bhavika24 committed on
Commit
f3ebab8
·
verified ·
1 Parent(s): 31318b4

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +13 -195
engine.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import re
3
- import sqlite3
4
  from openai import OpenAI
5
  from difflib import get_close_matches
6
  from datetime import datetime
@@ -26,7 +25,6 @@ api_key = os.getenv("OPENAI_API_KEY")
26
  if not api_key:
27
  raise ValueError("OPENAI_API_KEY environment variable is not set")
28
  client = OpenAI(api_key=api_key)
29
- conn = sqlite3.connect("mimic_iv.db", check_same_thread=False)
30
 
31
 
32
  # =========================
@@ -44,14 +42,6 @@ def humanize(text):
44
  return f"Sure \n\n{text}"
45
 
46
  def friendly(text):
47
- global LAST_SUGGESTED_DATE
48
- if LAST_SUGGESTED_DATE:
49
- return f"{text}\n\nLast data available is {LAST_SUGGESTED_DATE}"
50
- else:
51
- # If date not set yet, try to get it
52
- date = get_latest_data_date()
53
- if date:
54
- return f"{text}\n\nLast data available is {date}"
55
  return text
56
 
57
  def is_confirmation(text):
@@ -271,47 +261,6 @@ def describe_schema(max_tables=10):#what data you have or which table exist
271
  # TIME HANDLING
272
  # =========================
273
 
274
- def get_latest_data_date():
275
- """
276
- Returns the most meaningful 'latest date' for the system.
277
- Priority:
278
- 1. admissions.admittime
279
- 2. icustays.intime
280
- 3. chartevents.charttime
281
- """
282
-
283
- checks = [
284
- ("admissions", "admittime"),
285
- ("icustays", "intime"),
286
- ("chartevents", "charttime"),
287
- ]
288
-
289
- for table, column in checks:
290
- try:
291
- result = conn.execute(
292
- f"SELECT MAX({column}) FROM {table}"
293
- ).fetchone()
294
-
295
- if result and result[0]:
296
- return result[0]
297
- except Exception:
298
- continue
299
-
300
- return None
301
-
302
- def normalize_time_question(q):#total-actual date
303
- latest = get_latest_data_date()
304
- if not latest:
305
- return q
306
-
307
- if "today" in q:
308
- return q.replace("today", f"on {latest[:10]}")
309
-
310
- if "yesterday" in q:
311
- return q.replace("yesterday", f"on {latest[:10]}")
312
-
313
- return q
314
-
315
  # =========================
316
  # SQL GENERATION
317
  # =========================
@@ -493,49 +442,6 @@ def explain_sql(sql):
493
  "has_filter": "where" in sql.lower()
494
  }
495
 
496
- def run_query(sql):
497
- """Execute SQL query safely with validation and limits."""
498
- cur = conn.cursor()
499
-
500
- try:
501
- # 1️⃣ Validate query plan
502
- cur.execute("EXPLAIN QUERY PLAN " + sql)
503
- plan = cur.fetchall()
504
-
505
- for row in plan:
506
- detail = row[-1].lower()
507
- if "scan" in detail and "using index" not in detail:
508
- raise ValueError("Query rejected: full table scan detected")
509
-
510
- # 2️⃣ Execute query
511
- cur.execute(sql)
512
- rows = cur.fetchall()
513
-
514
- # ✅ 3️⃣ Guard against inflated COUNT results
515
- if "count(" in sql.lower() and "group by" not in sql.lower():
516
- if len(rows) == 1 and isinstance(rows[0][0], (int, float)):
517
- if rows[0][0] > 10_000_000:
518
- raise ValueError(
519
- "Suspiciously large count — possible join duplication"
520
- )
521
-
522
- # 4️⃣ Limit result size
523
- MAX_ROWS = 1000
524
- if len(rows) > MAX_ROWS:
525
- rows = rows[:MAX_ROWS]
526
-
527
- # 5️⃣ Extract columns
528
- cols = [c[0] for c in cur.description] if cur.description else []
529
-
530
- return cols, rows
531
-
532
- except sqlite3.Error as e:
533
- raise ValueError(f"Database query error: {str(e)}")
534
-
535
- finally:
536
- cur.close()
537
-
538
-
539
  # =========================
540
  # PATIENT SUMMARY
541
  # =========================
@@ -551,85 +457,18 @@ def validate_identifier(name):
551
  # Must start with letter or underscore, rest alphanumeric/underscore
552
  return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name))
553
 
554
- def build_table_summary(table_name):
555
- """Build summary for a table using metadata."""
556
- # Validate table name against metadata first
557
- schema = load_ai_schema()
558
- if table_name not in schema:
559
- return f"Table {table_name} not found in metadata."
560
-
561
- # Additional safety check
562
- if not validate_identifier(table_name):
563
- return f"Invalid table name: {table_name}"
564
-
565
- cur = conn.cursor()
566
-
567
- # Total rows (still need to query actual data for count)
568
- # Note: SQLite doesn't support parameterized table names
569
- # Since we validated table_name against metadata, it's safe
570
- try:
571
- total = cur.execute(
572
- f"SELECT COUNT(*) FROM {table_name}"
573
- ).fetchone()[0]
574
- except sqlite3.Error as e:
575
- return f"Error querying table {table_name}: {str(e)}"
576
-
577
- columns = schema[table_name]["columns"] # {col_name: description, ...}
578
-
579
- summary = f"Here's a summary of **{table_name}**:\n\n"
580
- summary += f"• Total records: {total}\n"
581
-
582
- # Try to summarize categorical columns using metadata
583
- for col_name, col_desc in columns.items():
584
- # Validate column name
585
- if not validate_identifier(col_name):
586
- continue
587
-
588
- # Try to determine if it's a categorical column based on name/description
589
- # Skip likely numeric/date columns
590
- col_lower = col_name.lower()
591
- if any(skip in col_lower for skip in ["id", "_id", "date", "time", "count", "amount", "price"]):
592
- continue
593
-
594
- # Try to get breakdown for text-like columns
595
- try:
596
- # Note: SQLite doesn't support parameterized identifiers, so we validate
597
- rows = cur.execute(
598
- f"""
599
- SELECT {col_name}, COUNT(*)
600
- FROM {table_name}
601
- GROUP BY {col_name}
602
- ORDER BY COUNT(*) DESC
603
- LIMIT 5
604
- """
605
- ).fetchall()
606
-
607
- if rows:
608
- summary += f"\n• {col_name.capitalize()} breakdown:\n"
609
- for val, count in rows:
610
- summary += f" - {val}: {count}\n"
611
- except (sqlite3.Error, sqlite3.OperationalError) as e:
612
- # Ignore columns that can't be grouped (likely not categorical)
613
- pass
614
-
615
- summary += "\nYou can ask more detailed questions about this data."
616
-
617
- return summary
618
-
619
  # =========================
620
  # MAIN ENGINE
621
  # =========================
622
 
623
  def process_question(question):
624
  question = correct_spelling(question)
625
- question = normalize_time_question(question)
626
 
627
  # 1️⃣ Metadata requests
628
  if any(x in question.lower() for x in ["what data", "what tables"]):
629
  return {
630
  "status": "ok",
631
- "message": describe_schema(),
632
- "data": []
633
  }
634
 
635
  # 2️⃣ Build LLM prompt
@@ -638,8 +477,7 @@ def process_question(question):
638
  except Exception as e:
639
  return {
640
  "status": "error",
641
- "message": str(e),
642
- "data": []
643
  }
644
 
645
  # 3️⃣ Generate SQL
@@ -648,15 +486,13 @@ def process_question(question):
648
  except Exception as e:
649
  return {
650
  "status": "error",
651
- "message": str(e),
652
- "data": []
653
  }
654
 
655
  if sql == "NOT_ANSWERABLE":
656
  return {
657
  "status": "ok",
658
- "message": "I don't have enough data to answer that.",
659
- "data": []
660
  }
661
 
662
  # 4️⃣ Sanitize & validate
@@ -665,36 +501,18 @@ def process_question(question):
665
  sql = correct_table_names(sql)
666
  sql = validate_sql(sql)
667
  sql_info = explain_sql(sql)
668
-
669
- except Exception as e:
670
- return {
671
- "status": "error",
672
- "message": str(e),
673
- "data": []
674
- }
675
-
676
- # 5️⃣ Execute
677
- try:
678
- cols, rows = run_query(sql)
679
  except Exception as e:
680
  return {
681
  "status": "error",
682
- "message": str(e),
683
- "data": []
684
  }
685
 
686
- # 6️⃣ Log
687
- log_interaction(
688
- user_q=question,
689
- sql=sql,
690
- result=rows[:10]
691
- )
692
-
693
- # 7️⃣ Return
694
  return {
695
- "status": "ok",
696
- "sql": sql,
697
- "sql_info": sql_info,
698
- "columns": cols,
699
- "data": rows
700
- }
 
 
1
  import os
2
  import re
 
3
  from openai import OpenAI
4
  from difflib import get_close_matches
5
  from datetime import datetime
 
25
  if not api_key:
26
  raise ValueError("OPENAI_API_KEY environment variable is not set")
27
  client = OpenAI(api_key=api_key)
 
28
 
29
 
30
  # =========================
 
42
  return f"Sure \n\n{text}"
43
 
44
  def friendly(text):
 
 
 
 
 
 
 
 
45
  return text
46
 
47
  def is_confirmation(text):
 
261
  # TIME HANDLING
262
  # =========================
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  # =========================
265
  # SQL GENERATION
266
  # =========================
 
442
  "has_filter": "where" in sql.lower()
443
  }
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  # =========================
446
  # PATIENT SUMMARY
447
  # =========================
 
457
  # Must start with letter or underscore, rest alphanumeric/underscore
458
  return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name))
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  # =========================
461
  # MAIN ENGINE
462
  # =========================
463
 
464
  def process_question(question):
465
  question = correct_spelling(question)
 
466
 
467
  # 1️⃣ Metadata requests
468
  if any(x in question.lower() for x in ["what data", "what tables"]):
469
  return {
470
  "status": "ok",
471
+ "message": describe_schema()
 
472
  }
473
 
474
  # 2️⃣ Build LLM prompt
 
477
  except Exception as e:
478
  return {
479
  "status": "error",
480
+ "message": str(e)
 
481
  }
482
 
483
  # 3️⃣ Generate SQL
 
486
  except Exception as e:
487
  return {
488
  "status": "error",
489
+ "message": str(e)
 
490
  }
491
 
492
  if sql == "NOT_ANSWERABLE":
493
  return {
494
  "status": "ok",
495
+ "message": "I don't have enough data to answer that."
 
496
  }
497
 
498
  # 4️⃣ Sanitize & validate
 
501
  sql = correct_table_names(sql)
502
  sql = validate_sql(sql)
503
  sql_info = explain_sql(sql)
 
 
 
 
 
 
 
 
 
 
 
504
  except Exception as e:
505
  return {
506
  "status": "error",
507
+ "message": str(e)
 
508
  }
509
 
510
+ # 5️⃣ Return SQL only (no execution)
 
 
 
 
 
 
 
511
  return {
512
+ "status": "ok",
513
+ "message": humanize(
514
+ "Here’s the SQL query I generated based on your question 😊"
515
+ ),
516
+ "sql": sql,
517
+ "sql_info": sql_info
518
+ }