bhavika24 committed on
Commit
82d3e6d
·
verified ·
1 Parent(s): 4897d3e

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +93 -38
engine.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import sqlite3
3
  from openai import OpenAI
4
  from difflib import get_close_matches
@@ -111,6 +112,8 @@ def extract_relevant_tables(question, max_tables=4):
111
  "consultant": ["encounters"],
112
  "doctor": ["encounters"],
113
  "visit": ["encounters"],
 
 
114
  "appointment": ["encounters"],
115
  "patient": ["patients"],
116
  "medication": ["medications"],
@@ -142,18 +145,23 @@ def extract_relevant_tables(question, max_tables=4):
142
  elif any(tok in col_l for tok in tokens):
143
  score += 1
144
 
145
- # 3️⃣ Description relevance
146
  if meta.get("description"):
147
  desc_tokens = set(meta["description"].lower().split())
148
- score += len(tokens & desc_tokens)
 
 
 
 
149
 
150
- # 4️⃣ Semantic intent mapping (important)
151
  for intent, tables in DOMAIN_HINTS.items():
152
  if intent in q and table_l in tables:
153
  score += 5
154
 
155
  # 5️⃣ Only add if meets minimum threshold (prevents low-quality matches)
156
- if score >= 3:
 
157
  matched.append((table, score))
158
 
159
  # Sort by relevance
@@ -292,8 +300,12 @@ You are a hospital SQL assistant.
292
  Rules:
293
  - Use only SELECT
294
  - SQLite syntax
 
295
  - Use only listed tables/columns
296
  - Return ONLY SQL or NOT_ANSWERABLE
 
 
 
297
  """
298
 
299
  for table, meta in schema.items():
@@ -302,6 +314,7 @@ Rules:
302
  prompt += f"- {col}: {desc}\n"
303
 
304
  prompt += f"\nQuestion: {question}\n"
 
305
  return prompt
306
 
307
 
@@ -325,6 +338,31 @@ def sanitize_sql(sql):
325
  sql = sql.split(";")[0]
326
  return sql.replace("\n", " ").strip()
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  def validate_sql(sql):
329
  if not sql.lower().startswith("select"):
330
  raise Exception("Only SELECT allowed")
@@ -362,42 +400,47 @@ def has_underlying_data(sql):
362
  def build_table_summary(table_name):
363
  cur = conn.cursor()
364
 
365
- # Total rows
366
  total = cur.execute(
367
  f"SELECT COUNT(*) FROM {table_name}"
368
  ).fetchone()[0]
369
 
370
- # Get column info
371
- columns = cur.execute(
372
- f"PRAGMA table_info({table_name})"
373
- ).fetchall()
 
 
374
 
375
- summary = f"Heres a summary of **{table_name}**:\n\n"
376
  summary += f"• Total records: {total}\n"
377
 
378
- # Try to summarize categorical columns
379
- for col in columns:
380
- col_name = col[1]
381
- col_type = col[2].lower()
382
-
383
- if col_type in ("text", "varchar"):
384
- try:
385
- rows = cur.execute(
386
- f"""
387
- SELECT {col_name}, COUNT(*)
388
- FROM {table_name}
389
- GROUP BY {col_name}
390
- ORDER BY COUNT(*) DESC
391
- LIMIT 5
392
- """
393
- ).fetchall()
394
-
395
- if rows:
396
- summary += f"\n• {col_name.capitalize()} breakdown:\n"
397
- for val, count in rows:
398
- summary += f" - {val}: {count}\n"
399
- except:
400
- pass # ignore columns that can't be grouped
 
 
 
401
 
402
  summary += "\nYou can ask more detailed questions about this data."
403
 
@@ -453,13 +496,22 @@ def process_question(question):
453
  "message": build_table_summary(matched_tables[0]),
454
  "data": []
455
  }
456
- if len(matched_tables) > 1:
 
 
 
 
 
 
 
 
 
457
  return {
458
  "status": "ok",
459
  "message": (
460
- "Your question matches multiple datasets:\n"
461
- + "\n".join(f"- {t}" for t in matched_tables)
462
- + "\n\nPlease be more specific."
463
  ),
464
  "data": []
465
  }
@@ -512,7 +564,10 @@ def process_question(question):
512
  "data": []
513
  }
514
 
515
- sql = validate_sql(sanitize_sql(sql))
 
 
 
516
  cols, rows = run_query(sql)
517
 
518
  # ----------------------------------
 
1
  import os
2
+ import re
3
  import sqlite3
4
  from openai import OpenAI
5
  from difflib import get_close_matches
 
112
  "consultant": ["encounters"],
113
  "doctor": ["encounters"],
114
  "visit": ["encounters"],
115
+ "visited": ["encounters"], # Handle past tense
116
+ "visits": ["encounters"], # Handle plural
117
  "appointment": ["encounters"],
118
  "patient": ["patients"],
119
  "medication": ["medications"],
 
145
  elif any(tok in col_l for tok in tokens):
146
  score += 1
147
 
148
+ # 3️⃣ Description relevance (less weight to avoid false positives)
149
  if meta.get("description"):
150
  desc_tokens = set(meta["description"].lower().split())
151
+ # Only count meaningful word matches, not common words
152
+ common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
153
+ meaningful_matches = tokens & desc_tokens - common_words
154
+ if meaningful_matches:
155
+ score += len(meaningful_matches) * 0.5 # Reduced weight
156
 
157
+ # 4️⃣ Semantic intent mapping (important - highest priority)
158
  for intent, tables in DOMAIN_HINTS.items():
159
  if intent in q and table_l in tables:
160
  score += 5
161
 
162
  # 5️⃣ Only add if meets minimum threshold (prevents low-quality matches)
163
+ # Increased threshold from 3 to 4 for better precision
164
+ if score >= 4:
165
  matched.append((table, score))
166
 
167
  # Sort by relevance
 
300
  Rules:
301
  - Use only SELECT
302
  - SQLite syntax
303
+ - Use ONLY the exact table names listed below (do not create or infer table names)
304
  - Use only listed tables/columns
305
  - Return ONLY SQL or NOT_ANSWERABLE
306
+
307
+ IMPORTANT: If the question mentions "visit", "visited", or "visits", use the table name "encounters" (NOT "visits" or "visit").
308
+ If the question mentions "consultant" or "doctor", use the table name "encounters".
309
  """
310
 
311
  for table, meta in schema.items():
 
314
  prompt += f"- {col}: {desc}\n"
315
 
316
  prompt += f"\nQuestion: {question}\n"
317
+ prompt += "\nRemember: Use EXACT table names from the list above. Do not pluralize or modify table names."
318
  return prompt
319
 
320
 
 
338
  sql = sql.split(";")[0]
339
  return sql.replace("\n", " ").strip()
340
 
341
def correct_table_names(sql):
    """Rewrite commonly mis-generated table names in ``sql`` to real ones.

    A correction is applied only when the wrong name is not itself a table in
    the metadata returned by ``load_ai_schema()`` and the replacement table
    is. Matching is case-insensitive and whole-word. Text inside
    single-quoted SQL string literals is left untouched, so a value such as
    ``'visit'`` in a WHERE clause is not rewritten.

    Args:
        sql: SQL text to correct (expected to be a ``str``).

    Returns:
        The SQL text with known wrong table names replaced, or ``sql``
        unchanged when no correction applies.
    """
    # Lower-cased table name -> name as spelled in the metadata, so the
    # existence checks below do not depend on the casing of the schema keys.
    # NOTE(review): assumes the schema keys are strings - confirm.
    known_tables = {name.lower(): name for name in load_ai_schema()}

    # Common table name mappings (wrong name -> intended table).
    table_corrections = {
        "visits": "encounters",
        "visit": "encounters",
        "providers": "encounters",  # if this table doesn't exist
    }

    # Keep only corrections whose wrong name is NOT a real table and whose
    # target IS one; the value is the target's spelling from the metadata.
    active = {
        wrong: known_tables[right]
        for wrong, right in table_corrections.items()
        if wrong not in known_tables and right in known_tables
    }
    if not active:
        return sql

    # Word boundaries avoid partial replacements (e.g. inside "visit_date").
    wrong_name_re = re.compile(
        r"\b(?:" + "|".join(re.escape(name) for name in active) + r")\b",
        re.IGNORECASE,
    )

    def _replacement(match):
        return active[match.group(0).lower()]

    # Splitting on a capturing group keeps the string literals in the result
    # at odd indices; only the even-indexed (non-literal) parts are rewritten.
    parts = re.split(r"('[^']*')", sql)
    for index in range(0, len(parts), 2):
        parts[index] = wrong_name_re.sub(_replacement, parts[index])

    return "".join(parts)
365
+
366
  def validate_sql(sql):
367
  if not sql.lower().startswith("select"):
368
  raise Exception("Only SELECT allowed")
 
400
  def build_table_summary(table_name):
401
  cur = conn.cursor()
402
 
403
+ # Total rows (still need to query actual data for count)
404
  total = cur.execute(
405
  f"SELECT COUNT(*) FROM {table_name}"
406
  ).fetchone()[0]
407
 
408
+ # Get column info from METADATA (ai_columns) not database structure
409
+ schema = load_ai_schema()
410
+ if table_name not in schema:
411
+ return f"Table {table_name} not found in metadata."
412
+
413
+ columns = schema[table_name]["columns"] # [(col_name, description), ...]
414
 
415
+ summary = f"Here's a summary of **{table_name}**:\n\n"
416
  summary += f"• Total records: {total}\n"
417
 
418
+ # Try to summarize categorical columns using metadata
419
+ for col_name, col_desc in columns:
420
+ # Try to determine if it's a categorical column based on name/description
421
+ # Skip likely numeric/date columns
422
+ col_lower = col_name.lower()
423
+ if any(skip in col_lower for skip in ["id", "_id", "date", "time", "count", "amount", "price"]):
424
+ continue
425
+
426
+ # Try to get breakdown for text-like columns
427
+ try:
428
+ rows = cur.execute(
429
+ f"""
430
+ SELECT {col_name}, COUNT(*)
431
+ FROM {table_name}
432
+ GROUP BY {col_name}
433
+ ORDER BY COUNT(*) DESC
434
+ LIMIT 5
435
+ """
436
+ ).fetchall()
437
+
438
+ if rows:
439
+ summary += f"\n• {col_name.capitalize()} breakdown:\n"
440
+ for val, count in rows:
441
+ summary += f" - {val}: {count}\n"
442
+ except:
443
+ pass # ignore columns that can't be grouped
444
 
445
  summary += "\nYou can ask more detailed questions about this data."
446
 
 
496
  "message": build_table_summary(matched_tables[0]),
497
  "data": []
498
  }
499
+
500
+ # Only block if too many tables matched AND it's not an analytical question
501
+ # Analytical questions (how many, count, etc.) often need multiple tables
502
+ is_analytical = any(k in q for k in [
503
+ "how many", "count", "total", "number of",
504
+ "average", "avg", "sum", "more than", "less than",
505
+ "compare", "trend"
506
+ ])
507
+
508
+ if len(matched_tables) > 4 and not is_analytical:
509
  return {
510
  "status": "ok",
511
  "message": (
512
+ "Your question matches too many datasets:\n"
513
+ + "\n".join(f"- {t}" for t in matched_tables[:5])
514
+ + "\n\nPlease be more specific about what you want to know."
515
  ),
516
  "data": []
517
  }
 
564
  "data": []
565
  }
566
 
567
+ # Sanitize, correct table names, then validate
568
+ sql = sanitize_sql(sql)
569
+ sql = correct_table_names(sql)
570
+ sql = validate_sql(sql)
571
  cols, rows = run_query(sql)
572
 
573
  # ----------------------------------