bhavika24 commited on
Commit
bd0831f
·
verified ·
1 Parent(s): 2e9817a

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +89 -36
engine.py CHANGED
@@ -3,13 +3,16 @@ import re
3
  import sqlite3
4
  from openai import OpenAI
5
  from difflib import get_close_matches
6
- from datetime import datetime
7
 
8
  # =========================
9
  # SETUP
10
  # =========================
11
 
12
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 
 
 
13
  conn = sqlite3.connect("hospital.db", check_same_thread=False)
14
 
15
  # =========================
@@ -56,12 +59,6 @@ KNOWN_TERMS = [
56
  "admitted", "admission",
57
  "year", "month", "last", "recent", "today"
58
  ]
59
- DOMAIN_ALIASES = {
60
- "consultant": ["provider", "encounter"],
61
- "doctor": ["provider"],
62
- "appointment": ["encounter"],
63
- "visit": ["encounter"],
64
- }
65
 
66
  def correct_spelling(q):
67
  words = q.split()
@@ -219,9 +216,13 @@ def describe_schema(max_tables=10):
219
  # =========================
220
 
221
  def get_latest_data_date():
 
222
  cur = conn.cursor()
223
- r = cur.execute("SELECT MAX(start_date) FROM encounters").fetchone()
224
- return r[0]
 
 
 
225
 
226
  def normalize_time_question(q):
227
  latest = get_latest_data_date()
@@ -327,22 +328,32 @@ If the question mentions "consultant" or "doctor", use the table name "encounter
327
 
328
 
329
  def call_llm(prompt):
330
- res = client.chat.completions.create(
331
- model="gpt-4.1-mini",
332
- messages=[
333
- {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
334
- {"role": "user", "content": prompt}
335
- ],
336
- temperature=0
337
- )
338
- return res.choices[0].message.content.strip()
 
 
 
 
 
 
339
 
340
  # =========================
341
  # SQL SAFETY
342
  # =========================
343
 
344
  def sanitize_sql(sql):
345
- sql = sql.replace("```", "").replace("sql", "").strip()
 
 
 
 
346
  sql = sql.split(";")[0]
347
  return sql.replace("\n", " ").strip()
348
 
@@ -373,14 +384,21 @@ def correct_table_names(sql):
373
 
374
  def validate_sql(sql):
375
  if not sql.lower().startswith("select"):
376
- raise Exception("Only SELECT allowed")
377
  return sql
378
 
379
  def run_query(sql):
 
380
  cur = conn.cursor()
381
- rows = cur.execute(sql).fetchall()
382
- cols = [c[0] for c in cur.description]
383
- return cols, rows
 
 
 
 
 
 
384
 
385
  # =========================
386
  # AGGREGATE SAFETY
@@ -391,32 +409,61 @@ def is_aggregate_only_query(sql):
391
  return ("count(" in s or "sum(" in s or "avg(" in s) and "group by" not in s
392
 
393
  def has_underlying_data(sql):
 
394
  base = sql.lower()
395
  if "from" not in base:
396
  return False
397
 
398
  base = base.split("from", 1)[1]
399
- test_sql = "SELECT 1 FROM " + base.split("group by")[0] + " LIMIT 1"
 
 
 
 
400
 
401
  cur = conn.cursor()
402
- return cur.execute(test_sql).fetchone() is not None
 
 
 
403
 
404
  # =========================
405
  # PATIENT SUMMARY
406
  # =========================
407
 
408
- def build_table_summary(table_name):
409
- cur = conn.cursor()
410
-
411
- # Total rows (still need to query actual data for count)
412
- total = cur.execute(
413
- f"SELECT COUNT(*) FROM {table_name}"
414
- ).fetchone()[0]
 
 
 
415
 
416
- # Get column info from METADATA (ai_columns) not database structure
 
 
417
  schema = load_ai_schema()
418
  if table_name not in schema:
419
  return f"Table {table_name} not found in metadata."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
  columns = schema[table_name]["columns"] # [(col_name, description), ...]
422
 
@@ -425,6 +472,10 @@ def build_table_summary(table_name):
425
 
426
  # Try to summarize categorical columns using metadata
427
  for col_name, col_desc in columns:
 
 
 
 
428
  # Try to determine if it's a categorical column based on name/description
429
  # Skip likely numeric/date columns
430
  col_lower = col_name.lower()
@@ -433,6 +484,7 @@ def build_table_summary(table_name):
433
 
434
  # Try to get breakdown for text-like columns
435
  try:
 
436
  rows = cur.execute(
437
  f"""
438
  SELECT {col_name}, COUNT(*)
@@ -447,8 +499,9 @@ def build_table_summary(table_name):
447
  summary += f"\n• {col_name.capitalize()} breakdown:\n"
448
  for val, count in rows:
449
  summary += f" - {val}: {count}\n"
450
- except:
451
- pass # ignore columns that can't be grouped
 
452
 
453
  summary += "\nYou can ask more detailed questions about this data."
454
 
 
3
  import sqlite3
4
  from openai import OpenAI
5
  from difflib import get_close_matches
 
6
 
7
  # =========================
8
  # SETUP
9
  # =========================
10
 
11
+ # Validate API key
12
+ api_key = os.getenv("OPENAI_API_KEY")
13
+ if not api_key:
14
+ raise ValueError("OPENAI_API_KEY environment variable is not set")
15
+ client = OpenAI(api_key=api_key)
16
  conn = sqlite3.connect("hospital.db", check_same_thread=False)
17
 
18
  # =========================
 
59
  "admitted", "admission",
60
  "year", "month", "last", "recent", "today"
61
  ]
 
 
 
 
 
 
62
 
63
  def correct_spelling(q):
64
  words = q.split()
 
216
  # =========================
217
 
218
  def get_latest_data_date():
219
+ """Get the latest data date from encounters table."""
220
  cur = conn.cursor()
221
+ try:
222
+ r = cur.execute("SELECT MAX(start_date) FROM encounters").fetchone()
223
+ return r[0] if r and r[0] else None
224
+ except sqlite3.Error:
225
+ return None
226
 
227
  def normalize_time_question(q):
228
  latest = get_latest_data_date()
 
328
 
329
 
330
  def call_llm(prompt):
331
+ """Call OpenAI API with error handling."""
332
+ try:
333
+ res = client.chat.completions.create(
334
+ model="gpt-4.1-mini",
335
+ messages=[
336
+ {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
337
+ {"role": "user", "content": prompt}
338
+ ],
339
+ temperature=0
340
+ )
341
+ if not res.choices or not res.choices[0].message.content:
342
+ raise ValueError("Empty response from OpenAI API")
343
+ return res.choices[0].message.content.strip()
344
+ except Exception as e:
345
+ raise ValueError(f"OpenAI API error: {str(e)}")
346
 
347
  # =========================
348
  # SQL SAFETY
349
  # =========================
350
 
351
  def sanitize_sql(sql):
352
+ # Remove code fence markers but preserve legitimate SQL
353
+ sql = sql.replace("```sql", "").replace("```", "").strip()
354
+ # Remove leading/trailing markdown code markers
355
+ if sql.startswith("sql"):
356
+ sql = sql[3:].strip()
357
  sql = sql.split(";")[0]
358
  return sql.replace("\n", " ").strip()
359
 
 
384
 
385
  def validate_sql(sql):
386
  if not sql.lower().startswith("select"):
387
+ raise ValueError("Only SELECT allowed")
388
  return sql
389
 
390
  def run_query(sql):
391
+ """Execute SQL query with proper error handling."""
392
  cur = conn.cursor()
393
+ try:
394
+ rows = cur.execute(sql).fetchall()
395
+ if cur.description:
396
+ cols = [c[0] for c in cur.description]
397
+ else:
398
+ cols = []
399
+ return cols, rows
400
+ except sqlite3.Error as e:
401
+ raise ValueError(f"Database query error: {str(e)}")
402
 
403
  # =========================
404
  # AGGREGATE SAFETY
 
409
  return ("count(" in s or "sum(" in s or "avg(" in s) and "group by" not in s
410
 
411
  def has_underlying_data(sql):
412
+ """Check if underlying data exists for the SQL query."""
413
  base = sql.lower()
414
  if "from" not in base:
415
  return False
416
 
417
  base = base.split("from", 1)[1]
418
+ # Split at GROUP BY, ORDER BY, LIMIT, etc. to get just the FROM clause
419
+ for clause in ["group by", "order by", "limit", "having"]:
420
+ base = base.split(clause)[0]
421
+
422
+ test_sql = "SELECT 1 FROM " + base.strip() + " LIMIT 1"
423
 
424
  cur = conn.cursor()
425
+ try:
426
+ return cur.execute(test_sql).fetchone() is not None
427
+ except sqlite3.Error:
428
+ return False
429
 
430
  # =========================
431
  # PATIENT SUMMARY
432
  # =========================
433
 
434
+ def validate_identifier(name):
435
+ """Validate that identifier is safe (only alphanumeric and underscores)."""
436
+ if not name or not isinstance(name, str):
437
+ return False
438
+ # Check for SQL injection attempts
439
+ forbidden = [";", "--", "/*", "*/", "'", '"', "`", "(", ")", " ", "\n", "\t"]
440
+ if any(char in name for char in forbidden):
441
+ return False
442
+ # Must start with letter or underscore, rest alphanumeric/underscore
443
+ return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name))
444
 
445
+ def build_table_summary(table_name):
446
+ """Build summary for a table using metadata."""
447
+ # Validate table name against metadata first
448
  schema = load_ai_schema()
449
  if table_name not in schema:
450
  return f"Table {table_name} not found in metadata."
451
+
452
+ # Additional safety check
453
+ if not validate_identifier(table_name):
454
+ return f"Invalid table name: {table_name}"
455
+
456
+ cur = conn.cursor()
457
+
458
+ # Total rows (still need to query actual data for count)
459
+ # Note: SQLite doesn't support parameterized table names
460
+ # Since we validated table_name against metadata, it's safe
461
+ try:
462
+ total = cur.execute(
463
+ f"SELECT COUNT(*) FROM {table_name}"
464
+ ).fetchone()[0]
465
+ except sqlite3.Error as e:
466
+ return f"Error querying table {table_name}: {str(e)}"
467
 
468
  columns = schema[table_name]["columns"] # [(col_name, description), ...]
469
 
 
472
 
473
  # Try to summarize categorical columns using metadata
474
  for col_name, col_desc in columns:
475
+ # Validate column name
476
+ if not validate_identifier(col_name):
477
+ continue
478
+
479
  # Try to determine if it's a categorical column based on name/description
480
  # Skip likely numeric/date columns
481
  col_lower = col_name.lower()
 
484
 
485
  # Try to get breakdown for text-like columns
486
  try:
487
+ # Note: SQLite doesn't support parameterized identifiers, so we validate
488
  rows = cur.execute(
489
  f"""
490
  SELECT {col_name}, COUNT(*)
 
499
  summary += f"\n• {col_name.capitalize()} breakdown:\n"
500
  for val, count in rows:
501
  summary += f" - {val}: {count}\n"
502
+ except (sqlite3.Error, sqlite3.OperationalError) as e:
503
+ # Ignore columns that can't be grouped (likely not categorical)
504
+ pass
505
 
506
  summary += "\nYou can ask more detailed questions about this data."
507