Spaces:
Sleeping
Sleeping
| """ | |
| Data Processor Agent for CoDA. | |
| Extracts metadata and insights from data files so that only compact summaries, | |
| not full datasets, are sent to the LLM. This keeps prompts within token limits | |
| while still providing rich context for visualization decisions. | |
| """ | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| import pandas as pd | |
| from pydantic import BaseModel, Field | |
| from coda.core.base_agent import AgentContext, BaseAgent | |
| from coda.core.llm import LLMProvider | |
| from coda.core.memory import SharedMemory | |
| logger = logging.getLogger(__name__) | |
class ColumnInfo(BaseModel):
    """Summary of one column of a data file: label, dtype, counts, examples."""

    # Column label, rendered as a string.
    name: str
    # Column dtype, rendered as a string.
    dtype: str

    # Count statistics for the column.
    non_null_count: int
    unique_count: int

    # A few example values from the column.
    sample_values: list[Any]
class DataFileInfo(BaseModel):
    """Summary of a single data file and its columns."""

    # Where the file lives and what format it is (its extension).
    file_path: str
    file_type: str

    # (number of rows, number of columns).
    shape: tuple[int, int]

    # One entry per column, in file order.
    columns: list[ColumnInfo]

    # In-memory size of the loaded table, in megabytes.
    memory_usage_mb: float
class DataAnalysis(BaseModel):
    """Structured result produced by the Data Processor agent.

    ``files``, ``insights`` and ``processing_steps`` are required; the other
    list fields default to a fresh empty list.
    """

    # --- required fields -------------------------------------------------
    files: list[DataFileInfo] = Field(
        ...,
        description="Metadata for each processed data file",
    )
    insights: list[str] = Field(
        ...,
        description="Key insights about the data (patterns, outliers, etc.)",
    )
    processing_steps: list[str] = Field(
        ...,
        description="Recommended data processing steps",
    )

    # --- optional fields (empty list when omitted) -----------------------
    aggregations_needed: list[str] = Field(
        default_factory=list,
        description="Suggested aggregations for visualization",
    )
    visualization_hints: list[str] = Field(
        default_factory=list,
        description="Hints for visualization based on data characteristics",
    )
    potential_issues: list[str] = Field(
        default_factory=list,
        description="Potential data quality issues",
    )
class DataProcessorAgent(BaseAgent[DataAnalysis]):
    """
    Processes data files to extract metadata and insights.

    Each supported file is loaded with pandas and reduced to per-column
    metadata (dtype, non-null count, unique count, a few sample values).
    Only that summary is placed in the LLM prompt, which keeps the prompt
    small while giving downstream agents enough context about the data.
    """

    MEMORY_KEY = "data_analysis"
    SUPPORTED_EXTENSIONS = {".csv", ".json", ".xlsx", ".xls", ".parquet"}

    # Maximum rendered length of one sample value in the prompt; long
    # free-text cells would otherwise inflate the prompt.
    MAX_SAMPLE_CHARS = 50

    # LLM-produced fields of DataAnalysis that must be lists of plain strings.
    LIST_FIELDS = (
        "insights",
        "processing_steps",
        "aggregations_needed",
        "visualization_hints",
        "potential_issues",
    )

    # Keys tried, in order, when the LLM returns an object where a string was
    # expected; the first key present supplies the text.
    _DESCRIPTION_KEYS = (
        "description",
        "desc",
        "text",
        "value",
        "step",
        "hint",
        "issue",
    )

    def __init__(
        self,
        llm: LLMProvider,
        memory: SharedMemory,
        name: Optional[str] = None,
    ) -> None:
        super().__init__(llm, memory, name or "DataProcessor")

    def execute(self, context: AgentContext) -> DataAnalysis:
        """Extract file metadata into shared memory, then run the LLM analysis.

        Metadata for every readable, supported file in ``context.data_paths``
        is stored under the ``"raw_file_info"`` key before delegating to the
        base ``execute``; ``_build_prompt`` and ``_parse_response`` read it
        back from there.
        """
        logger.info(f"[{self._name}] Processing {len(context.data_paths)} data files")
        file_infos = [
            info
            for info in map(self._extract_file_metadata, context.data_paths)
            if info is not None
        ]
        logger.info(
            f"[{self._name}] Extracted metadata for "
            f"{len(file_infos)}/{len(context.data_paths)} files"
        )
        self._memory.store(
            key="raw_file_info",
            value=[f.model_dump() for f in file_infos],
            agent_name=self._name,
        )
        return super().execute(context)

    def _extract_file_metadata(self, file_path: str) -> Optional[DataFileInfo]:
        """Build a ``DataFileInfo`` summary for one data file.

        Args:
            file_path: Path to the data file.

        Returns:
            The file's metadata, or ``None`` when the path is not an existing
            regular file, has an unsupported extension, or cannot be loaded
            or analyzed. Failures are logged instead of raised so that one bad
            file does not abort processing of the others.
        """
        path = Path(file_path)
        # is_file() rather than exists(): a directory would pass exists()
        # and only fail later inside the pandas reader.
        if not path.is_file():
            logger.warning(f"File not found: {path}")
            return None
        suffix = path.suffix.lower()
        if suffix not in self.SUPPORTED_EXTENSIONS:
            logger.warning(f"Unsupported file type: {path.suffix}")
            return None
        try:
            df = self._load_dataframe(path)
            return DataFileInfo(
                file_path=str(path),
                file_type=suffix,
                shape=(len(df), len(df.columns)),
                columns=self._analyze_columns(df),
                memory_usage_mb=float(df.memory_usage(deep=True).sum()) / (1024 * 1024),
            )
        except Exception:
            # Deliberate best-effort boundary: skip the file, but keep the
            # traceback in the log so the failure can be diagnosed.
            logger.exception(f"Failed to process {path}")
            return None

    def _load_dataframe(self, path: Path) -> pd.DataFrame:
        """Load ``path`` into a DataFrame using the reader for its extension.

        Raises:
            ValueError: If the extension has no registered reader.
        """
        suffix = path.suffix.lower()
        if suffix == ".csv":
            return pd.read_csv(path)
        if suffix == ".json":
            return pd.read_json(path)
        if suffix in {".xlsx", ".xls"}:
            return pd.read_excel(path)
        if suffix == ".parquet":
            return pd.read_parquet(path)
        raise ValueError(f"Unsupported format: {suffix}")

    def _analyze_columns(self, df: pd.DataFrame) -> list[ColumnInfo]:
        """Summarize every column of ``df``, in column order.

        ``DataFrame.items()`` yields one Series per column position, so
        duplicated column labels are handled (``df[label]`` would return a
        DataFrame for a duplicated label and break ``.dtype``).
        """
        return [
            ColumnInfo(
                name=str(label),
                dtype=str(series.dtype),
                non_null_count=int(series.count()),
                unique_count=self._count_unique(series),
                sample_values=series.dropna().head(5).tolist(),
            )
            for label, series in df.items()
        ]

    @staticmethod
    def _count_unique(series: pd.Series) -> int:
        """Return the number of distinct non-null values in ``series``.

        ``Series.nunique`` raises ``TypeError`` when cells hold unhashable
        objects (e.g. lists or dicts from nested JSON); in that case values
        are compared by their string form instead.
        """
        try:
            return int(series.nunique())
        except TypeError:
            return int(series.dropna().astype(str).nunique())

    def _get_system_prompt(self) -> str:
        return """You are a Data Analyst specialist in a data visualization team.
Your expertise is in understanding data structures, identifying patterns, and recommending processing steps for effective visualization.

Your responsibilities:
1. Analyze metadata to understand data characteristics
2. Identify insights and patterns relevant to visualization
3. Recommend data processing and aggregation steps
4. Suggest visualization approaches based on data types
5. Flag potential data quality issues

Always respond with a valid JSON object matching the required schema."""

    def _build_prompt(self, context: AgentContext) -> str:
        """Assemble the user prompt from the query, query analysis and file metadata."""
        file_info = self._get_from_memory("raw_file_info") or []
        query_analysis = self._get_from_memory("query_analysis") or {}
        file_summary = self._format_file_info(file_info)

        query_context = ""
        if query_analysis:
            query_context = f"""
Query Analysis:
- Visualization Types: {query_analysis.get('visualization_types', [])}
- Key Points: {query_analysis.get('key_points', [])}
- Data Requirements: {query_analysis.get('data_requirements', [])}
"""

        # The template below is annotated with // comments for the model's
        # benefit; since those are not valid JSON, the prompt explicitly asks
        # for them to be left out of the response.
        return f"""Analyze the following data files for visualization purposes.

User Query: {context.query}
{query_context}
Data Files:
{file_summary}

Based on this metadata, provide a JSON object with these fields.
IMPORTANT: All list fields must contain SIMPLE STRINGS, not objects.
Do not include the // comments in your response.
{{
  "insights": ["string1", "string2", ...], // Simple string descriptions of patterns
  "processing_steps": ["step1", "step2", ...], // Simple string descriptions of steps
  "aggregations_needed": ["agg1", "agg2", ...], // Simple string descriptions
  "visualization_hints": ["hint1", "hint2", ...], // Simple string hints
  "potential_issues": ["issue1", "issue2", ...] // Simple string issues
}}

JSON Response:"""

    def _format_file_info(self, file_info: list[dict]) -> str:
        """Render file metadata dicts (``DataFileInfo.model_dump()`` shape) as prompt text.

        At most three sample values per column are shown, each truncated to
        ``MAX_SAMPLE_CHARS`` characters. The non-null count is included so the
        model can spot missing data.
        """
        if not file_info:
            return "No data files available."
        lines = []
        for f in file_info:
            n_rows, n_cols = f["shape"][0], f["shape"][1]
            lines.append(f"\nFile: {f['file_path']}")
            lines.append(f"  Type: {f['file_type']}")
            lines.append(f"  Shape: {n_rows} rows × {n_cols} columns")
            lines.append("  Columns:")
            for col in f.get("columns", []):
                samples = ", ".join(
                    self._truncate_sample(str(v))
                    for v in col.get("sample_values", [])[:3]
                )
                lines.append(
                    f"    - {col['name']} ({col['dtype']}): "
                    f"{col['non_null_count']}/{n_rows} non-null, "
                    f"{col['unique_count']} unique, samples: [{samples}]"
                )
        return "\n".join(lines)

    def _truncate_sample(self, text: str) -> str:
        """Shorten ``text`` to ``MAX_SAMPLE_CHARS`` characters, marking the cut with '...'."""
        limit = self.MAX_SAMPLE_CHARS
        return text if len(text) <= limit else text[:limit] + "..."

    def _normalize_list_field(self, value: Any) -> list[str]:
        """Coerce an LLM-produced field into a list of strings.

        - ``None`` becomes ``[]``.
        - A dict becomes ``["key: value", ...]``.
        - A list is converted element-wise by ``_item_to_text``.
        - Any other value becomes a one-element list holding its ``str()``.
        """
        if value is None:
            return []
        if isinstance(value, dict):
            return [f"{k}: {v}" for k, v in value.items()]
        if isinstance(value, list):
            return [self._item_to_text(item) for item in value]
        return [str(value)]

    def _item_to_text(self, item: Any) -> str:
        """Render one list element from the LLM response as a string.

        Strings pass through unchanged. Dicts are reduced to the value of the
        first key in ``_DESCRIPTION_KEYS`` they contain, or else to ``str(item)``.
        Anything else is converted with ``str()``.
        """
        if isinstance(item, str):
            return item
        if isinstance(item, dict):
            for key in self._DESCRIPTION_KEYS:
                if key in item:
                    return str(item[key])
        return str(item)

    def _parse_response(self, response: str) -> DataAnalysis:
        """Turn the LLM response into a ``DataAnalysis``.

        The string-list fields are normalized. ``files`` is always taken from
        the metadata stored in memory, never from the LLM output.
        """
        data = self._extract_json(response)
        for key in self.LIST_FIELDS:
            data[key] = self._normalize_list_field(data.get(key))
        data["files"] = self._get_from_memory("raw_file_info") or []
        return DataAnalysis(**data)

    def _get_output_key(self) -> str:
        return self.MEMORY_KEY