bhavika24 commited on
Commit
3a7a8cd
·
verified ·
1 Parent(s): cf60b47

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. Dockerfile +11 -20
  3. UI.py +69 -0
  4. engine.py +236 -0
  5. hospital.db +3 -0
  6. requirements.txt +3 -3
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ hospital.db filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -1,20 +1,11 @@
1
- FROM python:3.13.5-slim
2
-
3
- WORKDIR /app
4
-
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- git \
9
- && rm -rf /var/lib/apt/lists/*
10
-
11
- COPY requirements.txt ./
12
- COPY src/ ./src/
13
-
14
- RUN pip3 install -r requirements.txt
15
-
16
- EXPOSE 8501
17
-
18
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
-
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
FROM python:3.10-slim

WORKDIR /app

COPY . /app

RUN pip install --no-cache-dir -r requirements.txt

EXPOSE 8501

# FIX: the uploaded entry-point file is "UI.py" (uppercase). The container's
# Linux filesystem is case-sensitive, so "streamlit run ui.py" would fail at
# startup with "File does not exist".
CMD ["streamlit", "run", "UI.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
 
 
 
 
 
 
 
UI.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from engine import process_question

st.set_page_config(page_title="Hospital AI Assistant", layout="wide")

st.title("🏥 Hospital AI Assistant")
st.caption("Ask questions about patients, conditions, visits, medications, labs")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay chat history.  Assistant replies embed raw HTML (the <details> SQL
# toggle below), so they must be rendered with unsafe_allow_html=True —
# otherwise Streamlit escapes the tags and shows them as literal text
# (fix over the original, which never rendered the toggle).
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"], unsafe_allow_html=(msg["role"] == "assistant"))

# Chat input
user_input = st.chat_input("Ask a question about hospital data...")

if user_input:
    # Show user message
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    # Call AI engine directly; surface any failure as an error payload instead
    # of crashing the Streamlit script run.
    with st.spinner("Thinking..."):
        try:
            result = process_question(user_input)
        except Exception as e:
            result = {"status": "error", "message": str(e)}

    # Build assistant reply
    if result.get("status") == "ok":
        reply = ""

        # Time note (if any)
        if result.get("note"):
            reply += f"🕒 *{result['note']}*\n\n"

        # Data table rendered as markdown, capped at 10 rows for readability
        if result.get("data"):
            columns = result.get("columns", [])
            data = result["data"]

            table_md = "| " + " | ".join(columns) + " |\n"
            table_md += "| " + " | ".join(["---"] * len(columns)) + " |\n"

            for row in data[:10]:
                table_md += "| " + " | ".join(str(x) for x in row) + " |\n"

            reply += table_md
        else:
            reply += result.get("message", "No data found.")

        # Collapsible generated-SQL section (raw HTML — see unsafe_allow_html)
        reply += "\n\n---\n"
        reply += "<details><summary><b>Generated SQL</b></summary>\n\n"
        reply += f"```sql\n{result['sql']}\n```"
        reply += "\n</details>"

    else:
        reply = f"❌ {result.get('message', 'Something went wrong')}"

    # Show assistant message (HTML allowed so the <details> toggle renders)
    st.session_state.messages.append({"role": "assistant", "content": reply})
    with st.chat_message("assistant"):
        st.markdown(reply, unsafe_allow_html=True)
engine.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sqlite3
3
+ from openai import OpenAI
4
+
5
+
6
+ # =========================
7
+ # Setup
8
+ # =========================
9
+
10
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
11
+ conn = sqlite3.connect("hospital.db", check_same_thread=False)
12
+
13
+
14
+ # =========================
15
+ # Metadata Loader
16
+ # =========================
17
+
18
def load_ai_schema():
    """Load the AI-exposed schema from the metadata tables.

    Returns a dict mapping each AI-enabled table name to
    {"description": str, "columns": [(column_name, description), ...]},
    including only columns flagged ai_allowed.
    """
    cursor = conn.cursor()

    enabled_tables = cursor.execute("""
        SELECT table_name, description
        FROM ai_tables
        WHERE ai_enabled = 1
    """).fetchall()

    schema = {}
    for name, description in enabled_tables:
        allowed_columns = cursor.execute("""
            SELECT column_name, description
            FROM ai_columns
            WHERE table_name = ? AND ai_allowed = 1
        """, (name,)).fetchall()

        schema[name] = {
            "description": description,
            "columns": allowed_columns,
        }

    return schema
42
+
43
+
44
+ # =========================
45
+ # Prompt Builder
46
+ # =========================
47
+
48
def build_prompt(question: str) -> str:
    """Assemble the LLM prompt: generation rules, allowed schema, question."""
    schema = load_ai_schema()

    pieces = ["""
You are a hospital data assistant.

Rules:
- Generate only SELECT SQL queries.
- Use only the tables and columns provided.
- Do not invent tables or columns.
- This database is SQLite. Use SQLite-compatible date functions.
- For recent days use: date('now', '-N day')
- Use case-insensitive matching for text fields.
- Prefer LIKE with wildcards for medical condition names.
- Use COUNT, AVG, MIN, MAX, GROUP BY when the question asks for totals, averages, or comparisons.
- If the question cannot be answered using the schema, return NOT_ANSWERABLE.
- Do not explain the query.
- Return only SQL or NOT_ANSWERABLE.

Available schema:
"""]

    # One line per table, then one indented line per allowed column.
    for table, meta in schema.items():
        pieces.append(f"\nTable: {table} - {meta['description']}\n")
        for col, desc in meta["columns"]:
            pieces.append(f" - {col}: {desc}\n")

    pieces.append(f"\nUser question: {question}\n")
    return "".join(pieces)
77
+
78
+
79
+ # =========================
80
+ # LLM Call
81
+ # =========================
82
+
83
def call_llm(prompt: str) -> str:
    """Send *prompt* to the chat model and return its reply text, stripped."""
    messages = [
        {"role": "system", "content": "You are a SQL generator. Return only SQL. No explanation."},
        {"role": "user", "content": prompt},
    ]

    # temperature=0.0 keeps SQL generation as deterministic as possible.
    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=messages,
        temperature=0.0,
    )

    return completion.choices[0].message.content.strip()
94
+
95
+
96
+ # =========================
97
+ # SQL Generation
98
+ # =========================
99
+
100
def generate_sql(question: str) -> str:
    """Produce candidate SQL for *question* via the LLM; whitespace-stripped."""
    return call_llm(build_prompt(question)).strip()
104
+
105
+
106
+ # =========================
107
+ # SQL Cleaning & Validation
108
+ # =========================
109
+
110
def clean_sql(sql: str) -> str:
    """Normalize raw LLM output into bare SQL.

    Chat models frequently wrap the answer in markdown fences
    (```sql ... ```).  This strips the fence and an optional leading
    "sql" language tag.

    FIX: the original ran sql.replace("sql\\n", "") on the whole string,
    which could delete that substring from *inside* a query (e.g. in a
    string literal).  Now only a standalone leading language tag is
    removed, and only when the input was actually fenced.
    """
    sql = sql.strip()

    if sql.startswith("```"):
        # Keep the first fenced segment.
        parts = sql.split("```")
        if len(parts) > 1:
            sql = parts[1]
        sql = sql.strip()
        # Drop a leading "sql" tag only when it stands alone before whitespace,
        # so e.g. a query starting with "sqlite_master" is left intact.
        if sql[:3].lower() == "sql" and (len(sql) == 3 or sql[3].isspace()):
            sql = sql[3:]

    return sql.strip()
121
+
122
+
123
def validate_sql(sql: str) -> str:
    """Clean the generated SQL and enforce a read-only policy.

    Raises Exception when the statement is not a SELECT or contains a
    write/DDL keyword.

    FIX: the original used plain substring checks, so legitimate queries
    referencing identifiers such as "last_update" or "drop_in_visits"
    were rejected.  Keywords are now matched as whole words ("_" is a
    word character, so it does not create a boundary).
    """
    import re  # local import keeps the module's top-level imports untouched

    sql = clean_sql(sql)
    s = sql.lower()

    if not s.startswith("select"):
        raise Exception("Only SELECT queries allowed")

    if re.search(r"\b(insert|update|delete|drop|alter|truncate)\b", s):
        raise Exception("Forbidden SQL operation detected")

    return sql
136
+
137
+
138
+ # =========================
139
+ # Query Runner
140
+ # =========================
141
+
142
def run_query(sql: str):
    """Execute *sql* against the shared connection.

    Returns (column_names, rows) where rows is a list of tuples.
    """
    cursor = conn.cursor()
    rows = cursor.execute(sql).fetchall()
    column_names = [entry[0] for entry in cursor.description]
    return column_names, rows
147
+
148
+
149
+ # =========================
150
+ # Guardrails
151
+ # =========================
152
+
153
def is_question_answerable(question):
    """Cheap guardrail: accept only questions in the hospital-data domain.

    Returns True when the question mentions at least one core domain
    keyword (case-insensitive substring match), else False.

    FIX: the original also called load_ai_schema() and built a
    schema_text string that was never used — dead code costing a DB
    round-trip per question.  The decision logic is unchanged.
    """
    keywords = [
        "patient", "encounter", "condition", "observation", "medication",
        "visit", "diagnosis", "lab", "vital",
    ]

    q = question.lower()

    # If none of the core domain keywords are present, likely out of scope.
    return any(k in q for k in keywords)
166
+
167
+
168
+ # =========================
169
+ # Time Awareness
170
+ # =========================
171
+
172
def get_latest_data_date():
    """Return the most recent encounter start date in the database.

    Returns None when the encounters table is empty (MAX of no rows).
    """
    _cols, rows = run_query("SELECT MAX(start_date) FROM encounters;")
    return rows[0][0]
176
+
177
+
178
def check_time_relevance(question: str):
    """If the question implies a recent timeframe, note the data cutoff.

    Returns a note string mentioning the latest available data date, or
    None when the question has no time-related wording.
    """
    lowered = question.lower()
    time_words = ("last", "recent", "today", "this month", "this year")

    for word in time_words:
        if word in lowered:
            latest = get_latest_data_date()
            return f"Note: Latest available data is from {latest}."

    return None
184
+
185
+
186
+ # =========================
187
+ # Empty Result Interpreter
188
+ # =========================
189
+
190
def interpret_empty_result(question: str):
    """Explain an empty result set by reporting the data cutoff date.

    The *question* parameter is currently unused but kept for interface
    stability with process_question.
    """
    cutoff = get_latest_data_date()
    return f"No results found. Available data is up to {cutoff}."
193
+
194
+
195
+ # =========================
196
+ # ORCHESTRATOR (Single Entry Point)
197
+ # =========================
198
+
199
def process_question(question: str):
    """Single entry point: guardrail -> SQL generation -> validation -> run.

    Returns a dict with a "status" key:
      - "rejected" with a "message" when the question is out of scope;
      - "ok" with "sql", "note", and either "columns"/"data" (capped at
        50 rows) or an empty "data" plus an explanatory "message".
    Validation/execution errors propagate to the caller.
    """
    # 1. Guardrail: refuse questions outside the supported domain.
    if not is_question_answerable(question):
        return {
            "status": "rejected",
            "message": "This question is not supported by the available data.",
        }

    # 2. Time relevance note (may be None).
    time_note = check_time_relevance(question)

    # 3-4. Generate candidate SQL and enforce the read-only policy.
    sql = validate_sql(generate_sql(question))

    # 5. Execute the validated query.
    columns, rows = run_query(sql)

    # 6. Empty result: explain the data cutoff instead of a bare table.
    if not rows:
        return {
            "status": "ok",
            "sql": sql,
            "message": interpret_empty_result(question),
            "data": [],
            "note": time_note,
        }

    # 7. Normal response.
    return {
        "status": "ok",
        "sql": sql,
        "columns": columns,
        "data": rows[:50],  # demo safety limit
        "note": time_note,
    }
hospital.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d70473d08ef49bcb62c9c1edbcdb824014bd102e5235631167fb28b0d5732ad5
3
+ size 40407040
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- altair
2
- pandas
3
- streamlit
 
1
+ streamlit
2
+ openai
3
+ pandas