File size: 7,344 Bytes
3a7a8cd
 
 
03bec39
3a7a8cd
 
 
 
 
 
 
 
 
 
03bec39
 
 
 
 
 
 
 
726ac48
03bec39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a7a8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726ac48
 
 
 
3a7a8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726ac48
3a7a8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03bec39
 
 
3a7a8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726ac48
3a7a8cd
726ac48
 
 
 
 
 
 
 
 
 
 
 
3a7a8cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import os
import re
import sqlite3
from difflib import get_close_matches

from openai import OpenAI


# =========================
# Setup
# =========================

# Shared OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Module-wide SQLite connection to hospital.db. check_same_thread=False lets
# this single connection be used from threads other than the one that opened it.
conn = sqlite3.connect("hospital.db", check_same_thread=False)


# =========================
# Known Terms for Spell Correction
# =========================

KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions", "diagnosis", "encounter", "encounters",
    "visit", "visits", "observation", "observations", "lab", "labs", "test", "tests",
    "medication", "medications", "drug", "drugs", "prescription", "prescriptions",
    "diabetes", "hypertension", "asthma", "cancer", "admitted", "admission"
]

# Punctuation peeled off each word before matching and re-attached afterwards.
_EDGE_PUNCTUATION = ",.?"


def correct_spelling(question: str, cutoff: float = 0.8) -> str:
    """Fix likely misspellings of domain terms in *question*.

    Each whitespace-separated word is lowercased and stripped of leading and
    trailing ``,.?``, then compared against ``KNOWN_TERMS`` with
    :func:`difflib.get_close_matches`. A word is replaced by the closest known
    term only when the similarity ratio is at least *cutoff* AND the term
    starts with the same letter as the word.

    Args:
        question: Free-text user question.
        cutoff: Minimum ``difflib`` similarity ratio (0-1) for a replacement.

    Returns:
        The words re-joined with single spaces. Words that already are known
        terms, or have no acceptable match, are kept exactly as typed.
        Corrected words keep their surrounding punctuation.
    """
    corrected_words = []

    for word in question.split():
        core = word.strip(_EDGE_PUNCTUATION)
        clean_word = core.lower()

        # Nothing to correct: pure punctuation, or already an exact known term.
        if not clean_word or clean_word in KNOWN_TERMS:
            corrected_words.append(word)
            continue

        # Only consider terms sharing the first letter. Typos rarely change
        # it, and without this restriction ordinary words get rewritten:
        # "latest" scores exactly 0.8 against "test".
        candidates = [term for term in KNOWN_TERMS if term[0] == clean_word[0]]
        matches = get_close_matches(clean_word, candidates, n=1, cutoff=cutoff)
        if not matches:
            corrected_words.append(word)
            continue

        # Re-attach the punctuation that was stripped for matching.
        prefix = word[: len(word) - len(word.lstrip(_EDGE_PUNCTUATION))]
        suffix = word[len(word.rstrip(_EDGE_PUNCTUATION)):]
        corrected_words.append(prefix + matches[0] + suffix)

    return " ".join(corrected_words)


# =========================
# Metadata Loader
# =========================

def load_ai_schema():
    """Read the AI-exposed part of the schema from the metadata tables.

    Tables come from ``ai_tables`` rows with ``ai_enabled = 1``. For each such
    table, its columns come from ``ai_columns`` rows with ``ai_allowed = 1``.
    Uses the module-level ``conn``.

    Returns:
        dict mapping table name to
        ``{"description": <table description>,
           "columns": [(column_name, column_description), ...]}``.
    """
    tables_query = """

        SELECT table_name, description

        FROM ai_tables

        WHERE ai_enabled = 1

    """
    columns_query = """

            SELECT column_name, description

            FROM ai_columns

            WHERE table_name = ? AND ai_allowed = 1

        """

    cursor = conn.cursor()
    enabled_tables = cursor.execute(tables_query).fetchall()

    return {
        name: {
            "description": description,
            "columns": cursor.execute(columns_query, (name,)).fetchall(),
        }
        for name, description in enabled_tables
    }


# =========================
# Prompt Builder
# =========================

def build_prompt(question: str) -> str:
    """Assemble the text-to-SQL prompt for *question*.

    The prompt consists of:
    - a fixed block of instructions;
    - for each AI-enabled table, one ``Table: <name> - <description>`` line
      with its allowed columns listed beneath it;
    - the user question.

    The schema is re-read from the database via ``load_ai_schema()`` on
    every call.
    """
    schema = load_ai_schema()

    instructions = """

You are a hospital data assistant.



Rules:

- Generate only SELECT SQL queries.

- Use only the tables and columns provided.

- Do not invent tables or columns.

- This database is SQLite. Use SQLite-compatible date functions.

- For recent days use: date('now', '-N day')

- Use case-insensitive matching for text fields.

- Prefer LIKE with wildcards for medical condition names.

- Use COUNT, AVG, MIN, MAX, GROUP BY when the question asks for totals, averages, or comparisons.

- If the question cannot be answered using the schema, return NOT_ANSWERABLE.

- Do not explain the query.

- Return only SQL or NOT_ANSWERABLE.



Available schema:

"""

    sections = [instructions]
    for table, meta in schema.items():
        sections.append(f"\nTable: {table} - {meta['description']}\n")
        sections.extend(f"  - {col}: {desc}\n" for col, desc in meta["columns"])
    sections.append(f"\nUser question: {question}\n")
    return "".join(sections)


# =========================
# LLM Call
# =========================

def call_llm(prompt: str, model: str = "gpt-4.1-mini") -> str:
    """Send *prompt* to the chat-completions API and return the reply text.

    Uses the module-level OpenAI ``client`` with a fixed "SQL generator"
    system message and temperature 0.

    Args:
        prompt: Full user-message text (see ``build_prompt``).
        model: Chat model name; defaults to the model this module used originally.

    Returns:
        The first choice's message content, stripped of surrounding
        whitespace, or ``""`` when the API returned no text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a SQL generator. Return only SQL. No explanation."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )

    # message.content is typed Optional in the SDK; calling .strip() on None
    # would raise AttributeError. An empty reply is later rejected by validate_sql.
    content = response.choices[0].message.content
    return (content or "").strip()


# =========================
# SQL Generation
# =========================

def generate_sql(question: str) -> str:
    """Ask the LLM to translate *question* into SQL and return its stripped reply.

    The reply is returned unvalidated; it may contain markdown fences or the
    literal ``NOT_ANSWERABLE``.
    """
    return call_llm(build_prompt(question)).strip()


# =========================
# SQL Cleaning & Validation
# =========================

# Language tags that may follow an opening ``` fence in LLM output.
_FENCE_LANGUAGE_TAGS = {"sql", "sqlite"}

# Statements that must never appear in a generated query. Matched as whole
# words (case-insensitive), so identifiers such as ``updated_at`` or
# ``last_update`` are not mistaken for UPDATE.
_FORBIDDEN_SQL_RE = re.compile(
    r"\b(?:insert|update|delete|drop|alter|truncate|create|attach|detach|pragma)\b",
    re.IGNORECASE,
)


def clean_sql(sql: str) -> str:
    """Normalize raw LLM output into bare SQL text.

    Strips surrounding whitespace. If the text starts with a markdown code
    fence, keeps only the content of that first fence. Drops a leading
    language-tag line (``sql`` / ``sqlite``, any case) if present.
    """
    sql = sql.strip()

    # Keep only the body of a leading markdown code fence.
    if sql.startswith("```"):
        sql = sql.split("```")[1]

    # Remove only a leading tag line. Removing every "sql\n" occurrence would
    # corrupt identifiers ending in "sql" that are followed by a newline.
    first_line, newline, rest = sql.lstrip().partition("\n")
    if newline and first_line.strip().lower() in _FENCE_LANGUAGE_TAGS:
        sql = rest

    return sql.strip()


def validate_sql(sql: str) -> str:
    """Clean *sql* and ensure it is a read-only SELECT statement.

    Returns:
        The cleaned SQL (see ``clean_sql``).

    Raises:
        ValueError: if the query does not start with SELECT, or if it contains
            a forbidden statement keyword as a whole word. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    sql = clean_sql(sql)

    if not sql.lower().startswith("select"):
        raise ValueError("Only SELECT queries allowed")

    if _FORBIDDEN_SQL_RE.search(sql):
        raise ValueError("Forbidden SQL operation detected")

    return sql


# =========================
# Query Runner
# =========================

def run_query(sql: str):
    """Execute *sql* on the module-level connection and fetch every row.

    Returns:
        ``(column_names, rows)``: the list of result column names from the
        cursor description, and the list of row tuples.
    """
    cursor = conn.cursor()
    cursor.execute(sql)
    rows = cursor.fetchall()
    column_names = [column[0] for column in cursor.description]
    return column_names, rows


# =========================
# Guardrails
# =========================

def is_question_answerable(question):
    """Cheap keyword guardrail: is *question* about the hospital data at all?

    Returns True when the lowercased question contains at least one domain
    keyword as a plain substring (so "patients" and "outpatient" both match
    "patient"), otherwise False.
    """
    keywords = [
        "patient", "encounter", "condition", "observation",
        "medication", "visit", "diagnosis", "lab", "vital", "admitted",
        # Domain terms from KNOWN_TERMS (the spell-correction vocabulary), so
        # that e.g. "How many people have diabetes?" is not rejected.
        # "test" is deliberately omitted: as a substring it matches "latest".
        "admission", "drug", "prescription",
        "diabetes", "hypertension", "asthma", "cancer",
    ]

    q = question.lower()
    return any(k in q for k in keywords)


# =========================
# Time Awareness
# =========================

def get_latest_data_date():
    """Return ``MAX(start_date)`` from ``encounters`` (None when the table is empty)."""
    _, rows = run_query("SELECT MAX(start_date) FROM encounters;")
    (latest,) = rows[0]
    return latest


def check_time_relevance(question: str):
    """Return a data-coverage note for time-sensitive questions, else None.

    A question counts as time-sensitive when its lowercased text contains
    "last", "recent", "today", "this month" or "this year" as a plain
    substring. The note reports the value of ``get_latest_data_date()``.
    """
    lowered = question.lower()
    time_markers = ("last", "recent", "today", "this month", "this year")
    if not any(marker in lowered for marker in time_markers):
        return None
    return f"Latest available data is from {get_latest_data_date()}."


# =========================
# Empty Result Interpreter
# =========================

def interpret_empty_result(question: str):
    """Build the message shown when a query returned no rows.

    *question* is currently unused; the message only reports the value of
    ``get_latest_data_date()``.
    """
    coverage_end = get_latest_data_date()
    return f"No results found. Available data is up to {coverage_end}."


# =========================
# ORCHESTRATOR (Single Entry Point)
# =========================

def process_question(question: str):
    """Answer a natural-language *question* about the hospital database.

    Pipeline: spell correction, keyword guardrail, time-relevance note,
    LLM SQL generation, SQL validation, then execution.

    Returns:
        ``{"status": "rejected", "message": ...}`` when the keyword guardrail
        rejects the question or the LLM replies ``NOT_ANSWERABLE``.
        Otherwise ``{"status": "ok", "sql", "columns", "data", "note"}``.
        When no rows match, ``data`` is ``[]`` and a ``message`` describes
        data coverage. ``data`` holds at most the first 50 rows.

    Raises:
        Exception: propagated from ``validate_sql`` when the generated SQL is
            not an allowed SELECT, or from sqlite3 when execution fails.
    """
    rejection_message = "This question is not supported by the available data."

    # 0. Spell correction
    question = correct_spelling(question)

    # 1. Guardrail
    if not is_question_answerable(question):
        return {"status": "rejected", "message": rejection_message}

    # 2. Time relevance: non-None exactly when the question has a time marker.
    time_note = check_time_relevance(question)

    # 3. Generate SQL
    sql = generate_sql(question)

    # The prompt instructs the model to reply NOT_ANSWERABLE when the schema
    # cannot answer the question. Treat that as a rejection instead of letting
    # validate_sql raise "Only SELECT queries allowed".
    if clean_sql(sql).rstrip(".").upper() == "NOT_ANSWERABLE":
        return {"status": "rejected", "message": rejection_message}

    # 4. Validate SQL
    sql = validate_sql(sql)

    # 5. Execute query
    columns, rows = run_query(sql)

    # 6. Empty result: explain it in terms of data coverage.
    #    The time-scoped check reuses time_note, so it recognizes the same
    #    markers as check_time_relevance (including "today").
    if not rows:
        if time_note is not None:
            latest = get_latest_data_date()
            return {
                "status": "ok",
                "sql": sql,
                "message": f"No data available for the requested time period. Latest available data is from {latest}.",
                "columns": columns,
                "data": [],
                "note": None,
            }

        # Not time-scoped, so time_note is None here.
        return {
            "status": "ok",
            "sql": sql,
            "message": interpret_empty_result(question),
            "columns": columns,
            "data": [],
            "note": None,
        }

    # 7. Normal response
    return {
        "status": "ok",
        "sql": sql,
        "columns": columns,
        "data": rows[:50],  # demo safety limit
        "note": time_note,
    }