bhavika24 commited on
Commit
31318b4
·
verified ·
1 Parent(s): 484dcc5

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +106 -294
engine.py CHANGED
@@ -17,8 +17,6 @@ def log_interaction(user_q, sql=None, result=None, error=None):
17
  "error": error
18
  })
19
 
20
-
21
-
22
  # =========================
23
  # SETUP
24
  # =========================
@@ -38,8 +36,6 @@ conn = sqlite3.connect("mimic_iv.db", check_same_thread=False)
38
  LAST_PROMPT_TYPE = None
39
  LAST_SUGGESTED_DATE = None
40
 
41
-
42
-
43
  # =========================
44
  # HUMAN RESPONSE HELPERS
45
  # =========================
@@ -87,8 +83,6 @@ def correct_spelling(q):
87
  fixed.append(match[0] if match else clean)
88
  return " ".join(fixed)
89
 
90
-
91
-
92
  # =========================
93
  # SCHEMA
94
  # =========================
@@ -174,8 +168,6 @@ def extract_relevant_tables(question, max_tables=4):
174
  "vital": ["chartevents"],
175
  "vitals": ["chartevents"],
176
  }
177
-
178
-
179
  # Only include hints for tables that exist in the schema
180
  for intent, possible_tables in hint_mappings.items():
181
  matching_tables = [t for t in possible_tables if t in table_names_lower]
@@ -238,7 +230,6 @@ def extract_relevant_tables(question, max_tables=4):
238
 
239
  return [t[0] for t in matched[:max_tables]]
240
 
241
-
242
  # =========================
243
  # HUMAN SCHEMA DESCRIPTION
244
  # =========================
@@ -308,8 +299,6 @@ def get_latest_data_date():
308
 
309
  return None
310
 
311
-
312
-
313
  def normalize_time_question(q):#total-actual date
314
  latest = get_latest_data_date()
315
  if not latest:
@@ -323,50 +312,6 @@ def normalize_time_question(q):#total-actual date
323
 
324
  return q
325
 
326
- # =========================
327
- # UNSUPPORTED QUESTIONS
328
- # =========================
329
-
330
- def is_question_supported(question):
331
- q = question.lower()
332
-
333
- # 1️⃣ Always allow analytical / time-based queries
334
- analytic_keywords = [
335
- "count", "total", "average", "avg", "sum",
336
- "how many", "number of",
337
- "trend", "increase", "decrease", "compare",
338
- "last", "latest", "recent", "past",
339
- "day", "days", "month", "year"
340
- ]
341
-
342
- if any(k in q for k in analytic_keywords):
343
- return True
344
-
345
- # 2️⃣ Schema-based relevance check
346
- schema = load_ai_schema()
347
-
348
- for table, meta in schema.items():
349
- table_l = table.lower()
350
-
351
- # Table name mentioned
352
- if table_l in q:
353
- return True
354
-
355
- # Column or description match
356
- for col, desc in meta["columns"].items():
357
- if col.lower() in q:
358
- return True
359
-
360
- if isinstance(desc, str) and any(
361
- word in desc.lower() for word in q.split()
362
- ):
363
- return True
364
-
365
- return False
366
-
367
-
368
-
369
-
370
  # =========================
371
  # SQL GENERATION
372
  # =========================
@@ -406,8 +351,7 @@ STRICT RULES:
406
  - Do NOT wrap SQL in markdown
407
  - Use explicit JOIN conditions
408
  - Prefer COUNT(*) for totals
409
-
410
- Always use these joins:
411
  - patients.subject_id = admissions.subject_id
412
  - admissions.hadm_id = icustays.hadm_id
413
  - icustays.stay_id = chartevents.stay_id
@@ -447,8 +391,6 @@ Join hints:
447
 
448
  return prompt
449
 
450
-
451
-
452
  def call_llm(prompt):
453
  """Call OpenAI API with error handling."""
454
  try:
@@ -515,65 +457,85 @@ def correct_table_names(sql):
515
 
516
 
517
  def validate_sql(sql):
518
- if " join " in sql.lower() and " on " not in sql.lower():
519
- raise ValueError("JOIN without ON condition is not allowed")
520
 
521
- if ";" in sql.strip()[:-1]:
522
- raise ValueError("Multiple SQL statements are not allowed")
 
523
 
524
- FORBIDDEN = ["insert", "update", "delete", "drop", "alter"]
525
- if any(k in sql.lower() for k in FORBIDDEN):
 
526
  raise ValueError("Unsafe SQL detected")
527
 
528
- if not sql.lower().startswith("select"):
529
- raise ValueError("Only SELECT allowed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  return sql
531
 
 
 
 
 
 
 
 
532
  def run_query(sql):
533
- """Execute SQL query with proper error handling."""
534
  cur = conn.cursor()
 
535
  try:
536
- rows = cur.execute(sql).fetchall()
537
- if cur.description:
538
- cols = [c[0] for c in cur.description]
539
- else:
540
- cols = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  return cols, rows
 
542
  except sqlite3.Error as e:
543
  raise ValueError(f"Database query error: {str(e)}")
544
 
545
- # =========================
546
- # AGGREGATE SAFETY
547
- # =========================
548
-
549
- def is_aggregate_only_query(sql):
550
- s = sql.lower()
551
- return (
552
- any(fn in s for fn in ["count(", "sum(", "avg("])
553
- and "group by" not in s
554
- and "over(" not in s
555
- )
556
 
557
 
558
- def has_underlying_data(sql):
559
- """Check if underlying data exists for the SQL query."""
560
- base = sql.lower()
561
- if "from" not in base:
562
- return False
563
-
564
- base = base.split("from", 1)[1]
565
- # Split at GROUP BY, ORDER BY, LIMIT, etc. to get just the FROM clause
566
- for clause in ["group by", "order by", "limit", "having"]:
567
- base = base.split(clause)[0]
568
-
569
- test_sql = "SELECT 1 FROM " + base.strip() + " LIMIT 1"
570
-
571
- cur = conn.cursor()
572
- try:
573
- return cur.execute(test_sql).fetchone() is not None
574
- except sqlite3.Error:
575
- return False
576
-
577
  # =========================
578
  # PATIENT SUMMARY
579
  # =========================
@@ -654,235 +616,85 @@ def build_table_summary(table_name):
654
 
655
  return summary
656
 
657
-
658
-
659
  # =========================
660
  # MAIN ENGINE
661
  # =========================
662
 
663
  def process_question(question):
664
-
665
- global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
666
-
667
- q = question.strip().lower()
668
-
669
- # ----------------------------------
670
- # Normalize first
671
- # ----------------------------------
672
  question = correct_spelling(question)
673
  question = normalize_time_question(question)
674
-
675
- LAST_PROMPT_TYPE = None
676
- LAST_SUGGESTED_DATE = None
677
-
678
 
679
- # ----------------------------------
680
- # Handle "data updated till"
681
- # ----------------------------------
682
- if any(x in q for x in ["updated", "upto", "up to", "latest data"]):
683
  return {
684
  "status": "ok",
685
- "message": f"Data is available up to {get_latest_data_date()}",
686
  "data": []
687
  }
688
 
689
- # ----------------------------------
690
- # Extract relevant tables
691
- # ----------------------------------
692
- matched_tables = extract_relevant_tables(question)
693
-
694
- # ----------------------------------
695
- # SUMMARY ONLY IF USER ASKS FOR IT
696
- # ----------------------------------
697
- if (
698
- len(matched_tables) == 1
699
- and any(k in q for k in ["summary", "overview", "describe"])
700
- and not any(k in q for k in ["count", "total", "how many", "average"])
701
- ):
702
-
703
- return {
704
- "status": "ok",
705
- "message": build_table_summary(matched_tables[0]),
706
- "data": []
707
- }
708
-
709
- # Only block if too many tables matched AND it's not an analytical question
710
- # Analytical questions (how many, count, etc.) often need multiple tables
711
- is_analytical = any(k in q for k in [
712
- "how many", "count", "total", "number of",
713
- "average", "avg", "sum", "more than", "less than",
714
- "compare", "trend"
715
- ])
716
-
717
- if len(matched_tables) > 4 and not is_analytical:
718
- return {
719
- "status": "ok",
720
- "message": (
721
- "Your question matches too many datasets:\n"
722
- + "\n".join(f"- {t}" for t in matched_tables[:5])
723
- + "\n\nPlease be more specific about what you want to know."
724
- ),
725
- "data": []
726
- }
727
-
728
-
729
- # ----------------------------------
730
- # Metadata discovery
731
- # ----------------------------------
732
- if any(x in q for x in ["what data", "what tables", "which data"]):
733
  return {
734
- "status": "ok",
735
- "message": humanize(describe_schema()),
736
  "data": []
737
  }
738
- # ----------------------------------
739
- # # LAST DATA / RECENT DATA HANDLING
740
- # # ----------------------------------
741
- if any(x in q for x in ["last data", "latest data"]):
742
- return {
743
- "status": "ok",
744
- "message": f"Latest data available is from {get_latest_data_date()}",
745
- "data": []
746
- }
747
-
748
- if "last" in q and "day" in q and ("visit" in q or "admission" in q):
749
- sql = """
750
- SELECT subject_id, admittime
751
- FROM admissions
752
- WHERE admittime >= date(
753
- (SELECT MAX(admittime) FROM admissions),
754
- '-30 days'
755
- )
756
- ORDER BY admittime DESC
757
- """
758
- cols, rows = run_query(sql)
759
-
760
- log_interaction(
761
- user_q=question,
762
- sql=sql,
763
- result=rows
764
- )
765
 
766
- return {
767
- "status": "ok",
768
- "sql": sql,
769
- "columns": cols,
770
- "data": rows
771
- }
772
-
773
- # ----------------------------------
774
- # Unsupported question check
775
- # ----------------------------------
776
- if not is_question_supported(question):
777
- log_interaction(
778
- user_q=question,
779
- error="Unsupported question"
780
- )
781
- return {
782
- "status": "ok",
783
- "message": (
784
- "That information isn’t available in the system.\n\n"
785
- "You can ask about:\n"
786
- "• Patients\n"
787
- "• Admissions / Visits\n"
788
- "• ICU stays\n"
789
- "• Diagnoses / Conditions\n"
790
- "• Vitals & lab measurements"
791
- ),
792
- "data": []
793
- }
794
-
795
- # ----------------------------------
796
- # Generate SQL
797
- # ----------------------------------
798
  try:
799
- sql = call_llm(build_prompt(question))
800
- except ValueError as e:
801
- log_interaction(
802
- user_q=question,
803
- error=str(e)
804
- )
805
- return {
806
- "status": "ok",
807
- "message": str(e),
808
- "data": []
809
- }
810
-
811
-
812
-
813
- if sql == "NOT_ANSWERABLE":
814
  return {
815
- "status": "ok",
816
- "message": "I don't have enough data to answer that.",
817
  "data": []
818
  }
819
 
820
- # Sanitize, correct table names, then validate
821
- sql = sanitize_sql(sql)
822
- sql = correct_table_names(sql)
823
- sql = validate_sql(sql)
824
- cols, rows = run_query(sql)
825
-
826
- # ✅ LOG ONCE (THIS FIXES YOUR DOWNLOAD ISSUE)
827
- log_interaction(
828
- user_q=question,
829
- sql=sql,
830
- result=rows
831
- )
832
-
833
- if not rows:
834
  return {
835
  "status": "ok",
836
- "message": friendly("No records found."),
837
  "data": []
838
  }
839
 
840
- return {
841
- "status": "ok",
842
- "sql": sql,
843
- "columns": cols,
844
- "data": rows
845
- }
846
-
847
-
848
-
849
- # ----------------------------------
850
- # No data handling
851
- # ----------------------------------
852
- if is_aggregate_only_query(sql) and not has_underlying_data(sql):
853
- LAST_PROMPT_TYPE = "NO_DATA"
854
- LAST_SUGGESTED_DATE = get_latest_data_date()
855
-
856
  return {
857
- "status": "ok",
858
- "message": friendly("No data is available for that time period."),
859
- "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
860
  "data": []
861
  }
862
 
863
- if not rows:
864
- log_interaction(
865
- user_q=question,
866
- sql=sql,
867
- result=[]
868
- )
869
-
870
- LAST_PROMPT_TYPE = "NO_DATA"
871
- LAST_SUGGESTED_DATE = get_latest_data_date()
872
-
873
  return {
874
- "status": "ok",
875
- "message": friendly("No records found."),
876
- "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
877
  "data": []
878
  }
879
 
880
- # ----------------------------------
881
- # Success
882
- # ----------------------------------
 
 
 
 
 
883
  return {
884
  "status": "ok",
885
  "sql": sql,
 
886
  "columns": cols,
887
  "data": rows
888
  }
 
17
  "error": error
18
  })
19
 
 
 
20
  # =========================
21
  # SETUP
22
  # =========================
 
36
  LAST_PROMPT_TYPE = None
37
  LAST_SUGGESTED_DATE = None
38
 
 
 
39
  # =========================
40
  # HUMAN RESPONSE HELPERS
41
  # =========================
 
83
  fixed.append(match[0] if match else clean)
84
  return " ".join(fixed)
85
 
 
 
86
  # =========================
87
  # SCHEMA
88
  # =========================
 
168
  "vital": ["chartevents"],
169
  "vitals": ["chartevents"],
170
  }
 
 
171
  # Only include hints for tables that exist in the schema
172
  for intent, possible_tables in hint_mappings.items():
173
  matching_tables = [t for t in possible_tables if t in table_names_lower]
 
230
 
231
  return [t[0] for t in matched[:max_tables]]
232
 
 
233
  # =========================
234
  # HUMAN SCHEMA DESCRIPTION
235
  # =========================
 
299
 
300
  return None
301
 
 
 
302
  def normalize_time_question(q):#total-actual date
303
  latest = get_latest_data_date()
304
  if not latest:
 
312
 
313
  return q
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  # =========================
316
  # SQL GENERATION
317
  # =========================
 
351
  - Do NOT wrap SQL in markdown
352
  - Use explicit JOIN conditions
353
  - Prefer COUNT(*) for totals
354
+ - Use these joins only if columns from both tables are required.
 
355
  - patients.subject_id = admissions.subject_id
356
  - admissions.hadm_id = icustays.hadm_id
357
  - icustays.stay_id = chartevents.stay_id
 
391
 
392
  return prompt
393
 
 
 
394
  def call_llm(prompt):
395
  """Call OpenAI API with error handling."""
396
  try:
 
457
 
458
 
459
  def validate_sql(sql):
460
+ sql_l = sql.lower().strip()
 
461
 
462
+ # Must be SELECT
463
+ if not sql_l.startswith("select"):
464
+ raise ValueError("Only SELECT statements are allowed")
465
 
466
+ # Block dangerous keywords
467
+ forbidden = ["insert", "update", "delete", "drop", "alter", "truncate"]
468
+ if any(word in sql_l for word in forbidden):
469
  raise ValueError("Unsafe SQL detected")
470
 
471
+ # Block multiple statements
472
+ if ";" in sql_l[:-1]:
473
+ raise ValueError("Multiple SQL statements are not allowed")
474
+
475
+ # JOIN must have ON
476
+ if " join " in sql_l and " on " not in sql_l:
477
+ raise ValueError("JOIN without ON condition is not allowed")
478
+
479
+ # Prevent SELECT *
480
+ if "select *" in sql_l:
481
+ raise ValueError("SELECT * is not allowed")
482
+
483
+ # Enforce LIMIT
484
+ if "limit" not in sql_l:
485
+ sql += " LIMIT 100"
486
+
487
  return sql
488
 
489
+ def explain_sql(sql):
490
+ return {
491
+ "type": "aggregation" if "count(" in sql else "selection",
492
+ "has_join": "join" in sql.lower(),
493
+ "has_filter": "where" in sql.lower()
494
+ }
495
+
496
  def run_query(sql):
497
+ """Execute SQL query safely with validation and limits."""
498
  cur = conn.cursor()
499
+
500
  try:
501
+ # 1️⃣ Validate query plan
502
+ cur.execute("EXPLAIN QUERY PLAN " + sql)
503
+ plan = cur.fetchall()
504
+
505
+ for row in plan:
506
+ detail = row[-1].lower()
507
+ if "scan" in detail and "using index" not in detail:
508
+ raise ValueError("Query rejected: full table scan detected")
509
+
510
+ # 2️⃣ Execute query
511
+ cur.execute(sql)
512
+ rows = cur.fetchall()
513
+
514
+ # ✅ 3️⃣ Guard against inflated COUNT results
515
+ if "count(" in sql.lower() and "group by" not in sql.lower():
516
+ if len(rows) == 1 and isinstance(rows[0][0], (int, float)):
517
+ if rows[0][0] > 10_000_000:
518
+ raise ValueError(
519
+ "Suspiciously large count — possible join duplication"
520
+ )
521
+
522
+ # 4️⃣ Limit result size
523
+ MAX_ROWS = 1000
524
+ if len(rows) > MAX_ROWS:
525
+ rows = rows[:MAX_ROWS]
526
+
527
+ # 5️⃣ Extract columns
528
+ cols = [c[0] for c in cur.description] if cur.description else []
529
+
530
  return cols, rows
531
+
532
  except sqlite3.Error as e:
533
  raise ValueError(f"Database query error: {str(e)}")
534
 
535
+ finally:
536
+ cur.close()
 
 
 
 
 
 
 
 
 
537
 
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  # =========================
540
  # PATIENT SUMMARY
541
  # =========================
 
616
 
617
  return summary
618
 
 
 
619
  # =========================
620
  # MAIN ENGINE
621
  # =========================
622
 
623
  def process_question(question):
 
 
 
 
 
 
 
 
624
  question = correct_spelling(question)
625
  question = normalize_time_question(question)
 
 
 
 
626
 
627
+ # 1️⃣ Metadata requests
628
+ if any(x in question.lower() for x in ["what data", "what tables"]):
 
 
629
  return {
630
  "status": "ok",
631
+ "message": describe_schema(),
632
  "data": []
633
  }
634
 
635
+ # 2️⃣ Build LLM prompt
636
+ try:
637
+ prompt = build_prompt(question)
638
+ except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  return {
640
+ "status": "error",
641
+ "message": str(e),
642
  "data": []
643
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644
 
645
+ # 3️⃣ Generate SQL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
  try:
647
+ sql = call_llm(prompt)
648
+ except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
649
  return {
650
+ "status": "error",
651
+ "message": str(e),
652
  "data": []
653
  }
654
 
655
+ if sql == "NOT_ANSWERABLE":
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  return {
657
  "status": "ok",
658
+ "message": "I don't have enough data to answer that.",
659
  "data": []
660
  }
661
 
662
+ # 4️⃣ Sanitize & validate
663
+ try:
664
+ sql = sanitize_sql(sql)
665
+ sql = correct_table_names(sql)
666
+ sql = validate_sql(sql)
667
+ sql_info = explain_sql(sql)
668
+
669
+ except Exception as e:
 
 
 
 
 
 
 
 
670
  return {
671
+ "status": "error",
672
+ "message": str(e),
 
673
  "data": []
674
  }
675
 
676
+ # 5️⃣ Execute
677
+ try:
678
+ cols, rows = run_query(sql)
679
+ except Exception as e:
 
 
 
 
 
 
680
  return {
681
+ "status": "error",
682
+ "message": str(e),
 
683
  "data": []
684
  }
685
 
686
+ # 6️⃣ Log
687
+ log_interaction(
688
+ user_q=question,
689
+ sql=sql,
690
+ result=rows[:10]
691
+ )
692
+
693
+ # 7️⃣ Return
694
  return {
695
  "status": "ok",
696
  "sql": sql,
697
+ "sql_info": sql_info,
698
  "columns": cols,
699
  "data": rows
700
  }