# ml_module/tools/data_preprocessing_tools.py
import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

import pandas as pd
from agno.tools import Toolkit, tool

from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.exceptions import FileOperationException
from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths
from ml_module.core.response_formatter import (
    FormattedResponse,
    Severity,
    make_text_response,
    metric_block,
    simple_table,
    simple_table_with_types,
    visualization_block,
    text_block,
)

# Matches the "_v<digits>" version marker embedded in artifact filenames.
_VERSION_PATTERN = re.compile(r"_v(\d+)")

# get_preprocessing_history probes versions 1.._MAX_HISTORY_VERSIONS.
_MAX_HISTORY_VERSIONS = 20

# Number of most-recent versions shown in the history table.
_HISTORY_TABLE_ROWS = 10

# Number of rows shown in the UI preview table of a processed sample.
_PREVIEW_TABLE_ROWS = 10

# Number of rows per version embedded in a "full" version comparison.
_COMPARISON_SAMPLE_ROWS = 5


def _compute_fill_value(series: pd.Series, strategy: str) -> Any:
    """Return the imputation value for ``series`` under ``strategy``.

    ``strategy`` is one of 'mean', 'median' or 'mode'. Returns None when no
    fill value exists (the mode of an all-null column), so the caller can
    skip that column instead of failing.
    """
    if strategy == 'mean':
        return series.mean()
    if strategy == 'median':
        return series.median()
    modes = series.mode()
    # mode() ignores nulls, so it is empty for an all-null column; indexing
    # position 0 there would raise.
    return None if modes.empty else modes.iloc[0]


def _json_safe_records(frame: pd.DataFrame) -> List[Dict[str, Any]]:
    """Convert ``frame`` into a list of row dicts that serialize to strict JSON.

    Round-tripping through ``DataFrame.to_json`` turns NaN/NaT into null,
    timestamps into ISO-8601 strings and numpy scalars into plain Python
    values. Plain ``to_dict('records')`` would leave NaN (invalid strict JSON,
    rejected by browser ``JSON.parse``) and Timestamp objects in the output.
    """
    if not frame.columns.is_unique:
        # to_json(orient="records") rejects duplicate column names; fall back
        # to the plain conversion in that (rare) case.
        return frame.to_dict('records')
    return json.loads(frame.to_json(orient="records", date_format="iso"))


class DataPreprocessingToolkit(Toolkit):
    """A toolkit for safe, pre-built data cleaning and preprocessing operations.

    Every storage path is scoped under ``{user_id}/{project_id}``. When a
    ``ProjectService`` is supplied, the files produced are additionally
    registered as project artifacts.
    """

    def __init__(self, storage_service: MLStorageService, user_id: str, project_id: str,
                 project_service: Optional[ProjectService] = None):
        super().__init__(name="data_preprocessing_tools")
        self.storage = storage_service
        self.project_service = project_service
        self.user_id = user_id
        self.project_id = project_id

    def _get_base_path(self, subfolder: str = "") -> str:
        """Return ``{user_id}/{project_id}``, optionally suffixed with ``/{subfolder}``."""
        if subfolder:
            return f"{self.user_id}/{self.project_id}/{subfolder}"
        return f"{self.user_id}/{self.project_id}"

    def _extract_version_from_path(self, artifact_path: str) -> Optional[int]:
        """Return the integer after the first ``_v`` marker in ``artifact_path``, or None."""
        match = _VERSION_PATTERN.search(artifact_path)
        # \d+ only matches digits int() accepts, so no ValueError handling is needed.
        return int(match.group(1)) if match else None

    def _sample_path(self, version: int) -> str:
        """Return the storage path of the processed-sample JSON for ``version``."""
        return StoragePaths.CLEANED_SAMPLE.format(
            user_id=self.user_id, project_id=self.project_id, version=version
        )

    def _change_log_path(self, version: int) -> str:
        """Return the storage path of the change-log JSON for ``version``."""
        return StoragePaths.CHANGE_LOG.format(
            user_id=self.user_id, project_id=self.project_id, version=version
        )

    def _load_json_or_default(self, path: str, default: Dict[str, Any]) -> Dict[str, Any]:
        """Best-effort JSON load: return ``default`` if ``path`` cannot be loaded."""
        try:
            return self.storage.load_json(path)
        except Exception:
            # Used for optional artifacts (e.g. a change log that was never
            # written). `Exception` rather than a bare except so Ctrl-C and
            # SystemExit still propagate.
            return default

    @tool
    def handle_missing_values(
        self, input_path: str, output_filename: str, strategy: str, columns: Optional[List[str]] = None
    ) -> FormattedResponse:
        """
        Handles missing values in a dataset using a specified strategy.

        Args:
            input_path (str): The path to the source dataset (e.g., 'raw/dataset.csv').
            output_filename (str): The name for the processed file (e.g., 'cleaned_data.csv').
            strategy (str): The method to use. Must be one of: 'mean', 'median', 'mode', 'drop_row'.
            columns (Optional[List[str]]): A list of specific column names to apply the strategy to.
                If None, applies to all possible columns (numeric columns only for 'mean' and 'median').

        Returns:
            FormattedResponse: Structured confirmation with cleaning metrics and artifact reference.
        """
        valid_strategies = ['mean', 'median', 'mode', 'drop_row']
        if strategy not in valid_strategies:
            response = make_text_response(
                f"Invalid strategy '{strategy}'. Must be one of {valid_strategies}.",
                severity=Severity.ERROR,
            )
            response.summary = "Invalid preprocessing strategy"
            response.done = True
            return response

        # Computed outside the try so the error handler can always reference it.
        source_path = f"{self._get_base_path()}/{input_path}"
        try:
            df = self.storage.load_dataframe(source_path)

            # Always a plain list: a pandas Index (df.columns) raises ValueError
            # when used in a boolean context such as `cols or []`.
            if columns:
                target_cols = list(columns)
            elif strategy in ('mean', 'median'):
                # Mean/median are only defined for numeric data; skip the other
                # columns instead of failing on text/datetime data.
                target_cols = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]
            else:
                target_cols = list(df.columns)

            if strategy == 'drop_row':
                df = df.dropna(subset=target_cols)
            else:
                for col in target_cols:
                    if not df[col].isnull().any():
                        continue
                    fill_value = _compute_fill_value(df[col], strategy)
                    if fill_value is not None:
                        df[col] = df[col].fillna(fill_value)

            output_path = f"{self._get_base_path('processed')}/{output_filename}"
            info = self.storage.save_dataframe(df, output_path)

            if self.project_service:
                # Artifacts are only registered when the filename carries a _v<N> marker.
                version = self._extract_version_from_path(output_filename)
                if version is not None:
                    extra_metadata = {
                        "strategy": strategy,
                        "columns": list(target_cols),
                        "rows": len(df),
                    }
                    info.metadata.update(extra_metadata)
                    self.project_service.register_artifact(
                        self.user_id,
                        self.project_id,
                        ArtifactTypes.CLEANED_DATA,
                        version,
                        info,
                        version_scope="processed",
                        extra_metadata=extra_metadata,
                    )

            columns_summary = [{"column": col, "strategy": strategy} for col in target_cols]
            blocks = [
                text_block(
                    f"Applied `{strategy}` strategy to {len(target_cols)} columns",
                    severity=Severity.SUCCESS,
                ),
                metric_block("Rows After Cleaning", len(df)),
                simple_table(columns_summary, caption="Columns processed", block_id="columns_processed"),
                text_block(f"Cleaned dataset saved to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Handled missing values using {strategy}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("handle missing values", source_path, e) from e

    @tool
    def save_processed_sample_head(
        self, processed_csv_path: str, version: int, limit: int = DEFAULT_SAMPLE_ROWS
    ) -> FormattedResponse:
        """
        Save a sample head of processed data for UI preview.

        Args:
            processed_csv_path (str): Path to the processed CSV file (e.g., 'processed/cleaned_data_v1.csv').
            version (int): Version number for the sample file.
            limit (int): Number of rows to include in sample (default: 20). Negative values are treated as 0.

        Returns:
            FormattedResponse: Structured preview details and artifact reference for the sample JSON.
        """
        try:
            full_path = f"{self._get_base_path()}/{processed_csv_path}"
            df = self.storage.load_dataframe(full_path)

            # head(-n) would mean "all but the last n rows"; clamp instead.
            sample_df = df.head(max(limit, 0))

            sample_data = {
                # JSON-safe rows: NaN -> null, timestamps -> ISO strings.
                "data": _json_safe_records(sample_df),
                "columns": list(df.columns),
                "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                "shape": list(df.shape),
                "sample_rows": len(sample_df),
                "total_rows": len(df),
                "created_at": datetime.now().isoformat(),
                "version": version,
            }

            sample_path = self._sample_path(version)
            info = self.storage.save_json(sample_data, sample_path)

            if self.project_service:
                extra_metadata = {
                    "sample_rows": sample_data["sample_rows"],
                    "total_rows": sample_data["total_rows"],
                    "columns": sample_data["columns"],
                }
                info.metadata.update(extra_metadata)
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.CLEANED_SAMPLE,
                    version,
                    info,
                    version_scope="processed",
                    extra_metadata=extra_metadata,
                )

            preview_rows = sample_data["data"][:_PREVIEW_TABLE_ROWS]
            blocks = [
                metric_block("Sample Rows", sample_data["sample_rows"]),
                metric_block("Total Rows", sample_data["total_rows"]),
                simple_table_with_types(preview_rows, caption="Sample preview", block_id="processed_sample_preview"),
                text_block(f"Sample JSON saved to `{sample_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Created processed sample v{version}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("save processed sample head", processed_csv_path, e) from e

    @tool
    def generate_change_log(
        self,
        version: int,
        operations: List[str],
        before_stats: Dict[str, Any],
        after_stats: Dict[str, Any],
        columns_affected: Optional[List[str]] = None
    ) -> FormattedResponse:
        """
        Generate a human-readable change log for preprocessing operations.

        Args:
            version (int): Version number for the change log.
            operations (List[str]): List of operations performed (e.g., ["handled missing values with median", "removed outliers"]).
            before_stats (Dict[str, Any]): Statistics before processing (shape, nulls, etc.).
                The keys 'row_count', 'null_count' and 'shape' are read; missing counts default to 0.
            after_stats (Dict[str, Any]): Statistics after processing (same keys as before_stats).
            columns_affected (Optional[List[str]]): List of columns that were modified.

        Returns:
            FormattedResponse: Structured change-log summary with artifact reference.
        """
        try:
            affected = columns_affected or []
            summary_text = f"Applied {len(operations)} operations to the dataset"
            change_log = {
                "version": version,
                "timestamp": datetime.now().isoformat(),
                "operations_performed": operations,
                "statistics": {
                    "before": before_stats,
                    "after": after_stats,
                    "changes": {
                        "rows_removed": before_stats.get('row_count', 0) - after_stats.get('row_count', 0),
                        "columns_modified": len(affected),
                        "null_values_handled": before_stats.get('null_count', 0) - after_stats.get('null_count', 0),
                    },
                },
                "columns_affected": affected,
                "human_readable": {
                    "summary": summary_text,
                    "details": operations,
                    "impact": f"Dataset shape changed from {before_stats.get('shape', 'unknown')} to {after_stats.get('shape', 'unknown')}",
                },
            }

            change_log_path = self._change_log_path(version)
            info = self.storage.save_json(change_log, change_log_path)

            if self.project_service:
                info.metadata.update({
                    "operations": operations,
                    "columns_affected": affected,
                })
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.CHANGE_LOG,
                    version,
                    info,
                    version_scope="processed",
                    extra_metadata={
                        "operations": operations,
                        "summary": summary_text,
                    },
                )

            summary_rows = [
                {"metric": "Rows", "before": before_stats.get("row_count"), "after": after_stats.get("row_count")},
                {"metric": "Null values", "before": before_stats.get("null_count"), "after": after_stats.get("null_count")},
            ]
            blocks = [
                text_block(f"Recorded {len(operations)} preprocessing operations", severity=Severity.INFO),
                simple_table(summary_rows, caption="Dataset stats delta", block_id="stats_delta"),
                text_block(f"Change log saved to `{change_log_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Captured preprocessing change log v{version}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("generate change log", f"version_{version}", e) from e

    @tool
    def compare_preprocessing_versions(
        self, version_a: int, version_b: int, comparison_type: str = "full"
    ) -> FormattedResponse:
        """
        Compare two preprocessing versions and generate a detailed diff summary.

        Args:
            version_a (int): First version number for comparison (typically older).
            version_b (int): Second version number for comparison (typically newer).
            comparison_type (str): Type of comparison - "full", "summary", or "stats_only".
                Only "full" additionally embeds the first sample rows of each version.

        Returns:
            FormattedResponse: Structured comparison summary with artifact reference.
        """
        try:
            sample_a = self.storage.load_json(self._sample_path(version_a))
            sample_b = self.storage.load_json(self._sample_path(version_b))

            # Change logs are optional; each version falls back independently,
            # so one missing log does not discard the other.
            changelog_a = self._load_json_or_default(
                self._change_log_path(version_a), {"operations_performed": ["Unknown operations"]}
            )
            changelog_b = self._load_json_or_default(
                self._change_log_path(version_b), {"operations_performed": ["Unknown operations"]}
            )

            cols_a = sample_a.get("columns", [])
            cols_b = sample_b.get("columns", [])
            rows_a = sample_a.get("total_rows", 0)
            rows_b = sample_b.get("total_rows", 0)
            ops_a = changelog_a.get("operations_performed", [])
            ops_b = changelog_b.get("operations_performed", [])

            # Sorted so the saved comparison is deterministic across runs.
            added_cols = sorted(set(cols_b) - set(cols_a), key=str)
            removed_cols = sorted(set(cols_a) - set(cols_b), key=str)
            known_ops = set(ops_a)
            # Deduplicated, in version-B order.
            new_ops = list(dict.fromkeys(op for op in ops_b if op not in known_ops))

            comparison = {
                "versions_compared": {"from": version_a, "to": version_b},
                "timestamp": datetime.now().isoformat(),
                "data_changes": {
                    "shape_change": {
                        "from": sample_a.get("shape", [0, 0]),
                        "to": sample_b.get("shape", [0, 0]),
                    },
                    "row_count_change": {
                        "from": rows_a,
                        "to": rows_b,
                        "difference": rows_b - rows_a,
                    },
                    "columns_change": {
                        "from": cols_a,
                        "to": cols_b,
                        "added": added_cols,
                        "removed": removed_cols,
                    },
                    "dtypes_changes": self._compare_dtypes(
                        sample_a.get("dtypes", {}),
                        sample_b.get("dtypes", {}),
                    ),
                },
                "operations": {
                    "version_a_operations": ops_a,
                    "version_b_operations": ops_b,
                    "new_operations": new_ops,
                },
                "human_readable": {
                    "summary": f"Comparison between v{version_a} and v{version_b}",
                    "key_differences": self._generate_key_differences(sample_a, sample_b, changelog_a, changelog_b),
                },
            }

            if comparison_type == "full":
                comparison["sample_data"] = {
                    "version_a_sample": sample_a.get("data", [])[:_COMPARISON_SAMPLE_ROWS],
                    "version_b_sample": sample_b.get("data", [])[:_COMPARISON_SAMPLE_ROWS],
                }

            comparison_path = (
                f"{self._get_base_path('processed')}/version_comparison_v{version_a}_v{version_b}.json"
            )
            info = self.storage.save_json(comparison, comparison_path)

            metric_rows = [
                {
                    "metric": "Row count",
                    "from": rows_a,
                    "to": rows_b,
                    "delta": rows_b - rows_a,
                },
                {
                    "metric": "Columns",
                    "from": len(cols_a),
                    "to": len(cols_b),
                    "delta": len(added_cols) - len(removed_cols),
                },
            ]
            # map(str, ...) so non-string column names (e.g. ints) can be joined.
            added_text = ', '.join(map(str, added_cols)) or 'None'
            removed_text = ', '.join(map(str, removed_cols)) or 'None'
            blocks = [
                text_block(
                    f"Compared preprocessing versions v{version_a} → v{version_b}",
                    severity=Severity.INFO,
                ),
                simple_table(metric_rows, caption="Key dataset deltas", block_id="dataset_deltas"),
                text_block(f"Columns added: {added_text}\nColumns removed: {removed_text}"),
                text_block(f"Comparison saved to `{comparison_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Generated comparison v{version_a} vs v{version_b}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("compare preprocessing versions", f"v{version_a}_v{version_b}", e) from e

    def _compare_dtypes(self, dtypes_a: Dict[str, str], dtypes_b: Dict[str, str]) -> Dict[str, Any]:
        """Return ``{column: {"from": dtype_a, "to": dtype_b}}`` for every column whose dtype differs.

        A column present in only one mapping is reported with None on the other side.
        """
        all_columns = sorted(set(dtypes_a) | set(dtypes_b), key=str)
        return {
            col: {"from": dtypes_a.get(col), "to": dtypes_b.get(col)}
            for col in all_columns
            if dtypes_a.get(col) != dtypes_b.get(col)
        }

    def _generate_key_differences(self, sample_a: Dict, sample_b: Dict, changelog_a: Dict, changelog_b: Dict) -> List[str]:
        """Generate human-readable key differences between two versions.

        Reports row-count changes, added/removed columns (sorted) and operations
        present only in version B (in version-B order). Returns a single
        "No significant differences detected" entry when nothing changed.
        """
        differences = []

        rows_a = sample_a.get("total_rows", 0)
        rows_b = sample_b.get("total_rows", 0)
        if rows_b > rows_a:
            differences.append(f"Added {rows_b - rows_a} rows")
        elif rows_a > rows_b:
            differences.append(f"Removed {rows_a - rows_b} rows")

        cols_a = set(sample_a.get("columns", []))
        cols_b = set(sample_b.get("columns", []))
        added_cols = sorted(cols_b - cols_a, key=str)
        removed_cols = sorted(cols_a - cols_b, key=str)
        if added_cols:
            differences.append(f"Added columns: {', '.join(map(str, added_cols))}")
        if removed_cols:
            differences.append(f"Removed columns: {', '.join(map(str, removed_cols))}")

        ops_a = set(changelog_a.get("operations_performed", []))
        new_ops = list(dict.fromkeys(
            op for op in changelog_b.get("operations_performed", []) if op not in ops_a
        ))
        if new_ops:
            differences.append(f"New operations: {'; '.join(map(str, new_ops))}")

        if not differences:
            differences.append("No significant differences detected")
        return differences

    @tool
    def get_preprocessing_history(self) -> FormattedResponse:
        """
        Get the complete preprocessing history for this project.

        Probes processed-sample versions 1, 2, ... (up to 20) and stops at the
        first version whose sample cannot be loaded.

        Returns:
            FormattedResponse: Structured history overview with artifact reference.
        """
        try:
            history = {
                "project_id": self.project_id,
                "user_id": self.user_id,
                "timestamp": datetime.now().isoformat(),
                "versions": [],
                "summary": {
                    "total_versions": 0,
                    "latest_version": 0,
                },
            }

            for version in range(1, _MAX_HISTORY_VERSIONS + 1):
                try:
                    sample_data = self.storage.load_json(self._sample_path(version))
                except Exception:
                    # Treated as "version does not exist". Versions are assumed
                    # to be contiguous, so probing stops at the first gap.
                    break

                changelog = self._load_json_or_default(
                    self._change_log_path(version),
                    {"operations_performed": ["Operations not recorded"]},
                )
                human_readable = changelog.get("human_readable") or {}
                history["versions"].append({
                    "version": version,
                    "created_at": sample_data.get("created_at", "Unknown"),
                    "shape": sample_data.get("shape", [0, 0]),
                    "total_rows": sample_data.get("total_rows", 0),
                    "operations": changelog.get("operations_performed", []),
                    "human_readable_summary": human_readable.get("summary", "No summary available"),
                })
                history["summary"]["latest_version"] = version

            history["summary"]["total_versions"] = len(history["versions"])

            history_path = f"{self._get_base_path('processed')}/preprocessing_history.json"
            info = self.storage.save_json(history, history_path)

            # Show the most recent runs, as the table caption says (versions
            # are appended in ascending order).
            rows = [
                {
                    "version": item.get("version"),
                    "rows": item.get("total_rows"),
                    "summary": item.get("human_readable_summary"),
                }
                for item in history["versions"][-_HISTORY_TABLE_ROWS:]
            ]
            blocks = [
                text_block(
                    f"Indexed {len(history['versions'])} preprocessing versions",
                    severity=Severity.INFO,
                ),
                simple_table(rows, caption="Recent preprocessing runs", block_id="preprocessing_history"),
                text_block(f"History saved to `{history_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Compiled preprocessing history",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("get preprocessing history", "all_versions", e) from e