File size: 11,508 Bytes
783a952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# ml_module/tools/data_tools.py
import re
from typing import Dict, Optional

import pandas as pd
from agno.tools import Toolkit, tool

from ml_module.services.storage_service import MLStorageService
from ml_module.services.project_service import ProjectService
from ml_module.core.constants import ArtifactTypes, DEFAULT_SAMPLE_ROWS, StoragePaths
from ml_module.core.response_formatter import (
    FormattedResponse,
    Severity,
    make_text_response,
    metric_block,
    simple_table,
    simple_table_with_types,
    visualization_block,
    text_block,
)

class DataAnalysisToolkit(Toolkit):
    """A collection of safe tools for performing data analysis.

    Each tool loads a dataset through the injected storage service, writes a
    JSON artifact into project storage, optionally registers that artifact with
    the project service, and returns a ``FormattedResponse`` for the UI.
    Any failure inside a tool is reported as an error response (``done=True``)
    instead of being raised to the caller.
    """

    def __init__(
        self,
        storage_service: MLStorageService,
        user_id: str,
        project_id: str,
        project_service: Optional[ProjectService] = None,
    ):
        """
        Args:
            storage_service: Service used to load dataframes and save JSON artifacts.
            user_id: Owner of the project; first component of artifact paths.
            project_id: Project whose artifacts are written and registered.
            project_service: Optional registry. When ``None``, artifacts are saved
                but not registered, and raw-version lookups fall back to defaults.
        """
        super().__init__(name="data_analysis_tools")
        self.storage = storage_service
        self.user_id = user_id
        self.project_id = project_id
        self.project_service = project_service

    def _get_base_path(self) -> str:
        """Return the ``{user_id}/{project_id}`` storage prefix for this toolkit."""
        return f"{self.user_id}/{self.project_id}"

    @staticmethod
    def _error_response(message: str) -> FormattedResponse:
        """Build a terminal (``done=True``) error response carrying ``message``."""
        response = make_text_response(message, severity=Severity.ERROR)
        response.done = True
        return response

    def _extract_version_from_path(self, path: str) -> Optional[int]:
        """Return the number in the first ``_v<digits>`` token of ``path``, or ``None``."""
        match = re.search(r"_v(\d+)", path)
        # ``\d+`` only matches decimal digits, all of which ``int`` accepts,
        # so this conversion cannot raise.
        return int(match.group(1)) if match else None

    def _resolve_raw_version(self, dataset_path: str, default: int = 1) -> int:
        """Determine the raw-data version that ``dataset_path`` belongs to.

        Precedence: a ``_v<N>`` token in the path, then the project's latest
        "raw" version (only when a project service is configured), then
        ``default``. A falsy result (``None`` or ``0``) also resolves to ``default``.
        """
        version = self._extract_version_from_path(dataset_path)
        if version is None and self.project_service:
            try:
                version = self.project_service.get_latest_version(
                    self.user_id, self.project_id, "raw"
                )
            except Exception:
                # Best effort: a failed lookup must not break the calling tool.
                version = default
        return version or default

    @tool
    def get_data_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Calculates and saves a high-level summary of the dataset. This includes
        shape (rows and columns), a list of column names, and data types for each column.
        This should be the VERY FIRST tool you use to understand the dataset.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.

        Returns:
            FormattedResponse: Structured dataset summary with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            summary = {
                "shape": {"rows": df.shape[0], "columns": df.shape[1]},
                "column_names": list(df.columns),
                "column_data_types": {col: str(dtype) for col, dtype in df.dtypes.items()},
            }
            output_path = f"{self._get_base_path()}/analysis/data_profile.json"
            info = self.storage.save_json(summary, output_path)

            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns": summary["column_names"]})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.DATA_PROFILE,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={"shape": summary["shape"]},
                )
            dtype_rows = [
                {"column": col, "dtype": dtype}
                for col, dtype in summary["column_data_types"].items()
            ]
            blocks = [
                metric_block("Rows", summary["shape"]["rows"]),
                metric_block("Columns", summary["shape"]["columns"]),
                simple_table_with_types(dtype_rows, caption="Column data types", block_id="column_dtypes"),
                text_block(f"Saved summary to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary="Data summary generated",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not get data summary: {e}")

    @tool
    def get_missing_values_summary(self, dataset_path: str) -> FormattedResponse:
        """
        Analyzes the dataset for missing (null or NaN) values in each column and saves a
        report. This is a crucial step for assessing data quality.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.

        Returns:
            FormattedResponse: Structured missing-value overview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            missing_values = df.isnull().sum()
            total_missing = int(missing_values.sum())
            total_cells = df.shape[0] * df.shape[1]
            # Guard empty datasets: 0 rows or 0 columns would otherwise divide by zero.
            missing_ratio = total_missing / total_cells if total_cells else 0.0
            missing_summary = {
                "total_missing_values": total_missing,
                "missing_percentage": f"{missing_ratio:.2%}",
                "missing_values_per_column": {
                    col: int(count) for col, count in missing_values.items() if count > 0
                },
            }
            columns_with_missing = list(missing_summary["missing_values_per_column"].keys())
            output_path = f"{self._get_base_path()}/analysis/missing_values_report.json"
            info = self.storage.save_json(missing_summary, output_path)

            if self.project_service:
                version = self._resolve_raw_version(dataset_path)
                info.metadata.update({"columns_with_missing": list(columns_with_missing)})
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.MISSING_VALUES,
                    version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "total_missing": missing_summary["total_missing_values"],
                        "missing_percentage": missing_summary["missing_percentage"],
                    },
                )
            table_rows = [
                {"column": col, "missing": count}
                for col, count in missing_summary["missing_values_per_column"].items()
            ]
            blocks = [
                metric_block(
                    "Total Missing",
                    missing_summary["total_missing_values"],
                    unit="cells",
                ),
                text_block(
                    f"Overall missing percentage: {missing_summary['missing_percentage']}",
                    severity=Severity.INFO,
                ),
                simple_table(table_rows, caption="Missing values per column", block_id="missing_values"),
                text_block(f"Saved missing-values report to `{output_path}`"),
            ]
            summary_text = (
                "No missing values detected"
                if not columns_with_missing
                else f"Missing values recorded for {len(columns_with_missing)} columns"
            )
            return FormattedResponse(
                blocks=blocks,
                summary=summary_text,
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not analyze missing values: {e}")

    @tool
    def save_sample_head(
        self,
        dataset_path: str,
        limit: Optional[int] = None,
        version: Optional[int] = None,
    ) -> FormattedResponse:
        """
        Saves the first N rows of the dataset as a JSON file for UI preview.
        Includes both the data sample and schema information.

        Args:
            dataset_path (str): The full path to the dataset file within project storage.
            limit (Optional[int]): Number of rows to sample (defaults to DEFAULT_SAMPLE_ROWS;
                non-positive values also fall back to the default).
            version (Optional[int]): Raw-data version to save the sample under. When omitted,
                it is taken from a `_v<N>` token in the path or the project's latest raw version.

        Returns:
            FormattedResponse: Structured sample preview with artifact reference.
        """
        try:
            df = self.storage.load_dataframe(dataset_path)
            # ``df.head(-n)`` returns every row except the last n, so a negative
            # limit would save (almost) the whole dataset instead of a sample.
            rows_to_sample = limit if limit is not None and limit > 0 else DEFAULT_SAMPLE_ROWS
            resolved_version = version or self._resolve_raw_version(dataset_path)

            sample_df = df.head(rows_to_sample)

            # Sample payload: dataset-level stats, full-dataset schema, and the sampled rows.
            sample_data = {
                "dataset_info": {
                    "total_rows": len(df),
                    "total_columns": len(df.columns),
                    "sample_rows": len(sample_df),
                    "source_path": dataset_path,
                },
                "schema": {
                    "columns": list(df.columns),
                    "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
                    "null_counts": {col: int(count) for col, count in df.isnull().sum().items()},
                },
                "sample_data": {
                    "columns": list(sample_df.columns),
                    "rows": sample_df.to_dict(orient="records"),
                },
            }

            # Versioned output path comes from the shared path templates.
            output_path = StoragePaths.SAMPLE_RAW_HEAD.format(
                user_id=self.user_id,
                project_id=self.project_id,
                version=resolved_version,
            )
            info = self.storage.save_json(sample_data, output_path)

            if self.project_service:
                info.metadata.update({
                    "sample_rows": sample_data["dataset_info"].get("sample_rows"),
                    "total_rows": sample_data["dataset_info"].get("total_rows"),
                })
                self.project_service.register_artifact(
                    self.user_id,
                    self.project_id,
                    ArtifactTypes.SAMPLE_RAW_HEAD,
                    resolved_version,
                    info,
                    version_scope="raw",
                    extra_metadata={
                        "columns": sample_data["schema"].get("columns", []),
                    },
                )
            # The UI preview shows at most 10 rows, even if the saved sample is larger.
            preview_rows = sample_df.head(min(rows_to_sample, 10)).to_dict(orient="records")
            blocks = [
                metric_block("Rows Sampled", len(sample_df)),
                metric_block("Total Rows", len(df)),
                simple_table_with_types(preview_rows, caption="First rows preview", block_id="sample_preview"),
                text_block(f"Sample saved to `{output_path}`"),
            ]
            return FormattedResponse(
                blocks=blocks,
                summary=f"Saved sample head (first {len(sample_df)} rows)",
                correlation_id=info.path,
                done=True,
            )
        except Exception as e:
            return self._error_response(f"Could not save dataset sample: {e}")