File size: 15,533 Bytes
681cb59
 
 
 
 
 
 
 
 
 
 
 
df5683f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681cb59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1f8415
681cb59
e1f8415
 
 
fb82bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df5683f
681cb59
 
 
e1f8415
df5683f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681cb59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1f8415
681cb59
 
e1f8415
681cb59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
from typing import Dict, List, Any, Optional
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

class CSVQueryEngine:
    """Answer natural-language questions over a collection of indexed CSV files.

    Cheap statistical questions (average / min / max / count / unique) are
    answered directly with pandas; anything else is delegated to the LLM with
    a context block assembled from the relevant CSVs.
    """

    def __init__(self, index_manager, llm):
        """Initialize with index manager and language model.

        Args:
            index_manager: Object exposing ``indexes`` — a mapping of
                csv_id -> {"path": str, "metadata": dict} — and a
                ``find_relevant_csvs(query)`` method.
            llm: Model object whose ``complete(prompt)`` returns an object
                with a ``.text`` attribute.
        """
        self.index_manager = index_manager
        self.llm = llm

    def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
        """Build a textual context (schema, samples, stats) for the LLM prompt.

        Args:
            query: The user query (currently unused; kept for interface
                stability and future query-aware context trimming).
            csv_ids: Candidate CSV ids; unknown ids are silently skipped.

        Returns:
            A newline-joined description of each CSV: metadata, head rows,
            numeric stats, categorical summaries, and date ranges.
        """
        context_parts = []

        for csv_id in csv_ids:
            # Skip ids that are not (or no longer) indexed.
            if csv_id not in self.index_manager.indexes:
                continue

            metadata = self.index_manager.indexes[csv_id]["metadata"]
            file_path = self.index_manager.indexes[csv_id]["path"]

            # High-level schema information.
            context_parts.append(f"CSV File: {metadata['filename']}")
            context_parts.append(f"Columns: {', '.join(metadata['columns'])}")
            context_parts.append(f"Row Count: {metadata['row_count']}")

            try:
                df = pd.read_csv(file_path)
                context_parts.append("\nSample Data:")
                context_parts.append(df.head(5).to_string())

                # Basic statistics for numeric columns.
                context_parts.append("\nNumeric Column Statistics:")
                numeric_cols = df.select_dtypes(include=['number']).columns
                for col in numeric_cols:
                    stats = df[col].describe()
                    context_parts.append(f"{col} - mean: {stats['mean']:.2f}, min: {stats['min']:.2f}, max: {stats['max']:.2f}")

                # Cardinality / top values for categorical columns.
                categorical_cols = df.select_dtypes(include=['object', 'category']).columns
                if len(categorical_cols) > 0:
                    context_parts.append("\nCategorical Column Information:")
                    for col in categorical_cols:
                        value_counts = df[col].value_counts().head(5)
                        context_parts.append(f"{col} - unique values: {df[col].nunique()}, top values: {', '.join(value_counts.index.astype(str))}")

                # Date range information, when date-like columns are present.
                date_cols = self._detect_date_columns(df)
                if date_cols:
                    context_parts.append("\nDate Column Information:")
                    for col in date_cols:
                        if not pd.api.types.is_datetime64_dtype(df[col]):
                            df[col] = pd.to_datetime(df[col], errors='coerce')
                        context_parts.append(f"{col} - range: {df[col].min()} to {df[col].max()}")

            except Exception as e:
                # Best-effort: a broken file degrades the context, not the query.
                context_parts.append(f"Error reading CSV: {str(e)}")

        return "\n\n".join(context_parts)

    @staticmethod
    def _detect_date_columns(df: pd.DataFrame) -> List[str]:
        """Return the columns of *df* that look like dates.

        A column qualifies if it already has a datetime64 dtype, or if it is
        an object (string) column whose every value parses as a datetime.
        Numeric columns are deliberately excluded: ``pd.to_datetime`` happily
        coerces plain integers to epoch timestamps, which previously caused
        int columns to be misreported as dates.
        """
        date_cols = []
        for col in df.columns:
            try:
                if pd.api.types.is_datetime64_dtype(df[col]):
                    date_cols.append(col)
                elif (not df.empty
                      and df[col].dtype == object
                      and pd.to_datetime(df[col], errors='coerce').notna().all()):
                    date_cols.append(col)
            except Exception:
                # Unparseable column — simply not a date column.
                pass
        return date_cols

    def _generate_prompt(self, query: str, context: str) -> str:
        """Generate a prompt for the LLM combining *context* and *query*."""
        return f"""You are an AI assistant specialized in analyzing CSV data.
        Your goal is to help users understand their data and extract insights.
        
        Below is information about CSV files that might help answer the query:
        
        {context}
        
        User Query: {query}
        
        Please provide a comprehensive and accurate answer based on the data.
        If calculations are needed, explain your process.
        If the data doesn't contain information to answer the query, say so clearly.
        
        Answer:"""

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files.

        Returns:
            Dict with ``"answer"`` (str) and ``"sources"`` (list of dicts).
        """
        # Find relevant CSV files.
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)

        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": []
            }

        # Fast path: answer simple statistical questions without the LLM.
        direct_answer = self._handle_statistical_query(query_text, relevant_csvs)
        if direct_answer:
            return {
                "answer": direct_answer,
                "sources": self._get_sources(relevant_csvs)
            }

        # Otherwise build a context and ask the LLM.
        context = self._prepare_context(query_text, relevant_csvs)
        prompt = self._generate_prompt(query_text, context)
        response = self.llm.complete(prompt)

        return {
            "answer": response.text,
            "sources": self._get_sources(relevant_csvs)
        }

    def _get_sources(self, csv_ids: List[str]) -> List[Dict[str, str]]:
        """Get source attribution entries (filename + first columns) for a response."""
        sources = []

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            metadata = self.index_manager.indexes[csv_id]["metadata"]
            sources.append({
                "csv": metadata["filename"],
                # Truncate long column lists to keep the attribution compact.
                "columns": ", ".join(metadata["columns"][:5]) + ("..." if len(metadata["columns"]) > 5 else "")
            })

        return sources

    def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
        """Handle direct statistical queries without using the LLM.

        Args:
            query: Raw user query text.
            csv_ids: Candidate CSV ids to search, in order.

        Returns:
            A formatted answer string from the first CSV that yields results,
            or ``None`` if the query is not statistical / nothing matched.
        """
        query_lower = query.lower()

        # Detect query type via keyword spotting.
        is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
        is_max_query = "maximum" in query_lower or "max" in query_lower
        is_min_query = "minimum" in query_lower or "min" in query_lower
        is_count_query = "count" in query_lower or "how many" in query_lower
        is_unique_query = "unique" in query_lower or "distinct" in query_lower

        if not (is_avg_query or is_max_query or is_min_query or is_count_query or is_unique_query):
            return None  # Not a statistical query

        # Extract potential column names from query.
        query_words = set(query_lower.replace("?", "").replace(",", "").split())

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            file_path = self.index_manager.indexes[csv_id]["path"]

            try:
                df = pd.read_csv(file_path)

                # Find relevant columns: any whose name overlaps the query words
                # or appears verbatim inside the query.
                target_columns = []
                for col in df.columns:
                    col_lower = col.lower()
                    if any(word in col_lower for word in query_words) or col_lower in query_lower:
                        target_columns.append(col)

                # No direct match: infer from common domain vocabulary.
                if not target_columns:
                    if any(word in query_lower for word in ["age", "old", "young"]):
                        age_cols = [col for col in df.columns if "age" in col.lower()]
                        if age_cols:
                            target_columns = age_cols
                    elif any(word in query_lower for word in ["class", "category", "type", "grade"]):
                        class_cols = [col for col in df.columns if any(term in col.lower()
                                    for term in ["class", "category", "type", "grade"])]
                        if class_cols:
                            target_columns = class_cols
                    elif any(word in query_lower for word in ["income", "salary", "money", "price", "cost"]):
                        income_cols = [col for col in df.columns if any(term in col.lower()
                                    for term in ["income", "salary", "wage", "earnings", "price", "cost"])]
                        if income_cols:
                            target_columns = income_cols
                    elif any(word in query_lower for word in ["date", "time", "year", "month", "day"]):
                        date_cols = self._detect_date_columns(df)
                        if date_cols:
                            target_columns = date_cols

                # Still nothing: fall back to all columns for count/unique
                # queries, or the numeric columns for other statistics.
                if not target_columns:
                    if is_count_query or is_unique_query:
                        target_columns = df.columns.tolist()
                    else:
                        target_columns = df.select_dtypes(include=['number']).columns.tolist()

                # Perform the requested calculation. Branch precedence mirrors
                # the keyword checks above: avg > max > min > count > unique.
                results = []
                if is_avg_query:
                    for col in target_columns:
                        if pd.api.types.is_numeric_dtype(df[col]):
                            value = df[col].mean()
                            results.append(f"The average {col} is {value:.2f}")
                elif is_max_query:
                    for col in target_columns:
                        value = df[col].max()
                        if pd.api.types.is_numeric_dtype(df[col]):
                            results.append(f"The maximum {col} is {value}")
                        else:
                            # For non-numeric columns, the max is alphabetical.
                            results.append(f"The maximum (alphabetically) {col} is '{value}'")
                elif is_min_query:
                    for col in target_columns:
                        value = df[col].min()
                        if pd.api.types.is_numeric_dtype(df[col]):
                            results.append(f"The minimum {col} is {value}")
                        else:
                            # For non-numeric columns, the min is alphabetical.
                            results.append(f"The minimum (alphabetically) {col} is '{value}'")
                elif is_count_query:
                    # Row count is a property of the table, not of a column:
                    # report it exactly once (previously it was duplicated per
                    # matched column).
                    results.append(f"The total count of rows is {len(df)}")
                elif is_unique_query:
                    for col in target_columns:
                        value = df[col].nunique()
                        unique_values = df[col].unique()
                        unique_str = ", ".join(str(x) for x in unique_values[:5])
                        if len(unique_values) > 5:
                            unique_str += f", ... and {len(unique_values) - 5} more"
                        results.append(f"There are {value} unique values in {col}: {unique_str}")

                if results:
                    return "\n".join(results)

            except Exception as e:
                print(f"Error processing CSV for statistical query: {e}")

        return None  # No results found

    def _handle_statistical_query1(self, query: str, csv_ids: List[str]) -> Optional[str]:
        """Legacy statistical-query handler (numeric columns only).

        NOTE(review): superseded by :meth:`_handle_statistical_query`, which
        also supports unique/distinct queries and non-numeric min/max. Not
        called by :meth:`query`; retained only for backward compatibility.
        """
        query_lower = query.lower()

        # Detect query type.
        is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
        is_max_query = "maximum" in query_lower or "max" in query_lower
        is_min_query = "minimum" in query_lower or "min" in query_lower
        is_count_query = "count" in query_lower or "how many" in query_lower

        if not (is_avg_query or is_max_query or is_min_query or is_count_query):
            return None  # Not a statistical query

        # Extract potential column names from query.
        query_words = set(query_lower.replace("?", "").replace(",", "").split())

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            file_path = self.index_manager.indexes[csv_id]["path"]

            try:
                df = pd.read_csv(file_path)

                # Find relevant columns based on query.
                target_columns = []
                for col in df.columns:
                    col_lower = col.lower()
                    if any(word in col_lower for word in query_words):
                        target_columns.append(col)

                # If no direct matches, try to infer from common column names.
                if not target_columns:
                    if "age" in query_lower:
                        age_cols = [col for col in df.columns if "age" in col.lower()]
                        if age_cols:
                            target_columns = age_cols
                    elif "income" in query_lower or "salary" in query_lower:
                        income_cols = [col for col in df.columns if any(term in col.lower()
                                       for term in ["income", "salary", "wage", "earnings"])]
                        if income_cols:
                            target_columns = income_cols

                # If still no matches, use all numeric columns.
                if not target_columns:
                    target_columns = df.select_dtypes(include=['number']).columns.tolist()

                # Perform the requested calculation (numeric columns only).
                results = []
                for col in target_columns:
                    if not pd.api.types.is_numeric_dtype(df[col]):
                        continue

                    if is_avg_query:
                        value = df[col].mean()
                        results.append(f"The average {col} is {value:.2f}")
                    elif is_max_query:
                        value = df[col].max()
                        results.append(f"The maximum {col} is {value}")
                    elif is_min_query:
                        value = df[col].min()
                        results.append(f"The minimum {col} is {value}")
                    elif is_count_query:
                        value = len(df)
                        results.append(f"The total count of {col} is {value}")

                if results:
                    return "\n".join(results)

            except Exception as e:
                print(f"Error processing CSV for statistical query: {e}")

        return None  # No results found