# ml_module/tools/data_tools.py
import logging
import re
from typing import Dict, Optional

import pandas as pd
from agno.tools import Toolkit, tool

from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths
from ml_module.core.response_formatter import (
    FormattedResponse,
    Severity,
    make_text_response,
    metric_block,
    simple_table,
    simple_table_with_types,
    visualization_block,
    text_block,
)

logger = logging.getLogger(__name__)


class DataAnalysisToolkit(Toolkit):
    """A collection of safe tools for performing data analysis."""

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        """
        Bind the toolkit to one user/project pair.

        Args:
            storage_service: Service used to load dataframes and save JSON artifacts.
            user_id: Owner of the project; first segment of every output path.
            project_id: Project identifier; second segment of every output path.
            project_service: Optional service used to register artifacts and look up
                versions. When omitted, artifacts are saved but never registered.
        """
        super().__init__(name="data_analysis_tools")
        self.storage = storage_service
        self.user_id = user_id
        self.project_id = project_id
        self.project_service = project_service

    def _get_base_path(self) -> str:
        """Return the `<user_id>/<project_id>` prefix used for analysis outputs."""
        return f"{self.user_id}/{self.project_id}"

    def _extract_version_from_path(self, path: str) -> Optional[int]:
        """
        Return the integer following the first `_v` marker in `path`, if any.

        Returns None when the path has no `_v<digits>` segment or the digits
        cannot be converted to an int.
        """
        match = re.search(r"_v(\d+)", path)
        if match is None:
            return None
        try:
            return int(match.group(1))
        except ValueError:
            # Kept deliberately: int() can still reject an absurdly long digit run.
            return None

    def _resolve_raw_version(self, dataset_path: str, default: int = 1) -> int:
        """
        Determine which raw-data version an artifact should be registered under.

        The version embedded in `dataset_path` wins. Otherwise the project service
        (when available) is asked for the latest "raw" version. Any failure or
        falsy result falls back to `default`.
        """
        version = self._extract_version_from_path(dataset_path)
        if version is None and self.project_service:
            try:
                version = self.project_service.get_latest_version(
                    self.user_id, self.project_id, "raw"
                )
            except Exception:
                # Best-effort lookup: keep going with the default, but leave a trace.
                logger.warning(
                    "Could not resolve latest raw version for %s/%s; using default %s",
                    self.user_id,
                    self.project_id,
                    default,
                    exc_info=True,
                )
                version = default
        return version or default

    def _error_response(self, message: str) -> FormattedResponse:
        """Build the terminal (done=True) error response shared by every tool."""
        response = make_text_response(message, severity=Severity.ERROR)
        response.done = True
        return response

    @tool
    def get_data_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Calculates and saves a high-level summary of the dataset.
        This includes shape (rows and columns), a list of column names, and data types for each column.
        This should be the VERY FIRST tool you use to understand the dataset.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.

        Returns:
            FormattedResponse: Structured dataset summary with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)

            summary = {
                "shape": {"rows": df.shape[0], "columns": df.shape[1]},
                "column_names": list(df.columns),
                "column_data_types": {col: str(dtype) for col, dtype in df.dtypes.items()},
            }

            output_path = f"{self._get_base_path()}/analysis/data_profile.json"
            info = self.storage.save_json(summary, output_path)

            # Registration is optional: without a project service we only save the file.
            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns": summary["column_names"]})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.DATA_PROFILE,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={"shape": summary["shape"]},
                )

            dtype_rows = [
                {"column": col, "dtype": dtype}
                for col, dtype in summary["column_data_types"].items()
            ]
            blocks = [
                metric_block("Rows", summary["shape"]["rows"]),
                metric_block("Columns", summary["shape"]["columns"]),
                simple_table_with_types(
                    dtype_rows, caption="Column data types", block_id="column_dtypes"
                ),
                text_block(f"Saved summary to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Data summary generated",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            # Tool boundary: report the failure to the caller instead of raising.
            logger.exception("get_data_summary failed for %s", dataset_path)
            return self._error_response(f"Could not get data summary: {e}")

    @tool
    def get_missing_values_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Analyzes the dataset for missing (null or NaN) values in each column and saves a report.
        This is a crucial step for assessing data quality.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.

        Returns:
            FormattedResponse: Structured missing-value overview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            missing_values = df.isnull().sum()

            total_missing = int(missing_values.sum())
            total_cells = df.shape[0] * df.shape[1]
            # An empty frame has zero cells; report 0% rather than dividing by zero.
            missing_ratio = total_missing / total_cells if total_cells else 0.0

            missing_summary = {
                "total_missing_values": total_missing,
                "missing_percentage": f"{missing_ratio:.2%}",
                "missing_values_per_column": {
                    col: int(count) for col, count in missing_values.items() if count > 0
                },
            }
            columns_with_missing = list(missing_summary["missing_values_per_column"])

            output_path = f"{self._get_base_path()}/analysis/missing_values_report.json"
            info = self.storage.save_json(missing_summary, output_path)

            # Registration is optional: without a project service we only save the file.
            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns_with_missing": columns_with_missing})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.MISSING_VALUES,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "total_missing": missing_summary["total_missing_values"],
                        "missing_percentage": missing_summary["missing_percentage"],
                    },
                )

            table_rows = [
                {"column": col, "missing": count}
                for col, count in missing_summary["missing_values_per_column"].items()
            ]
            blocks = [
                metric_block(
                    "Total Missing",
                    missing_summary["total_missing_values"],
                    unit="cells",
                ),
                text_block(
                    f"Overall missing percentage: {missing_summary['missing_percentage']}",
                    severity=Severity.INFO,
                ),
                simple_table(
                    table_rows, caption="Missing values per column", block_id="missing_values"
                ),
                text_block(f"Saved missing-values report to `{output_path}`"),
            ]
            summary_text = (
                "No missing values detected"
                if not columns_with_missing
                else f"Missing values recorded for {len(columns_with_missing)} columns"
            )
            return FormattedResponse(
                blocks=blocks,
                summary=summary_text,
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            # Tool boundary: report the failure to the caller instead of raising.
            logger.exception("get_missing_values_summary failed for %s", dataset_path)
            return self._error_response(f"Could not analyze missing values: {e}")

    @tool
    def save_sample_head(
        self,
        dataset_path: str,
        limit: Optional[int] = None,
        version: Optional[int] = None,
    ) -> FormattedResponse:
        """
        Saves the first N rows of the dataset as a JSON file for UI preview.
        Includes both the data sample and schema information.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.
            limit (Optional[int]): Number of rows to sample (defaults to DEFAULT_SAMPLE_ROWS).
                Zero or negative values also fall back to the default.
            version (Optional[int]): Raw-data version to save and register the sample under.
                When omitted, the version is resolved from the dataset path or the project.

        Returns:
            FormattedResponse: Structured sample preview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)

            # DataFrame.head(-n) means "all but the last n rows", which is never a
            # sensible preview, so only strictly positive limits are honoured.
            if limit is not None and limit > 0:
                rows_to_sample = limit
            else:
                rows_to_sample = DEFAULT_SAMPLE_ROWS
            resolved_version = version or self._resolve_raw_version(dataset_path)

            # Get sample data (first N rows)
            sample_df = df.head(rows_to_sample)

            # Create comprehensive sample data structure
            # NOTE(review): rows may contain NaN or non-JSON-native values; this relies
            # on storage.save_json to serialise them - confirm against that service.
            sample_data = {
                "dataset_info": {
                    "total_rows": len(df),
                    "total_columns": len(df.columns),
                    "sample_rows": len(sample_df),
                    "source_path": dataset_path,
                },
                "schema": {
                    "columns": list(df.columns),
                    "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                    "null_counts": {
                        col: int(count) for col, count in df.isnull().sum().items()
                    },
                },
                "sample_data": {
                    "columns": list(sample_df.columns),
                    "rows": sample_df.to_dict(orient="records"),
                },
            }

            # Use versioned path from constants
            output_path = StoragePaths.SAMPLE_RAW_HEAD.format(
                user_id=self.user_id,
                project_id=self.project_id,
                version=resolved_version,
            )
            info = self.storage.save_json(sample_data, output_path)

            # Registration is optional: without a project service we only save the file.
            if self.project_service:
                info.metadata.update({
                    "sample_rows": sample_data["dataset_info"].get("sample_rows"),
                    "total_rows": sample_data["dataset_info"].get("total_rows"),
                })
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.SAMPLE_RAW_HEAD,
                    resolved_version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "columns": sample_data["schema"].get("columns", []),
                    },
                )

            # The inline preview is capped at 10 rows regardless of the saved sample size.
            preview_rows = sample_df.head(min(rows_to_sample, 10)).to_dict(orient="records")
            blocks = [
                metric_block("Rows Sampled", len(sample_df)),
                metric_block("Total Rows", len(df)),
                simple_table_with_types(
                    preview_rows, caption="First rows preview", block_id="sample_preview"
                ),
                text_block(f"Sample saved to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Saved sample head (first {len(sample_df)} rows)",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            # Tool boundary: report the failure to the caller instead of raising.
            logger.exception("save_sample_head failed for %s", dataset_path)
            return self._error_response(f"Could not save dataset sample: {e}")