import os
import re
import sqlite3
from difflib import get_close_matches

from openai import OpenAI

# =========================
# Setup
# =========================
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# NOTE(review): LLM-generated SQL is executed on this connection and
# validate_sql() is the only guard in this file; consider opening the
# database read-only if the deployment allows it.
conn = sqlite3.connect("hospital.db", check_same_thread=False)

# =========================
# Known Terms for Spell Correction
# =========================
KNOWN_TERMS = [
    "patient", "patients",
    "condition", "conditions", "diagnosis",
    "encounter", "encounters", "visit", "visits",
    "observation", "observations", "lab", "labs", "test", "tests",
    "medication", "medications", "drug", "drugs",
    "prescription", "prescriptions",
    "diabetes", "hypertension", "asthma", "cancer",
    "admitted", "admission",
]

# Set view of KNOWN_TERMS for O(1) "already spelled correctly" checks.
_KNOWN_TERM_SET = frozenset(KNOWN_TERMS)

# Punctuation ignored at the edges of a word when spell-checking it.
_EDGE_PUNCTUATION = ",.?"


def correct_spelling(question: str) -> str:
    """Replace near-miss spellings of known domain terms in *question*.

    Each whitespace-separated word is compared (lower-cased, with leading
    and trailing ``,.?`` removed) against ``KNOWN_TERMS`` using
    ``difflib.get_close_matches`` with a 0.8 similarity cutoff.

    Words that already are a known term, or that have no close match, are
    kept exactly as typed. When a word is corrected, the punctuation that
    surrounded it is re-attached so "diabtes?" becomes "diabetes?" rather
    than losing the question mark.

    Args:
        question: Free-text user question.

    Returns:
        The question with corrected words, joined by single spaces.
    """
    corrected_words = []
    for word in question.split():
        # Split the token into leading punctuation, core word, trailing
        # punctuation so the punctuation survives a correction.
        without_prefix = word.lstrip(_EDGE_PUNCTUATION)
        prefix = word[:len(word) - len(without_prefix)]
        core = without_prefix.rstrip(_EDGE_PUNCTUATION)
        suffix = without_prefix[len(core):]

        clean_word = core.lower()
        if not clean_word or clean_word in _KNOWN_TERM_SET:
            # Nothing to check, or already correct: keep the user's text.
            corrected_words.append(word)
            continue

        matches = get_close_matches(clean_word, KNOWN_TERMS, n=1, cutoff=0.8)
        if matches:
            corrected_words.append(f"{prefix}{matches[0]}{suffix}")
        else:
            corrected_words.append(word)
    return " ".join(corrected_words)


# =========================
# Metadata Loader
# =========================
def load_ai_schema():
    """Load the AI-visible schema from the metadata tables.

    Reads ``ai_tables`` (rows with ``ai_enabled = 1``) and, for each such
    table, the rows of ``ai_columns`` with ``ai_allowed = 1``.

    Returns:
        A dict mapping table name to
        ``{"description": <table description>,
           "columns": [(column_name, description), ...]}``.
    """
    cur = conn.cursor()
    schema = {}

    tables = cur.execute("""
        SELECT table_name, description
        FROM ai_tables
        WHERE ai_enabled = 1
    """).fetchall()

    for table_name, desc in tables:
        cols = cur.execute("""
            SELECT column_name, description
            FROM ai_columns
            WHERE table_name = ? AND ai_allowed = 1
        """, (table_name,)).fetchall()

        schema[table_name] = {
            "description": desc,
            "columns": cols,
        }

    return schema


# =========================
# Prompt Builder
# =========================
def build_prompt(question: str) -> str:
    """Build the LLM prompt: fixed rules, the allowed schema, the question.

    Args:
        question: The (already spell-corrected) user question.

    Returns:
        The full prompt text to send as the user message.
    """
    schema = load_ai_schema()

    header = """
You are a hospital data assistant.

Rules:
- Generate only SELECT SQL queries.
- Use only the tables and columns provided.
- Do not invent tables or columns.
- This database is SQLite. Use SQLite-compatible date functions.
- For recent days use: date('now', '-N day')
- Use case-insensitive matching for text fields.
- Prefer LIKE with wildcards for medical condition names.
- Use COUNT, AVG, MIN, MAX, GROUP BY when the question asks for totals, averages, or comparisons.
- If the question cannot be answered using the schema, return NOT_ANSWERABLE.
- Do not explain the query.
- Return only SQL or NOT_ANSWERABLE.

Available schema:
"""

    # Collect the pieces and join once instead of growing a string with +=.
    parts = [header]
    for table, meta in schema.items():
        parts.append(f"\nTable: {table} - {meta['description']}\n")
        parts.extend(f" - {col}: {desc}\n" for col, desc in meta["columns"])
    parts.append(f"\nUser question: {question}\n")
    return "".join(parts)


# =========================
# LLM Call
# =========================
def call_llm(prompt: str) -> str:
    """Send *prompt* to the chat model and return its stripped text reply.

    Args:
        prompt: The user-message content built by ``build_prompt``.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when the reply carries no text.
    """
    response = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": "You are a SQL generator. Return only SQL. No explanation."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )
    # The API types message.content as optional; guard against None so a
    # text-less reply is rejected by validation instead of raising
    # AttributeError here.
    content = response.choices[0].message.content
    return (content or "").strip()


# =========================
# SQL Generation
# =========================
def generate_sql(question: str) -> str:
    """Turn a natural-language question into the model's raw SQL reply."""
    return call_llm(build_prompt(question)).strip()


# =========================
# SQL Cleaning & Validation
# =========================
# Optional language tag right after an opening ``` fence.
_FENCE_LANG_RE = re.compile(r"^\s*(?:sqlite3?|sql)\b[ \t]*\r?\n?", re.IGNORECASE)

# Write/DDL keywords, matched as whole words only so identifiers such as
# "updated_at" or "last_update" do not trip the check.
_FORBIDDEN_SQL_RE = re.compile(
    r"\b(?:insert|update|delete|drop|alter|truncate)\b", re.IGNORECASE
)


def clean_sql(sql: str) -> str:
    """Strip whitespace and a surrounding markdown code fence from *sql*.

    When the text starts with a ``` fence, only the content up to the next
    fence (or the end of the text) is kept, and a leading ``sql`` /
    ``sqlite`` language tag is removed. The query text itself is never
    altered.
    """
    sql = sql.strip()

    # Remove markdown code fences if present
    if sql.startswith("```"):
        fenced = sql[3:].split("```", 1)[0]
        # Drop only the tag at the very start; a blanket replace of
        # "sql\n" could also delete text from inside the query.
        sql = _FENCE_LANG_RE.sub("", fenced, count=1).strip()

    return sql


def validate_sql(sql: str) -> str:
    """Return the cleaned query if it is a read-only SELECT, else raise.

    The check is deliberately conservative: a forbidden keyword appearing
    as a whole word anywhere in the text (including inside a string
    literal) rejects the query.

    Args:
        sql: Raw model output, possibly wrapped in a code fence.

    Returns:
        The cleaned SQL text.

    Raises:
        ValueError: If the query does not start with SELECT or contains a
            forbidden write/DDL keyword.
    """
    sql = clean_sql(sql)

    if not sql.lower().startswith("select"):
        raise ValueError("Only SELECT queries allowed")

    if _FORBIDDEN_SQL_RE.search(sql):
        raise ValueError("Forbidden SQL operation detected")

    return sql


# =========================
# Query Runner
# =========================
def run_query(sql: str):
    """Execute *sql* on the shared connection.

    Returns:
        A ``(columns, rows)`` tuple: the column names taken from the
        cursor description, and all fetched rows.
    """
    cur = conn.cursor()
    result = cur.execute(sql).fetchall()
    columns = [desc[0] for desc in cur.description]
    return columns, result


# =========================
# Guardrails
# =========================


def is_question_answerable(question):
    """Return True when *question* mentions at least one domain keyword.

    Matching is a case-insensitive substring test, so plural and derived
    forms ("patients", "visited") also pass.
    """
    keywords = (
        "patient", "encounter", "condition", "observation",
        "medication", "visit", "diagnosis", "lab", "vital", "admitted",
    )
    q = question.lower()
    return any(k in q for k in keywords)


# =========================
# Time Awareness
# =========================
# Words/phrases that make a question relative to "now". Matched as whole
# words so that e.g. "everlasting" or "elastic" do not count as "last".
_RELATIVE_TIME_TERMS = (
    "last", "recent", "recently", "today", "this month", "this year",
)


def _mentions_relative_time(question: str) -> bool:
    """Return True if *question* contains a relative-time word or phrase."""
    # Replace every non-alphanumeric character with a space so punctuation
    # ("today?", "today's") does not hide a term, then pad with spaces so
    # whole-word matching reduces to a substring test.
    tokens = "".join(
        ch if ch.isalnum() else " " for ch in question.lower()
    ).split()
    normalized = f" {' '.join(tokens)} "
    return any(f" {term} " in normalized for term in _RELATIVE_TIME_TERMS)


def get_latest_data_date():
    """Return ``MAX(start_date)`` from the ``encounters`` table."""
    sql = "SELECT MAX(start_date) FROM encounters;"
    _, rows = run_query(sql)
    return rows[0][0]


def check_time_relevance(question: str):
    """Return a data-coverage note for time-relative questions.

    Returns:
        A sentence stating the latest available data date when the
        question uses a relative-time term, otherwise ``None``. The
        database is only queried in the first case.
    """
    if _mentions_relative_time(question):
        latest = get_latest_data_date()
        return f"Latest available data is from {latest}."
    return None


# =========================
# Empty Result Interpreter
# =========================
def interpret_empty_result(question: str):
    """Return the generic "no results" message with the data coverage date.

    *question* is accepted for interface compatibility; the message does
    not depend on it.
    """
    latest = get_latest_data_date()
    return f"No results found. Available data is up to {latest}."


# =========================
# ORCHESTRATOR (Single Entry Point)
# =========================
def process_question(question: str):
    """Answer a natural-language question end to end.

    Steps: spell-correct, apply the keyword guardrail, compute the time
    note, generate SQL with the LLM, validate it, run it, and shape the
    response.

    Args:
        question: Free-text user question.

    Returns:
        A dict with ``"status"`` set to ``"rejected"`` (plus ``"message"``)
        when the question is out of scope or the model answers
        NOT_ANSWERABLE; otherwise ``"status": "ok"`` with ``"sql"``,
        ``"columns"``, ``"data"`` and ``"note"``, plus ``"message"`` when
        the result set is empty. ``"data"`` holds at most 50 rows.

    Raises:
        Exception: Propagated from ``validate_sql`` when the generated
            text is not an allowed SELECT, and from query execution.
    """
    rejected = {
        "status": "rejected",
        "message": "This question is not supported by the available data.",
    }

    # 0. Spell correction
    question = correct_spelling(question)

    # 1. Guardrail
    if not is_question_answerable(question):
        return rejected

    # 2. Time relevance
    time_note = check_time_relevance(question)

    # 3. Generate SQL
    sql = generate_sql(question)

    # The prompt instructs the model to reply NOT_ANSWERABLE when the
    # schema cannot answer the question; treat that as a rejection rather
    # than letting validation raise on it.
    if clean_sql(sql).upper().startswith("NOT_ANSWERABLE"):
        return rejected

    # 4. Validate SQL
    sql = validate_sql(sql)

    # 5. Execute query
    columns, rows = run_query(sql)

    # 6. Handle empty result with data coverage awareness
    if not rows:
        if _mentions_relative_time(question):
            latest = get_latest_data_date()
            return {
                "status": "ok",
                "sql": sql,
                "columns": columns,
                "message": f"No data available for the requested time period. Latest available data is from {latest}.",
                "data": [],
                "note": None,
            }

        return {
            "status": "ok",
            "sql": sql,
            "columns": columns,
            "message": interpret_empty_result(question),
            "data": [],
            "note": time_note,
        }

    # 7. Normal response
    return {
        "status": "ok",
        "sql": sql,
        "columns": columns,
        "data": rows[:50],  # demo safety limit
        "note": time_note,
    }