Chamin09 committed on
Commit
681cb59
·
verified ·
1 Parent(s): e1f8415

Update indexes/query_engine.py

Browse files
Files changed (1) hide show
  1. indexes/query_engine.py +111 -99
indexes/query_engine.py CHANGED
@@ -1,106 +1,118 @@
1
- def query(self, query_text: str) -> Dict[str, Any]:
2
- """Process a natural language query across CSV files."""
3
- # Find relevant CSV files
4
- relevant_csvs = self.index_manager.find_relevant_csvs(query_text)
5
-
6
- if not relevant_csvs:
7
- return {
8
- "answer": "No relevant CSV files found for your query.",
9
- "sources": []
10
- }
11
-
12
- # Check for direct statistical queries
13
- direct_answer = self._handle_statistical_query(query_text, relevant_csvs)
14
- if direct_answer:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  return {
16
- "answer": direct_answer,
17
  "sources": self._get_sources(relevant_csvs)
18
  }
19
 
20
- # If not a direct statistical query, use the regular approach
21
- context = self._prepare_context(query_text, relevant_csvs)
22
- prompt = self._generate_prompt(query_text, context)
23
- response = self.llm.complete(prompt)
24
-
25
- return {
26
- "answer": response.text,
27
- "sources": self._get_sources(relevant_csvs)
28
- }
29
-
30
- def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
31
- """Handle direct statistical queries without using the LLM."""
32
- query_lower = query.lower()
33
-
34
- # Detect query type
35
- is_avg_query = "average" in query_lower or "mean" in query_lower or "avg" in query_lower
36
- is_max_query = "maximum" in query_lower or "max" in query_lower
37
- is_min_query = "minimum" in query_lower or "min" in query_lower
38
- is_count_query = "count" in query_lower or "how many" in query_lower
39
-
40
- if not (is_avg_query or is_max_query or is_min_query or is_count_query):
41
- return None # Not a statistical query
42
-
43
- # Extract potential column names from query
44
- query_words = set(query_lower.replace("?", "").replace(",", "").split())
45
-
46
- for csv_id in csv_ids:
47
- if csv_id not in self.index_manager.indexes:
48
- continue
49
-
50
- file_path = self.index_manager.indexes[csv_id]["path"]
51
- metadata = self.index_manager.indexes[csv_id]["metadata"]
52
 
53
- try:
54
- df = pd.read_csv(file_path)
55
-
56
- # Find relevant columns based on query
57
- target_columns = []
58
- for col in df.columns:
59
- col_lower = col.lower()
60
- # Check if column name appears in query
61
- if any(word in col_lower for word in query_words):
62
- target_columns.append(col)
63
-
64
- # If no direct matches, try to infer from common column names
65
- if not target_columns:
66
- if "age" in query_lower:
67
- age_cols = [col for col in df.columns if "age" in col.lower()]
68
- if age_cols:
69
- target_columns = age_cols
70
- elif "income" in query_lower or "salary" in query_lower:
71
- income_cols = [col for col in df.columns if any(term in col.lower()
72
- for term in ["income", "salary", "wage", "earnings"])]
73
- if income_cols:
74
- target_columns = income_cols
75
- # Add more common column inferences as needed
76
-
77
- # If still no matches, use all numeric columns
78
- if not target_columns:
79
- target_columns = df.select_dtypes(include=['number']).columns.tolist()
80
-
81
- # Perform the requested calculation
82
- results = []
83
- for col in target_columns:
84
- if not pd.api.types.is_numeric_dtype(df[col]):
85
- continue
86
-
87
- if is_avg_query:
88
- value = df[col].mean()
89
- results.append(f"The average {col} is {value:.2f}")
90
- elif is_max_query:
91
- value = df[col].max()
92
- results.append(f"The maximum {col} is {value}")
93
- elif is_min_query:
94
- value = df[col].min()
95
- results.append(f"The minimum {col} is {value}")
96
- elif is_count_query:
97
- value = len(df)
98
- results.append(f"The total count of {col} is {value}")
99
 
100
- if results:
101
- return "\n".join(results)
102
 
103
- except Exception as e:
104
- print(f"Error processing CSV for statistical query: {e}")
105
-
106
- return None # No results found
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, Optional
2
+ import pandas as pd
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ import numpy as np
5
+
6
class CSVQueryEngine:
    """Answer natural-language questions about indexed CSV files.

    Simple statistical questions (average / maximum / minimum / count) are
    answered directly with pandas; every other question is forwarded to the
    language model together with context taken from the relevant files.

    The collaborators are used as follows:

    * ``index_manager.find_relevant_csvs(query_text)`` returns the ids of the
      CSV files to consult.
    * ``index_manager.indexes[csv_id]["path"]`` is the location handed to
      ``pandas.read_csv``.
    * ``llm.complete(prompt)`` returns an object whose ``text`` attribute is
      the answer.

    NOTE(review): ``query`` calls ``self._prepare_context``,
    ``self._generate_prompt`` and ``self._get_sources``; none of them is
    defined in the portion of the file visible here - confirm they exist.
    """

    # Whole-word triggers for each statistic, in precedence order: when a
    # query mentions several statistics only the first one listed is answered.
    _STAT_KEYWORDS = (
        ("average", frozenset({"average", "mean", "avg"})),
        ("maximum", frozenset({"maximum", "max"})),
        ("minimum", frozenset({"minimum", "min"})),
        ("count", frozenset({"count"})),
    )

    # Filler words that must never be used to pick a column: as substrings
    # they occur in almost every column name ("a" in "salary", "is" in ...).
    _STOPWORDS = frozenset({
        "a", "an", "the", "of", "in", "on", "at", "by", "for", "to", "from",
        "with", "and", "or", "is", "are", "was", "were", "be", "do", "does",
        "what", "which", "who", "how", "many", "much", "there", "all", "each",
        "per", "me", "show", "give", "tell", "it", "its", "this", "that",
    })

    # Column-name fragments treated as "income-like" by the inference step.
    _INCOME_TERMS = ("income", "salary", "wage", "earnings")

    # Shorter query words must equal a whole column token; otherwise "age"
    # would select "wage", "percentage", "page_views", ...
    _MIN_SUBSTRING_LEN = 4

    def __init__(self, index_manager, llm):
        """Store the index manager and the language model used by ``query``."""
        self.index_manager = index_manager
        self.llm = llm

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files.

        Args:
            query_text: The user's question.

        Returns:
            A dict with two keys: ``"answer"`` (the reply text) and
            ``"sources"`` (an empty list when no relevant file was found,
            otherwise whatever ``self._get_sources`` returns).
        """
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)
        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": [],
            }

        # Cheap path first: statistics computed straight from the data need
        # no LLM call.
        direct_answer = self._handle_statistical_query(query_text, relevant_csvs)
        if direct_answer:
            return {
                "answer": direct_answer,
                "sources": self._get_sources(relevant_csvs),
            }

        # Not a direct statistical query: build a prompt and ask the LLM.
        context = self._prepare_context(query_text, relevant_csvs)
        prompt = self._generate_prompt(query_text, context)
        response = self.llm.complete(prompt)
        return {
            "answer": response.text,
            "sources": self._get_sources(relevant_csvs),
        }

    def _handle_statistical_query(self, query: str, csv_ids: List[str]) -> Optional[str]:
        """Handle direct statistical queries without using the LLM.

        Args:
            query: The user's question.
            csv_ids: Ids of candidate CSV files, tried in the given order.

        Returns:
            Newline-joined result sentences from the first CSV that yields
            at least one, or ``None`` when the query asks for no supported
            statistic or no file produces a result.
        """
        words = self._tokenize(query)
        statistic = self._detect_statistic(words)
        if statistic is None:
            return None  # Not a statistical query

        all_words = set(words)
        # The statistic's own keyword ("max", "count", ...) says what to
        # compute, not which column to compute it on.
        stat_keywords = dict(self._STAT_KEYWORDS)[statistic]
        search_words = all_words - self._STOPWORDS - stat_keywords

        indexes = self.index_manager.indexes
        for csv_id in csv_ids:
            try:
                file_path = indexes[csv_id]["path"]
            except KeyError:
                continue  # Unknown id or entry without a path: skip it.

            try:
                df = pd.read_csv(file_path)
                columns = self._select_columns(df, all_words, search_words)
                results = self._compute_statistic(df, columns, statistic)
            except Exception as e:
                # Deliberately best-effort: an unreadable or malformed file
                # must not prevent the remaining files from being tried.
                print(f"Error processing CSV for statistical query: {e}")
                continue

            if results:
                return "\n".join(results)

        return None  # No results found

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Return the lower-cased alphanumeric words of *text*, in order."""
        cleaned = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return cleaned.split()

    @classmethod
    def _detect_statistic(cls, words: List[str]) -> Optional[str]:
        """Return the requested statistic's label, or None if there is none.

        Keywords are compared against whole words so that "administrators"
        (contains "min") or "country" (contains "count") do not qualify.
        """
        word_set = set(words)
        phrase = " " + " ".join(words) + " "
        for label, keywords in cls._STAT_KEYWORDS:
            if word_set & keywords:
                return label
            if label == "count" and " how many " in phrase:
                return label
        return None

    @classmethod
    def _column_matches(cls, column_name: str, search_words) -> bool:
        """Tell whether any search word names the lower-cased *column_name*."""
        column_tokens = set(cls._tokenize(column_name))
        return any(
            word in column_tokens
            or (len(word) >= cls._MIN_SUBSTRING_LEN and word in column_name)
            for word in search_words
        )

    @classmethod
    def _select_columns(cls, df: pd.DataFrame, all_words, search_words) -> List[Any]:
        """Pick the columns of *df* the query refers to.

        Tries, in order: columns named by the query, columns inferred from
        the words "age" / "income" / "salary", and finally every numeric
        column.
        """
        lowered = [(col, str(col).lower()) for col in df.columns]

        direct = [col for col, name in lowered if cls._column_matches(name, search_words)]
        if direct:
            return direct

        # "age" must be a whole word here: as a substring it is part of
        # every query containing "average".
        if "age" in all_words:
            age_cols = [col for col, name in lowered if "age" in name]
            if age_cols:
                return age_cols

        if "income" in all_words or "salary" in all_words:
            income_cols = [
                col for col, name in lowered
                if any(term in name for term in cls._INCOME_TERMS)
            ]
            if income_cols:
                return income_cols

        # Nothing identified a column: report on all numeric columns.
        return df.select_dtypes(include=["number"]).columns.tolist()

    @staticmethod
    def _compute_statistic(df: pd.DataFrame, columns: List[Any], statistic: str) -> List[str]:
        """Return one sentence per numeric column in *columns*.

        Non-numeric columns are skipped. For average / maximum / minimum,
        columns without any non-missing value are skipped as well, so no
        "nan" answers are produced. The count reported is the number of
        rows in the whole file.
        """
        row_count = len(df)
        results = []
        for col in columns:
            series = df[col]
            if not pd.api.types.is_numeric_dtype(series):
                continue

            if statistic == "count":
                results.append(f"The total count of {col} is {row_count}")
                continue

            if series.count() == 0:
                continue  # Only missing values: nothing meaningful to report.

            if statistic == "average":
                results.append(f"The average {col} is {series.mean():.2f}")
            elif statistic == "maximum":
                results.append(f"The maximum {col} is {series.max()}")
            else:
                results.append(f"The minimum {col} is {series.min()}")
        return results