| | from typing import Dict, List, Any, Optional |
| | import pandas as pd |
| | from sklearn.metrics.pairwise import cosine_similarity |
| | import numpy as np |
| |
|
class CSVQueryEngine:
    """Query engine that answers natural-language questions over indexed CSV files.

    Collaborators (duck-typed, supplied by the caller):
      index_manager: must expose ``find_relevant_csvs(query) -> list[str]`` and
          an ``indexes`` dict mapping csv_id -> {"metadata": {"filename", "columns",
          "row_count"}, "path": str}.
      llm: must expose ``complete(prompt)`` returning an object with a ``.text``
          attribute.
    """

    # Common English words that must not trigger a column match when they
    # happen to appear as a substring of a column name (e.g. "is" inside
    # "distance"). Without this filter almost every query matched every column.
    _STOPWORDS = {
        "the", "a", "an", "is", "are", "was", "were", "of", "in", "on", "for",
        "to", "and", "or", "what", "which", "who", "how", "with", "by", "from",
    }

    def __init__(self, index_manager, llm):
        """Initialize with index manager and language model."""
        self.index_manager = index_manager
        self.llm = llm

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files.

        Returns a dict with:
          "answer": the LLM's answer text (or a fallback message when no
              relevant CSV is found).
          "sources": list of source descriptors (see ``_get_sources``).
        """
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)

        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": []
            }

        context = self._prepare_context(query_text, relevant_csvs)
        prompt = self._generate_prompt(query_text, context)
        response = self.llm.complete(prompt)

        return {
            "answer": response.text,
            "sources": self._get_sources(relevant_csvs)
        }

    def _match_columns(self, df: pd.DataFrame, query_lower: str,
                       query_words: set) -> List[str]:
        """Return the columns of *df* that the query appears to reference.

        A column matches when its lowercased name occurs in the query, or a
        meaningful query word (3+ characters, not a stopword) occurs in the
        name. Falls back to all numeric columns when nothing matches, so an
        aggregate question that omits the exact column name still gets stats.
        """
        meaningful = {
            w for w in query_words if len(w) >= 3 and w not in self._STOPWORDS
        }
        matches = [
            col for col in df.columns
            if col.lower() in query_lower
            or any(word in col.lower() for word in meaningful)
        ]
        if not matches:
            matches = df.select_dtypes(include=['number']).columns.tolist()
        return matches

    def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
        """Prepare context from relevant CSV files with pre-calculated statistics.

        For average/max/min style questions the numeric answer is computed
        here with pandas and embedded directly in the context, so the LLM
        does not have to do arithmetic itself (which it is unreliable at).
        """
        context_parts: List[str] = []
        calculated_answers: Dict[str, Any] = {}

        query_lower = query.lower()
        # "maximum"/"minimum" also contain "max"/"min", so the short forms
        # cover both spellings.
        is_avg_question = "average" in query_lower or "mean" in query_lower
        is_max_question = "maximum" in query_lower or "max" in query_lower
        is_min_question = "minimum" in query_lower or "min" in query_lower

        # Normalize the query into bare words for column matching.
        query_words = set(query_lower.replace("?", "").replace(",", "").split())

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue  # stale id from the retriever; skip silently

            metadata = self.index_manager.indexes[csv_id]["metadata"]
            file_path = self.index_manager.indexes[csv_id]["path"]

            context_parts.append(f"CSV File: {metadata['filename']}")
            context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
            context_parts.append(f"Row Count: {metadata['row_count']}")

            try:
                df = pd.read_csv(file_path)
                context_parts.append("\nSample Data:")
                context_parts.append(df.head(3).to_string())

                for col in self._match_columns(df, query_lower, query_words):
                    if not pd.api.types.is_numeric_dtype(df[col]):
                        continue
                    if is_avg_question:
                        avg_value = df[col].mean()
                        context_parts.append(f"\nThe average {col} is: {avg_value:.2f}")
                        # setdefault: if two CSVs share a column name, keep the
                        # first file's answer instead of silently overwriting.
                        calculated_answers.setdefault(f"average_{col}", avg_value)
                    if is_max_question:
                        max_value = df[col].max()
                        context_parts.append(f"\nThe maximum {col} is: {max_value}")
                        calculated_answers.setdefault(f"max_{col}", max_value)
                    if is_min_question:
                        min_value = df[col].min()
                        context_parts.append(f"\nThe minimum {col} is: {min_value}")
                        calculated_answers.setdefault(f"min_{col}", min_value)
            except Exception as e:
                # Best-effort: a missing/malformed file must not abort the
                # whole query; surface the problem to the LLM instead.
                context_parts.append(f"Error reading CSV: {str(e)}")

        if calculated_answers:
            context_parts.append("\nDirect Answer:")
            for key, value in calculated_answers.items():
                context_parts.append(f"{key.replace('_', ' ')}: {value}")

        return "\n\n".join(context_parts)

    def _prepare_context1(self, query: str, csv_ids: List[str]) -> str:
        """Legacy context builder (no question-specific pre-computed answers).

        NOTE(review): not called by ``query``; superseded by
        ``_prepare_context``. It emits describe() statistics for every numeric
        column instead of answering the specific question. Kept only in case
        an external caller still uses it — confirm and delete.
        """
        context_parts = []

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            metadata = self.index_manager.indexes[csv_id]["metadata"]
            file_path = self.index_manager.indexes[csv_id]["path"]

            context_parts.append(f"CSV File: {metadata['filename']}")
            context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
            context_parts.append(f"Row Count: {metadata['row_count']}")

            try:
                df = pd.read_csv(file_path)
                context_parts.append("\nSample Data:")
                context_parts.append(df.head(5).to_string())

                context_parts.append("\nNumeric Column Statistics:")
                numeric_cols = df.select_dtypes(include=['number']).columns
                for col in numeric_cols:
                    stats = df[col].describe()
                    context_parts.append(f"{col} - mean: {stats['mean']:.2f}, min: {stats['min']:.2f}, max: {stats['max']:.2f}")
            except Exception as e:
                context_parts.append(f"Error reading CSV: {str(e)}")

        return "\n\n".join(context_parts)

    def _generate_prompt(self, query: str, context: str) -> str:
        """Generate a prompt for the LLM embedding the prepared context."""
        return f"""You are an AI assistant specialized in analyzing CSV data.
Your goal is to help users understand their data and extract insights.

Below is information about CSV files that might help answer the query:

{context}

User Query: {query}

Please provide a comprehensive and accurate answer based on the data.
If calculations are needed, explain your process.
If the data doesn't contain information to answer the query, say so clearly.

Answer:"""

    def _get_sources(self, csv_ids: List[str]) -> List[Dict[str, str]]:
        """Get source information for the response.

        Each entry lists the filename and up to the first five column names,
        with "..." appended when the file has more columns.
        """
        sources = []

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            metadata = self.index_manager.indexes[csv_id]["metadata"]
            sources.append({
                "csv": metadata["filename"],
                "columns": ", ".join(metadata["columns"][:5]) + ("..." if len(metadata["columns"]) > 5 else "")
            })

        return sources
| |
|