| | from typing import Dict, List, Any, Optional |
| | import pandas as pd |
| | from sklearn.metrics.pairwise import cosine_similarity |
| | import numpy as np |
| |
|
class CSVQueryEngine:
    """Answer natural-language queries over a collection of indexed CSV files.

    Cheap statistical queries (average / max / min / count / unique) are
    answered directly with pandas.  Everything else is delegated to the LLM,
    with a context assembled from each relevant CSV's metadata, sample rows
    and summary statistics.
    """

    def __init__(self, index_manager, llm):
        """Initialize with index manager and language model.

        Args:
            index_manager: Object exposing an ``indexes`` mapping of
                csv_id -> {"metadata": ..., "path": ...} and a
                ``find_relevant_csvs(query)`` method.
            llm: Model exposing ``complete(prompt)`` returning an object
                with a ``.text`` attribute.
        """
        self.index_manager = index_manager
        self.llm = llm

    @staticmethod
    def _find_date_columns(df: pd.DataFrame) -> List[str]:
        """Return columns that are datetime-typed or fully parseable as dates.

        NOTE(review): ``to_datetime(...).notna().all()`` is True for an
        all-empty column too — the heuristic is intentionally permissive.
        """
        date_cols = []
        for col in df.columns:
            try:
                if (pd.api.types.is_datetime64_dtype(df[col])
                        or pd.to_datetime(df[col], errors='coerce').notna().all()):
                    date_cols.append(col)
            # Narrowed from a bare `except:`; only parse failures are expected.
            except (ValueError, TypeError):
                continue
        return date_cols

    def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
        """Prepare context from relevant CSV files.

        Builds a textual summary (metadata, sample rows, numeric /
        categorical / date column statistics) for each known ``csv_id``.
        ``query`` is currently unused but kept for interface stability.
        """
        context_parts = []

        for csv_id in csv_ids:
            index_entry = self.index_manager.indexes.get(csv_id)
            if index_entry is None:
                continue

            metadata = index_entry["metadata"]
            file_path = index_entry["path"]

            context_parts.append(f"CSV File: {metadata['filename']}")
            context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
            context_parts.append(f"Row Count: {metadata['row_count']}")

            try:
                df = pd.read_csv(file_path)
                context_parts.append("\nSample Data:")
                context_parts.append(df.head(5).to_string())

                # Numeric summary: mean / min / max per numeric column.
                context_parts.append("\nNumeric Column Statistics:")
                for col in df.select_dtypes(include=['number']).columns:
                    stats = df[col].describe()
                    context_parts.append(
                        f"{col} - mean: {stats['mean']:.2f}, min: {stats['min']:.2f}, max: {stats['max']:.2f}"
                    )

                # Categorical summary: cardinality and most frequent values.
                categorical_cols = df.select_dtypes(include=['object', 'category']).columns
                if len(categorical_cols) > 0:
                    context_parts.append("\nCategorical Column Information:")
                    for col in categorical_cols:
                        value_counts = df[col].value_counts().head(5)
                        context_parts.append(
                            f"{col} - unique values: {df[col].nunique()}, "
                            f"top values: {', '.join(value_counts.index.astype(str))}"
                        )

                # Date summary: value range of each detected date column.
                date_cols = self._find_date_columns(df)
                if date_cols:
                    context_parts.append("\nDate Column Information:")
                    for col in date_cols:
                        if not pd.api.types.is_datetime64_dtype(df[col]):
                            df[col] = pd.to_datetime(df[col], errors='coerce')
                        context_parts.append(f"{col} - range: {df[col].min()} to {df[col].max()}")

            except Exception as e:
                # Best effort: surface the failure inside the context rather
                # than aborting the whole query.
                context_parts.append(f"Error reading CSV: {str(e)}")

        return "\n\n".join(context_parts)

    def _generate_prompt(self, query: str, context: str) -> str:
        """Generate the LLM prompt combining the data context and the query."""
        return f"""You are an AI assistant specialized in analyzing CSV data.
Your goal is to help users understand their data and extract insights.

Below is information about CSV files that might help answer the query:

{context}

User Query: {query}

Please provide a comprehensive and accurate answer based on the data.
If calculations are needed, explain your process.
If the data doesn't contain information to answer the query, say so clearly.

Answer:"""

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files.

        Returns:
            Dict with ``"answer"`` (str) and ``"sources"`` (list of
            {"csv", "columns"} dicts).
        """
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)

        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": []
            }

        # Fast path: answer simple statistical questions without the LLM.
        direct_answer = self._handle_statistical_query(query_text, relevant_csvs)
        if direct_answer:
            return {
                "answer": direct_answer,
                "sources": self._get_sources(relevant_csvs)
            }

        # Fallback: build a context and let the LLM answer.
        context = self._prepare_context(query_text, relevant_csvs)
        prompt = self._generate_prompt(query_text, context)
        response = self.llm.complete(prompt)

        return {
            "answer": response.text,
            "sources": self._get_sources(relevant_csvs)
        }

    def _get_sources(self, csv_ids: List[str]) -> List[Dict[str, str]]:
        """Get source information (filename + first columns) for the response."""
        sources = []

        for csv_id in csv_ids:
            index_entry = self.index_manager.indexes.get(csv_id)
            if index_entry is None:
                continue

            metadata = index_entry["metadata"]
            columns = metadata["columns"]
            sources.append({
                "csv": metadata["filename"],
                # Show at most five column names; flag truncation with "...".
                "columns": ", ".join(columns[:5]) + ("..." if len(columns) > 5 else "")
            })

        return sources

    def _resolve_target_columns(self, df: pd.DataFrame, query_lower: str,
                                query_words: set) -> List[str]:
        """Pick the columns a statistical query most likely refers to.

        Tries a direct name match first, then falls back to domain keyword
        groups (age / category / money / date).  May return an empty list.
        """
        # Direct match: a query word appears in the column name, or the
        # column name appears verbatim in the query.
        direct = [
            col for col in df.columns
            if any(word in col.lower() for word in query_words)
            or col.lower() in query_lower
        ]
        if direct:
            return direct

        if any(word in query_lower for word in ["age", "old", "young"]):
            return [col for col in df.columns if "age" in col.lower()]
        if any(word in query_lower for word in ["class", "category", "type", "grade"]):
            return [col for col in df.columns
                    if any(term in col.lower()
                           for term in ["class", "category", "type", "grade"])]
        if any(word in query_lower for word in ["income", "salary", "money", "price", "cost"]):
            return [col for col in df.columns
                    if any(term in col.lower()
                           for term in ["income", "salary", "wage", "earnings", "price", "cost"])]
        if any(word in query_lower for word in ["date", "time", "year", "month", "day"]):
            return self._find_date_columns(df)
        return []

    def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
        """Handle direct statistical queries without using the LLM.

        Returns a newline-joined answer string, or None when the query is
        not a recognized statistical question (caller then uses the LLM).
        """
        query_lower = query.lower()
        # Word-level keyword matching: plain substring tests false-positived
        # on e.g. "administrator" (contains "min") and "country" ("count").
        query_words = set(query_lower.replace("?", "").replace(",", "").split())

        is_avg_query = bool(query_words & {"average", "mean", "avg"})
        is_max_query = bool(query_words & {"maximum", "max"})
        is_min_query = bool(query_words & {"minimum", "min"})
        is_count_query = "count" in query_words or "how many" in query_lower
        is_unique_query = bool(query_words & {"unique", "distinct"})

        if not (is_avg_query or is_max_query or is_min_query or is_count_query or is_unique_query):
            return None

        for csv_id in csv_ids:
            index_entry = self.index_manager.indexes.get(csv_id)
            if index_entry is None:
                continue

            try:
                df = pd.read_csv(index_entry["path"])

                target_columns = self._resolve_target_columns(df, query_lower, query_words)

                # Last resort: every column for count/unique queries,
                # numeric columns for the aggregate ones.
                if not target_columns:
                    if is_count_query or is_unique_query:
                        target_columns = df.columns.tolist()
                    else:
                        target_columns = df.select_dtypes(include=['number']).columns.tolist()

                results = []

                # Bug fix: row count is a table-level statistic — report it
                # once, not once per matched column.  Aggregate flags keep
                # their original precedence over the count flag.
                if (is_count_query and target_columns
                        and not (is_avg_query or is_max_query or is_min_query)):
                    results.append(f"The total count of rows is {len(df)}")

                for col in target_columns:
                    if is_avg_query:
                        if pd.api.types.is_numeric_dtype(df[col]):
                            value = df[col].mean()
                            results.append(f"The average {col} is {value:.2f}")
                    elif is_max_query:
                        value = df[col].max()
                        if pd.api.types.is_numeric_dtype(df[col]):
                            results.append(f"The maximum {col} is {value}")
                        else:
                            # Lexicographic max for non-numeric columns.
                            results.append(f"The maximum (alphabetically) {col} is '{value}'")
                    elif is_min_query:
                        value = df[col].min()
                        if pd.api.types.is_numeric_dtype(df[col]):
                            results.append(f"The minimum {col} is {value}")
                        else:
                            results.append(f"The minimum (alphabetically) {col} is '{value}'")
                    elif is_count_query:
                        continue  # table-level count already reported above
                    elif is_unique_query:
                        unique_values = df[col].unique()
                        unique_str = ", ".join(str(x) for x in unique_values[:5])
                        if len(unique_values) > 5:
                            unique_str += f", ... and {len(unique_values) - 5} more"
                        results.append(
                            f"There are {df[col].nunique()} unique values in {col}: {unique_str}"
                        )

                if results:
                    return "\n".join(results)

            except Exception as e:
                # Per-file best effort: log and fall through to the LLM path.
                print(f"Error processing CSV for statistical query: {e}")

        return None
| |
|