Chamin09 committed on
Commit
e1f8415
·
verified ·
1 Parent(s): 53319fb

Create query_engine.py

Browse files
Files changed (1) hide show
  1. indexes/query_engine.py +106 -0
indexes/query_engine.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def query(self, query_text: str) -> Dict[str, Any]:
    """Answer a natural language question using the indexed CSV files.

    Relevant CSVs are looked up through ``self.index_manager``. If none are
    found, a fixed "nothing found" answer with no sources is returned.
    Otherwise ``self._handle_statistical_query`` gets the first chance to
    answer; when it returns a falsy value, a context and prompt are built
    and sent to ``self.llm.complete`` and the ``text`` of its response is
    used as the answer.

    Args:
        query_text: The user's question.

    Returns:
        A dict with the keys ``"answer"`` and ``"sources"``.
    """
    csv_ids = self.index_manager.find_relevant_csvs(query_text)

    # Guard clause: nothing to search, so there is nothing to cite either.
    if not csv_ids:
        return {
            "answer": "No relevant CSV files found for your query.",
            "sources": [],
        }

    # Prefer the direct statistical answer; fall back to the LLM only when
    # the statistical handler produced nothing.
    answer = self._handle_statistical_query(query_text, csv_ids)
    if not answer:
        context = self._prepare_context(query_text, csv_ids)
        prompt = self._generate_prompt(query_text, context)
        answer = self.llm.complete(prompt).text

    return {"answer": answer, "sources": self._get_sources(csv_ids)}
29
+
30
+ def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
31
+ """Handle direct statistical queries without using the LLM."""
32
+ query_lower = query.lower()
33
+
34
+ # Detect query type
35
+ is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
36
+ is_max_query = "maximum" in query_lower or "max" in query_lower
37
+ is_min_query = "minimum" in query_lower or "min" in query_lower
38
+ is_count_query = "count" in query_lower or "how many" in query_lower
39
+
40
+ if not (is_avg_query or is_max_query or is_min_query or is_count_query):
41
+ return None # Not a statistical query
42
+
43
+ # Extract potential column names from query
44
+ query_words = set(query_lower.replace("?", "").replace(",", "").split())
45
+
46
+ for csv_id in csv_ids:
47
+ if csv_id not in self.index_manager.indexes:
48
+ continue
49
+
50
+ file_path = self.index_manager.indexes[csv_id]["path"]
51
+ metadata = self.index_manager.indexes[csv_id]["metadata"]
52
+
53
+ try:
54
+ df = pd.read_csv(file_path)
55
+
56
+ # Find relevant columns based on query
57
+ target_columns = []
58
+ for col in df.columns:
59
+ col_lower = col.lower()
60
+ # Check if column name appears in query
61
+ if any(word in col_lower for word in query_words):
62
+ target_columns.append(col)
63
+
64
+ # If no direct matches, try to infer from common column names
65
+ if not target_columns:
66
+ if "age" in query_lower:
67
+ age_cols = [col for col in df.columns if "age" in col.lower()]
68
+ if age_cols:
69
+ target_columns = age_cols
70
+ elif "income" in query_lower or "salary" in query_lower:
71
+ income_cols = [col for col in df.columns if any(term in col.lower()
72
+ for term in ["income", "salary", "wage", "earnings"])]
73
+ if income_cols:
74
+ target_columns = income_cols
75
+ # Add more common column inferences as needed
76
+
77
+ # If still no matches, use all numeric columns
78
+ if not target_columns:
79
+ target_columns = df.select_dtypes(include=['number']).columns.tolist()
80
+
81
+ # Perform the requested calculation
82
+ results = []
83
+ for col in target_columns:
84
+ if not pd.api.types.is_numeric_dtype(df[col]):
85
+ continue
86
+
87
+ if is_avg_query:
88
+ value = df[col].mean()
89
+ results.append(f"The average {col} is {value:.2f}")
90
+ elif is_max_query:
91
+ value = df[col].max()
92
+ results.append(f"The maximum {col} is {value}")
93
+ elif is_min_query:
94
+ value = df[col].min()
95
+ results.append(f"The minimum {col} is {value}")
96
+ elif is_count_query:
97
+ value = len(df)
98
+ results.append(f"The total count of {col} is {value}")
99
+
100
+ if results:
101
+ return "\n".join(results)
102
+
103
+ except Exception as e:
104
+ print(f"Error processing CSV for statistical query: {e}")
105
+
106
+ return None # No results found