Spaces:
Running
Running
| # ml_module/tools/data_preprocessing_tools.py | |
| import json | |
| import re | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Optional | |
| import pandas as pd | |
| from agno.tools import Toolkit, tool | |
| from ml_module.services.storage_service import MLStorageService | |
| from ml_module.services.project_service import ProjectService | |
| from ml_module.core.exceptions import FileOperationException | |
| from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths | |
| from ml_module.core.response_formatter import ( | |
| FormattedResponse, | |
| Severity, | |
| make_text_response, | |
| metric_block, | |
| simple_table, | |
| simple_table_with_types, | |
| visualization_block, | |
| text_block, | |
| ) | |
class DataPreprocessingToolkit(Toolkit):
    """A toolkit for safe, pre-built data cleaning and preprocessing operations.

    All paths are resolved under ``{user_id}/{project_id}`` so an instance only
    touches its own project's storage. When a ``project_service`` is supplied,
    the cleaned dataset, sample and change-log artifacts produced here are also
    registered with it.
    """

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        super().__init__(name="data_preprocessing_tools")
        self.storage = storage_service
        # Optional: artifact registration is skipped when this is None.
        self.project_service = project_service
        self.user_id = user_id
        self.project_id = project_id

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_base_path(self, subfolder: str = "") -> str:
        """Return ``{user_id}/{project_id}``, optionally suffixed with ``/{subfolder}``."""
        if subfolder:
            return f"{self.user_id}/{self.project_id}/{subfolder}"
        return f"{self.user_id}/{self.project_id}"

    def _extract_version_from_path(self, artifact_path: str) -> Optional[int]:
        """Return the version number encoded as ``_v<digits>`` in *artifact_path*.

        The last occurrence wins, so a versioned file inside a versioned folder
        (``run_v1/cleaned_v3.csv``) reports the file's own version (3).
        Returns None when the path carries no version marker.
        """
        matches = re.findall(r"_v([0-9]+)", artifact_path)
        return int(matches[-1]) if matches else None

    def _sample_path(self, version: int) -> str:
        """Return the storage path of the processed-sample JSON for *version*."""
        return StoragePaths.CLEANED_SAMPLE.format(
            user_id=self.user_id,
            project_id=self.project_id,
            version=version,
        )

    def _change_log_path(self, version: int) -> str:
        """Return the storage path of the change-log JSON for *version*."""
        return StoragePaths.CHANGE_LOG.format(
            user_id=self.user_id,
            project_id=self.project_id,
            version=version,
        )

    def _try_load_json(self, path: str) -> Optional[Dict[str, Any]]:
        """Load a JSON-object artifact; return None if it is unreadable or not an object.

        Used for optional artifacts whose absence must not abort the caller.
        ``Exception`` is caught (rather than a narrower type) because the storage
        service's "not found" exception type is not visible here; unlike a bare
        ``except:`` this still lets KeyboardInterrupt/SystemExit propagate.
        """
        try:
            data = self.storage.load_json(path)
        except Exception:
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _error_response(message: str, summary: str) -> FormattedResponse:
        """Build a finished error response for rejected tool input."""
        response = make_text_response(message, severity=Severity.ERROR)
        response.summary = summary
        response.done = True
        return response

    @staticmethod
    def _ordered_difference(items: List[Any], excluded: List[Any]) -> List[Any]:
        """Return unique elements of *items* absent from *excluded*, keeping *items*' order.

        Used instead of ``list(set(a) - set(b))`` so saved artifacts and messages
        are deterministic from run to run.
        """
        excluded_set = set(excluded)
        return [item for item in dict.fromkeys(items) if item not in excluded_set]

    @staticmethod
    def _join_names(names: List[Any]) -> str:
        """Comma-join *names*, stringifying non-string entries (e.g. integer column labels)."""
        return ", ".join(str(name) for name in names)

    @staticmethod
    def _stat_difference(before: Dict[str, Any], after: Dict[str, Any], key: str) -> Any:
        """Return ``before[key] - after[key]``, treating missing or None values as 0."""
        return (before.get(key) or 0) - (after.get(key) or 0)

    @staticmethod
    def _compute_fill_value(series: pd.Series, strategy: str) -> Any:
        """Return the value used to fill gaps in *series*, or None if none can be computed.

        'mean'/'median' raise for non-numeric data such as strings, and an
        all-missing column has no mean/median/mode; both cases return None so the
        caller can skip that column instead of failing the whole dataset.
        """
        try:
            if strategy == 'mean':
                value = series.mean()
            elif strategy == 'median':
                value = series.median()
            else:  # 'mode'
                modes = series.mode()  # NaN is excluded, so all-NaN -> empty
                if modes.empty:
                    return None
                value = modes.iloc[0]
        except (TypeError, ValueError):
            return None
        return None if pd.isna(value) else value

    # ------------------------------------------------------------------
    # Tools
    # ------------------------------------------------------------------

    def handle_missing_values(
        self,
        input_path: str,
        output_filename: str,
        strategy: str,
        columns: Optional[List[str]] = None,
    ) -> FormattedResponse:
        """
        Handles missing values in a dataset using a specified strategy.

        Args:
            input_path (str): The path to the source dataset (e.g., 'raw/dataset.csv').
            output_filename (str): The name for the processed file (e.g., 'cleaned_data.csv').
                If it contains a ``_v<number>`` marker (e.g. 'cleaned_data_v2.csv') and a
                project service is configured, the result is registered as that version.
            strategy (str): The method to use. Must be one of: 'mean', 'median', 'mode', 'drop_row'.
            columns (Optional[List[str]]): A list of specific column names to apply the strategy to.
                If None or empty, applies to all columns. Columns where the statistic cannot
                be computed (e.g. 'mean' on a text column, or an entirely empty column) are
                skipped and reported instead of failing the whole operation.

        Returns:
            FormattedResponse: Structured confirmation with cleaning metrics and artifact
            reference, or an error response if the strategy or a column name is invalid.

        Raises:
            FileOperationException: If loading, cleaning or saving the dataset fails.
        """
        valid_strategies = ['mean', 'median', 'mode', 'drop_row']
        if strategy not in valid_strategies:
            return self._error_response(
                f"Invalid strategy '{strategy}'. Must be one of {valid_strategies}.",
                "Invalid preprocessing strategy",
            )
        if isinstance(columns, str):
            # Tolerate a single column name passed as a bare string.
            columns = [columns]

        source_path = f"{self._get_base_path()}/{input_path}"
        try:
            df = self.storage.load_dataframe(source_path)

            if columns:
                unknown_cols = [col for col in columns if col not in df.columns]
                if unknown_cols:
                    return self._error_response(
                        f"Unknown columns {unknown_cols}. Available columns: {list(df.columns)}.",
                        "Invalid preprocessing columns",
                    )
                target_cols = list(columns)
            else:
                # Materialize as a list: a pandas Index has no unambiguous truth value.
                target_cols = list(df.columns)

            rows_before = len(df)
            skipped_cols: List[Any] = []
            if strategy == 'drop_row':
                df = df.dropna(subset=target_cols)
            else:
                for col in target_cols:
                    if not df[col].isnull().any():
                        continue
                    fill_value = self._compute_fill_value(df[col], strategy)
                    if fill_value is None:
                        skipped_cols.append(col)
                        continue
                    df[col] = df[col].fillna(fill_value)

            output_path = f"{self._get_base_path('processed')}/{output_filename}"
            info = self.storage.save_dataframe(df, output_path)

            if self.project_service:
                version = self._extract_version_from_path(output_filename)
                if version is not None:
                    extra_metadata = {
                        "strategy": strategy,
                        "columns": target_cols,
                        "skipped_columns": skipped_cols,
                        "rows": len(df),
                    }
                    info.metadata.update(extra_metadata)
                    self.project_service.register_artifact(
                        self.user_id,
                        self.project_id,
                        ArtifactTypes.CLEANED_DATA,
                        version,
                        info,
                        version_scope="processed",
                        extra_metadata=extra_metadata,
                    )

            skipped_set = set(skipped_cols)
            applied_cols = [col for col in target_cols if col not in skipped_set]
            columns_summary = [{"column": col, "strategy": strategy} for col in applied_cols]

            blocks = [
                text_block(
                    f"Applied `{strategy}` strategy to {len(applied_cols)} columns",
                    severity=Severity.SUCCESS,
                ),
                metric_block("Rows After Cleaning", len(df)),
            ]
            if strategy == 'drop_row':
                blocks.append(metric_block("Rows Removed", rows_before - len(df)))
            blocks.append(
                simple_table(columns_summary, caption="Columns processed", block_id="columns_processed")
            )
            if skipped_cols:
                blocks.append(
                    text_block(
                        f"Skipped {len(skipped_cols)} column(s) where `{strategy}` could not be "
                        f"computed: {self._join_names(skipped_cols)}",
                        severity=Severity.INFO,
                    )
                )
            blocks.append(text_block(f"Cleaned dataset saved to `{output_path}`"))

            return FormattedResponse(
                blocks=blocks,
                summary=f"Handled missing values using {strategy}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("handle missing values", source_path, e) from e

    def save_processed_sample_head(
        self,
        processed_csv_path: str,
        version: int,
        limit: int = DEFAULT_SAMPLE_ROWS,
    ) -> FormattedResponse:
        """
        Save a sample head of processed data for UI preview.

        Args:
            processed_csv_path (str): Path to the processed CSV file, relative to the project
                (e.g., 'processed/cleaned_data_v1.csv').
            version (int): Version number for the sample file.
            limit (int): Number of rows to include in the sample (default: DEFAULT_SAMPLE_ROWS).
                Negative values are treated as 0.

        Returns:
            FormattedResponse: Structured preview details and artifact reference for the sample JSON.

        Raises:
            FileOperationException: If the dataset cannot be loaded or the sample cannot be saved.
        """
        full_path = f"{self._get_base_path()}/{processed_csv_path}"
        try:
            df = self.storage.load_dataframe(full_path)

            # head() with a negative n returns all-but-the-last rows; clamp so a bad
            # limit yields an empty sample rather than (almost) the whole dataset.
            sample_df = df.head(max(limit, 0))

            # to_dict('records') can leave NaN, numpy scalars and Timestamps in the rows,
            # which json either rejects or writes as non-standard `NaN`. Round-tripping
            # through DataFrame.to_json yields null / native numbers / ISO-8601 strings.
            records = json.loads(
                sample_df.to_json(orient="records", date_format="iso", double_precision=15)
            )

            sample_data = {
                "data": records,
                "columns": list(df.columns),
                "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                "shape": list(df.shape),
                "sample_rows": len(sample_df),
                "total_rows": len(df),
                "created_at": datetime.now().isoformat(),
                "version": version,
            }

            sample_path = self._sample_path(version)
            info = self.storage.save_json(sample_data, sample_path)

            if self.project_service:
                extra_metadata = {
                    "sample_rows": sample_data["sample_rows"],
                    "total_rows": sample_data["total_rows"],
                    "columns": sample_data["columns"],
                }
                info.metadata.update(extra_metadata)
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.CLEANED_SAMPLE,
                    version,
                    info,
                    version_scope="processed",
                    extra_metadata=extra_metadata,
                )

            preview_rows = records[:10]
            blocks = [
                metric_block("Sample Rows", sample_data["sample_rows"]),
                metric_block("Total Rows", sample_data["total_rows"]),
                simple_table_with_types(preview_rows, caption="Sample preview", block_id="processed_sample_preview"),
                text_block(f"Sample JSON saved to `{sample_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Created processed sample v{version}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("save processed sample head", processed_csv_path, e) from e

    def generate_change_log(
        self,
        version: int,
        operations: List[str],
        before_stats: Dict[str, Any],
        after_stats: Dict[str, Any],
        columns_affected: Optional[List[str]] = None,
    ) -> FormattedResponse:
        """
        Generate a human-readable change log for preprocessing operations.

        Args:
            version (int): Version number for the change log.
            operations (List[str]): List of operations performed (e.g., ["handled missing values with median", "removed outliers"]).
            before_stats (Dict[str, Any]): Statistics before processing. The keys read here are
                'row_count', 'null_count' and 'shape'; missing or None counts are treated as 0.
            after_stats (Dict[str, Any]): Statistics after processing, with the same keys.
            columns_affected (Optional[List[str]]): List of columns that were modified.

        Returns:
            FormattedResponse: Structured change-log summary with artifact reference.

        Raises:
            FileOperationException: If the change log cannot be built or saved.
        """
        operations = list(operations or [])
        before_stats = before_stats or {}
        after_stats = after_stats or {}
        columns_affected = list(columns_affected or [])

        try:
            summary_text = f"Applied {len(operations)} operations to the dataset"
            change_log = {
                "version": version,
                "timestamp": datetime.now().isoformat(),
                "operations_performed": operations,
                "statistics": {
                    "before": before_stats,
                    "after": after_stats,
                    "changes": {
                        "rows_removed": self._stat_difference(before_stats, after_stats, "row_count"),
                        "columns_modified": len(columns_affected),
                        "null_values_handled": self._stat_difference(before_stats, after_stats, "null_count"),
                    },
                },
                "columns_affected": columns_affected,
                "human_readable": {
                    "summary": summary_text,
                    "details": operations,
                    "impact": f"Dataset shape changed from {before_stats.get('shape', 'unknown')} to {after_stats.get('shape', 'unknown')}",
                },
            }

            change_log_path = self._change_log_path(version)
            info = self.storage.save_json(change_log, change_log_path)

            if self.project_service:
                # Same dict for the stored metadata and the registry, as in the other tools.
                extra_metadata = {
                    "operations": operations,
                    "columns_affected": columns_affected,
                    "summary": summary_text,
                }
                info.metadata.update(extra_metadata)
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.CHANGE_LOG,
                    version,
                    info,
                    version_scope="processed",
                    extra_metadata=extra_metadata,
                )

            summary_rows = [
                {"metric": "Rows", "before": before_stats.get("row_count"), "after": after_stats.get("row_count")},
                {"metric": "Null values", "before": before_stats.get("null_count"), "after": after_stats.get("null_count")},
            ]
            blocks = [
                text_block(f"Recorded {len(operations)} preprocessing operations", severity=Severity.INFO),
                simple_table(summary_rows, caption="Dataset stats delta", block_id="stats_delta"),
                text_block(f"Change log saved to `{change_log_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Captured preprocessing change log v{version}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("generate change log", f"version_{version}", e) from e

    def compare_preprocessing_versions(
        self,
        version_a: int,
        version_b: int,
        comparison_type: str = "full",
    ) -> FormattedResponse:
        """
        Compare two preprocessing versions and generate a detailed diff summary.

        Args:
            version_a (int): First version number for comparison (typically older).
            version_b (int): Second version number for comparison (typically newer).
            comparison_type (str): Type of comparison - "full", "summary", or "stats_only".
                Only "full" additionally embeds the first 5 sample rows of each version.

        Returns:
            FormattedResponse: Structured comparison summary with artifact reference, or an
            error response if ``comparison_type`` is not recognised.

        Raises:
            FileOperationException: If either version's sample cannot be loaded or the
                comparison cannot be saved. Missing change logs are tolerated.
        """
        valid_comparison_types = ["full", "summary", "stats_only"]
        if comparison_type not in valid_comparison_types:
            return self._error_response(
                f"Invalid comparison type '{comparison_type}'. Must be one of {valid_comparison_types}.",
                "Invalid comparison type",
            )

        try:
            # Samples are required: a missing one is an error.
            sample_a = self.storage.load_json(self._sample_path(version_a))
            sample_b = self.storage.load_json(self._sample_path(version_b))

            # Change logs are optional and loaded independently, so one missing log
            # does not discard the other.
            changelog_a = self._try_load_json(self._change_log_path(version_a))
            changelog_b = self._try_load_json(self._change_log_path(version_b))
            if changelog_a is None:
                changelog_a = {"operations_performed": ["Unknown operations"]}
            if changelog_b is None:
                changelog_b = {"operations_performed": ["Unknown operations"]}

            columns_a = sample_a.get("columns", [])
            columns_b = sample_b.get("columns", [])
            rows_a = sample_a.get("total_rows", 0)
            rows_b = sample_b.get("total_rows", 0)
            ops_a = changelog_a.get("operations_performed", [])
            ops_b = changelog_b.get("operations_performed", [])

            comparison = {
                "versions_compared": {"from": version_a, "to": version_b},
                "timestamp": datetime.now().isoformat(),
                "data_changes": {
                    "shape_change": {
                        "from": sample_a.get("shape", [0, 0]),
                        "to": sample_b.get("shape", [0, 0]),
                    },
                    "row_count_change": {
                        "from": rows_a,
                        "to": rows_b,
                        "difference": rows_b - rows_a,
                    },
                    "columns_change": {
                        "from": columns_a,
                        "to": columns_b,
                        "added": self._ordered_difference(columns_b, columns_a),
                        "removed": self._ordered_difference(columns_a, columns_b),
                    },
                    "dtypes_changes": self._compare_dtypes(
                        sample_a.get("dtypes", {}),
                        sample_b.get("dtypes", {}),
                    ),
                },
                "operations": {
                    "version_a_operations": ops_a,
                    "version_b_operations": ops_b,
                    "new_operations": self._ordered_difference(ops_b, ops_a),
                },
                "human_readable": {
                    "summary": f"Comparison between v{version_a} and v{version_b}",
                    "key_differences": self._generate_key_differences(sample_a, sample_b, changelog_a, changelog_b),
                },
            }

            if comparison_type == "full":
                comparison["sample_data"] = {
                    "version_a_sample": sample_a.get("data", [])[:5],
                    "version_b_sample": sample_b.get("data", [])[:5],
                }

            comparison_path = (
                f"{self._get_base_path('processed')}/version_comparison_v{version_a}_v{version_b}.json"
            )
            info = self.storage.save_json(comparison, comparison_path)

            column_changes = comparison["data_changes"]["columns_change"]
            metric_rows = [
                {
                    "metric": "Row count",
                    "from": rows_a,
                    "to": rows_b,
                    "delta": rows_b - rows_a,
                },
                {
                    "metric": "Columns",
                    "from": len(columns_a),
                    "to": len(columns_b),
                    "delta": len(column_changes["added"]) - len(column_changes["removed"]),
                },
            ]
            added_text = self._join_names(column_changes["added"]) or 'None'
            removed_text = self._join_names(column_changes["removed"]) or 'None'
            blocks = [
                text_block(
                    f"Compared preprocessing versions v{version_a} → v{version_b}",
                    severity=Severity.INFO,
                ),
                simple_table(metric_rows, caption="Key dataset deltas", block_id="dataset_deltas"),
                text_block(f"Columns added: {added_text}\nColumns removed: {removed_text}"),
                text_block(f"Comparison saved to `{comparison_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Generated comparison v{version_a} vs v{version_b}",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("compare preprocessing versions", f"v{version_a}_v{version_b}", e) from e

    def _compare_dtypes(self, dtypes_a: Dict[str, str], dtypes_b: Dict[str, str]) -> Dict[str, Any]:
        """Return ``{column: {"from": dtype_a, "to": dtype_b}}`` for every column whose dtype differs.

        A column present in only one version is reported with None on the other
        side. Columns are visited in a deterministic order (version A's, then any
        new ones from version B).
        """
        changes = {}
        for col in dict.fromkeys([*dtypes_a, *dtypes_b]):
            dtype_a = dtypes_a.get(col)
            dtype_b = dtypes_b.get(col)
            if dtype_a != dtype_b:
                changes[col] = {"from": dtype_a, "to": dtype_b}
        return changes

    def _generate_key_differences(self, sample_a: Dict, sample_b: Dict, changelog_a: Dict, changelog_b: Dict) -> List[str]:
        """Generate human-readable key differences (rows, columns, new operations) between versions."""
        differences = []

        rows_a = sample_a.get("total_rows", 0)
        rows_b = sample_b.get("total_rows", 0)
        if rows_b > rows_a:
            differences.append(f"Added {rows_b - rows_a} rows")
        elif rows_a > rows_b:
            differences.append(f"Removed {rows_a - rows_b} rows")

        cols_a = sample_a.get("columns", [])
        cols_b = sample_b.get("columns", [])
        added_cols = self._ordered_difference(cols_b, cols_a)
        removed_cols = self._ordered_difference(cols_a, cols_b)
        if added_cols:
            differences.append(f"Added columns: {self._join_names(added_cols)}")
        if removed_cols:
            differences.append(f"Removed columns: {self._join_names(removed_cols)}")

        new_ops = self._ordered_difference(
            changelog_b.get("operations_performed", []),
            changelog_a.get("operations_performed", []),
        )
        if new_ops:
            differences.append(f"New operations: {'; '.join(str(op) for op in new_ops)}")

        if not differences:
            differences.append("No significant differences detected")
        return differences

    def get_preprocessing_history(self, max_versions: int = 20) -> FormattedResponse:
        """
        Get the complete preprocessing history for this project.

        Versions are probed in order starting at 1. The scan stops at the first
        version whose sample JSON cannot be loaded, or after ``max_versions``.

        Args:
            max_versions (int): Upper bound on the number of versions probed (default: 20).

        Returns:
            FormattedResponse: Structured history overview with artifact reference.

        Raises:
            FileOperationException: If the history cannot be compiled or saved.
        """
        try:
            history = {
                "project_id": self.project_id,
                "user_id": self.user_id,
                "timestamp": datetime.now().isoformat(),
                "versions": [],
                "summary": {
                    "total_versions": 0,
                    "latest_version": 0,
                },
            }

            for version in range(1, max_versions + 1):
                sample_data = self._try_load_json(self._sample_path(version))
                if sample_data is None:
                    # No readable sample for this version: treat it as the end of the history.
                    break

                changelog = self._try_load_json(self._change_log_path(version))
                if changelog is None:
                    changelog = {"operations_performed": ["Operations not recorded"]}

                history["versions"].append({
                    "version": version,
                    "created_at": sample_data.get("created_at", "Unknown"),
                    "shape": sample_data.get("shape", [0, 0]),
                    "total_rows": sample_data.get("total_rows", 0),
                    "operations": changelog.get("operations_performed", []),
                    "human_readable_summary": (changelog.get("human_readable") or {}).get(
                        "summary", "No summary available"
                    ),
                })
                history["summary"]["latest_version"] = version

            history["summary"]["total_versions"] = len(history["versions"])

            history_path = f"{self._get_base_path('processed')}/preprocessing_history.json"
            info = self.storage.save_json(history, history_path)

            # Newest first: the table is labelled "Recent", so show the last 10 versions.
            recent_versions = history["versions"][-10:][::-1]
            rows = [
                {
                    "version": item.get("version"),
                    "rows": item.get("total_rows"),
                    "summary": item.get("human_readable_summary"),
                }
                for item in recent_versions
            ]
            blocks = [
                text_block(
                    f"Indexed {len(history['versions'])} preprocessing versions",
                    severity=Severity.INFO,
                ),
                simple_table(rows, caption="Recent preprocessing runs", block_id="preprocessing_history"),
                text_block(f"History saved to `{history_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Compiled preprocessing history",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            raise FileOperationException("get preprocessing history", "all_versions", e) from e