bhavika24 commited on
Commit
de9f3fd
·
verified ·
1 Parent(s): cb559f9

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. engine.py +136 -62
  3. metadata.json +49 -0
  4. mimic_iv.db +3 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  hospital.db filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  hospital.db filter=lfs diff=lfs merge=lfs -text
37
+ mimic_iv.db filter=lfs diff=lfs merge=lfs -text
engine.py CHANGED
@@ -13,7 +13,7 @@ api_key = os.getenv("OPENAI_API_KEY")
13
  if not api_key:
14
  raise ValueError("OPENAI_API_KEY environment variable is not set")
15
  client = OpenAI(api_key=api_key)
16
- conn = sqlite3.connect("mimic_iv_demo.db", check_same_thread=False)
17
 
18
 
19
  # =========================
@@ -54,20 +54,22 @@ def is_why_question(text):
54
  # =========================
55
 
56
  KNOWN_TERMS = [
57
- "patient", "patients", "condition", "conditions",
58
- "encounter", "encounters", "visit", "visits",
59
- "medication", "medications",
60
- "admitted", "admission",
61
- "year", "month", "last", "recent", "today"
 
62
  ]
63
 
 
64
  def correct_spelling(q):
65
  words = q.split()
66
  fixed = []
67
  for w in words:
68
  clean = w.lower().strip(",.?")
69
  match = get_close_matches(clean, KNOWN_TERMS, n=1, cutoff=0.8)
70
- fixed.append(match[0] if match else w)
71
  return " ".join(fixed)
72
 
73
 
@@ -77,20 +79,26 @@ def correct_spelling(q):
77
  # =========================
78
  import json
79
  from functools import lru_cache
 
 
 
 
 
 
80
 
81
  @lru_cache(maxsize=1)
82
  def load_ai_schema():
83
  """Load schema from metadata JSON file with error handling."""
84
  try:
85
- with open("hospital_metadata.json", "r") as f:
86
  schema = json.load(f)
87
  if not isinstance(schema, dict):
88
  raise ValueError("Invalid metadata format: expected a dictionary")
89
  return schema
90
  except FileNotFoundError:
91
- raise FileNotFoundError("hospital_metadata.json file not found. Please create it with your table metadata.")
92
  except json.JSONDecodeError as e:
93
- raise ValueError(f"Invalid JSON in hospital_metadata.json: {str(e)}")
94
  except Exception as e:
95
  raise ValueError(f"Error loading metadata: {str(e)}")
96
 
@@ -114,17 +122,16 @@ def extract_relevant_tables(question, max_tables=4):
114
 
115
  # Build hints only for tables that actually exist
116
  hint_mappings = {
117
- "consultant": ["encounter", "encounters", "visit", "visits"],
118
- "doctor": ["encounter", "encounters", "provider", "providers"],
119
- "visit": ["encounter", "encounters", "visit", "visits"],
120
- "visited": ["encounter", "encounters", "visit", "visits"],
121
- "visits": ["encounter", "encounters", "visit", "visits"],
122
- "appointment": ["encounter", "encounters", "appointment", "appointments"],
123
- "patient": ["patient", "patients"],
124
- "medication": ["medication", "medications", "drug", "drugs"],
125
- "drug": ["medication", "medications", "drug", "drugs"],
126
- "condition": ["condition", "conditions", "diagnosis", "diagnoses"],
127
- "diagnosis": ["condition", "conditions", "diagnosis", "diagnoses"]
128
  }
129
 
130
  # Only include hints for tables that exist in the schema
@@ -150,6 +157,9 @@ def extract_relevant_tables(question, max_tables=4):
150
 
151
  # 2️⃣ Column relevance
152
  for col, desc in meta["columns"].items():
 
 
 
153
  col_l = col.lower()
154
  if col_l in q:
155
  score += 3
@@ -158,7 +168,7 @@ def extract_relevant_tables(question, max_tables=4):
158
 
159
  # 3️⃣ Description relevance (less weight to avoid false positives)
160
  if meta.get("description"):
161
- desc_tokens = set(meta["description"].lower().split())
162
  # Only count meaningful word matches, not common words
163
  common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
164
  meaningful_matches = tokens & desc_tokens - common_words
@@ -204,7 +214,8 @@ def describe_schema(max_tables=10):
204
  response += f"• **{table.capitalize()}** — {meta['description']}\n"
205
  # Show only first 5 columns per table
206
  for col, desc in list(meta["columns"].items())[:5]:
207
- response += f" - {col}: {desc}\n"
 
208
  if len(meta["columns"]) > 5:
209
  response += f" ... and {len(meta['columns']) - 5} more columns\n"
210
  response += "\n"
@@ -306,7 +317,8 @@ def is_question_supported(question):
306
 
307
  # Description match
308
  if meta.get("description"):
309
- desc_tokens = set(meta["description"].lower().split())
 
310
  score += len(tokens & desc_tokens)
311
 
312
  # ✅ If any table is relevant enough → supported
@@ -325,43 +337,80 @@ def build_prompt(question):
325
  matched = extract_relevant_tables(question)
326
  full_schema = load_ai_schema()
327
 
328
- if matched:
329
- schema = {t: full_schema[t] for t in matched}
330
- else:
331
- # 🚫 Don't send all 100+ tables! Return a helpful error with available tables
332
- available_tables = list(full_schema.keys())[:10] # Show first 10 tables
333
  tables_list = "\n".join(f"- {t}" for t in available_tables)
334
  if len(full_schema) > 10:
335
  tables_list += f"\n... and {len(full_schema) - 10} more tables"
 
336
  raise ValueError(
337
- f"I couldn't find any relevant tables for your question.\n\n"
338
  f"Available tables:\n{tables_list}\n\n"
339
- f"Please try mentioning a specific table name or use 'what data' to see all available tables."
340
  )
341
 
 
 
 
 
 
 
 
 
342
  prompt = """
343
- You are a hospital SQL assistant.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
- Rules:
346
- - Use only SELECT
347
- - SQLite syntax
348
- - Use ONLY the exact table names listed below (do not create or infer table names)
349
- - Use only listed tables/columns
350
- - Return ONLY SQL or NOT_ANSWERABLE
351
 
352
- IMPORTANT: Use EXACTLY the table names provided in the list below. Do not pluralize, modify, or guess table names.
353
  """
354
 
355
  for table, meta in schema.items():
356
  prompt += f"\nTable: {table}\n"
 
357
  for col, desc in meta["columns"].items():
358
- prompt += f"- {col}: {desc}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  prompt += f"\nQuestion: {question}\n"
361
- prompt += "\nRemember: Use EXACT table names from the list above. Do not pluralize or modify table names."
 
 
 
 
 
362
  return prompt
363
 
364
 
 
365
  def call_llm(prompt):
366
  """Call OpenAI API with error handling."""
367
  try:
@@ -393,31 +442,51 @@ def sanitize_sql(sql):
393
  return sql.replace("\n", " ").strip()
394
 
395
  def correct_table_names(sql):
396
- """Fix common table name mistakes in generated SQL."""
397
  schema = load_ai_schema()
398
- valid_tables = set(schema.keys())
399
-
400
- sql_lower = sql.lower()
401
- sql_corrected = sql
402
-
403
- # Common table name mappings (case-insensitive replacement)
404
  table_corrections = {
405
- "visits": "encounters",
406
- "visit": "encounters",
407
- "providers": "encounters", # if this table doesn't exist
 
408
  }
409
-
410
- # Check each correction
411
- for wrong_name, correct_name in table_corrections.items():
412
- # Only correct if the wrong table doesn't exist AND correct one does
413
- if wrong_name.lower() not in valid_tables and correct_name.lower() in valid_tables:
414
- # Use word boundaries to avoid partial replacements
415
- pattern = r'\b' + re.escape(wrong_name) + r'\b'
416
- sql_corrected = re.sub(pattern, correct_name, sql_corrected, flags=re.IGNORECASE)
417
-
418
- return sql_corrected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
  def validate_sql(sql):
 
 
 
 
 
 
 
 
 
 
421
  if not sql.lower().startswith("select"):
422
  raise ValueError("Only SELECT allowed")
423
  return sql
@@ -441,7 +510,12 @@ def run_query(sql):
441
 
442
  def is_aggregate_only_query(sql):
443
  s = sql.lower()
444
- return ("count(" in s or "sum(" in s or "avg(" in s) and "group by" not in s
 
 
 
 
 
445
 
446
  def has_underlying_data(sql):
447
  """Check if underlying data exists for the SQL query."""
 
13
  if not api_key:
14
  raise ValueError("OPENAI_API_KEY environment variable is not set")
15
  client = OpenAI(api_key=api_key)
16
+ conn = sqlite3.connect("mimic_iv.db", check_same_thread=False)
17
 
18
 
19
  # =========================
 
54
  # =========================
55
 
56
# Vocabulary used to snap near-miss words in user questions onto the
# MIMIC-IV domain terms the schema matcher understands.
KNOWN_TERMS = [
    "patient", "patients",
    "admission", "admissions",
    "icu", "stay", "icustay",
    "diagnosis", "procedure",
    "medication", "lab",
    "year", "month", "recent", "today"
]


def correct_spelling(q):
    """Normalize a question word-by-word.

    Each token is lowercased, stripped of trailing punctuation, and
    replaced by its closest KNOWN_TERMS entry when the fuzzy match is
    at least 0.8 similar; otherwise the normalized token is kept.
    """
    def _fix(token):
        normalized = token.lower().strip(",.?")
        candidates = get_close_matches(normalized, KNOWN_TERMS, n=1, cutoff=0.8)
        return candidates[0] if candidates else normalized

    return " ".join(_fix(token) for token in q.split())
74
 
75
 
 
79
  # =========================
80
  import json
81
  from functools import lru_cache
82
def col_desc(desc):
    """Return a plain-text description for a column metadata entry.

    Metadata values may be a bare string or a dict carrying a
    "description" key; any other value is coerced with str().
    """
    if not isinstance(desc, dict):
        return str(desc)
    return desc.get("description", "")
87
+
88
 
89
@lru_cache(maxsize=1)
def load_ai_schema():
    """Load the table schema from metadata.json, cached after first read.

    Returns:
        dict: table name -> {"description": ..., "columns": {...}}.

    Raises:
        FileNotFoundError: metadata.json does not exist.
        ValueError: the file is not valid JSON, cannot be read, or does
            not contain a top-level dictionary.
    """
    try:
        with open("metadata.json", "r") as f:
            schema = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError("metadata.json file not found. Please create it with your table metadata.") from None
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in metadata.json: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Error loading metadata: {str(e)}") from e
    # Validate OUTSIDE the try block: previously this ValueError was
    # caught by `except Exception` and re-wrapped with the generic
    # "Error loading metadata" message, hiding the specific cause.
    if not isinstance(schema, dict):
        raise ValueError("Invalid metadata format: expected a dictionary")
    return schema
104
 
 
122
 
123
  # Build hints only for tables that actually exist
124
  hint_mappings = {
125
+ "patient": ["patients"],
126
+ "admission": ["admissions"],
127
+ "visit": ["admissions", "icustays"],
128
+ "icu": ["icustays", "chartevents"],
129
+ "diagnosis": ["diagnoses_icd"],
130
+ "procedure": ["procedures_icd"],
131
+ "medication": ["prescriptions", "emar", "pharmacy"],
132
+ "lab": ["labevents"],
133
+ "vital": ["chartevents"],
134
+ "stay": ["icustays"]
 
135
  }
136
 
137
  # Only include hints for tables that exist in the schema
 
157
 
158
  # 2️⃣ Column relevance
159
  for col, desc in meta["columns"].items():
160
+ desc_text = col_desc(desc)
161
+ desc_tokens = set(desc_text.lower().split())
162
+
163
  col_l = col.lower()
164
  if col_l in q:
165
  score += 3
 
168
 
169
  # 3️⃣ Description relevance (less weight to avoid false positives)
170
  if meta.get("description"):
171
+ desc_tokens = set(col_desc(meta.get("description", "")).lower().split())
172
  # Only count meaningful word matches, not common words
173
  common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
174
  meaningful_matches = tokens & desc_tokens - common_words
 
214
  response += f"• **{table.capitalize()}** — {meta['description']}\n"
215
  # Show only first 5 columns per table
216
  for col, desc in list(meta["columns"].items())[:5]:
217
+ response += f" - {col}: {col_desc(desc)}\n"
218
+
219
  if len(meta["columns"]) > 5:
220
  response += f" ... and {len(meta['columns']) - 5} more columns\n"
221
  response += "\n"
 
317
 
318
  # Description match
319
  if meta.get("description"):
320
+ desc_tokens = set(col_desc(meta["description"]).lower().split())
321
+
322
  score += len(tokens & desc_tokens)
323
 
324
  # ✅ If any table is relevant enough → supported
 
337
  matched = extract_relevant_tables(question)
338
  full_schema = load_ai_schema()
339
 
340
+ if not matched:
341
+ available_tables = list(full_schema.keys())[:10]
 
 
 
342
  tables_list = "\n".join(f"- {t}" for t in available_tables)
343
  if len(full_schema) > 10:
344
  tables_list += f"\n... and {len(full_schema) - 10} more tables"
345
+
346
  raise ValueError(
347
+ "I couldn't find any relevant tables for your question.\n\n"
348
  f"Available tables:\n{tables_list}\n\n"
349
+ "Try mentioning a table name or ask: 'what data is available?'"
350
  )
351
 
352
+ schema = {t: full_schema[t] for t in matched}
353
+
354
+ IMPORTANT_COLS = {
355
+ "subject_id", "hadm_id", "stay_id",
356
+ "icustay_id", "itemid",
357
+ "charttime", "starttime", "endtime"
358
+ }
359
+
360
  prompt = """
361
+ You are an expert SQLite query generator.
362
+
363
+ STRICT RULES:
364
+ - Use ONLY the tables and columns listed below
365
+ - NEVER invent table or column names
366
+ - If the answer cannot be derived, return: NOT_ANSWERABLE
367
+ - Do NOT explain the SQL
368
+ - Do NOT wrap SQL in markdown
369
+ - Use explicit JOIN conditions
370
+ - Prefer COUNT(*) for totals
371
+
372
+ Always use these joins:
373
+ - patients.subject_id = admissions.subject_id
374
+ - admissions.hadm_id = icustays.hadm_id
375
+ - icustays.stay_id = chartevents.stay_id
376
 
 
 
 
 
 
 
377
 
378
+ Schema:
379
  """
380
 
381
  for table, meta in schema.items():
382
  prompt += f"\nTable: {table}\n"
383
+
384
  for col, desc in meta["columns"].items():
385
+ text = f"{col} {col_desc(desc)}".lower()
386
+
387
+ # Keep columns relevant to question
388
+ if any(w in text for w in question.lower().split()):
389
+ prompt += f"- {col}\n"
390
+
391
+ # Always keep join / key columns
392
+ elif col in IMPORTANT_COLS or col.endswith("_id"):
393
+ prompt += f"- {col}\n"
394
+
395
+ # Optional: help LLM with joins (very helpful for MIMIC)
396
+ prompt += """
397
+ Join hints:
398
+ - patients.subject_id ↔ admissions.subject_id
399
+ - admissions.hadm_id ↔ icustays.hadm_id
400
+ - icustays.stay_id ↔ chartevents.stay_id
401
+ """
402
 
403
  prompt += f"\nQuestion: {question}\n"
404
+ prompt += "\nUse EXACT table and column names as shown above."
405
+
406
+ # Safety cap
407
+ if len(prompt) > 6000:
408
+ prompt = prompt[:6000] + "\n\n# Schema truncated for safety\n"
409
+
410
  return prompt
411
 
412
 
413
+
414
  def call_llm(prompt):
415
  """Call OpenAI API with error handling."""
416
  try:
 
442
  return sql.replace("\n", " ").strip()
443
 
444
def correct_table_names(sql):
    """Rewrite hallucinated table names after FROM/JOIN to real schema tables.

    A name is replaced only when it is NOT already a valid schema table,
    has a known correction, and that correction exists in the schema;
    everything else is left untouched.
    """
    known = {name.lower() for name in load_ai_schema()}

    # Frequent LLM mistakes -> canonical MIMIC-IV table names.
    table_corrections = {
        "visit": "admissions",
        "visits": "admissions",
        "provider": "caregiver",
        "providers": "caregiver"
    }

    def _substitute(m):
        keyword, table = m.group(1), m.group(2)
        lowered = table.lower()
        if lowered not in known and lowered in table_corrections:
            replacement = table_corrections[lowered]
            if replacement in known:
                return f"{keyword} {replacement}"
        return m.group(0)

    # Only touch identifiers in table position (right after FROM/JOIN).
    return re.sub(
        r"\b(from|join)\s+([a-zA-Z_][a-zA-Z0-9_]*)",
        _substitute,
        sql,
        flags=re.IGNORECASE,
    )
476
+
477
+
478
 
479
def validate_sql(sql):
    """Validate that *sql* is a single, safe SELECT statement.

    Returns the SQL unchanged on success.

    Raises:
        ValueError: for a JOIN without an ON condition, multiple
            statements, write/DDL keywords, or a non-SELECT query.
    """
    lowered = sql.lower()

    # The prompt rules only emit JOIN ... ON, so a JOIN with no ON is
    # almost certainly malformed output.
    if " join " in lowered and " on " not in lowered:
        raise ValueError("JOIN without ON condition is not allowed")

    # A single trailing semicolon is tolerated; any other ';' means
    # multiple statements.
    if ";" in sql.strip()[:-1]:
        raise ValueError("Multiple SQL statements are not allowed")

    # Word-boundary match so identifiers such as "last_update" or
    # "dropped_flag" are not false positives (a bare substring scan
    # rejected them before).
    if re.search(r"\b(insert|update|delete|drop|alter)\b", lowered):
        raise ValueError("Unsafe SQL detected")

    if not lowered.startswith("select"):
        raise ValueError("Only SELECT allowed")
    return sql
 
510
 
511
def is_aggregate_only_query(sql):
    """Return True when *sql* computes only bare aggregates.

    True means the query uses COUNT/SUM/AVG with no GROUP BY and no
    window (OVER) clause, i.e. it yields a single summary row.
    """
    s = sql.lower()
    has_aggregate = any(fn in s for fn in ("count(", "sum(", "avg("))
    # \s* tolerates "OVER (" written with whitespace, which the plain
    # substring test "over(" used to miss, misclassifying window
    # queries as bare aggregates.
    is_windowed = re.search(r"\bover\s*\(", s) is not None
    return has_aggregate and "group by" not in s and not is_windowed
518
+
519
 
520
  def has_underlying_data(sql):
521
  """Check if underlying data exists for the SQL query."""
metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "patients": {
3
+ "description": "Patient demographic information",
4
+ "columns": {
5
+ "subject_id": "Unique patient identifier",
6
+ "gender": "Biological sex",
7
+ "anchor_age": "Approximate age",
8
+ "anchor_year": "Anchor year for age"
9
+ }
10
+ },
11
+ "admissions": {
12
+ "description": "Hospital admissions for patients",
13
+ "columns": {
14
+ "hadm_id": "Hospital admission ID",
15
+ "subject_id": "Patient ID",
16
+ "admittime": "Admission timestamp",
17
+ "dischtime": "Discharge timestamp",
18
+ "admission_type": "Emergency, elective, etc",
19
+ "admission_location": "Source of admission"
20
+ }
21
+ },
22
+ "icustays": {
23
+ "description": "ICU stay records",
24
+ "columns": {
25
+ "stay_id": "ICU stay identifier",
26
+ "hadm_id": "Hospital admission ID",
27
+ "subject_id": "Patient ID",
28
+ "intime": "ICU admission time",
29
+ "outtime": "ICU discharge time"
30
+ }
31
+ },
32
+ "chartevents": {
33
+ "description": "Time-series ICU measurements (vitals, labs)",
34
+ "columns": {
35
+ "stay_id": "ICU stay ID",
36
+ "itemid": "Measurement type",
37
+ "charttime": "Time of observation",
38
+ "valuenum": "Numeric value"
39
+ }
40
+ },
41
+ "diagnoses_icd": {
42
+ "description": "ICD diagnoses for admissions",
43
+ "columns": {
44
+ "hadm_id": "Hospital admission ID",
45
+ "icd_code": "Diagnosis code",
46
+ "icd_version": "ICD version"
47
+ }
48
+ }
49
+ }
mimic_iv.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f199f1b69f2ec1b722011e5055797c0e11f139f1dc899e9076f9ecef6d7c1ce6
3
+ size 128155648