Spaces:
Sleeping
Sleeping
import os
import re
import sqlite3
from difflib import get_close_matches

from openai import OpenAI
| # ========================= | |
| # Setup | |
| # ========================= | |
# Shared OpenAI client; the API key is read from the OPENAI_API_KEY
# environment variable.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Single module-level connection reused by every helper below.
# check_same_thread=False allows it to be used from threads other than the
# one that created it.
conn = sqlite3.connect(database="hospital.db", check_same_thread=False)
| # ========================= | |
| # Known Terms for Spell Correction | |
| # ========================= | |
# Vocabulary used by correct_spelling() as fuzzy-match targets.
KNOWN_TERMS = [
    # core entities
    "patient", "patients",
    "condition", "conditions", "diagnosis",
    "encounter", "encounters", "visit", "visits",
    # observations / labs
    "observation", "observations", "lab", "labs", "test", "tests",
    # medications
    "medication", "medications", "drug", "drugs",
    "prescription", "prescriptions",
    # common condition names and admission vocabulary
    "diabetes", "hypertension", "asthma", "cancer",
    "admitted", "admission",
]
def correct_spelling(
    question: str,
    known_terms: "list[str] | None" = None,
    cutoff: float = 0.8,
) -> str:
    """Replace likely misspellings of domain terms in *question*.

    Each whitespace-separated word is fuzzy-matched (difflib, ``cutoff``)
    against ``known_terms`` (defaults to KNOWN_TERMS).  Only the word's core
    is replaced; leading/trailing ``,.?`` punctuation is kept.  Words that
    already contain a known term (exact terms, or inflections such as
    "visited") are left untouched: they already pass the keyword guardrail,
    and fuzzy matching would otherwise rewrite valid words, e.g. "latest"
    scores exactly 0.8 against "test".

    Args:
        question: raw user question.
        known_terms: vocabulary to correct towards; None means KNOWN_TERMS.
        cutoff: minimum difflib similarity ratio for a replacement.

    Returns:
        The question re-joined with single spaces.
    """
    terms = KNOWN_TERMS if known_terms is None else known_terms
    corrected_words = []
    for word in question.split():
        core = word.strip(",.?")
        clean_word = core.lower()
        # Nothing to correct: pure punctuation, or already contains a known term.
        if not clean_word or any(term in clean_word for term in terms):
            corrected_words.append(word)
            continue
        matches = get_close_matches(clean_word, terms, n=1, cutoff=cutoff)
        if not matches:
            corrected_words.append(word)
            continue
        # Re-attach the punctuation that was stripped for matching.
        prefix = word[: len(word) - len(word.lstrip(",.?"))]
        suffix = word[len(word.rstrip(",.?")):]
        corrected_words.append(prefix + matches[0] + suffix)
    return " ".join(corrected_words)
| # ========================= | |
| # Metadata Loader | |
| # ========================= | |
def load_ai_schema():
    """Read the AI-exposed schema from the ai_tables / ai_columns metadata tables.

    Returns:
        dict mapping every ``table_name`` with ``ai_enabled = 1`` to
        ``{"description": <table description>,
        "columns": [(column_name, description), ...]}``, where the column
        list holds only rows of ai_columns with ``ai_allowed = 1``.
    """
    cursor = conn.cursor()
    enabled_tables = cursor.execute("""
        SELECT table_name, description
        FROM ai_tables
        WHERE ai_enabled = 1
    """).fetchall()
    # The dict comprehension runs the per-table column query in the same
    # order as the tables were returned.
    return {
        name: {
            "description": table_description,
            "columns": cursor.execute("""
                SELECT column_name, description
                FROM ai_columns
                WHERE table_name = ? AND ai_allowed = 1
            """, (name,)).fetchall(),
        }
        for name, table_description in enabled_tables
    }
| # ========================= | |
| # Prompt Builder | |
| # ========================= | |
def build_prompt(question: str) -> str:
    """Assemble the text-to-SQL prompt for *question*.

    Layout: a fixed rule list, then one "Table:" section per table returned
    by load_ai_schema() listing its allowed columns, then the user question.
    """
    instructions = [
        "You are a hospital data assistant.",
        "Rules:",
        "- Generate only SELECT SQL queries.",
        "- Use only the tables and columns provided.",
        "- Do not invent tables or columns.",
        "- This database is SQLite. Use SQLite-compatible date functions.",
        "- For recent days use: date('now', '-N day')",
        "- Use case-insensitive matching for text fields.",
        "- Prefer LIKE with wildcards for medical condition names.",
        "- Use COUNT, AVG, MIN, MAX, GROUP BY when the question asks for totals, averages, or comparisons.",
        "- If the question cannot be answered using the schema, return NOT_ANSWERABLE.",
        "- Do not explain the query.",
        "- Return only SQL or NOT_ANSWERABLE.",
        "Available schema:",
    ]
    sections = ["\n" + "\n".join(instructions) + "\n"]
    for table_name, table_meta in load_ai_schema().items():
        sections.append(f"\nTable: {table_name} - {table_meta['description']}\n")
        sections.extend(
            f" - {column_name}: {column_desc}\n"
            for column_name, column_desc in table_meta["columns"]
        )
    sections.append(f"\nUser question: {question}\n")
    return "".join(sections)
| # ========================= | |
| # LLM Call | |
| # ========================= | |
def call_llm(prompt: str, model: str = "gpt-4.1-mini") -> str:
    """Send *prompt* to the chat-completions API and return the reply text.

    Args:
        prompt: full user prompt (see build_prompt).
        model: chat model name; defaults to the original hard-coded model.

    Returns:
        The first choice's message text, stripped.  Returns "" when the API
        returns no text content, so downstream validation rejects it cleanly
        instead of crashing with an AttributeError.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a SQL generator. Return only SQL. No explanation."},
            {"role": "user", "content": prompt},
        ],
        # Deterministic output: the same question should yield the same SQL.
        temperature=0.0,
    )
    # message.content is Optional in the SDK (e.g. on refusals).
    content = response.choices[0].message.content
    return (content or "").strip()
| # ========================= | |
| # SQL Generation | |
| # ========================= | |
def generate_sql(question: str) -> str:
    """Ask the LLM for SQL answering *question*; the result is NOT yet validated."""
    return call_llm(build_prompt(question)).strip()
| # ========================= | |
| # SQL Cleaning & Validation | |
| # ========================= | |
def clean_sql(sql: str) -> str:
    """Strip whitespace and an optional Markdown code fence from LLM output.

    If the text contains a ``` fence, only the first fenced body is kept,
    and a leading ``sql``/``sqlite`` info-string line (any case) is dropped.
    Text without a fence is only whitespace-stripped.  It is never rewritten
    internally, so identifiers ending in "sql" survive intact.

    Returns:
        The bare SQL (or other reply text), stripped.
    """
    sql = sql.strip()
    if "```" in sql:
        # split() yields [text before fence, fenced body, ...]; keep the body.
        fenced = sql.split("```")[1]
        info, newline, body = fenced.partition("\n")
        # Drop only a language tag on the fence's first line.
        if newline and info.strip().lower() in ("sql", "sqlite"):
            fenced = body
        sql = fenced
    return sql.strip()
def validate_sql(sql: str) -> str:
    """Clean LLM output and ensure it is a read-only query.

    Accepts statements starting with SELECT or WITH (CTE-based selects).
    Rejects any statement containing a write/DDL/connection keyword as a
    whole word.  Identifiers such as ``last_updated`` or ``is_deleted`` are
    therefore allowed.  Note that a keyword inside a string literal is still
    rejected, which is a deliberately conservative choice.

    Returns:
        The cleaned SQL.

    Raises:
        ValueError (an Exception subclass): "Only SELECT queries allowed" or
            "Forbidden SQL operation detected".
    """
    sql = clean_sql(sql)
    s = sql.lower()
    if not s.startswith(("select", "with")):
        raise ValueError("Only SELECT queries allowed")
    # WITH may prefix INSERT/UPDATE/DELETE/REPLACE in SQLite, so the keyword
    # scan below is what actually keeps CTE queries read-only.
    forbidden = (
        r"\b(insert|update|delete|drop|alter|truncate|create|attach|detach"
        r"|pragma|vacuum|reindex|replace\s+into)\b"
    )
    if re.search(forbidden, s):
        raise ValueError("Forbidden SQL operation detected")
    return sql
| # ========================= | |
| # Query Runner | |
| # ========================= | |
def run_query(sql: str):
    """Execute *sql* on the shared connection.

    Returns:
        (column_names, rows): column names taken from cursor.description in
        result order, and the full fetchall() result.  Assumes *sql* produces
        a result set (description is None otherwise).
    """
    result_cursor = conn.execute(sql)
    all_rows = result_cursor.fetchall()
    header = [entry[0] for entry in result_cursor.description]
    return header, all_rows
| # ========================= | |
| # Guardrails | |
| # ========================= | |
def is_question_answerable(question):
    """Cheap pre-LLM guardrail: True if the question mentions a supported entity.

    A plain case-insensitive substring test, so "laboratory" matches "lab"
    and "patients" matches "patient".
    """
    lowered = question.lower()
    return any(
        keyword in lowered
        for keyword in (
            "patient", "encounter", "condition", "observation",
            "medication", "visit", "diagnosis", "lab", "vital", "admitted",
        )
    )
| # ========================= | |
| # Time Awareness | |
| # ========================= | |
def get_latest_data_date():
    """Return MAX(start_date) from the encounters table (None if it is empty)."""
    _columns, result_rows = run_query("SELECT MAX(start_date) FROM encounters;")
    # An aggregate without GROUP BY always yields exactly one row.
    first_row = result_rows[0]
    return first_row[0]
def check_time_relevance(question: str):
    """Return a data-coverage note for time-relative questions, else None.

    Triggered by a case-insensitive substring match on "last", "recent",
    "today", "this month" or "this year".  Only then is the database queried
    for the latest encounter date.
    """
    lowered = question.lower()
    triggers = ("last", "recent", "today", "this month", "this year")
    if not any(trigger in lowered for trigger in triggers):
        return None
    return f"Latest available data is from {get_latest_data_date()}."
| # ========================= | |
| # Empty Result Interpreter | |
| # ========================= | |
def interpret_empty_result(question: str):
    """Build the generic "no rows" message including the data cutoff date.

    *question* is currently unused; it is kept so the signature stays stable.
    """
    cutoff_date = get_latest_data_date()
    return f"No results found. Available data is up to {cutoff_date}."
| # ========================= | |
| # ORCHESTRATOR (Single Entry Point) | |
| # ========================= | |
def process_question(question: str, max_rows: int = 50):
    """Answer a natural-language question end to end (single entry point).

    Pipeline: spell correction -> keyword guardrail -> time-relevance note ->
    LLM SQL generation -> NOT_ANSWERABLE check -> validation -> execution ->
    response shaping.

    Args:
        question: raw user question.
        max_rows: cap on returned rows (demo safety limit; default 50).

    Returns:
        dict with "status":
        - "rejected" plus "message": the question failed the keyword
          guardrail, or the LLM replied NOT_ANSWERABLE.
        - "ok" plus "sql", "data" and "note": "columns" is present for
          non-empty results; "message" is present for empty results.

    Raises:
        ValueError from validate_sql when the generated SQL is not a permitted
        read-only query.  Errors raised by run_query for invalid SQL propagate
        unchanged.
    """
    rejected = {
        "status": "rejected",
        "message": "This question is not supported by the available data."
    }

    # 0. Spell correction
    question = correct_spelling(question)

    # 1. Guardrail
    if not is_question_answerable(question):
        return rejected

    # 2. Time relevance (None unless the question is time-relative)
    time_note = check_time_relevance(question)

    # 3. Generate SQL
    sql = generate_sql(question)

    # 3b. The prompt instructs the model to return NOT_ANSWERABLE.  Without
    # this check the sentinel reached validate_sql and raised
    # "Only SELECT queries allowed".
    if clean_sql(sql).upper().startswith("NOT_ANSWERABLE"):
        return rejected

    # 4. Validate SQL
    sql = validate_sql(sql)

    # 5. Execute query
    columns, rows = run_query(sql)

    # 6. Handle empty result with data coverage awareness.  Reuse the same
    # time trigger as step 2 so both stay consistent (previously "today" was
    # missing here).
    if not rows:
        if time_note is not None:
            latest = get_latest_data_date()
            return {
                "status": "ok",
                "sql": sql,
                "message": f"No data available for the requested time period. Latest available data is from {latest}.",
                "data": [],
                "note": None
            }
        return {
            "status": "ok",
            "sql": sql,
            "message": interpret_empty_result(question),
            "data": [],
            "note": time_note
        }

    # 7. Normal response
    return {
        "status": "ok",
        "sql": sql,
        "columns": columns,
        "data": rows[:max_rows],  # demo safety limit
        "note": time_note
    }