# ml_module/tools/data_tools.py
import logging
import re
from typing import Dict, Optional

import pandas as pd
from agno.tools import Toolkit, tool

from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths
from ml_module.core.response_formatter import (
    FormattedResponse,
    Severity,
    make_text_response,
    metric_block,
    simple_table,
    simple_table_with_types,
    visualization_block,
    text_block,
)
class DataAnalysisToolkit(Toolkit):
    """A collection of safe tools for performing data analysis.

    Each public tool loads a dataset through the storage service, writes a
    JSON artifact under the project's storage area, registers that artifact
    with the project service when one was supplied, and returns a
    ``FormattedResponse``. Any exception raised along the way is converted
    into an error-severity response rather than propagated to the caller.
    """

    # Matches a version marker such as ``_v3`` embedded in a storage path.
    _VERSION_PATTERN = re.compile(r"_v(\d+)")
    # Upper bound on the rows echoed back in the sample preview table.
    _PREVIEW_ROW_CAP = 10

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        """
        Binds the toolkit to one user's project.

        Args:
            storage_service (MLStorageService): Used to load dataframes and save JSON artifacts.
            user_id (str): Owner of the project; first segment of the storage base path.
            project_id (str): Project identifier; second segment of the storage base path.
            project_service (Optional[ProjectService]): When provided, saved artifacts are
                registered with it; when omitted, registration is skipped.
        """
        # NOTE(review): no tool methods are registered with the Toolkit base
        # class here - confirm callers invoke the methods directly or that
        # registration happens elsewhere.
        super().__init__(name="data_analysis_tools")
        self.storage = storage_service
        self.user_id = user_id
        self.project_id = project_id
        self.project_service = project_service

    def _get_base_path(self) -> str:
        """Returns the ``<user_id>/<project_id>`` prefix used for analysis outputs."""
        return f"{self.user_id}/{self.project_id}"

    def _extract_version_from_path(self, path: str) -> Optional[int]:
        """
        Returns the integer following the first ``_v`` marker in ``path``.

        Returns ``None`` when the path carries no ``_v<digits>`` marker.
        """
        match = self._VERSION_PATTERN.search(path)
        # ``\d+`` only captures decimal digits, so ``int`` cannot fail here.
        return int(match.group(1)) if match else None

    def _resolve_raw_version(self, dataset_path: str, default: int = 1) -> int:
        """
        Determines which raw-data version an artifact should be filed under.

        The version encoded in ``dataset_path`` wins. Otherwise the project
        service (if any) is asked for the latest "raw" version. ``default`` is
        used when neither source yields a truthy version.
        """
        version = self._extract_version_from_path(dataset_path)
        if version is None and self.project_service:
            try:
                version = self.project_service.get_latest_version(
                    self.user_id, self.project_id, "raw"
                )
            except Exception:
                # Best-effort lookup: keep the fallback, but leave a trace so
                # a failing project service does not go unnoticed.
                logging.getLogger(__name__).warning(
                    "Could not resolve latest raw version for project %s; using default %s",
                    self.project_id,
                    default,
                    exc_info=True,
                )
                version = default
        return version or default

    def _error_response(self, message: str) -> FormattedResponse:
        """Builds the terminal error response shared by all tools."""
        response = make_text_response(message, severity=Severity.ERROR)
        response.done = True
        return response

    def get_data_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Calculates and saves a high-level summary of the dataset. This includes
        shape (rows and columns), a list of column names, and data types for each column.
        This should be the VERY FIRST tool you use to understand the dataset.
        Args:
            dataset_path (str): The full path to the dataset file within project storage.
        Returns:
            FormattedResponse: Structured dataset summary with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            summary = {
                "shape": {"rows": df.shape[0], "columns": df.shape[1]},
                "column_names": list(df.columns),
                "column_data_types": {col: str(dtype) for col, dtype in df.dtypes.items()},
            }
            output_path = f"{self._get_base_path()}/analysis/data_profile.json"
            info = self.storage.save_json(summary, output_path)

            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns": summary["column_names"]})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.DATA_PROFILE,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={"shape": summary["shape"]},
                )

            dtype_rows = [
                {"column": col, "dtype": dtype}
                for col, dtype in summary["column_data_types"].items()
            ]
            blocks = [
                metric_block("Rows", summary["shape"]["rows"]),
                metric_block("Columns", summary["shape"]["columns"]),
                simple_table_with_types(dtype_rows, caption="Column data types", block_id="column_dtypes"),
                text_block(f"Saved summary to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Data summary generated",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not get data summary: {e}")

    def get_missing_values_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Analyzes the dataset for missing (null or NaN) values in each column and saves a
        report. This is a crucial step for assessing data quality.
        Args:
            dataset_path (str): The full path to the dataset file within project storage.
        Returns:
            FormattedResponse: Structured missing-value overview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            missing_values = df.isnull().sum()
            total_missing = int(missing_values.sum())
            # A dataset with no rows or no columns has zero cells; report 0%
            # instead of dividing by zero.
            total_cells = df.shape[0] * df.shape[1]
            missing_ratio = total_missing / total_cells if total_cells else 0.0
            missing_summary = {
                "total_missing_values": total_missing,
                "missing_percentage": f"{missing_ratio:.2%}",
                "missing_values_per_column": {
                    col: int(count) for col, count in missing_values.items() if count > 0
                },
            }
            columns_with_missing = list(missing_summary["missing_values_per_column"])
            output_path = f"{self._get_base_path()}/analysis/missing_values_report.json"
            info = self.storage.save_json(missing_summary, output_path)

            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns_with_missing": columns_with_missing})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.MISSING_VALUES,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "total_missing": missing_summary["total_missing_values"],
                        "missing_percentage": missing_summary["missing_percentage"],
                    },
                )

            table_rows = [
                {"column": col, "missing": count}
                for col, count in missing_summary["missing_values_per_column"].items()
            ]
            blocks = [
                metric_block(
                    "Total Missing",
                    missing_summary["total_missing_values"],
                    unit="cells",
                ),
                text_block(
                    f"Overall missing percentage: {missing_summary['missing_percentage']}",
                    severity=Severity.INFO,
                ),
                simple_table(table_rows, caption="Missing values per column", block_id="missing_values"),
                text_block(f"Saved missing-values report to `{output_path}`"),
            ]
            summary_text = (
                "No missing values detected"
                if not columns_with_missing
                else f"Missing values recorded for {len(columns_with_missing)} columns"
            )
            return FormattedResponse(
                blocks=blocks,
                summary=summary_text,
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not analyze missing values: {e}")

    def save_sample_head(
        self,
        dataset_path: str,
        limit: Optional[int] = None,
        version: Optional[int] = None,
    ) -> FormattedResponse:
        """
        Saves the first N rows of the dataset as a JSON file for UI preview.
        Includes both the data sample and schema information.
        Args:
            dataset_path (str): The full path to the dataset file within project storage.
            limit (Optional[int]): Number of rows to sample (defaults to DEFAULT_SAMPLE_ROWS).
                Missing, zero or negative values fall back to the default.
            version (Optional[int]): Raw-data version to file the sample under. When not
                given, the version is resolved from the dataset path or the project service.
        Returns:
            FormattedResponse: Structured sample preview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            # DataFrame.head(-n) returns everything except the last n rows, so
            # a negative limit must not reach it.
            rows_to_sample = limit if limit is not None and limit > 0 else DEFAULT_SAMPLE_ROWS
            resolved_version = version or self._resolve_raw_version(dataset_path)

            sample_df = df.head(rows_to_sample)
            # NOTE(review): the records may hold NaN, NumPy scalars or
            # timestamps - confirm that save_json can serialise them.
            sample_data = {
                "dataset_info": {
                    "total_rows": len(df),
                    "total_columns": len(df.columns),
                    "sample_rows": len(sample_df),
                    "source_path": dataset_path,
                },
                "schema": {
                    "columns": list(df.columns),
                    "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                    "null_counts": {col: int(count) for col, count in df.isnull().sum().items()},
                },
                "sample_data": {
                    "columns": list(sample_df.columns),
                    "rows": sample_df.to_dict(orient="records"),
                },
            }

            # Use versioned path from constants
            output_path = StoragePaths.SAMPLE_RAW_HEAD.format(
                user_id=self.user_id,
                project_id=self.project_id,
                version=resolved_version,
            )
            info = self.storage.save_json(sample_data, output_path)

            if self.project_service:
                info.metadata.update({
                    "sample_rows": sample_data["dataset_info"].get("sample_rows"),
                    "total_rows": sample_data["dataset_info"].get("total_rows"),
                })
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.SAMPLE_RAW_HEAD,
                    resolved_version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "columns": sample_data["schema"].get("columns", []),
                    },
                )

            # The response table is capped so a large saved sample does not
            # flood the reply; the stored artifact still holds every sampled row.
            preview_rows = sample_df.head(
                min(rows_to_sample, self._PREVIEW_ROW_CAP)
            ).to_dict(orient="records")
            blocks = [
                metric_block("Rows Sampled", len(sample_df)),
                metric_block("Total Rows", len(df)),
                simple_table_with_types(preview_rows, caption="First rows preview", block_id="sample_preview"),
                text_block(f"Sample saved to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Saved sample head (first {len(sample_df)} rows)",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not save dataset sample: {e}")