bhavika24 committed on
Commit
28620f4
·
verified ·
1 Parent(s): 7b117b4

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +68 -27
engine.py CHANGED
@@ -80,8 +80,19 @@ from functools import lru_cache
80
 
81
  @lru_cache(maxsize=1)
82
  def load_ai_schema():
83
- with open("hospital_metadata.json", "r") as f:
84
- return json.load(f)
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  # =========================
87
  # TABLE MATCHING (CORE LOGIC)
@@ -94,20 +105,33 @@ def extract_relevant_tables(question, max_tables=4):
94
 
95
  matched = []
96
 
97
- # Lightweight intent hints (NO hard dependency)
98
- DOMAIN_HINTS = {
99
- "consultant": ["encounters"],
100
- "doctor": ["encounters"],
101
- "visit": ["encounters"],
102
- "visited": ["encounters"], # Handle past tense
103
- "visits": ["encounters"], # Handle plural
104
- "appointment": ["encounters"],
105
- "patient": ["patients"],
106
- "medication": ["medications"],
107
- "drug": ["medications"],
108
- "condition": ["conditions"],
109
- "diagnosis": ["conditions"]
 
 
 
 
 
 
 
110
  }
 
 
 
 
 
 
111
 
112
  # Early exit threshold - if we find a perfect match, we can stop early
113
  VERY_HIGH_SCORE = 10
@@ -125,7 +149,7 @@ def extract_relevant_tables(question, max_tables=4):
125
  continue
126
 
127
  # 2️⃣ Column relevance
128
- for col, _ in meta["columns"]:
129
  col_l = col.lower()
130
  if col_l in q:
131
  score += 3
@@ -173,7 +197,7 @@ def describe_schema(max_tables=10):
173
  for table, meta in shown_tables:
174
  response += f"• **{table.capitalize()}** — {meta['description']}\n"
175
  # Show only first 5 columns per table
176
- for col, desc in list(meta["columns"])[:5]:
177
  response += f" - {col}: {desc}\n"
178
  if len(meta["columns"]) > 5:
179
  response += f" ... and {len(meta['columns']) - 5} more columns\n"
@@ -198,12 +222,30 @@ def describe_schema(max_tables=10):
198
  # =========================
199
 
200
  def get_latest_data_date():
201
- try:
202
- return conn.execute(
203
- "SELECT MAX(admittime) FROM admissions"
204
- ).fetchone()[0]
205
- except:
206
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
 
209
  def normalize_time_question(q):
@@ -249,7 +291,7 @@ def is_question_supported(question):
249
  score += 3
250
 
251
  # Column name match
252
- for col, _ in meta["columns"]:
253
  col_l = col.lower()
254
  if col_l in q:
255
  score += 2
@@ -295,8 +337,7 @@ Rules:
295
  - Use only listed tables/columns
296
  - Return ONLY SQL or NOT_ANSWERABLE
297
 
298
- IMPORTANT: If the question mentions "visit", "visited", or "visits", use the table name "encounters" (NOT "visits" or "visit").
299
- If the question mentions "consultant" or "doctor", use the table name "encounters".
300
  """
301
 
302
  for table, meta in schema.items():
@@ -447,7 +488,7 @@ def build_table_summary(table_name):
447
  except sqlite3.Error as e:
448
  return f"Error querying table {table_name}: {str(e)}"
449
 
450
- columns = schema[table_name]["columns"] # [(col_name, description), ...]
451
 
452
  summary = f"Here's a summary of **{table_name}**:\n\n"
453
  summary += f"• Total records: {total}\n"
 
80
 
81
  @lru_cache(maxsize=1)
82
  def load_ai_schema():
83
+ """Load schema from metadata JSON file with error handling."""
84
+ try:
85
+ with open("hospital_metadata.json", "r") as f:
86
+ schema = json.load(f)
87
+ if not isinstance(schema, dict):
88
+ raise ValueError("Invalid metadata format: expected a dictionary")
89
+ return schema
90
+ except FileNotFoundError:
91
+ raise FileNotFoundError("hospital_metadata.json file not found. Please create it with your table metadata.")
92
+ except json.JSONDecodeError as e:
93
+ raise ValueError(f"Invalid JSON in hospital_metadata.json: {str(e)}")
94
+ except Exception as e:
95
+ raise ValueError(f"Error loading metadata: {str(e)}")
96
 
97
  # =========================
98
  # TABLE MATCHING (CORE LOGIC)
 
105
 
106
  matched = []
107
 
108
+ # Lightweight intent hints - dynamically filter to only include tables that exist
109
+ # Map natural language terms to potential table names (check against schema)
110
+ all_tables = list(schema.keys())
111
+ table_names_lower = [t.lower() for t in all_tables]
112
+
113
+ DOMAIN_HINTS = {}
114
+
115
+ # Build hints only for tables that actually exist
116
+ hint_mappings = {
117
+ "consultant": ["encounter", "encounters", "visit", "visits"],
118
+ "doctor": ["encounter", "encounters", "provider", "providers"],
119
+ "visit": ["encounter", "encounters", "visit", "visits"],
120
+ "visited": ["encounter", "encounters", "visit", "visits"],
121
+ "visits": ["encounter", "encounters", "visit", "visits"],
122
+ "appointment": ["encounter", "encounters", "appointment", "appointments"],
123
+ "patient": ["patient", "patients"],
124
+ "medication": ["medication", "medications", "drug", "drugs"],
125
+ "drug": ["medication", "medications", "drug", "drugs"],
126
+ "condition": ["condition", "conditions", "diagnosis", "diagnoses"],
127
+ "diagnosis": ["condition", "conditions", "diagnosis", "diagnoses"]
128
  }
129
+
130
+ # Only include hints for tables that exist in the schema
131
+ for intent, possible_tables in hint_mappings.items():
132
+ matching_tables = [t for t in possible_tables if t in table_names_lower]
133
+ if matching_tables:
134
+ DOMAIN_HINTS[intent] = matching_tables
135
 
136
  # Early exit threshold - if we find a perfect match, we can stop early
137
  VERY_HIGH_SCORE = 10
 
149
  continue
150
 
151
  # 2️⃣ Column relevance
152
+ for col, desc in meta["columns"].items():
153
  col_l = col.lower()
154
  if col_l in q:
155
  score += 3
 
197
  for table, meta in shown_tables:
198
  response += f"• **{table.capitalize()}** — {meta['description']}\n"
199
  # Show only first 5 columns per table
200
+ for col, desc in list(meta["columns"].items())[:5]:
201
  response += f" - {col}: {desc}\n"
202
  if len(meta["columns"]) > 5:
203
  response += f" ... and {len(meta['columns']) - 5} more columns\n"
 
222
  # =========================
223
 
224
  def get_latest_data_date():
225
+ """Get the latest data date by checking tables with date columns."""
226
+ schema = load_ai_schema()
227
+
228
+ # Common date column names to check
229
+ date_columns = ["date", "start_date", "end_date", "admission_date", "admittime", "dischtime", "created_at", "updated_at"]
230
+
231
+ # Try to find a table with a date column
232
+ for table_name in schema.keys():
233
+ columns = schema[table_name].get("columns", {})
234
+
235
+ # Check if table has any date-like column
236
+ for col_name in columns.keys():
237
+ col_lower = col_name.lower()
238
+ if any(date_col in col_lower for date_col in date_columns):
239
+ try:
240
+ result = conn.execute(
241
+ f"SELECT MAX({col_name}) FROM {table_name}"
242
+ ).fetchone()
243
+ if result and result[0]:
244
+ return result[0]
245
+ except (sqlite3.Error, sqlite3.OperationalError):
246
+ continue # Try next table/column
247
+
248
+ return None
249
 
250
 
251
  def normalize_time_question(q):
 
291
  score += 3
292
 
293
  # Column name match
294
+ for col, desc in meta["columns"].items():
295
  col_l = col.lower()
296
  if col_l in q:
297
  score += 2
 
337
  - Use only listed tables/columns
338
  - Return ONLY SQL or NOT_ANSWERABLE
339
 
340
+ IMPORTANT: Use EXACTLY the table names provided in the list below. Do not pluralize, modify, or guess table names.
 
341
  """
342
 
343
  for table, meta in schema.items():
 
488
  except sqlite3.Error as e:
489
  return f"Error querying table {table_name}: {str(e)}"
490
 
491
+ columns = schema[table_name]["columns"] # {col_name: description, ...}
492
 
493
  summary = f"Here's a summary of **{table_name}**:\n\n"
494
  summary += f"• Total records: {total}\n"