bhavika24 committed on
Commit
4897d3e
·
verified ·
1 Parent(s): eb31619

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +84 -47
engine.py CHANGED
@@ -47,6 +47,12 @@ KNOWN_TERMS = [
47
  "admitted", "admission",
48
  "year", "month", "last", "recent", "today"
49
  ]
 
 
 
 
 
 
50
 
51
  def correct_spelling(q):
52
  words = q.split()
@@ -93,27 +99,42 @@ def load_ai_schema():
93
  # TABLE MATCHING (CORE LOGIC)
94
  # =========================
95
 
96
- def extract_relevant_tables(question, max_tables=5):
97
  schema = load_ai_schema()
98
  q = question.lower()
99
-
100
  tokens = set(q.replace("?", "").replace(",", "").split())
 
101
  matched = []
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  for table, meta in schema.items():
104
  score = 0
105
  table_l = table.lower()
106
 
107
- # 1️⃣ Table name match (strong signal)
108
  if table_l in q:
109
- score += 5
110
-
111
- # 2️⃣ Description match
112
- if meta.get("description"):
113
- desc_words = meta["description"].lower().split()
114
- score += len(tokens & set(desc_words)) * 2
115
 
116
- # 3️⃣ Column name matches
117
  for col, _ in meta["columns"]:
118
  col_l = col.lower()
119
  if col_l in q:
@@ -121,46 +142,52 @@ def extract_relevant_tables(question, max_tables=5):
121
  elif any(tok in col_l for tok in tokens):
122
  score += 1
123
 
124
- # 4️⃣ Weak semantic hints
125
- semantic_map = {
126
- "patient": ["patient", "patients"],
127
- "visit": ["visit", "encounter"],
128
- "medication": ["drug", "medicine"],
129
- "admission": ["admit", "admission"],
130
- "date": ["date", "year", "month"]
131
- }
132
 
133
- for key, words in semantic_map.items():
134
- if any(w in q for w in words) and key in table_l:
135
- score += 2
 
136
 
137
- if score > 0:
 
138
  matched.append((table, score))
139
 
140
  # Sort by relevance
141
  matched.sort(key=lambda x: x[1], reverse=True)
142
 
143
- # Return top N tables
144
  return [t[0] for t in matched[:max_tables]]
145
 
146
 
147
-
148
-
149
  # =========================
150
  # HUMAN SCHEMA DESCRIPTION
151
  # =========================
152
 
153
- def describe_schema():
154
  schema = load_ai_schema()
 
155
 
156
- response = "Heres the data I currently have access to:\n\n"
157
 
158
- for table, meta in schema.items():
 
 
 
159
  response += f"• **{table.capitalize()}** — {meta['description']}\n"
160
- for col, desc in meta["columns"]:
 
161
  response += f" - {col}: {desc}\n"
 
 
162
  response += "\n"
163
 
 
 
 
 
164
  response += (
165
  "You can ask things like:\n"
166
  "• How many patients are there?\n"
@@ -168,9 +195,6 @@ def describe_schema():
168
  "• Admissions by year\n\n"
169
  "Just tell me what you want to explore "
170
  )
171
-
172
-
173
-
174
 
175
  return response
176
 
@@ -204,21 +228,21 @@ def is_question_supported(question):
204
  q = question.lower()
205
  tokens = set(q.replace("?", "").replace(",", "").split())
206
 
207
- # 1️⃣ Allow analytical intent even if table not mentioned
208
  analytic_keywords = {
209
  "count", "total", "average", "avg", "sum",
210
- "how many", "number of", "trend", "trendline",
211
- "increase", "decrease", "compare"
212
  }
213
 
214
  if any(k in q for k in analytic_keywords):
215
  return True
216
 
217
- # 2️⃣ Schema-based scoring
218
  schema = load_ai_schema()
219
- score = 0
220
 
221
  for table, meta in schema.items():
 
222
  table_l = table.lower()
223
 
224
  # Table name match
@@ -235,11 +259,15 @@ def is_question_supported(question):
235
 
236
  # Description match
237
  if meta.get("description"):
238
- desc_tokens = meta["description"].lower().split()
239
- score += len(tokens & set(desc_tokens))
 
 
 
 
 
 
240
 
241
- # 3️⃣ Threshold — prevents random questions
242
- return score >= 2
243
 
244
 
245
  # =========================
@@ -252,7 +280,11 @@ def build_prompt(question):
252
  if matched:
253
  schema = {t: load_ai_schema()[t] for t in matched}
254
  else:
255
- schema = load_ai_schema() # fallback if nothing matched
 
 
 
 
256
 
257
  prompt = """
258
  You are a hospital SQL assistant.
@@ -323,9 +355,6 @@ def has_underlying_data(sql):
323
  cur = conn.cursor()
324
  return cur.execute(test_sql).fetchone() is not None
325
 
326
-
327
-
328
-
329
  # =========================
330
  # PATIENT SUMMARY
331
  # =========================
@@ -466,12 +495,20 @@ def process_question(question):
466
  # ----------------------------------
467
  # Generate SQL
468
  # ----------------------------------
469
- sql = call_llm(build_prompt(question))
 
 
 
 
 
 
 
 
470
 
471
  if sql == "NOT_ANSWERABLE":
472
  return {
473
  "status": "ok",
474
- "message": "I dont have enough data to answer that.",
475
  "data": []
476
  }
477
 
 
47
  "admitted", "admission",
48
  "year", "month", "last", "recent", "today"
49
  ]
50
+ DOMAIN_ALIASES = {
51
+ "consultant": ["provider", "encounter"],
52
+ "doctor": ["provider"],
53
+ "appointment": ["encounter"],
54
+ "visit": ["encounter"],
55
+ }
56
 
57
  def correct_spelling(q):
58
  words = q.split()
 
99
  # TABLE MATCHING (CORE LOGIC)
100
  # =========================
101
 
102
+ def extract_relevant_tables(question, max_tables=4):
103
  schema = load_ai_schema()
104
  q = question.lower()
 
105
  tokens = set(q.replace("?", "").replace(",", "").split())
106
+
107
  matched = []
108
 
109
+ # Lightweight intent hints (NO hard dependency)
110
+ DOMAIN_HINTS = {
111
+ "consultant": ["encounters"],
112
+ "doctor": ["encounters"],
113
+ "visit": ["encounters"],
114
+ "appointment": ["encounters"],
115
+ "patient": ["patients"],
116
+ "medication": ["medications"],
117
+ "drug": ["medications"],
118
+ "condition": ["conditions"],
119
+ "diagnosis": ["conditions"]
120
+ }
121
+
122
+ # Early exit threshold - if we find a perfect match, we can stop early
123
+ VERY_HIGH_SCORE = 10
124
+
125
  for table, meta in schema.items():
126
  score = 0
127
  table_l = table.lower()
128
 
129
+ # 1️⃣ Strong signal: table name (exact match is very high confidence)
130
  if table_l in q:
131
+ score += 6
132
+ # Early exit optimization: if exact table match found, prioritize it
133
+ if score >= VERY_HIGH_SCORE:
134
+ matched.append((table, score))
135
+ continue
 
136
 
137
+ # 2️⃣ Column relevance
138
  for col, _ in meta["columns"]:
139
  col_l = col.lower()
140
  if col_l in q:
 
142
  elif any(tok in col_l for tok in tokens):
143
  score += 1
144
 
145
+ # 3️⃣ Description relevance
146
+ if meta.get("description"):
147
+ desc_tokens = set(meta["description"].lower().split())
148
+ score += len(tokens & desc_tokens)
 
 
 
 
149
 
150
+ # 4️⃣ Semantic intent mapping (important)
151
+ for intent, tables in DOMAIN_HINTS.items():
152
+ if intent in q and table_l in tables:
153
+ score += 5
154
 
155
+ # 5️⃣ Only add if meets minimum threshold (prevents low-quality matches)
156
+ if score >= 3:
157
  matched.append((table, score))
158
 
159
  # Sort by relevance
160
  matched.sort(key=lambda x: x[1], reverse=True)
161
 
 
162
  return [t[0] for t in matched[:max_tables]]
163
 
164
 
 
 
165
  # =========================
166
  # HUMAN SCHEMA DESCRIPTION
167
  # =========================
168
 
169
+ def describe_schema(max_tables=10):
170
  schema = load_ai_schema()
171
+ total_tables = len(schema)
172
 
173
+ response = f"Here's the data I currently have access to ({total_tables} tables):\n\n"
174
 
175
+ # Show only top N tables to avoid overwhelming output
176
+ shown_tables = list(schema.items())[:max_tables]
177
+
178
+ for table, meta in shown_tables:
179
  response += f"• **{table.capitalize()}** — {meta['description']}\n"
180
+ # Show only first 5 columns per table
181
+ for col, desc in list(meta["columns"])[:5]:
182
  response += f" - {col}: {desc}\n"
183
+ if len(meta["columns"]) > 5:
184
+ response += f" ... and {len(meta['columns']) - 5} more columns\n"
185
  response += "\n"
186
 
187
+ if total_tables > max_tables:
188
+ response += f"\n... and {total_tables - max_tables} more tables.\n"
189
+ response += "Ask about a specific table to see its details.\n\n"
190
+
191
  response += (
192
  "You can ask things like:\n"
193
  "• How many patients are there?\n"
 
195
  "• Admissions by year\n\n"
196
  "Just tell me what you want to explore "
197
  )
 
 
 
198
 
199
  return response
200
 
 
228
  q = question.lower()
229
  tokens = set(q.replace("?", "").replace(",", "").split())
230
 
231
+ # 1️⃣ Allow analytical intent even without table names
232
  analytic_keywords = {
233
  "count", "total", "average", "avg", "sum",
234
+ "how many", "number of", "trend",
235
+ "increase", "decrease", "compare", "more than", "less than"
236
  }
237
 
238
  if any(k in q for k in analytic_keywords):
239
  return True
240
 
241
+ # 2️⃣ Check schema relevance (table-by-table)
242
  schema = load_ai_schema()
 
243
 
244
  for table, meta in schema.items():
245
+ score = 0
246
  table_l = table.lower()
247
 
248
  # Table name match
 
259
 
260
  # Description match
261
  if meta.get("description"):
262
+ desc_tokens = set(meta["description"].lower().split())
263
+ score += len(tokens & desc_tokens)
264
+
265
+ # ✅ If any table is relevant enough → supported
266
+ if score >= 2:
267
+ return True
268
+
269
+ return False
270
 
 
 
271
 
272
 
273
  # =========================
 
280
  if matched:
281
  schema = {t: load_ai_schema()[t] for t in matched}
282
  else:
283
+ # 🚫 Don't send all 100+ tables! Return a helpful error instead
284
+ raise ValueError(
285
+ "I couldn't find any relevant tables for your question. "
286
+ "Please try mentioning a specific table name or use 'what data' to see available tables."
287
+ )
288
 
289
  prompt = """
290
  You are a hospital SQL assistant.
 
355
  cur = conn.cursor()
356
  return cur.execute(test_sql).fetchone() is not None
357
 
 
 
 
358
  # =========================
359
  # PATIENT SUMMARY
360
  # =========================
 
495
  # ----------------------------------
496
  # Generate SQL
497
  # ----------------------------------
498
+ try:
499
+ sql = call_llm(build_prompt(question))
500
+ except ValueError as e:
501
+ # Handle case where no relevant tables found
502
+ return {
503
+ "status": "ok",
504
+ "message": str(e),
505
+ "data": []
506
+ }
507
 
508
  if sql == "NOT_ANSWERABLE":
509
  return {
510
  "status": "ok",
511
+ "message": "I don't have enough data to answer that.",
512
  "data": []
513
  }
514