"""Hospital SQL assistant.

Translates natural-language questions into safe, read-only SQLite queries
over a MIMIC-IV demo database, using table metadata from
``hospital_metadata.json`` and OpenAI for SQL generation.
"""

import os
import re
import json
import sqlite3
from difflib import get_close_matches
from functools import lru_cache

from openai import OpenAI

# =========================
# SETUP
# =========================

# Validate API key up front so failures happen at startup, not mid-request.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set")

client = OpenAI(api_key=api_key)
# check_same_thread=False: connection is shared across request threads.
conn = sqlite3.connect("mimic_iv_demo.db", check_same_thread=False)

# =========================
# CONVERSATION STATE
# =========================

# Module-level state describing the last turn (set by process_question).
LAST_PROMPT_TYPE = None
LAST_SUGGESTED_DATE = None

# =========================
# HUMAN RESPONSE HELPERS
# =========================


def humanize(text):
    """Wrap *text* in a friendly preamble."""
    return f"Sure \n\n{text}"


def friendly(text):
    """Append the latest available data date to *text*, if known."""
    global LAST_SUGGESTED_DATE
    if LAST_SUGGESTED_DATE:
        return f"{text}\n\nLast data available is {LAST_SUGGESTED_DATE}"
    else:
        # If date not set yet, try to look it up from the database.
        date = get_latest_data_date()
        if date:
            return f"{text}\n\nLast data available is {date}"
    return text


def is_confirmation(text):
    """Return True if *text* is a short affirmative reply."""
    return text.strip().lower() in ["yes", "yep", "yeah", "ok", "okay", "sure"]


def is_why_question(text):
    """Return True if *text* starts with 'why' (case-insensitive)."""
    return text.strip().lower().startswith("why")


# =========================
# SPELL CORRECTION
# =========================

# Vocabulary used for fuzzy spell correction of user questions.
KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions",
    "encounter", "encounters", "visit", "visits",
    "medication", "medications", "admitted", "admission",
    "year", "month", "last", "recent", "today"
]


def correct_spelling(q):
    """Replace close-misspelled domain words in *q* with known terms.

    Punctuation-stripped, lowercased words within difflib cutoff 0.8 of a
    KNOWN_TERMS entry are swapped for that entry; everything else is kept.
    """
    words = q.split()
    fixed = []
    for w in words:
        clean = w.lower().strip(",.?")
        match = get_close_matches(clean, KNOWN_TERMS, n=1, cutoff=0.8)
        fixed.append(match[0] if match else w)
    return " ".join(fixed)


# =========================
# SCHEMA
# =========================


@lru_cache(maxsize=1)
def load_ai_schema():
    """Load schema from metadata JSON file with error handling.

    Returns the parsed dict; raises FileNotFoundError or ValueError with a
    user-actionable message on any failure. Cached for the process lifetime.
    """
    try:
        with open("hospital_metadata.json", "r") as f:
            schema = json.load(f)
        if not isinstance(schema, dict):
            raise ValueError("Invalid metadata format: expected a dictionary")
        return schema
    except FileNotFoundError:
        raise FileNotFoundError("hospital_metadata.json file not found. Please create it with your table metadata.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in hospital_metadata.json: {str(e)}")
    except Exception as e:
        raise ValueError(f"Error loading metadata: {str(e)}")


# =========================
# TABLE MATCHING (CORE LOGIC)
# =========================


def extract_relevant_tables(question, max_tables=4):
    """Score each schema table against *question* and return the best matches.

    Scoring: +6 exact table-name mention, +3 exact column mention, +1 token
    overlap with a column name, +0.5 per meaningful description-word overlap,
    +5 semantic intent hit. Tables below a size-dependent threshold are
    dropped; at most *max_tables* names are returned, best first.
    """
    schema = load_ai_schema()
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())
    matched = []

    # Lightweight intent hints - dynamically filter to only include tables
    # that exist. Map natural language terms to potential table names.
    all_tables = list(schema.keys())
    table_names_lower = [t.lower() for t in all_tables]
    DOMAIN_HINTS = {}
    hint_mappings = {
        "consultant": ["encounter", "encounters", "visit", "visits"],
        "doctor": ["encounter", "encounters", "provider", "providers"],
        "visit": ["encounter", "encounters", "visit", "visits"],
        "visited": ["encounter", "encounters", "visit", "visits"],
        "visits": ["encounter", "encounters", "visit", "visits"],
        "appointment": ["encounter", "encounters", "appointment", "appointments"],
        "patient": ["patient", "patients"],
        "medication": ["medication", "medications", "drug", "drugs"],
        "drug": ["medication", "medications", "drug", "drugs"],
        "condition": ["condition", "conditions", "diagnosis", "diagnoses"],
        "diagnosis": ["condition", "conditions", "diagnosis", "diagnoses"]
    }
    # Only include hints for tables that exist in the schema.
    for intent, possible_tables in hint_mappings.items():
        matching_tables = [t for t in possible_tables if t in table_names_lower]
        if matching_tables:
            DOMAIN_HINTS[intent] = matching_tables

    for table, meta in schema.items():
        score = 0
        table_l = table.lower()

        # 1. Strong signal: table name mentioned verbatim in the question.
        # NOTE: a previous "early exit" on score >= 10 here was dead code
        # (score can be at most 6 at this point) and has been removed.
        if table_l in q:
            score += 6

        # 2. Column relevance: exact column mention, or token overlap.
        for col, desc in meta["columns"].items():
            col_l = col.lower()
            if col_l in q:
                score += 3
            elif any(tok in col_l for tok in tokens):
                score += 1

        # 3. Description relevance (less weight to avoid false positives).
        if meta.get("description"):
            desc_tokens = set(meta["description"].lower().split())
            # Only count meaningful word matches, not common stopwords.
            common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
            meaningful_matches = tokens & (desc_tokens - common_words)
            if meaningful_matches:
                score += len(meaningful_matches) * 0.5  # Reduced weight

        # 4. Semantic intent mapping (important - highest priority).
        for intent, tables in DOMAIN_HINTS.items():
            if intent in q and table_l in tables:
                score += 5

        # 5. Only keep tables above a minimum threshold; be more lenient
        # for small schemas where there are few candidates.
        threshold = 2 if len(schema) <= 5 else 4
        if score >= threshold:
            matched.append((table, score))

    # Sort by relevance, best first.
    matched.sort(key=lambda x: x[1], reverse=True)

    # If no matches but schema is very small, return all tables
    # (with lower confidence).
    if not matched and len(schema) <= 3:
        return list(schema.keys())[:max_tables]

    return [t[0] for t in matched[:max_tables]]


# =========================
# HUMAN SCHEMA DESCRIPTION
# =========================


def describe_schema(max_tables=10):
    """Return a human-readable summary of the available tables.

    Shows at most *max_tables* tables and 5 columns each to keep the reply
    readable, then appends example questions.
    """
    schema = load_ai_schema()
    total_tables = len(schema)
    response = f"Here's the data I currently have access to ({total_tables} tables):\n\n"

    # Show only top N tables to avoid overwhelming output.
    shown_tables = list(schema.items())[:max_tables]
    for table, meta in shown_tables:
        response += f"• **{table.capitalize()}** — {meta['description']}\n"
        # Show only first 5 columns per table.
        for col, desc in list(meta["columns"].items())[:5]:
            response += f"  - {col}: {desc}\n"
        if len(meta["columns"]) > 5:
            response += f"  ... and {len(meta['columns']) - 5} more columns\n"
        response += "\n"

    if total_tables > max_tables:
        response += f"\n... and {total_tables - max_tables} more tables.\n"
        response += "Ask about a specific table to see its details.\n\n"

    response += (
        "You can ask things like:\n"
        "• How many patients are there?\n"
        "• Patient count by gender\n"
        "• Admissions by year\n\n"
        "Just tell me what you want to explore "
    )
    return response


# =========================
# TIME HANDLING
# =========================


def get_latest_data_date():
    """Get the latest data date by checking tables with date columns.

    Scans schema tables for date-like column names and returns the first
    MAX(column) value found, or None if no date column yields a result.
    """
    schema = load_ai_schema()
    # Common date column names to check.
    date_columns = ["date", "start_date", "end_date", "admission_date",
                    "admittime", "dischtime", "created_at", "updated_at"]

    for table_name in schema.keys():
        columns = schema[table_name].get("columns", {})
        for col_name in columns.keys():
            col_lower = col_name.lower()
            if any(date_col in col_lower for date_col in date_columns):
                try:
                    result = conn.execute(
                        f"SELECT MAX({col_name}) FROM {table_name}"
                    ).fetchone()
                    if result and result[0]:
                        return result[0]
                except (sqlite3.Error, sqlite3.OperationalError):
                    continue  # Try next table/column
    return None


def normalize_time_question(q):
    """Rewrite 'today'/'yesterday' in *q* to the latest available data date."""
    latest = get_latest_data_date()
    if not latest:
        return q
    if "today" in q:
        return q.replace("today", f"on {latest[:10]}")
    if "yesterday" in q:
        return q.replace("yesterday", f"on {latest[:10]}")
    return q


# =========================
# UNSUPPORTED QUESTIONS
# =========================


def is_question_supported(question):
    """Heuristically decide whether *question* is answerable from the schema.

    Analytical phrasing is always allowed; otherwise at least one table must
    score >= 2 on name/column/description overlap.
    """
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())

    # 1. Allow analytical intent even without table names.
    # Multi-word phrases work because this is a substring test on q.
    analytic_keywords = {
        "count", "total", "average", "avg", "sum",
        "how many", "number of", "trend", "increase", "decrease",
        "compare", "more than", "less than"
    }
    if any(k in q for k in analytic_keywords):
        return True

    # 2. Check schema relevance (table-by-table).
    schema = load_ai_schema()
    for table, meta in schema.items():
        score = 0
        table_l = table.lower()
        # Table name match
        if table_l in q:
            score += 3
        # Column name match
        for col, desc in meta["columns"].items():
            col_l = col.lower()
            if col_l in q:
                score += 2
            elif any(tok in col_l for tok in tokens):
                score += 1
        # Description match
        if meta.get("description"):
            desc_tokens = set(meta["description"].lower().split())
            score += len(tokens & desc_tokens)
        # If any table is relevant enough -> supported.
        if score >= 2:
            return True

    return False


# =========================
# SQL GENERATION
# =========================


def build_prompt(question):
    """Build the LLM prompt: rules + schema for the matched tables + question.

    Raises ValueError (with a helpful table list) when no relevant tables are
    found, so we never ship the entire schema to the model.
    """
    matched = extract_relevant_tables(question)
    full_schema = load_ai_schema()

    if matched:
        schema = {t: full_schema[t] for t in matched}
    else:
        # Don't send all 100+ tables! Return a helpful error instead.
        available_tables = list(full_schema.keys())[:10]
        tables_list = "\n".join(f"- {t}" for t in available_tables)
        if len(full_schema) > 10:
            tables_list += f"\n... and {len(full_schema) - 10} more tables"
        raise ValueError(
            f"I couldn't find any relevant tables for your question.\n\n"
            f"Available tables:\n{tables_list}\n\n"
            f"Please try mentioning a specific table name or use 'what data' to see all available tables."
        )

    prompt = """
You are a hospital SQL assistant.

Rules:
- Use only SELECT
- SQLite syntax
- Use ONLY the exact table names listed below (do not create or infer table names)
- Use only listed tables/columns
- Return ONLY SQL or NOT_ANSWERABLE

IMPORTANT: Use EXACTLY the table names provided in the list below. Do not pluralize, modify, or guess table names.
"""
    for table, meta in schema.items():
        prompt += f"\nTable: {table}\n"
        for col, desc in meta["columns"].items():
            prompt += f"- {col}: {desc}\n"

    prompt += f"\nQuestion: {question}\n"
    prompt += "\nRemember: Use EXACT table names from the list above. Do not pluralize or modify table names."
    return prompt


def call_llm(prompt):
    """Call OpenAI API with error handling; return the raw model text."""
    try:
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
                {"role": "user", "content": prompt}
            ],
            temperature=0
        )
        if not res.choices or not res.choices[0].message.content:
            raise ValueError("Empty response from OpenAI API")
        return res.choices[0].message.content.strip()
    except Exception as e:
        raise ValueError(f"OpenAI API error: {str(e)}")


# =========================
# SQL SAFETY
# =========================


def sanitize_sql(sql):
    """Strip markdown fences, keep only the first statement, flatten newlines."""
    # Remove code fence markers but preserve legitimate SQL.
    sql = sql.replace("```sql", "").replace("```", "").strip()
    # Remove a stray leading 'sql' language tag.
    if sql.startswith("sql"):
        sql = sql[3:].strip()
    # Keep only the first statement (defends against stacked queries).
    sql = sql.split(";")[0]
    return sql.replace("\n", " ").strip()


def correct_table_names(sql):
    """Fix common table name mistakes in generated SQL."""
    schema = load_ai_schema()
    # BUG FIX: compare lowercase-to-lowercase; schema keys may be mixed case,
    # so the original set(schema.keys()) membership test could never match.
    valid_tables = {t.lower() for t in schema.keys()}
    sql_corrected = sql

    # Common table name mappings (case-insensitive replacement).
    table_corrections = {
        "visits": "encounters",
        "visit": "encounters",
        "providers": "encounters",  # if this table doesn't exist
    }

    for wrong_name, correct_name in table_corrections.items():
        # Only correct if the wrong table doesn't exist AND the correct one does.
        if wrong_name.lower() not in valid_tables and correct_name.lower() in valid_tables:
            # Word boundaries avoid partial replacements inside identifiers.
            pattern = r'\b' + re.escape(wrong_name) + r'\b'
            sql_corrected = re.sub(pattern, correct_name, sql_corrected, flags=re.IGNORECASE)

    return sql_corrected


def validate_sql(sql):
    """Reject anything that is not a SELECT statement."""
    if not sql.lower().startswith("select"):
        raise ValueError("Only SELECT allowed")
    return sql


def run_query(sql):
    """Execute SQL query with proper error handling.

    Returns (column_names, rows); raises ValueError on a database error.
    """
    cur = conn.cursor()
    try:
        rows = cur.execute(sql).fetchall()
        if cur.description:
            cols = [c[0] for c in cur.description]
        else:
            cols = []
        return cols, rows
    except sqlite3.Error as e:
        raise ValueError(f"Database query error: {str(e)}")


# =========================
# AGGREGATE SAFETY
# =========================


def is_aggregate_only_query(sql):
    """True for a bare aggregate (COUNT/SUM/AVG) query with no GROUP BY."""
    s = sql.lower()
    return ("count(" in s or "sum(" in s or "avg(" in s) and "group by" not in s


def has_underlying_data(sql):
    """Check if underlying data exists for the SQL query.

    Re-runs just the FROM/WHERE portion with LIMIT 1 to distinguish
    'aggregate over zero rows' from real data.
    """
    base = sql.lower()
    if "from" not in base:
        return False
    base = base.split("from", 1)[1]
    # Trim trailing clauses so only the FROM/WHERE part remains.
    for clause in ["group by", "order by", "limit", "having"]:
        base = base.split(clause)[0]
    test_sql = "SELECT 1 FROM " + base.strip() + " LIMIT 1"
    cur = conn.cursor()
    try:
        return cur.execute(test_sql).fetchone() is not None
    except sqlite3.Error:
        return False


# =========================
# PATIENT SUMMARY
# =========================


def validate_identifier(name):
    """Validate that identifier is safe (only alphanumeric and underscores)."""
    if not name or not isinstance(name, str):
        return False
    # Check for SQL injection attempts.
    forbidden = [";", "--", "/*", "*/", "'", '"', "`", "(", ")", " ", "\n", "\t"]
    if any(char in name for char in forbidden):
        return False
    # Must start with letter or underscore, rest alphanumeric/underscore.
    return bool(re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name))


def build_table_summary(table_name):
    """Build summary for a table using metadata.

    Emits the total row count plus top-5 value breakdowns for columns that
    look categorical (not ids/dates/amounts).
    """
    # Validate table name against metadata first.
    schema = load_ai_schema()
    if table_name not in schema:
        return f"Table {table_name} not found in metadata."

    # Additional safety check before interpolating into SQL.
    if not validate_identifier(table_name):
        return f"Invalid table name: {table_name}"

    cur = conn.cursor()

    # Total rows. SQLite doesn't support parameterized table names; since we
    # validated table_name against metadata, interpolation is safe here.
    try:
        total = cur.execute(
            f"SELECT COUNT(*) FROM {table_name}"
        ).fetchone()[0]
    except sqlite3.Error as e:
        return f"Error querying table {table_name}: {str(e)}"

    columns = schema[table_name]["columns"]  # {col_name: description, ...}

    summary = f"Here's a summary of **{table_name}**:\n\n"
    summary += f"• Total records: {total}\n"

    for col_name, col_desc in columns.items():
        # Validate column name before interpolation.
        if not validate_identifier(col_name):
            continue
        # Skip likely numeric/date columns; only break down categorical ones.
        col_lower = col_name.lower()
        if any(skip in col_lower for skip in ["id", "_id", "date", "time", "count", "amount", "price"]):
            continue
        try:
            rows = cur.execute(
                f"""
                SELECT {col_name}, COUNT(*)
                FROM {table_name}
                GROUP BY {col_name}
                ORDER BY COUNT(*) DESC
                LIMIT 5
                """
            ).fetchall()
            if rows:
                summary += f"\n• {col_name.capitalize()} breakdown:\n"
                for val, count in rows:
                    summary += f"  - {val}: {count}\n"
        except (sqlite3.Error, sqlite3.OperationalError):
            # Ignore columns that can't be grouped (likely not categorical).
            pass

    summary += "\nYou can ask more detailed questions about this data."
    return summary


# =========================
# MAIN ENGINE
# =========================


def process_question(question):
    """End-to-end pipeline: normalize, route, generate SQL, execute, respond.

    Always returns a dict with at least 'status', 'message' or 'sql', and
    'data'; never raises for expected no-data / unsupported cases.
    """
    global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
    q = question.strip().lower()

    # ----------------------------------
    # Normalize first
    # ----------------------------------
    question = correct_spelling(question)
    question = normalize_time_question(question)
    LAST_PROMPT_TYPE = None
    LAST_SUGGESTED_DATE = None

    # ----------------------------------
    # Handle "data updated till"
    # ----------------------------------
    if any(x in q for x in ["updated", "upto", "up to", "latest data"]):
        return {
            "status": "ok",
            "message": f"Data is available up to {get_latest_data_date()}",
            "data": []
        }

    # ----------------------------------
    # Extract relevant tables
    # ----------------------------------
    matched_tables = extract_relevant_tables(question)

    # ----------------------------------
    # SUMMARY ONLY IF USER ASKS FOR IT
    # ----------------------------------
    if (
        len(matched_tables) == 1
        and any(k in q for k in ["summary", "overview", "describe"])
        and not any(k in q for k in ["count", "total", "how many", "average"])
    ):
        return {
            "status": "ok",
            "message": build_table_summary(matched_tables[0]),
            "data": []
        }

    # Only block if too many tables matched AND it's not an analytical
    # question — analytical questions often legitimately need several tables.
    is_analytical = any(k in q for k in [
        "how many", "count", "total", "number of",
        "average", "avg", "sum", "more than", "less than",
        "compare", "trend"
    ])
    if len(matched_tables) > 4 and not is_analytical:
        return {
            "status": "ok",
            "message": (
                "Your question matches too many datasets:\n"
                + "\n".join(f"- {t}" for t in matched_tables[:5])
                + "\n\nPlease be more specific about what you want to know."
            ),
            "data": []
        }

    # ----------------------------------
    # Metadata discovery
    # ----------------------------------
    if any(x in q for x in ["what data", "what tables", "which data"]):
        return {
            "status": "ok",
            "message": humanize(describe_schema()),
            "data": []
        }

    # ----------------------------------
    # Unsupported question check
    # ----------------------------------
    if not is_question_supported(question):
        return {
            "status": "ok",
            "message": (
                "That information isn’t available in the system.\n\n"
                "You can ask about:\n"
                "• Patients\n"
                "• Visits\n"
                "• Conditions\n"
                "• Medications"
            ),
            "data": []
        }

    # ----------------------------------
    # Generate SQL
    # ----------------------------------
    try:
        sql = call_llm(build_prompt(question))
    except ValueError as e:
        # Handle case where no relevant tables were found.
        return {
            "status": "ok",
            "message": str(e),
            "data": []
        }

    if sql == "NOT_ANSWERABLE":
        return {
            "status": "ok",
            "message": "I don't have enough data to answer that.",
            "data": []
        }

    # Sanitize, correct table names, then validate.
    sql = sanitize_sql(sql)
    sql = correct_table_names(sql)
    sql = validate_sql(sql)

    cols, rows = run_query(sql)

    # ----------------------------------
    # No data handling
    # ----------------------------------
    if is_aggregate_only_query(sql) and not has_underlying_data(sql):
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No data is available for that time period."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }

    if not rows:
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No records found."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }

    # ----------------------------------
    # Success
    # ----------------------------------
    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows
    }