# Source: sirus/backend/ml_module/tools/data_preprocessing_tools.py
# Commit 783a952 (ranilmukesh): "Deploy SiRUS SQL Agent backend"
# ml_module/tools/data_preprocessing_tools.py
import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from agno.tools import Toolkit, tool
from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.exceptions import FileOperationException
from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths
from ml_module.core.response_formatter import (
FormattedResponse,
Severity,
make_text_response,
metric_block,
simple_table,
simple_table_with_types,
visualization_block,
text_block,
)
class DataPreprocessingToolkit(Toolkit):
    """A toolkit for safe, pre-built data cleaning and preprocessing operations.

    Every path a tool receives is resolved relative to ``{user_id}/{project_id}``.
    Outputs are written through the storage service and, when a project service
    is supplied, registered as versioned project artifacts.
    """

    # Highest version number get_preprocessing_history will probe.
    MAX_HISTORY_VERSIONS = 20
    # Number of sample rows rendered in the processed-sample preview table.
    PREVIEW_ROWS = 10
    # Number of rows per version embedded in a "full" version comparison.
    COMPARISON_SAMPLE_ROWS = 5
    # Number of most-recent versions rendered in the history table.
    HISTORY_TABLE_ROWS = 10

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        super().__init__(name="data_preprocessing_tools")
        self.storage = storage_service
        self.project_service = project_service
        self.user_id = user_id
        self.project_id = project_id

    # ------------------------------------------------------------------ helpers

    def _get_base_path(self, subfolder: str = "") -> str:
        """Return ``{user_id}/{project_id}``, optionally joined with ``subfolder``."""
        if subfolder:
            return f"{self.user_id}/{self.project_id}/{subfolder}"
        return f"{self.user_id}/{self.project_id}"

    def _extract_version_from_path(self, artifact_path: str) -> Optional[int]:
        """Return N from the first ``_v<N>`` marker in ``artifact_path`` (e.g. 'cleaned_v3.csv' -> 3), else None."""
        match = re.search(r"_v(\d+)", artifact_path)
        if match:
            try:
                return int(match.group(1))
            except ValueError:
                return None
        return None

    def _cleaned_sample_path(self, version: int) -> str:
        """Storage path of the processed-sample JSON for ``version``."""
        return StoragePaths.CLEANED_SAMPLE.format(
            user_id=self.user_id,
            project_id=self.project_id,
            version=version,
        )

    def _change_log_path(self, version: int) -> str:
        """Storage path of the change-log JSON for ``version``."""
        return StoragePaths.CHANGE_LOG.format(
            user_id=self.user_id,
            project_id=self.project_id,
            version=version,
        )

    def _try_load_dict(self, path: str) -> Optional[Dict[str, Any]]:
        """Load a JSON object from storage; return None if it cannot be loaded or is not an object."""
        try:
            data = self.storage.load_json(path)
        except Exception:
            # Used only for optional artifacts: absence is an expected outcome, not an error.
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _error_response(message: str, summary: str) -> FormattedResponse:
        """Build a terminal ERROR response for rejected tool input."""
        response = make_text_response(message, severity=Severity.ERROR)
        response.summary = summary
        response.done = True
        return response

    @staticmethod
    def _ordered_difference(items: List[Any], exclude: List[Any]) -> List[Any]:
        """Return the distinct elements of ``items`` not present in ``exclude``, in first-seen order.

        Used instead of a set difference so saved comparisons and messages are deterministic.
        """
        seen = set(exclude)
        result: List[Any] = []
        for item in items:
            if item not in seen:
                seen.add(item)  # also de-duplicates, like the set difference it replaces
                result.append(item)
        return result

    @staticmethod
    def _stat_difference(before: Dict[str, Any], after: Dict[str, Any], key: str) -> Any:
        """Return ``before[key] - after[key]``, treating a missing or None value as 0."""
        return (before.get(key) or 0) - (after.get(key) or 0)

    # -------------------------------------------------------------------- tools

    @tool
    def handle_missing_values(
        self,
        input_path: str,
        output_filename: str,
        strategy: str,
        columns: Optional[List[str]] = None
    ) -> FormattedResponse:
        """
        Handles missing values in a dataset using a specified strategy.
        Args:
            input_path (str): The path to the source dataset (e.g., 'raw/dataset.csv').
            output_filename (str): The name for the processed file (e.g., 'cleaned_data.csv').
                If it contains a version marker such as '_v2', the result is registered as
                cleaned-data version 2.
            strategy (str): The method to use. Must be one of: 'mean', 'median', 'mode', 'drop_row'.
            columns (Optional[List[str]]): A list of specific column names to apply the strategy to.
                If None or empty, applies to every column the strategy supports: numeric columns
                for 'mean'/'median', all columns for 'mode'/'drop_row'.
        Returns:
            FormattedResponse: Structured confirmation with cleaning metrics and artifact reference,
            or an error response for an invalid strategy or invalid columns.
        """
        valid_strategies = ['mean', 'median', 'mode', 'drop_row']
        if strategy not in valid_strategies:
            return self._error_response(
                f"Invalid strategy '{strategy}'. Must be one of {valid_strategies}.",
                "Invalid preprocessing strategy",
            )
        if isinstance(columns, str):
            # Tolerate a single column name passed as a bare string.
            columns = [columns]

        # Bound before `try` so the error handler can always reference it.
        source_path = f"{self._get_base_path()}/{input_path}"
        try:
            df = self.storage.load_dataframe(source_path)
            numeric_only = strategy in ('mean', 'median')

            if columns:
                unknown = [col for col in columns if col not in df.columns]
                if unknown:
                    return self._error_response(
                        f"Columns not found in dataset: {unknown}.",
                        "Unknown columns for preprocessing",
                    )
                target_cols = list(columns)
                if numeric_only:
                    non_numeric = [
                        col for col in target_cols
                        if not pd.api.types.is_numeric_dtype(df[col])
                    ]
                    if non_numeric:
                        return self._error_response(
                            f"Strategy '{strategy}' requires numeric columns; non-numeric: {non_numeric}.",
                            "Invalid columns for preprocessing strategy",
                        )
            else:
                # A plain list (not a pandas Index) so truthiness/len work below;
                # mean/median are only defined for numeric columns.
                target_cols = [
                    col for col in df.columns
                    if not numeric_only or pd.api.types.is_numeric_dtype(df[col])
                ]

            if strategy == 'drop_row':
                df = df.dropna(subset=target_cols)
            else:
                for col in target_cols:
                    series = df[col]
                    if not series.isnull().any():
                        continue
                    if strategy == 'mean':
                        fill_value = series.mean()
                    elif strategy == 'median':
                        fill_value = series.median()
                    else:  # 'mode'
                        modes = series.mode()
                        if modes.empty:
                            # Column is entirely missing: there is no mode to fill with.
                            continue
                        fill_value = modes.iloc[0]
                    df[col] = series.fillna(fill_value)

            output_path = f"{self._get_base_path('processed')}/{output_filename}"
            info = self.storage.save_dataframe(df, output_path)

            if self.project_service:
                version = self._extract_version_from_path(output_filename)
                if version is not None:
                    extra_metadata = {
                        "strategy": strategy,
                        "columns": list(target_cols),
                        "rows": len(df),
                    }
                    info.metadata.update(extra_metadata)
                    self.project_service.register_artifact(
                        self.user_id,
                        self.project_id,
                        ArtifactTypes.CLEANED_DATA,
                        version,
                        info,
                        version_scope="processed",
                        extra_metadata=extra_metadata,
                    )

            columns_summary = [{"column": col, "strategy": strategy} for col in target_cols]
            blocks = [
                text_block(
                    f"Applied `{strategy}` strategy to {len(target_cols)} columns",
                    severity=Severity.SUCCESS,
                ),
                metric_block("Rows After Cleaning", len(df)),
                simple_table(columns_summary, caption="Columns processed", block_id="columns_processed"),
                text_block(f"Cleaned dataset saved to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Handled missing values using {strategy}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("handle missing values", source_path, e) from e

    @tool
    def save_processed_sample_head(
        self,
        processed_csv_path: str,
        version: int,
        limit: int = DEFAULT_SAMPLE_ROWS
    ) -> FormattedResponse:
        """
        Save a sample head of processed data for UI preview.
        Args:
            processed_csv_path (str): Path to the processed CSV file (e.g., 'processed/cleaned_data_v1.csv').
            version (int): Version number for the sample file.
            limit (int): Number of rows to include in the sample (default: DEFAULT_SAMPLE_ROWS).
                Negative values are treated as 0.
        Returns:
            FormattedResponse: Structured preview details and artifact reference for the sample JSON.
        """
        try:
            full_path = f"{self._get_base_path()}/{processed_csv_path}"
            df = self.storage.load_dataframe(full_path)
            sample_df = df.head(max(int(limit), 0))

            # Round-trip through DataFrame.to_json so the payload is strict JSON:
            # NaN becomes null (NaN is not valid JSON for a browser JSON.parse),
            # datetimes become ISO-8601 strings and NumPy scalars become plain
            # numbers; to_dict('records') would leave all of these as-is.
            records = json.loads(sample_df.to_json(orient='records', date_format='iso'))

            sample_data = {
                "data": records,
                "columns": list(df.columns),
                "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                "shape": list(df.shape),
                "sample_rows": len(sample_df),
                "total_rows": len(df),
                "created_at": datetime.now().isoformat(),
                "version": version
            }

            sample_path = self._cleaned_sample_path(version)
            info = self.storage.save_json(sample_data, sample_path)

            if self.project_service:
                extra_metadata = {
                    "sample_rows": sample_data["sample_rows"],
                    "total_rows": sample_data["total_rows"],
                    "columns": sample_data["columns"],
                }
                info.metadata.update(extra_metadata)
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.CLEANED_SAMPLE,
                    version,
                    info,
                    version_scope="processed",
                    extra_metadata=extra_metadata,
                )

            preview_rows = records[:self.PREVIEW_ROWS]
            blocks = [
                metric_block("Sample Rows", sample_data["sample_rows"]),
                metric_block("Total Rows", sample_data["total_rows"]),
                simple_table_with_types(preview_rows, caption="Sample preview", block_id="processed_sample_preview"),
                text_block(f"Sample JSON saved to `{sample_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Created processed sample v{version}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("save processed sample head", processed_csv_path, e) from e

    @tool
    def generate_change_log(
        self,
        version: int,
        operations: List[str],
        before_stats: Dict[str, Any],
        after_stats: Dict[str, Any],
        columns_affected: Optional[List[str]] = None
    ) -> FormattedResponse:
        """
        Generate a human-readable change log for preprocessing operations.
        Args:
            version (int): Version number for the change log.
            operations (List[str]): List of operations performed (e.g., ["handled missing values with median", "removed outliers"]).
            before_stats (Dict[str, Any]): Statistics before processing. Keys read: 'row_count', 'null_count', 'shape'.
            after_stats (Dict[str, Any]): Statistics after processing. Keys read: 'row_count', 'null_count', 'shape'.
                Missing or null counts are treated as 0 when computing deltas.
            columns_affected (Optional[List[str]]): List of columns that were modified.
        Returns:
            FormattedResponse: Structured change-log summary with artifact reference.
        """
        try:
            operations = list(operations or [])
            before_stats = before_stats or {}
            after_stats = after_stats or {}
            affected = list(columns_affected or [])

            change_log = {
                "version": version,
                "timestamp": datetime.now().isoformat(),
                "operations_performed": operations,
                "statistics": {
                    "before": before_stats,
                    "after": after_stats,
                    "changes": {
                        "rows_removed": self._stat_difference(before_stats, after_stats, 'row_count'),
                        "columns_modified": len(affected),
                        "null_values_handled": self._stat_difference(before_stats, after_stats, 'null_count'),
                    }
                },
                "columns_affected": affected,
                "human_readable": {
                    "summary": f"Applied {len(operations)} operations to the dataset",
                    "details": operations,
                    "impact": f"Dataset shape changed from {before_stats.get('shape', 'unknown')} to {after_stats.get('shape', 'unknown')}"
                }
            }

            change_log_path = self._change_log_path(version)
            info = self.storage.save_json(change_log, change_log_path)

            if self.project_service:
                info.metadata.update({
                    "operations": operations,
                    "columns_affected": affected,
                })
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.CHANGE_LOG,
                    version,
                    info,
                    version_scope="processed",
                    extra_metadata={
                        "operations": operations,
                        "summary": change_log["human_readable"]["summary"],
                    },
                )

            summary_rows = [
                {"metric": "Rows", "before": before_stats.get("row_count"), "after": after_stats.get("row_count")},
                {"metric": "Null values", "before": before_stats.get("null_count"), "after": after_stats.get("null_count")},
            ]
            blocks = [
                text_block(f"Recorded {len(operations)} preprocessing operations", severity=Severity.INFO),
                simple_table(summary_rows, caption="Dataset stats delta", block_id="stats_delta"),
                text_block(f"Change log saved to `{change_log_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Captured preprocessing change log v{version}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("generate change log", f"version_{version}", e) from e

    @tool
    def compare_preprocessing_versions(
        self,
        version_a: int,
        version_b: int,
        comparison_type: str = "full"
    ) -> FormattedResponse:
        """
        Compare two preprocessing versions and generate a detailed diff summary.
        Both versions must have a processed sample saved; change logs are optional.
        Args:
            version_a (int): First version number for comparison (typically older).
            version_b (int): Second version number for comparison (typically newer).
            comparison_type (str): Type of comparison - "full", "summary", or "stats_only".
                "full" additionally embeds the first rows of each version's sample;
                the other values currently produce the same output.
        Returns:
            FormattedResponse: Structured comparison summary with artifact reference.
        """
        try:
            sample_a = self.storage.load_json(self._cleaned_sample_path(version_a))
            sample_b = self.storage.load_json(self._cleaned_sample_path(version_b))

            # Load each change log independently so one missing log does not
            # discard the other one.
            changelog_a = self._try_load_dict(self._change_log_path(version_a))
            if changelog_a is None:
                changelog_a = {"operations_performed": ["Unknown operations"]}
            changelog_b = self._try_load_dict(self._change_log_path(version_b))
            if changelog_b is None:
                changelog_b = {"operations_performed": ["Unknown operations"]}

            cols_a = sample_a.get("columns", [])
            cols_b = sample_b.get("columns", [])
            rows_a = sample_a.get("total_rows", 0)
            rows_b = sample_b.get("total_rows", 0)
            ops_a = changelog_a.get("operations_performed", [])
            ops_b = changelog_b.get("operations_performed", [])

            comparison = {
                "versions_compared": {"from": version_a, "to": version_b},
                "timestamp": datetime.now().isoformat(),
                "data_changes": {
                    "shape_change": {
                        "from": sample_a.get("shape", [0, 0]),
                        "to": sample_b.get("shape", [0, 0])
                    },
                    "row_count_change": {
                        "from": rows_a,
                        "to": rows_b,
                        "difference": rows_b - rows_a
                    },
                    "columns_change": {
                        "from": cols_a,
                        "to": cols_b,
                        "added": self._ordered_difference(cols_b, cols_a),
                        "removed": self._ordered_difference(cols_a, cols_b)
                    },
                    "dtypes_changes": self._compare_dtypes(
                        sample_a.get("dtypes", {}),
                        sample_b.get("dtypes", {})
                    )
                },
                "operations": {
                    "version_a_operations": ops_a,
                    "version_b_operations": ops_b,
                    "new_operations": self._ordered_difference(ops_b, ops_a)
                },
                "human_readable": {
                    "summary": f"Comparison between v{version_a} and v{version_b}",
                    "key_differences": self._generate_key_differences(sample_a, sample_b, changelog_a, changelog_b)
                }
            }

            if comparison_type == "full":
                comparison["sample_data"] = {
                    "version_a_sample": sample_a.get("data", [])[:self.COMPARISON_SAMPLE_ROWS],
                    "version_b_sample": sample_b.get("data", [])[:self.COMPARISON_SAMPLE_ROWS]
                }

            comparison_path = (
                f"{self._get_base_path('processed')}/version_comparison_v{version_a}_v{version_b}.json"
            )
            info = self.storage.save_json(comparison, comparison_path)

            diff = comparison["data_changes"]
            column_changes = diff["columns_change"]
            metric_rows = [
                {
                    "metric": "Row count",
                    "from": diff["row_count_change"]["from"],
                    "to": diff["row_count_change"]["to"],
                    "delta": diff["row_count_change"]["difference"],
                },
                {
                    "metric": "Columns",
                    "from": len(column_changes["from"]),
                    "to": len(column_changes["to"]),
                    "delta": len(column_changes["added"]) - len(column_changes["removed"]),
                },
            ]
            # map(str, ...) because column labels are not guaranteed to be strings.
            added_text = ', '.join(map(str, column_changes['added'])) or 'None'
            removed_text = ', '.join(map(str, column_changes['removed'])) or 'None'
            blocks = [
                text_block(
                    f"Compared preprocessing versions v{version_a} → v{version_b}",
                    severity=Severity.INFO,
                ),
                simple_table(metric_rows, caption="Key dataset deltas", block_id="dataset_deltas"),
                text_block(f"Columns added: {added_text}\nColumns removed: {removed_text}"),
                text_block(f"Comparison saved to `{comparison_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Generated comparison v{version_a} vs v{version_b}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("compare preprocessing versions", f"v{version_a}_v{version_b}", e) from e

    def _compare_dtypes(self, dtypes_a: Dict[str, str], dtypes_b: Dict[str, str]) -> Dict[str, Any]:
        """Map each column whose dtype differs between versions to ``{"from": dtype_a, "to": dtype_b}``.

        A column missing from one side reports None for that side. Columns are
        visited in version-A order, then columns new in version B.
        """
        changes: Dict[str, Any] = {}
        all_columns = list(dtypes_a) + [col for col in dtypes_b if col not in dtypes_a]
        for col in all_columns:
            dtype_a = dtypes_a.get(col)
            dtype_b = dtypes_b.get(col)
            if dtype_a != dtype_b:
                changes[col] = {"from": dtype_a, "to": dtype_b}
        return changes

    def _generate_key_differences(self, sample_a: Dict, sample_b: Dict, changelog_a: Dict, changelog_b: Dict) -> List[str]:
        """Describe row, column and operation differences between two versions as sentences.

        Returns ["No significant differences detected"] when nothing differs.
        """
        differences: List[str] = []

        rows_a = sample_a.get("total_rows", 0)
        rows_b = sample_b.get("total_rows", 0)
        if rows_b > rows_a:
            differences.append(f"Added {rows_b - rows_a} rows")
        elif rows_a > rows_b:
            differences.append(f"Removed {rows_a - rows_b} rows")

        cols_a = sample_a.get("columns", [])
        cols_b = sample_b.get("columns", [])
        added_cols = self._ordered_difference(cols_b, cols_a)
        removed_cols = self._ordered_difference(cols_a, cols_b)
        if added_cols:
            differences.append(f"Added columns: {', '.join(map(str, added_cols))}")
        if removed_cols:
            differences.append(f"Removed columns: {', '.join(map(str, removed_cols))}")

        new_ops = self._ordered_difference(
            changelog_b.get("operations_performed", []),
            changelog_a.get("operations_performed", []),
        )
        if new_ops:
            differences.append(f"New operations: {'; '.join(map(str, new_ops))}")

        if not differences:
            differences.append("No significant differences detected")
        return differences

    @tool
    def get_preprocessing_history(self) -> FormattedResponse:
        """
        Get the complete preprocessing history for this project.
        Scans versions 1, 2, ... (up to MAX_HISTORY_VERSIONS) and stops at the first
        version without a loadable processed sample.
        Returns:
            FormattedResponse: Structured history overview with artifact reference.
        """
        try:
            history = {
                "project_id": self.project_id,
                "user_id": self.user_id,
                "timestamp": datetime.now().isoformat(),
                "versions": [],
                "summary": {
                    "total_versions": 0,
                    "latest_version": 0
                }
            }

            for version in range(1, self.MAX_HISTORY_VERSIONS + 1):
                sample_data = self._try_load_dict(self._cleaned_sample_path(version))
                if sample_data is None:
                    # Assumes versions are numbered contiguously from 1: the first gap ends the scan.
                    break

                changelog = self._try_load_dict(self._change_log_path(version))
                if changelog is None:
                    changelog = {"operations_performed": ["Operations not recorded"]}

                history["versions"].append({
                    "version": version,
                    "created_at": sample_data.get("created_at", "Unknown"),
                    "shape": sample_data.get("shape", [0, 0]),
                    "total_rows": sample_data.get("total_rows", 0),
                    "operations": changelog.get("operations_performed", []),
                    "human_readable_summary": (changelog.get("human_readable") or {}).get(
                        "summary", "No summary available"
                    ),
                })
                history["summary"]["latest_version"] = version

            history["summary"]["total_versions"] = len(history["versions"])

            history_path = f"{self._get_base_path('processed')}/preprocessing_history.json"
            info = self.storage.save_json(history, history_path)

            # Show the most recent versions, matching the table caption.
            recent_versions = history["versions"][-self.HISTORY_TABLE_ROWS:]
            rows = [
                {
                    "version": item.get("version"),
                    "rows": item.get("total_rows"),
                    "summary": item.get("human_readable_summary"),
                }
                for item in recent_versions
            ]
            blocks = [
                text_block(
                    f"Indexed {len(history['versions'])} preprocessing versions",
                    severity=Severity.INFO,
                ),
                simple_table(rows, caption="Recent preprocessing runs", block_id="preprocessing_history"),
                text_block(f"History saved to `{history_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Compiled preprocessing history",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("get preprocessing history", "all_versions", e) from e