bhavika24 committed on
Commit
eb31619
·
verified ·
1 Parent(s): bfa0b78

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +73 -23
engine.py CHANGED
@@ -93,25 +93,56 @@ def load_ai_schema():
93
  # TABLE MATCHING (CORE LOGIC)
94
  # =========================
95
 
96
- def extract_relevant_tables(question):
97
  schema = load_ai_schema()
98
  q = question.lower()
99
 
 
100
  matched = []
101
 
102
  for table, meta in schema.items():
103
- # match table name
104
- if table.lower() in q:
105
- matched.append(table)
106
- continue
107
 
108
- # match column names
 
 
 
 
 
 
 
 
 
109
  for col, _ in meta["columns"]:
110
- if col.lower() in q:
111
- matched.append(table)
112
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- return list(set(matched))[:5]
115
 
116
 
117
 
@@ -137,8 +168,7 @@ def describe_schema():
137
  "• Admissions by year\n\n"
138
  "Just tell me what you want to explore "
139
  )
140
- if not schema:
141
- return "No AI-enabled tables are configured."
142
 
143
 
144
 
@@ -172,24 +202,44 @@ def normalize_time_question(q):
172
 
173
  def is_question_supported(question):
174
  q = question.lower()
 
 
 
 
 
 
 
 
175
 
176
- if any(k in q for k in [
177
- "count", "total", "average", "sum",
178
- "how many", "number of", "trend"
179
- ]):
180
  return True
181
 
 
182
  schema = load_ai_schema()
183
- for table, meta in schema.items():
184
- if table in q:
185
- return True
186
- for col, _ in meta["columns"]:
187
- if col in q:
188
- return True
189
 
190
- return False
 
191
 
 
 
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  # =========================
 
93
  # TABLE MATCHING (CORE LOGIC)
94
  # =========================
95
 
96
def extract_relevant_tables(question, max_tables=5, *, schema=None):
    """Rank schema tables by how relevant they look for *question*.

    Each table accumulates a score from four signals:

    * +5 if the lowercase table name occurs in the question,
    * +2 for every content word shared with the table description,
    * +3 for every column whose name occurs in the question, otherwise
      +1 if a content word of at least three characters occurs inside
      the column name,
    * +2 for every semantic hint (e.g. "drug" -> "medication") whose key
      occurs in the table name.

    Stop words and very short tokens are excluded from the fuzzy signals;
    otherwise words such as "a" or "me" are substrings of almost every
    column name and give every table a positive score.

    Args:
        question: Natural-language question. ``None`` or an empty string
            yields an empty result.
        max_tables: Upper bound on the number of table names returned.
        schema: Optional mapping of table name -> metadata dict. The
            metadata must have a ``"columns"`` iterable of 2-item pairs
            whose first item is the column name, and may have a
            ``"description"`` string. When omitted, ``load_ai_schema()``
            is called.

    Returns:
        Table names with a positive score, highest score first. Ties keep
        the schema's iteration order.
    """
    import re

    if not question:
        return []
    if schema is None:
        schema = load_ai_schema()
    if not schema:
        return []

    # Function words that carry no information about which table is meant.
    stopwords = frozenset({
        "a", "an", "and", "are", "as", "at", "be", "by", "do", "does",
        "for", "from", "has", "have", "how", "i", "in", "is", "it", "me",
        "my", "of", "on", "or", "our", "that", "the", "their", "this",
        "to", "was", "we", "were", "what", "when", "where", "which",
        "who", "with", "you",
    })

    # Weak semantic hints: a question word on the right suggests tables
    # whose name contains the key on the left.
    semantic_map = {
        "patient": ["patient", "patients"],
        "visit": ["visit", "encounter"],
        "medication": ["drug", "medicine"],
        "admission": ["admit", "admission"],
        "date": ["date", "year", "month"],
    }

    q = question.lower()
    # \w+ strips all punctuation, not only "?" and ",".
    tokens = set(re.findall(r"\w+", q)) - stopwords
    # Substring matching is only meaningful for reasonably long tokens.
    fuzzy_tokens = {tok for tok in tokens if len(tok) >= 3}
    # The question-side half of each hint does not depend on the table,
    # so evaluate it once instead of once per table.
    active_hints = [
        key
        for key, words in semantic_map.items()
        if any(word in q for word in words)
    ]

    scored = []
    for table, meta in schema.items():
        table_l = table.lower()
        score = 0

        # 1. Table name match (strong signal)
        if table_l in q:
            score += 5

        # 2. Description match
        description = meta.get("description")
        if description:
            desc_words = set(re.findall(r"\w+", description.lower()))
            score += 2 * len(tokens & desc_words)

        # 3. Column name matches
        for col, _ in meta["columns"]:
            col_l = col.lower()
            if col_l in q:
                score += 3
            elif any(tok in col_l for tok in fuzzy_tokens):
                score += 1

        # 4. Weak semantic hints
        score += 2 * sum(1 for key in active_hints if key in table_l)

        if score > 0:
            scored.append((table, score))

    # list.sort is stable, so equally scored tables keep schema order.
    scored.sort(key=lambda item: item[1], reverse=True)
    return [table for table, _ in scored[:max_tables]]
145
 
 
146
 
147
 
148
 
 
168
  "• Admissions by year\n\n"
169
  "Just tell me what you want to explore "
170
  )
171
+
 
172
 
173
 
174
 
 
202
 
203
  def is_question_supported(question):
204
  q = question.lower()
205
+ tokens = set(q.replace("?", "").replace(",", "").split())
206
+
207
+ # 1️⃣ Allow analytical intent even if table not mentioned
208
+ analytic_keywords = {
209
+ "count", "total", "average", "avg", "sum",
210
+ "how many", "number of", "trend", "trendline",
211
+ "increase", "decrease", "compare"
212
+ }
213
 
214
+ if any(k in q for k in analytic_keywords):
 
 
 
215
  return True
216
 
217
+ # 2️⃣ Schema-based scoring
218
  schema = load_ai_schema()
219
+ score = 0
 
 
 
 
 
220
 
221
+ for table, meta in schema.items():
222
+ table_l = table.lower()
223
 
224
+ # Table name match
225
+ if table_l in q:
226
+ score += 3
227
 
228
+ # Column name match
229
+ for col, _ in meta["columns"]:
230
+ col_l = col.lower()
231
+ if col_l in q:
232
+ score += 2
233
+ elif any(tok in col_l for tok in tokens):
234
+ score += 1
235
+
236
+ # Description match
237
+ if meta.get("description"):
238
+ desc_tokens = meta["description"].lower().split()
239
+ score += len(tokens & set(desc_tokens))
240
+
241
+ # 3️⃣ Threshold — prevents random questions
242
+ return score >= 2
243
 
244
 
245
  # =========================