# Source: sirus/backend/ml_module/tools/data_tools.py
# Commit 783a952 ("Deploy SiRUS SQL Agent backend") by ranilmukesh
# ml_module/tools/data_tools.py
import re
from typing import Dict, Optional
import pandas as pd
from agno.tools import Toolkit, tool
from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths
from ml_module.core.response_formatter import (
FormattedResponse,
Severity,
make_text_response,
metric_block,
simple_table,
simple_table_with_types,
visualization_block,
text_block,
)
class DataAnalysisToolkit(Toolkit):
    """A collection of safe tools for performing data analysis.

    Each tool does the following:
    - Loads a dataset through the injected ``MLStorageService``.
    - Writes a JSON artifact under the project's storage prefix.
    - Registers that artifact against a raw-data version, when a
      ``ProjectService`` was supplied.

    Tools do not propagate exceptions. Failures are returned as an error
    ``FormattedResponse`` with ``done=True``.
    """

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        super().__init__(name="data_analysis_tools")
        self.storage = storage_service
        self.user_id = user_id
        self.project_id = project_id
        # Optional: without it, artifacts are saved but never registered.
        self.project_service = project_service

    def _get_base_path(self) -> str:
        """Return the storage prefix ``<user_id>/<project_id>`` for this toolkit."""
        return f"{self.user_id}/{self.project_id}"

    def _extract_version_from_path(self, path: str) -> Optional[int]:
        """Return the number in the last ``_v<digits>`` marker of *path*, or None.

        The LAST occurrence is used because the version suffix belongs to the
        file name. Earlier path segments (e.g. user/project identifiers) may
        also contain ``_v<digits>`` and must not shadow it.
        """
        matches = re.findall(r"_v(\d+)", path)
        return int(matches[-1]) if matches else None

    def _resolve_raw_version(self, dataset_path: str, default: int = 1) -> int:
        """Pick the raw-data version that artifacts derived from *dataset_path* belong to.

        A ``_v<N>`` suffix in the path wins. Otherwise the project service is
        asked for the latest ``"raw"`` version. A failed lookup is tolerated
        (best-effort), and a missing, ``None`` or ``0`` version all fall back
        to *default*.
        """
        version = self._extract_version_from_path(dataset_path)
        if version is None and self.project_service:
            try:
                version = self.project_service.get_latest_version(
                    self.user_id, self.project_id, "raw"
                )
            except Exception:
                # Deliberately best-effort: a failed lookup must not abort the tool.
                version = default
        return version or default

    def _error_response(self, message: str) -> FormattedResponse:
        """Build a terminal (``done=True``) error response carrying *message*."""
        response = make_text_response(message, severity=Severity.ERROR)
        response.done = True
        return response

    @tool
    def get_data_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Calculates and saves a high-level summary of the dataset. This includes
        shape (rows and columns), a list of column names, and data types for each column.
        This should be the VERY FIRST tool you use to understand the dataset.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.

        Returns:
            FormattedResponse: Structured dataset summary with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            summary = {
                "shape": {"rows": df.shape[0], "columns": df.shape[1]},
                "column_names": list(df.columns),
                "column_data_types": {col: str(dtype) for col, dtype in df.dtypes.items()},
            }
            output_path = f"{self._get_base_path()}/analysis/data_profile.json"
            info = self.storage.save_json(summary, output_path)

            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns": summary["column_names"]})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.DATA_PROFILE,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={"shape": summary["shape"]},
                )

            dtype_rows = [
                {"column": col, "dtype": dtype}
                for col, dtype in summary["column_data_types"].items()
            ]
            blocks = [
                metric_block("Rows", summary["shape"]["rows"]),
                metric_block("Columns", summary["shape"]["columns"]),
                simple_table_with_types(dtype_rows, caption="Column data types", block_id="column_dtypes"),
                text_block(f"Saved summary to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Data summary generated",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not get data summary: {e}")

    @tool
    def get_missing_values_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Analyzes the dataset for missing (null or NaN) values in each column and saves a
        report. This is a crucial step for assessing data quality.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.

        Returns:
            FormattedResponse: Structured missing-value overview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            missing_values = df.isnull().sum()
            total_missing = int(missing_values.sum())
            total_cells = df.shape[0] * df.shape[1]
            # An empty dataset (no rows or no columns) has no cells; report 0%
            # instead of dividing 0 by 0.
            missing_ratio = total_missing / total_cells if total_cells else 0.0
            missing_per_column = {
                col: int(count) for col, count in missing_values.items() if count > 0
            }
            missing_summary = {
                "total_missing_values": total_missing,
                "missing_percentage": f"{missing_ratio:.2%}",
                "missing_values_per_column": missing_per_column,
            }
            output_path = f"{self._get_base_path()}/analysis/missing_values_report.json"
            info = self.storage.save_json(missing_summary, output_path)

            columns_with_missing = list(missing_per_column.keys())
            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns_with_missing": columns_with_missing})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.MISSING_VALUES,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "total_missing": missing_summary["total_missing_values"],
                        "missing_percentage": missing_summary["missing_percentage"],
                    },
                )

            table_rows = [
                {"column": col, "missing": count}
                for col, count in missing_per_column.items()
            ]
            blocks = [
                metric_block(
                    "Total Missing",
                    missing_summary["total_missing_values"],
                    unit="cells",
                ),
                text_block(
                    f"Overall missing percentage: {missing_summary['missing_percentage']}",
                    severity=Severity.INFO,
                ),
                simple_table(table_rows, caption="Missing values per column", block_id="missing_values"),
                text_block(f"Saved missing-values report to `{output_path}`"),
            ]
            summary_text = (
                "No missing values detected"
                if not columns_with_missing
                else f"Missing values recorded for {len(columns_with_missing)} columns"
            )
            return FormattedResponse(
                blocks=blocks,
                summary=summary_text,
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not analyze missing values: {e}")

    @tool
    def save_sample_head(
        self,
        dataset_path: str,
        limit: Optional[int] = None,
        version: Optional[int] = None,
    ) -> FormattedResponse:
        """
        Saves the first N rows of the dataset as a JSON file for UI preview.
        Includes both the data sample and schema information.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.
            limit (Optional[int]): Number of rows to sample (defaults to DEFAULT_SAMPLE_ROWS;
                non-positive values also fall back to the default).
            version (Optional[int]): Raw-data version to save the sample under. When omitted,
                it is inferred from a `_v<N>` suffix in the path or the latest raw version.

        Returns:
            FormattedResponse: Structured sample preview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            # Negative values would make df.head() drop rows from the end rather
            # than take a prefix, so only positive limits are honoured.
            rows_to_sample = limit if limit and limit > 0 else DEFAULT_SAMPLE_ROWS
            resolved_version = version or self._resolve_raw_version(dataset_path)

            # First N rows of the dataset.
            sample_df = df.head(rows_to_sample)
            # Replace missing cells with None so they serialize as JSON null.
            # NOTE(review): assumes save_json uses a stdlib-json-like encoder, which
            # would otherwise emit the non-standard token `NaN`.
            sample_rows = (
                sample_df.astype(object)
                .where(sample_df.notna(), None)
                .to_dict(orient="records")
            )

            sample_data = {
                "dataset_info": {
                    "total_rows": len(df),
                    "total_columns": len(df.columns),
                    "sample_rows": len(sample_df),
                    "source_path": dataset_path
                },
                "schema": {
                    "columns": list(df.columns),
                    "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                    "null_counts": {col: int(count) for col, count in df.isnull().sum().items()}
                },
                "sample_data": {
                    "columns": list(sample_df.columns),
                    "rows": sample_rows
                }
            }

            # Versioned output location defined centrally in StoragePaths.
            output_path = StoragePaths.SAMPLE_RAW_HEAD.format(
                user_id=self.user_id,
                project_id=self.project_id,
                version=resolved_version
            )
            info = self.storage.save_json(sample_data, output_path)

            if self.project_service:
                info.metadata.update({
                    "sample_rows": sample_data["dataset_info"].get("sample_rows"),
                    "total_rows": sample_data["dataset_info"].get("total_rows"),
                })
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.SAMPLE_RAW_HEAD,
                    resolved_version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "columns": sample_data["schema"].get("columns", []),
                    },
                )

            # The in-chat preview shows at most 10 rows of the saved sample.
            preview_rows = sample_rows[:10]
            blocks = [
                metric_block("Rows Sampled", len(sample_df)),
                metric_block("Total Rows", len(df)),
                simple_table_with_types(preview_rows, caption="First rows preview", block_id="sample_preview"),
                text_block(f"Sample saved to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Saved sample head (first {len(sample_df)} rows)",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not save dataset sample: {e}")