bhavika24 committed on
Commit
25a0c35
·
verified ·
1 Parent(s): 52a5931

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +87 -176
engine.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import sqlite3
3
  from openai import OpenAI
4
  from difflib import get_close_matches
 
5
 
6
 
7
  # =========================
@@ -13,40 +14,59 @@ conn = sqlite3.connect("hospital.db", check_same_thread=False)
13
 
14
 
15
  # =========================
16
- # Known Terms for Spell Correction
17
  # =========================
18
 
19
  KNOWN_TERMS = [
20
- "patient", "patients", "condition", "conditions", "diagnosis", "encounter", "encounters",
21
- "visit", "visits", "observation", "observations", "lab", "labs", "test", "tests",
22
- "medication", "medications", "drug", "drugs", "prescription", "prescriptions",
23
- "diabetes", "hypertension", "asthma", "cancer", "admitted", "admission"
 
 
24
  ]
25
 
26
 
27
  def correct_spelling(question: str) -> str:
28
  words = question.split()
29
- corrected_words = []
30
 
31
  for word in words:
32
- clean_word = word.lower().strip(",.?")
33
- matches = get_close_matches(clean_word, KNOWN_TERMS, n=1, cutoff=0.8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- if matches:
36
- corrected_words.append(matches[0])
37
- else:
38
- corrected_words.append(word)
39
 
40
- return " ".join(corrected_words)
 
 
 
 
 
 
41
 
42
 
43
  # =========================
44
- # Metadata Loader
45
  # =========================
46
 
47
  def load_ai_schema():
48
  cur = conn.cursor()
49
-
50
  schema = {}
51
 
52
  tables = cur.execute("""
@@ -55,17 +75,14 @@ def load_ai_schema():
55
  WHERE ai_enabled = 1
56
  """).fetchall()
57
 
58
- for table_name, desc in tables:
59
  cols = cur.execute("""
60
  SELECT column_name, description
61
  FROM ai_columns
62
  WHERE table_name = ? AND ai_allowed = 1
63
- """, (table_name,)).fetchall()
64
 
65
- schema[table_name] = {
66
- "description": desc,
67
- "columns": cols
68
- }
69
 
70
  return schema
71
 
@@ -81,23 +98,15 @@ def build_prompt(question: str) -> str:
81
  You are a hospital data assistant.
82
 
83
  Rules:
84
- - Generate only SELECT SQL queries.
85
- - Use only the tables and columns provided.
86
- - Do not invent tables or columns.
87
- - This database is SQLite. Use SQLite-compatible date functions.
88
- - For recent days use: date('now', '-N day')
89
- - Use case-insensitive matching for text fields.
90
- - Prefer LIKE with wildcards for medical condition names.
91
- - Use COUNT, AVG, MIN, MAX, GROUP BY when the question asks for totals, averages, or comparisons.
92
- - If the question cannot be answered using the schema, return NOT_ANSWERABLE.
93
- - Do not explain the query.
94
- - Return only SQL or NOT_ANSWERABLE.
95
-
96
- Available schema:
97
  """
98
 
99
  for table, meta in schema.items():
100
- prompt += f"\nTable: {table} - {meta['description']}\n"
101
  for col, desc in meta["columns"]:
102
  prompt += f" - {col}: {desc}\n"
103
 
@@ -106,212 +115,114 @@ Available schema:
106
 
107
 
108
  # =========================
109
- # LLM Call
110
  # =========================
111
 
112
  def call_llm(prompt: str) -> str:
113
- response = client.chat.completions.create(
114
  model="gpt-4.1-mini",
115
  messages=[
116
- {"role": "system", "content": "You are a SQL generator. Return only SQL. No explanation."},
117
  {"role": "user", "content": prompt}
118
  ],
119
- temperature=0.0
120
  )
121
-
122
- return response.choices[0].message.content.strip()
123
 
124
 
125
  # =========================
126
- # SQL Generation
127
  # =========================
128
 
129
- def generate_sql(question: str) -> str:
130
- prompt = build_prompt(question)
131
- sql = call_llm(prompt)
132
- return sql.strip()
133
 
134
 
135
- # =========================
136
- # SQL Cleaning & Validation
137
- # =========================
138
-
139
- def clean_sql(sql: str) -> str:
140
- sql = sql.strip()
141
-
142
- # Remove markdown code fences if present
143
- if sql.startswith("```"):
144
- parts = sql.split("```")
145
- if len(parts) > 1:
146
- sql = parts[1]
147
-
148
- sql = sql.replace("sql\n", "").strip()
149
- return sql
150
-
151
-
152
- def validate_sql(sql: str) -> str:
153
  sql = clean_sql(sql)
154
- s = sql.lower()
155
-
156
- forbidden = ["insert", "update", "delete", "drop", "alter", "truncate"]
157
-
158
- if not s.startswith("select"):
159
- raise Exception("Only SELECT queries allowed")
160
-
161
- if any(f in s for f in forbidden):
162
- raise Exception("Forbidden SQL operation detected")
163
-
164
  return sql
165
 
166
 
167
- # =========================
168
- # Query Runner
169
- # =========================
170
-
171
- def run_query(sql: str):
172
  cur = conn.cursor()
173
- result = cur.execute(sql).fetchall()
174
- columns = [desc[0] for desc in cur.description]
175
- return columns, result
176
 
177
 
178
  # =========================
179
- # Guardrails
180
- # =========================
181
-
182
- def is_question_answerable(question):
183
- keywords = [
184
- "patient", "encounter", "condition", "observation",
185
- "medication", "visit", "diagnosis", "lab", "vital", "admitted"
186
- ]
187
-
188
- q = question.lower()
189
-
190
- if not any(k in q for k in keywords):
191
- return False
192
-
193
- return True
194
-
195
-
196
- # =========================
197
- # Time Awareness
198
  # =========================
199
 
200
  def get_latest_data_date():
201
- sql = "SELECT MAX(start_date) FROM encounters;"
202
- _, rows = run_query(sql)
203
  return rows[0][0]
204
 
205
 
206
- def check_time_relevance(question: str):
207
- q = question.lower()
208
- if any(word in q for word in ["last", "recent", "today", "this month", "this year"]):
209
- latest = get_latest_data_date()
210
- return f"Latest available data is from {latest}."
211
- return None
212
-
213
-
214
- # =========================
215
- # Empty Result Interpreter
216
- # =========================
217
-
218
- def interpret_empty_result(question: str):
219
- latest = get_latest_data_date()
220
- return f"No results found. Available data is up to {latest}."
221
-
222
- # =========================
223
- # Data Range Check
224
- # =========================
225
- from datetime import datetime
226
-
227
- def is_request_out_of_data_range(question: str) -> bool:
228
  latest = get_latest_data_date()
229
-
230
  if not latest:
231
  return True
232
 
233
- latest_date = datetime.fromisoformat(latest.replace("Z", "").split("T")[0])
234
  now = datetime.now()
235
-
236
  q = question.lower()
237
 
238
  if "this year" in q:
239
- return latest_date.year < now.year
240
 
241
  if "last month" in q:
242
- return (now.year, now.month - 1) > (latest_date.year, latest_date.month)
243
 
244
  if "recent" in q or "last 30" in q:
245
- return (now - latest_date).days > 30
246
 
247
  return False
248
 
249
 
250
-
251
  # =========================
252
- # ORCHESTRATOR (Single Entry Point)
253
  # =========================
254
 
255
  def process_question(question: str):
256
- # 0. Spell correction
257
  question = correct_spelling(question)
258
 
259
- # 1. Guardrail
260
- if not is_question_answerable(question):
261
- return {
262
- "status": "rejected",
263
- "message": "This question is not supported by the available data."
264
- }
265
 
266
- # 2. Time relevance
267
- # 2. Time relevance check
268
- if is_request_out_of_data_range(question):
269
  latest = get_latest_data_date()
270
  return {
271
  "status": "ok",
272
- "message": f"No data available for the requested time period. Latest available data is from {latest}.",
273
- "data": [],
274
- "sql": None,
275
- "note": None
276
- }
277
-
278
 
279
- # 3. Generate SQL
280
  sql = generate_sql(question)
281
-
282
- # 4. Validate SQL
283
  sql = validate_sql(sql)
284
 
285
- # 5. Execute query
286
- columns, rows = run_query(sql)
287
-
288
- # 6. Handle empty result with data coverage awareness
289
- if len(rows) == 0:
290
- latest = get_latest_data_date()
291
- q = question.lower()
292
-
293
- if any(word in q for word in ["last", "recent", "this month", "this year"]):
294
- return {
295
- "status": "ok",
296
- "sql": sql,
297
- "message": f"No data available for the requested time period. Latest available data is from {latest}.",
298
- "data": [],
299
- "note": None
300
- }
301
 
 
302
  return {
303
  "status": "ok",
304
- "sql": sql,
305
- "message": interpret_empty_result(question),
306
- "data": [],
307
- "note": time_note
308
  }
309
 
310
- # 7. Normal response
311
  return {
312
  "status": "ok",
313
  "sql": sql,
314
- "columns": columns,
315
- "data": rows[:50], # demo safety limit
316
- "note": time_note
317
  }
 
2
  import sqlite3
3
  from openai import OpenAI
4
  from difflib import get_close_matches
5
+ from datetime import datetime
6
 
7
 
8
  # =========================
 
14
 
15
 
16
  # =========================
17
+ # Known Terms
18
  # =========================
19
 
20
KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions", "diagnosis",
    "encounter", "encounters", "visit", "visits",
    "observation", "observations", "lab", "labs",
    "medication", "medications",
    "diabetes", "hypertension", "asthma",
    "admitted", "admission",
]

# Punctuation that is ignored around a word while looking for a close match.
_EDGE_PUNCTUATION = ",.?"


def correct_spelling(question: str) -> str:
    """Replace misspelled domain words in *question* with their known spelling.

    Each whitespace-separated word is compared (lower-cased, with leading and
    trailing ``,``, ``.`` and ``?`` ignored) against ``KNOWN_TERMS`` using
    ``difflib.get_close_matches`` with a 0.8 similarity cutoff.

    Words that are already known terms, or that have no close match, are kept
    exactly as typed.  When a word is corrected, the punctuation around it is
    preserved, so ``"diabtes?"`` becomes ``"diabetes?"`` rather than
    ``"diabetes"``.

    Returns:
        The question with words re-joined by single spaces.
    """
    fixed = []

    for word in question.split():
        core = word.strip(_EDGE_PUNCTUATION)
        lowered = core.lower()

        # Nothing to correct: pure punctuation, or already spelled correctly.
        if not core or lowered in KNOWN_TERMS:
            fixed.append(word)
            continue

        match = get_close_matches(lowered, KNOWN_TERMS, n=1, cutoff=0.8)
        if not match:
            fixed.append(word)
            continue

        # Re-attach whatever punctuation surrounded the misspelled word.
        prefix_len = len(word) - len(word.lstrip(_EDGE_PUNCTUATION))
        prefix = word[:prefix_len]
        suffix = word[prefix_len + len(core):]
        fixed.append(prefix + match[0] + suffix)

    return " ".join(fixed)
40
+
41
+
42
+ # =========================
43
+ # Unsupported Concept Check
44
+ # =========================
45
+
46
def get_unsupported_reason(question: str):
    """Return why *question* cannot be answered, or ``None`` if it may be.

    The question is checked for words that refer to data this assistant does
    not expose (doctors, departments, insurance, staff).  Matching is done on
    whole words rather than raw substrings: a plain ``"hr" in question`` test
    also fires on "three", "chronic" or "through" and would reject ordinary
    clinical questions.

    Terms longer than two characters match as word prefixes so that plurals
    and derived forms ("doctors", "employees", "staffing") are still caught;
    very short terms such as "hr" must match the whole word.

    Returns:
        A user-facing explanation string for the first unsupported concept
        found, or ``None`` when no such concept is mentioned.
    """
    # Treat every non-alphanumeric character as a separator so punctuation
    # and hyphens never stay glued to a word.
    normalized = "".join(ch if ch.isalnum() else " " for ch in question.lower())
    words = normalized.split()

    def mentions(terms):
        return any(
            word == term if len(term) <= 2 else word.startswith(term)
            for word in words
            for term in terms
        )

    # Order matters: the first matching group decides the message.
    unsupported = (
        (("consultant", "doctor"),
         "Consultant or doctor workload data is not available."),
        (("specialization", "department"),
         "Doctor specialization or department data is not available."),
        (("insurance", "policy"),
         "Insurance-related data is not available."),
        (("staff", "employee", "hr"),
         "HR or staff data is not available."),
    )

    for terms, reason in unsupported:
        if mentions(terms):
            return reason

    return None
62
 
63
 
64
  # =========================
65
+ # Metadata
66
  # =========================
67
 
68
  def load_ai_schema():
69
  cur = conn.cursor()
 
70
  schema = {}
71
 
72
  tables = cur.execute("""
 
75
  WHERE ai_enabled = 1
76
  """).fetchall()
77
 
78
+ for table, desc in tables:
79
  cols = cur.execute("""
80
  SELECT column_name, description
81
  FROM ai_columns
82
  WHERE table_name = ? AND ai_allowed = 1
83
+ """, (table,)).fetchall()
84
 
85
+ schema[table] = {"description": desc, "columns": cols}
 
 
 
86
 
87
  return schema
88
 
 
98
  You are a hospital data assistant.
99
 
100
  Rules:
101
+ - Only generate SELECT queries.
102
+ - Use only provided tables and columns.
103
+ - SQLite syntax only.
104
+ - Use date('now', '-N day') for time filters.
105
+ - Return ONLY SQL or NOT_ANSWERABLE.
 
 
 
 
 
 
 
 
106
  """
107
 
108
  for table, meta in schema.items():
109
+ prompt += f"\nTable: {table}\n"
110
  for col, desc in meta["columns"]:
111
  prompt += f" - {col}: {desc}\n"
112
 
 
115
 
116
 
117
  # =========================
118
+ # LLM
119
  # =========================
120
 
121
def call_llm(prompt: str) -> str:
    """Send *prompt* to the chat model and return its reply as stripped text.

    Uses the module-level ``client`` with model ``gpt-4.1-mini`` and
    temperature 0, together with a system message asking for SQL only.

    Returns:
        The first choice's message content with surrounding whitespace
        removed.  If the API returns no text content, an empty string is
        returned so callers see a normal (invalid) reply instead of an
        ``AttributeError``.
    """
    res = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": "Return only SQL."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )
    # message.content is optional in the API response; guard before strip().
    content = res.choices[0].message.content
    return (content or "").strip()
 
131
 
132
 
133
  # =========================
134
+ # SQL Helpers
135
  # =========================
136
 
137
def clean_sql(sql):
    """Strip Markdown code fences and surrounding whitespace from model output.

    If the text contains a fenced block (three backticks, optionally followed
    by a ``sql``/``sqlite`` language tag), only the content of the first
    fenced block is kept; an unterminated fence runs to the end of the text.
    Text without any fence is returned stripped but otherwise untouched, so
    identifiers that happen to end in "sql" at a line break are not damaged.

    Args:
        sql: Raw text returned by the model.

    Returns:
        The bare SQL (or whatever the model replied) without fences.
    """
    import re  # local import keeps this block self-contained

    text = sql.strip()
    fenced = re.search(
        r"```(?:sqlite|sql)?\b(.*?)(?:```|\Z)"
        if re.match(r"(?is).*```(?:sqlite|sql)\b", text)
        else r"```(.*?)(?:```|\Z)",
        text,
        re.IGNORECASE | re.DOTALL,
    )
    if fenced:
        text = fenced.group(1)
    return text.strip()
 
 
139
 
140
 
141
def validate_sql(sql):
    """Clean *sql* and make sure it is a SELECT statement.

    The text is first passed through ``clean_sql``; the cleaned statement is
    accepted only when it begins with ``select`` (case-insensitive).

    Returns:
        The cleaned SQL string.

    Raises:
        Exception: with the message "Invalid SQL" when the cleaned text does
            not start with SELECT.
    """
    cleaned = clean_sql(sql)
    if cleaned.lower().startswith("select"):
        return cleaned
    raise Exception("Invalid SQL")
146
 
147
 
148
def run_query(sql):
    """Execute *sql* on the module-level connection and fetch every row.

    Args:
        sql: The SQL statement to run.

    Returns:
        A ``(cols, rows)`` tuple: the list of result column names and the
        list of row tuples.  ``cols`` is empty when the statement produced no
        result set.
    """
    cur = conn.cursor()
    try:
        rows = cur.execute(sql).fetchall()
        # description is None for statements that return no result set;
        # fall back to an empty tuple instead of failing on iteration.
        cols = [c[0] for c in (cur.description or ())]
    finally:
        # Release the cursor even when execution raises.
        cur.close()
    return cols, rows
153
 
154
 
155
  # =========================
156
+ # Time Logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  # =========================
158
 
159
def get_latest_data_date():
    """Return the greatest ``start_date`` stored in the ``encounters`` table.

    The value is returned exactly as the database yields it (first column of
    the first result row of the ``MAX(start_date)`` query).
    """
    result = run_query("SELECT MAX(start_date) FROM encounters;")
    latest_rows = result[1]
    first_row = latest_rows[0]
    return first_row[0]
162
 
163
 
164
# Sentinel distinguishing "caller did not pass latest" from "latest is None".
_LATEST_UNSET = object()


def is_out_of_range(question: str, latest=_LATEST_UNSET, now=None):
    """Tell whether *question* asks about a period newer than the stored data.

    Args:
        question: The user's question; only the phrases "this year",
            "last month", "recent" and "last 30" are inspected.
        latest: Latest data date as an ISO-style string.  When omitted it is
            looked up with ``get_latest_data_date()``.  Passing ``None`` (or
            an empty value) means "there is no data".
        now: Reference "current" datetime; defaults to ``datetime.now()``.
            Exposed so the check can be evaluated deterministically.

    Returns:
        ``True`` when there is no data at all, or when the requested period
        lies after the latest available date; ``False`` otherwise (including
        questions with no recognised time phrase).
    """
    if latest is _LATEST_UNSET:
        latest = get_latest_data_date()

    if not latest:
        return True

    # Keep only the calendar date: timestamps may be "T"- or space-separated
    # and may carry a zone suffix that would otherwise break parsing or make
    # the value incomparable with a naive "now".
    date_part = str(latest).replace("T", " ").split(" ")[0]
    latest_dt = datetime.fromisoformat(date_part)

    if now is None:
        now = datetime.now()
    q = question.lower()

    if "this year" in q:
        return latest_dt.year < now.year

    if "last month" in q:
        # In January the previous month is December of the previous year,
        # not "month 0" of the current one.
        if now.month == 1:
            previous = (now.year - 1, 12)
        else:
            previous = (now.year, now.month - 1)
        return previous > (latest_dt.year, latest_dt.month)

    if "recent" in q or "last 30" in q:
        return (now - latest_dt).days > 30

    return False
183
 
184
 
 
185
  # =========================
186
+ # MAIN ENTRY
187
  # =========================
188
 
189
def process_question(question: str):
    """Answer a natural-language *question* against the hospital database.

    Steps, in order:
      1. Spell-correct the question.
      2. Reject questions about concepts the data does not cover.
      3. Short-circuit when the requested time period is newer than the data.
      4. Ask the model for SQL, validate it and run it.

    Returns:
        A dict with a ``"status"`` key (``"ok"`` or ``"rejected"``).
        Rejections carry a ``"message"``.  Successful answers carry ``"sql"``,
        ``"columns"`` and at most 50 rows in ``"data"``.  Empty or
        out-of-range answers carry a ``"message"`` and an empty ``"data"``
        list.

    Raises:
        Exception: propagated from ``validate_sql`` when the model returns
            something that is neither SELECT SQL nor NOT_ANSWERABLE.
    """
    question = correct_spelling(question)

    # Unsupported concept
    reason = get_unsupported_reason(question)
    if reason:
        return {"status": "rejected", "message": reason}

    # Out-of-range data
    if is_out_of_range(question):
        latest = get_latest_data_date()
        if latest:
            note = f"Latest available data is from {latest}."
            suggestion = f"Try asking about data from {str(latest)[:4]}."
        else:
            # is_out_of_range() also reports True when there is no data at
            # all; slicing None for the year would raise TypeError.
            note = "No encounter data is available."
            suggestion = None
        return {
            "status": "ok",
            "message": "No data available for the requested time period.",
            "note": note,
            "suggestion": suggestion,
            "data": [],
        }

    # Generate SQL.  The prompt is built and sent directly because no
    # generate_sql() helper is defined in this module.
    raw_sql = call_llm(build_prompt(question))

    # The prompt tells the model to answer NOT_ANSWERABLE when the schema
    # cannot satisfy the question; report that as a rejection, not an error.
    if clean_sql(raw_sql).upper().startswith("NOT_ANSWERABLE"):
        return {
            "status": "rejected",
            "message": "This question is not supported by the available data.",
        }

    sql = validate_sql(raw_sql)

    # Execute
    cols, rows = run_query(sql)

    if not rows:
        return {
            "status": "ok",
            "sql": sql,
            "message": "No matching records found.",
            "data": [],
        }

    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows[:50],  # cap the payload size
    }