bhavika24 committed on
Commit
9496a38
·
verified ·
1 Parent(s): 0ff257a

Update engine.py

Browse files
Files changed (1) hide show
  1. engine.py +314 -282
engine.py CHANGED
@@ -1,282 +1,314 @@
1
- import os
2
- import sqlite3
3
- from openai import OpenAI
4
- from difflib import get_close_matches
5
- from datetime import datetime
6
-
7
- # =========================
8
- # SETUP
9
- # =========================
10
-
11
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
12
- conn = sqlite3.connect("hospital.db", check_same_thread=False)
13
-
14
- # =========================
15
- # HUMAN RESPONSE HELPERS
16
- # =========================
17
-
18
- def humanize(text):
19
- return f"Sure 🙂\n\n{text}"
20
-
21
- def friendly(text):
22
- return f"{text}\n\nIf you want, I can help you explore this further 🙂"
23
-
24
- # =========================
25
- # SPELL CORRECTION
26
- # =========================
27
-
28
- KNOWN_TERMS = [
29
- "patient", "patients", "condition", "conditions",
30
- "encounter", "encounters", "visit", "visits",
31
- "medication", "medications",
32
- "admitted", "admission",
33
- "year", "month", "last", "recent", "today"
34
- ]
35
-
36
- def correct_spelling(q):
37
- words = q.split()
38
- fixed = []
39
- for w in words:
40
- clean = w.lower().strip(",.?")
41
- match = get_close_matches(clean, KNOWN_TERMS, n=1, cutoff=0.8)
42
- fixed.append(match[0] if match else w)
43
- return " ".join(fixed)
44
-
45
- # =========================
46
- # SCHEMA
47
- # =========================
48
-
49
- def load_ai_schema():
50
- cur = conn.cursor()
51
- schema = {}
52
-
53
- tables = cur.execute("""
54
- SELECT table_name, description
55
- FROM ai_tables
56
- WHERE ai_enabled = 1
57
- """).fetchall()
58
-
59
- for table, desc in tables:
60
- cols = cur.execute("""
61
- SELECT column_name, description
62
- FROM ai_columns
63
- WHERE table_name = ? AND ai_allowed = 1
64
- """, (table,)).fetchall()
65
-
66
- schema[table] = {
67
- "description": desc,
68
- "columns": cols
69
- }
70
-
71
- return schema
72
-
73
- # =========================
74
- # HUMAN SCHEMA DESCRIPTION
75
- # =========================
76
-
77
- def describe_schema():
78
- schema = load_ai_schema()
79
-
80
- response = "Here’s the data I currently have access to:\n\n"
81
-
82
- for table, meta in schema.items():
83
- response += f"• **{table.capitalize()}** — {meta['description']}\n"
84
- for col, desc in meta["columns"]:
85
- response += f" - {col}: {desc}\n"
86
- response += "\n"
87
-
88
- response += (
89
- "You can ask things like:\n"
90
- "• How many patients are there?\n"
91
- "• Patient count by gender\n"
92
- "• Admissions by year\n\n"
93
- "Just tell me what you want to explore 🙂"
94
- )
95
-
96
- return response
97
-
98
- # =========================
99
- # TIME HANDLING
100
- # =========================
101
-
102
- def get_latest_data_date():
103
- cur = conn.cursor()
104
- r = cur.execute("SELECT MAX(start_date) FROM encounters").fetchone()
105
- return r[0]
106
-
107
- def normalize_time_question(q):
108
- latest = get_latest_data_date()
109
- if not latest:
110
- return q
111
-
112
- if "today" in q:
113
- return q.replace("today", f"on {latest[:10]}")
114
-
115
- if "yesterday" in q:
116
- return q.replace("yesterday", f"on {latest[:10]}")
117
-
118
- return q
119
-
120
- # =========================
121
- # UNSUPPORTED QUESTIONS
122
- # =========================
123
-
124
- def get_unsupported_reason(q):
125
- q = q.lower()
126
-
127
- if any(w in q for w in ["consultant", "doctor"]):
128
- return {
129
- "reason": "Doctor or consultant-level data is not available.",
130
- "suggestion": "Try asking about patients, visits, or admissions."
131
- }
132
-
133
- if any(w in q for w in ["department", "specialization"]):
134
- return {
135
- "reason": "Department-level data is not stored.",
136
- "suggestion": "Try patient or visit related questions."
137
- }
138
-
139
- return None
140
-
141
- # =========================
142
- # SQL GENERATION
143
- # =========================
144
-
145
- def build_prompt(question):
146
- schema = load_ai_schema()
147
-
148
- prompt = """
149
- You are a hospital SQL assistant.
150
-
151
- Rules:
152
- - Use only SELECT
153
- - SQLite syntax
154
- - Use only listed tables/columns
155
- - Return ONLY SQL or NOT_ANSWERABLE
156
- """
157
-
158
- for table, meta in schema.items():
159
- prompt += f"\nTable: {table}\n"
160
- for col, desc in meta["columns"]:
161
- prompt += f"- {col}: {desc}\n"
162
-
163
- prompt += f"\nQuestion: {question}\n"
164
- return prompt
165
-
166
- def call_llm(prompt):
167
- res = client.chat.completions.create(
168
- model="gpt-4.1-mini",
169
- messages=[
170
- {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
171
- {"role": "user", "content": prompt}
172
- ],
173
- temperature=0
174
- )
175
- return res.choices[0].message.content.strip()
176
-
177
- # =========================
178
- # SQL SAFETY
179
- # =========================
180
-
181
- def sanitize_sql(sql):
182
- sql = sql.replace("```", "").replace("sql", "").strip()
183
- sql = sql.split(";")[0]
184
- return sql.replace("\n", " ").strip()
185
-
186
- def validate_sql(sql):
187
- if not sql.lower().startswith("select"):
188
- raise Exception("Only SELECT allowed")
189
- return sql
190
-
191
- def run_query(sql):
192
- cur = conn.cursor()
193
- rows = cur.execute(sql).fetchall()
194
- cols = [c[0] for c in cur.description]
195
- return cols, rows
196
-
197
- # =========================
198
- # AGGREGATE SAFETY
199
- # =========================
200
-
201
- def is_aggregate_only_query(sql):
202
- s = sql.lower()
203
- return ("count(" in s or "sum(" in s or "avg(" in s) and "group by" not in s
204
-
205
- def has_underlying_data(sql):
206
- base = sql.lower()
207
- if "from" not in base:
208
- return False
209
-
210
- base = base.split("from", 1)[1]
211
- test_sql = "SELECT 1 FROM " + base.split("group by")[0] + " LIMIT 1"
212
-
213
- cur = conn.cursor()
214
- return cur.execute(test_sql).fetchone() is not None
215
-
216
- # =========================
217
- # MAIN ENGINE
218
- # =========================
219
-
220
- def process_question(question):
221
-
222
- question = correct_spelling(question)
223
- question = normalize_time_question(question)
224
-
225
- if any(x in question.lower() for x in ["what data", "what tables", "which data"]):
226
- return {
227
- "status": "ok",
228
- "message": humanize(describe_schema()),
229
- "data": []
230
- }
231
-
232
- unsupported = get_unsupported_reason(question)
233
- if unsupported:
234
- return {
235
- "status": "ok",
236
- "message": (
237
- f"{unsupported['reason']}\n\n"
238
- f"{unsupported['suggestion']}\n\n"
239
- "Example questions:\n"
240
- " How many patients were admitted last year?\n"
241
- "• Total visits by month\n"
242
- "• Patient count by gender"
243
- ),
244
- "data": []
245
- }
246
-
247
- sql = call_llm(build_prompt(question))
248
-
249
- if sql == "NOT_ANSWERABLE":
250
- return {
251
- "status": "ok",
252
- "message": "I don’t have enough data to answer that.",
253
- "data": []
254
- }
255
-
256
- sql = validate_sql(sanitize_sql(sql))
257
- cols, rows = run_query(sql)
258
-
259
- if is_aggregate_only_query(sql) and not has_underlying_data(sql):
260
- latest = get_latest_data_date()
261
- return {
262
- "status": "ok",
263
- "message": friendly("No data is available for that time period."),
264
- "note": f"Available data is only up to {latest}.",
265
- "data": []
266
- }
267
-
268
- if not rows:
269
- latest = get_latest_data_date()
270
- return {
271
- "status": "ok",
272
- "message": friendly("No records found."),
273
- "note": f"Available data is only up to {latest}.",
274
- "data": []
275
- }
276
-
277
- return {
278
- "status": "ok",
279
- "sql": sql,
280
- "columns": cols,
281
- "data": rows
282
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sqlite3
3
+ from openai import OpenAI
4
+ from difflib import get_close_matches
5
+ from datetime import datetime
6
+
7
+ # =========================
8
+ # SETUP
9
+ # =========================
10
+
11
# Shared OpenAI client; the API key is read from the OPENAI_API_KEY
# environment variable when this module is imported.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)

# One module-wide SQLite connection. check_same_thread=False allows threads
# other than the creating one to use it.
conn = sqlite3.connect(
    "hospital.db",
    check_same_thread=False,
)

# =========================
# CONVERSATION STATE
# =========================

# Kind of the most recent reply that invites a follow-up. process_question
# sets it to "NO_DATA" after an empty result and clears it after a
# successful query.
LAST_PROMPT_TYPE = None
19
+
20
+ # =========================
21
+ # HUMAN RESPONSE HELPERS
22
+ # =========================
23
+
24
def humanize(text):
    """Return *text* preceded by a short, friendly acknowledgement line."""
    return "Sure 🙂\n\n{}".format(text)
26
+
27
def friendly(text):
    """Return *text* followed by an offer to keep exploring the data."""
    return "{}\n\nIf you want, I can help you explore this further 🙂".format(text)
29
+
30
def is_confirmation(text):
    """Return True when *text* is a bare affirmative reply such as "yes".

    Matching ignores case, surrounding whitespace and trailing/leading
    sentence punctuation, so "Yes!" and " ok. " both count. Longer
    sentences that merely contain an affirmative word do not.
    """
    affirmatives = {"yes", "yep", "yeah", "ok", "okay", "sure"}
    # Users commonly end a one-word reply with punctuation; strip it so the
    # reply is not mistaken for a new question.
    normalized = text.strip().lower().strip(" .,!")
    return normalized in affirmatives
32
+
33
+ # =========================
34
+ # SPELL CORRECTION
35
+ # =========================
36
+
37
# Domain vocabulary that misspelled words in a question are snapped to.
KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions",
    "encounter", "encounters", "visit", "visits",
    "medication", "medications",
    "admitted", "admission",
    "year", "month", "last", "recent", "today"
]


def correct_spelling(q):
    """Replace near-misses of KNOWN_TERMS in *q* with the known spelling.

    The question is split on whitespace. Each word is compared, without its
    surrounding punctuation and in lower case, against KNOWN_TERMS using
    difflib with a 0.8 similarity cutoff. A matching word is replaced by the
    known term; punctuation that was attached to the word is kept in place.
    Words with no close match are returned untouched.

    Returns the corrected words joined by single spaces.
    """
    punctuation = ",.?!;:"
    fixed = []
    for word in q.split():
        core = word.strip(punctuation)
        match = get_close_matches(core.lower(), KNOWN_TERMS, n=1, cutoff=0.8)
        if not match:
            fixed.append(word)
            continue
        # Re-attach whatever punctuation surrounded the word so that
        # "patiens?" becomes "patients?" rather than losing the "?".
        start = word.index(core)
        lead = word[:start]
        trail = word[start + len(core):]
        fixed.append(f"{lead}{match[0]}{trail}")
    return " ".join(fixed)
53
+
54
+ # =========================
55
+ # SCHEMA
56
+ # =========================
57
+
58
def load_ai_schema():
    """Read the AI-visible schema from the metadata tables.

    Queries ai_tables for rows with ai_enabled = 1, then ai_columns for each
    of those tables' rows with ai_allowed = 1.

    Returns a dict keyed by table name; each value is a dict with
    "description" (the table description) and "columns" (the list of
    (column_name, description) rows fetched for that table).
    """
    tables_sql = (
        "\n"
        "SELECT table_name, description\n"
        "FROM ai_tables\n"
        "WHERE ai_enabled = 1\n"
    )
    columns_sql = (
        "\n"
        "SELECT column_name, description\n"
        "FROM ai_columns\n"
        "WHERE table_name = ? AND ai_allowed = 1\n"
    )

    cursor = conn.cursor()
    enabled_tables = cursor.execute(tables_sql).fetchall()

    result = {}
    for table_name, table_desc in enabled_tables:
        column_rows = cursor.execute(columns_sql, (table_name,)).fetchall()
        result[table_name] = {
            "description": table_desc,
            "columns": column_rows
        }
    return result
81
+
82
+ # =========================
83
+ # HUMAN SCHEMA DESCRIPTION
84
+ # =========================
85
+
86
def describe_schema():
    """Render the AI-visible schema as a human-readable chat message.

    Lists every table returned by load_ai_schema() with its description and
    columns, followed by a few example questions.
    """
    pieces = ["Here’s the data I currently have access to:\n\n"]

    for table_name, info in load_ai_schema().items():
        pieces.append(f"• **{table_name.capitalize()}** {info['description']}\n")
        pieces.extend(
            f" - {column}: {column_desc}\n"
            for column, column_desc in info["columns"]
        )
        pieces.append("\n")

    pieces.append(
        "You can ask things like:\n"
        "• How many patients are there?\n"
        "• Patient count by gender\n"
        "• Admissions by year\n\n"
        "Just tell me what you want to explore 🙂"
    )
    return "".join(pieces)
106
+
107
+ # =========================
108
+ # TIME HANDLING
109
+ # =========================
110
+
111
def get_latest_data_date():
    """Return MAX(start_date) from the encounters table.

    The value is whatever SQLite returns for that aggregate; it is None when
    the table has no non-NULL start_date.
    """
    row = conn.cursor().execute(
        "SELECT MAX(start_date) FROM encounters"
    ).fetchone()
    return row[0]
115
+
116
def normalize_time_question(q):
    """Anchor "today"/"yesterday" in *q* to the newest date in the data.

    "today" becomes "on <latest date>" and "yesterday" becomes
    "on <the day before the latest date>", where the latest date is the
    first 10 characters of get_latest_data_date(). Both words are replaced
    when both occur. The question is returned unchanged when there is no
    latest date.
    """
    # Local import: only `datetime` is imported at module level.
    from datetime import datetime, timedelta

    latest = get_latest_data_date()
    if not latest:
        return q

    latest_day = latest[:10]
    try:
        previous_day = (
            datetime.strptime(latest_day, "%Y-%m-%d") - timedelta(days=1)
        ).strftime("%Y-%m-%d")
    except ValueError:
        # Stored dates are not in YYYY-MM-DD form, so the previous day cannot
        # be computed; fall back to the latest date itself.
        previous_day = latest_day

    if "yesterday" in q:
        q = q.replace("yesterday", f"on {previous_day}")
    if "today" in q:
        q = q.replace("today", f"on {latest_day}")
    return q
128
+
129
+ # =========================
130
+ # UNSUPPORTED QUESTIONS
131
+ # =========================
132
+
133
def get_unsupported_reason(q):
    """Explain why a question cannot be answered, if it hits a known gap.

    The check is a case-insensitive substring search for a fixed set of
    keywords. Returns a dict with "reason" and "suggestion" for the first
    rule that matches, or None when no rule applies.
    """
    rules = (
        (
            ("consultant", "doctor"),
            "Doctor or consultant-level data is not available.",
            "Try asking about patients, visits, or admissions.",
        ),
        (
            ("department", "specialization"),
            "Department-level data is not stored.",
            "Try patient or visit related questions.",
        ),
    )

    lowered = q.lower()
    for keywords, reason, suggestion in rules:
        for keyword in keywords:
            if keyword in lowered:
                return {"reason": reason, "suggestion": suggestion}
    return None
149
+
150
+ # =========================
151
+ # SQL GENERATION
152
+ # =========================
153
+
154
def build_prompt(question):
    """Assemble the LLM prompt for turning *question* into SQL.

    The prompt consists of the fixed rules, one section per table from
    load_ai_schema() listing its columns and their descriptions, and finally
    the question itself.
    """
    sections = [
        "\n"
        "You are a hospital SQL assistant.\n"
        "\n"
        "Rules:\n"
        "- Use only SELECT\n"
        "- SQLite syntax\n"
        "- Use only listed tables/columns\n"
        "- Return ONLY SQL or NOT_ANSWERABLE\n"
    ]

    for table_name, table_meta in load_ai_schema().items():
        sections.append(f"\nTable: {table_name}\n")
        sections.extend(
            f"- {column}: {description}\n"
            for column, description in table_meta["columns"]
        )

    sections.append(f"\nQuestion: {question}\n")
    return "".join(sections)
174
+
175
def call_llm(prompt):
    """Send *prompt* to the chat model and return its reply text.

    Uses the module-level OpenAI client with temperature 0 and a system
    message asking for SQL or NOT_ANSWERABLE only.

    Returns the reply with surrounding whitespace removed. When the model
    returns no text at all (content is None or blank), "NOT_ANSWERABLE" is
    returned so callers can take their normal cannot-answer path instead of
    failing on a missing string.
    """
    res = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )
    content = res.choices[0].message.content
    # The API may return a message without text content; calling .strip() on
    # it would raise AttributeError.
    if content is None or not content.strip():
        return "NOT_ANSWERABLE"
    return content.strip()
185
+
186
+ # =========================
187
+ # SQL SAFETY
188
+ # =========================
189
+
190
def sanitize_sql(sql):
    """Extract a single-line SQL statement from raw LLM output.

    If the text contains a Markdown code fence, only the content of the
    first fence is used (an optional "sql"/"sqlite" language tag after the
    opening fence is dropped); otherwise the whole text is used. Everything
    from the first ";" onwards is discarded, newlines become spaces and
    surrounding whitespace is removed.

    Only the fence marker and its language tag are removed, so identifiers
    that contain "sql" (for example sqlite_master) are left intact.

    NOTE: the split on ";" is purely textual, so a semicolon inside a string
    literal also truncates the statement.
    """
    # Local import: `re` is not among the module-level imports visible here.
    import re

    fence = re.search(
        r"```(?:(?:sqlite|sql)\b)?(.*?)(?:```|\Z)",
        sql,
        flags=re.DOTALL | re.IGNORECASE,
    )
    body = fence.group(1) if fence else sql
    first_statement = body.split(";")[0]
    return first_statement.replace("\n", " ").strip()
194
+
195
def validate_sql(sql):
    """Return *sql* unchanged if it is a SELECT statement, else raise.

    The statement must start, after optional leading whitespace, with the
    keyword SELECT in any letter case. Text that merely begins with those
    letters (for example "selectx") is rejected.

    Raises:
        ValueError: if the statement does not start with SELECT. ValueError
            is a subclass of Exception, so callers catching Exception are
            unaffected.
    """
    # Local import: `re` is not among the module-level imports visible here.
    import re

    if re.match(r"\s*select\b", sql, flags=re.IGNORECASE) is None:
        raise ValueError("Only SELECT allowed")
    return sql
199
+
200
def run_query(sql):
    """Execute *sql* on the shared connection.

    Returns a (columns, rows) pair: the column names taken from the cursor
    description, and all fetched rows.
    """
    cursor = conn.cursor()
    cursor.execute(sql)
    records = cursor.fetchall()
    header = [column[0] for column in cursor.description]
    return header, records
205
+
206
+ # =========================
207
+ # AGGREGATE SAFETY
208
+ # =========================
209
+
210
def is_aggregate_only_query(sql):
    """Tell whether *sql* aggregates without grouping.

    True when the lower-cased text contains "count(", "sum(" or "avg(" and
    does not contain "group by". This is a plain substring check.
    """
    lowered = sql.lower()
    if "group by" in lowered:
        return False
    return any(marker in lowered for marker in ("count(", "sum(", "avg("))
213
+
214
def has_underlying_data(sql):
    """Check whether the rows an aggregate query reads from actually exist.

    Takes everything after the first FROM keyword of *sql*, drops any
    GROUP BY clause and what follows it, and runs
    "SELECT 1 FROM <that text> LIMIT 1" on the shared connection.

    Returns True if the probe yields a row, False if it yields none or if
    *sql* has no FROM keyword.

    The original letter case of the query is preserved, so string literals
    in the WHERE clause are compared exactly as the query wrote them.
    Keywords are matched as whole words, in any case.
    """
    # Local import: `re` is not among the module-level imports visible here.
    import re

    from_match = re.search(r"\bfrom\b", sql, flags=re.IGNORECASE)
    if from_match is None:
        return False

    tail = sql[from_match.end():]
    tail = re.split(
        r"\bgroup\s+by\b", tail, maxsplit=1, flags=re.IGNORECASE
    )[0]
    probe_sql = "SELECT 1 FROM " + tail + " LIMIT 1"

    cur = conn.cursor()
    return cur.execute(probe_sql).fetchone() is not None
224
+
225
+ # =========================
226
+ # MAIN ENGINE
227
+ # =========================
228
+
229
def _no_data_reply(message):
    """Build the empty-result reply and arm the "yes" follow-up prompt."""
    global LAST_PROMPT_TYPE
    LAST_PROMPT_TYPE = "NO_DATA"
    latest = get_latest_data_date()
    return {
        "status": "ok",
        "message": friendly(message),
        "note": f"Available data is only up to {latest}.",
        "data": []
    }


def process_question(question):
    """Answer a natural-language question about the hospital data.

    Steps, in order:
      1. A bare confirmation ("yes", "ok", ...) right after a no-data reply
         returns a list of example questions.
      2. The question is spell-corrected and relative dates are anchored to
         the newest date in the data.
      3. Questions about the available data return the schema description.
      4. Questions touching known unsupported topics return an explanation.
      5. Otherwise the LLM generates SQL, which is sanitized, validated and
         executed.

    Returns a dict that always has "status" and "data". Non-query replies
    carry a "message" (and a "note" for empty results); successful queries
    carry "sql" and "columns" with the rows in "data".

    Exceptions raised by validate_sql, run_query or the LLM call are not
    caught here and propagate to the caller.

    NOTE(review): LAST_PROMPT_TYPE is a module-level global, so the
    follow-up state is shared by every caller of this module.
    """
    global LAST_PROMPT_TYPE
    # Local import: `re` is not among the module-level imports visible here.
    import re

    # Keep the user's letter case: the text is passed to the LLM, and
    # lower-casing it would alter any case-sensitive values it mentions.
    question = question.strip()

    # Handle confirmation replies like "yes"
    if is_confirmation(question) and LAST_PROMPT_TYPE == "NO_DATA":
        return {
            "status": "ok",
            "message": (
                "Great 🙂\n\n"
                "Here are some things you can ask:\n"
                "• How many patients were admitted in 2021?\n"
                "• Patient count by gender\n"
                "• Total visits by month\n"
                "• Most common conditions\n\n"
                "Just type one of these or ask your own question."
            ),
            "data": []
        }

    # Any other message starts a new exchange. Clearing the state here keeps
    # a later "yes" from being tied to an old no-data reply after a schema,
    # unsupported-topic or cannot-answer response.
    LAST_PROMPT_TYPE = None

    # normalize_time_question matches these words case-sensitively, so fold
    # only them to lower case rather than the whole question.
    question = re.sub(
        r"\b(today|yesterday)\b",
        lambda m: m.group(0).lower(),
        question,
        flags=re.IGNORECASE,
    )
    question = correct_spelling(question)
    question = normalize_time_question(question)

    lowered = question.lower()
    if any(x in lowered for x in ["what data", "what tables", "which data"]):
        return {
            "status": "ok",
            "message": humanize(describe_schema()),
            "data": []
        }

    unsupported = get_unsupported_reason(question)
    if unsupported:
        return {
            "status": "ok",
            "message": (
                f"{unsupported['reason']}\n\n"
                f"{unsupported['suggestion']}\n\n"
                "Example questions:\n"
                "• How many patients were admitted last year?\n"
                "• Total visits by month\n"
                "• Patient count by gender"
            ),
            "data": []
        }

    sql = call_llm(build_prompt(question))

    # The model may wrap or decorate the sentinel (code fence, trailing
    # period, different case), so look for it instead of comparing exactly.
    if "NOT_ANSWERABLE" in sql.upper():
        return {
            "status": "ok",
            "message": "I don’t have enough data to answer that.",
            "data": []
        }

    sql = validate_sql(sanitize_sql(sql))
    cols, rows = run_query(sql)

    # An aggregate without GROUP BY returns a row even when nothing matched,
    # so check the underlying rows separately.
    if is_aggregate_only_query(sql) and not has_underlying_data(sql):
        return _no_data_reply("No data is available for that time period.")

    if not rows:
        return _no_data_reply("No records found.")

    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows
    }