import os
import re
import sqlite3
from datetime import datetime, timezone
from difflib import get_close_matches

from openai import OpenAI

# In-memory log of every handled interaction (one dict per question).
TRANSCRIPT = []


def log_interaction(user_q, sql=None, result=None, error=None):
    """Append one interaction record to ``TRANSCRIPT``.

    Args:
        user_q: The question text to record.
        sql: The SQL that was executed, if any.
        result: Query result; when it is a list only the first 10 items are
            stored as a preview, any other value is stored unchanged.
        error: Error text to record, if any.
    """
    # datetime.utcnow() is deprecated; this yields the same naive-UTC ISO string.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    TRANSCRIPT.append({
        "timestamp": timestamp,
        "question": user_q,
        "sql": sql,
        "result_preview": result[:10] if isinstance(result, list) else result,
        "error": error,
    })


# =========================
# SETUP
# =========================

# Fail fast at import time when the API key is missing.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set")

client = OpenAI(api_key=api_key)
conn = sqlite3.connect("mimic_iv.db", check_same_thread=False)


# =========================
# CONVERSATION STATE
# =========================

LAST_PROMPT_TYPE = None
LAST_SUGGESTED_DATE = None


# =========================
# HUMAN RESPONSE HELPERS
# =========================

def humanize(text):
    """Prefix *text* with a short acknowledgement line."""
    return f"Sure \n\n{text}"


def friendly(text):
    """Return *text* followed by a note about the last available data date.

    Uses ``LAST_SUGGESTED_DATE`` when it is set, otherwise asks
    ``get_latest_data_date()``; when neither yields a date, *text* is
    returned unchanged.
    """
    # Read-only access to the module-level state: no ``global`` needed.
    date = LAST_SUGGESTED_DATE or get_latest_data_date()
    if date:
        return f"{text}\n\nLast data available is {date}"
    return text


def is_confirmation(text):
    """Return True when *text* (trimmed, case-insensitive) is a yes-like reply."""
    return text.strip().lower() in ("yes", "yep", "yeah", "ok", "okay", "sure")


def is_why_question(text):
    """Return True when *text* (trimmed, case-insensitive) starts with "why"."""
    return text.strip().lower().startswith("why")


# =========================
# SPELL CORRECTION
# =========================

KNOWN_TERMS = [
    "patient", "patients", "admission", "admissions",
    "icu", "stay", "icustay", "diagnosis", "procedure",
    "medication", "lab", "year", "month", "recent", "today",
]


def correct_spelling(q):
    """Snap near-miss words in *q* to the closest entry of ``KNOWN_TERMS``.

    Every word is lowercased and stripped of leading/trailing ``,``, ``.``
    and ``?`` before matching; words without a close match (similarity
    cutoff 0.8) are kept in that cleaned form. Words are re-joined with
    single spaces.
    """
    fixed = []
    for word in q.split():
        clean = word.lower().strip(",.?")
        match = get_close_matches(clean, KNOWN_TERMS, n=1, cutoff=0.8)
        fixed.append(match[0] if match else clean)
    return " ".join(fixed)


# =========================
# SCHEMA
# =========================
import json
from functools import lru_cache

# Words that carry no table/column signal. Without this filter, tokens such as
# "a", "is" or "the" are substrings of almost every column name and made every
# table look relevant to every question.
_STOPWORDS = frozenset({
    "the", "and", "for", "are", "was", "were", "with", "from", "that", "this",
    "these", "those", "what", "which", "who", "whom", "how", "has", "have",
    "had", "does", "did", "not", "all", "any", "can", "you", "your", "show",
    "give", "list", "get", "tell", "there", "their", "them", "than", "then",
    "into", "about", "please", "want", "need",
})

_WORD_RE = re.compile(r"[a-z0-9_]+")

# Word stems (matched as word prefixes, so "diagnos" covers diagnosis/diagnoses)
# mapped to the tables they usually refer to.
_INTENT_HINTS = {
    # Patients & visits
    "patient": ("patients",),
    "admission": ("admissions",),
    "visit": ("admissions", "icustays"),
    # ICU
    "icu": ("icustays", "chartevents"),
    "stay": ("icustays",),
    # Diagnoses / conditions
    "diagnos": ("diagnoses_icd",),
    "condition": ("diagnoses_icd",),
    # Procedures
    "procedure": ("procedures_icd",),
    # Medications
    "medication": ("prescriptions", "emar", "pharmacy"),
    "drug": ("prescriptions",),
    # Labs & vitals
    "lab": ("labevents",),
    "vital": ("chartevents",),
}


def _word_tokens(text):
    """Return the set of lowercase word tokens in *text* (punctuation dropped)."""
    return set(_WORD_RE.findall(str(text).lower()))


def _meaningful_tokens(text):
    """Return word tokens of *text* that are >= 3 chars and not stopwords."""
    return {t for t in _word_tokens(text) if len(t) >= 3 and t not in _STOPWORDS}


def _hinted_tables(words, table_names_lower):
    """Map each existing (lowercase) table name to the number of intent stems hit."""
    hits = {}
    for stem, targets in _INTENT_HINTS.items():
        if any(word.startswith(stem) for word in words):
            for target in targets:
                if target in table_names_lower:
                    hits[target] = hits.get(target, 0) + 1
    return hits


def col_desc(desc):
    """Extract a description string from a metadata entry.

    Args:
        desc: Either a dict carrying a ``"description"`` key or any other value.

    Returns:
        The description as ``str``; ``""`` when it is missing or ``None``
        (previously ``None`` leaked into output as the text "None").
    """
    if isinstance(desc, dict):
        desc = desc.get("description")
    return "" if desc is None else str(desc)


@lru_cache(maxsize=1)
def load_ai_schema():
    """Load and cache the table metadata from ``metadata.json``.

    Returns:
        The parsed metadata dictionary.

    Raises:
        FileNotFoundError: ``metadata.json`` does not exist.
        ValueError: The file cannot be read, is not valid JSON, or its top
            level is not a dictionary.
    """
    try:
        with open("metadata.json", "r", encoding="utf-8") as f:
            schema = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            "metadata.json file not found. Please create it with your table metadata."
        ) from e
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in metadata.json: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Error loading metadata: {str(e)}") from e

    # Checked outside the try so this message is not re-wrapped by the
    # generic handler above.
    if not isinstance(schema, dict):
        raise ValueError("Invalid metadata format: expected a dictionary")
    return schema


# =========================
# TABLE MATCHING (CORE LOGIC)
# =========================

def extract_relevant_tables(question, max_tables=4):
    """Rank schema tables by relevance to *question*.

    Scoring per table: +6 when the table name occurs in the question, +3 per
    column whose name occurs in the question (else +1 when a meaningful
    question token is a substring of the column name), +0.5 per meaningful
    token shared with the table description, and +5 per distinct intent stem
    (see ``_INTENT_HINTS``) that points at the table.

    Args:
        question: The user question.
        max_tables: Maximum number of table names to return.

    Returns:
        Table names sorted by descending score; tables below the threshold
        (2 for schemas of up to 5 tables, otherwise 4) are dropped. When
        nothing matches and the schema has at most 3 tables, all tables are
        returned.
    """
    schema = load_ai_schema()
    q = question.lower()
    tokens = _meaningful_tokens(q)
    hinted = _hinted_tables(_word_tokens(q), {t.lower() for t in schema})

    threshold = 2 if len(schema) <= 5 else 4
    matched = []
    for table, meta in schema.items():
        table_l = table.lower()

        # 1. Strong signal: table name mentioned verbatim.
        score = 6 if table_l in q else 0

        # 2. Column relevance.
        for col in meta.get("columns", {}):
            col_l = col.lower()
            if col_l in q:
                score += 3
            elif any(tok in col_l for tok in tokens):
                score += 1

        # 3. Description relevance (low weight to limit false positives).
        description = meta.get("description")
        if description:
            score += 0.5 * len(tokens & _word_tokens(col_desc(description)))

        # 4. Semantic intent mapping.
        score += 5 * hinted.get(table_l, 0)

        if score >= threshold:
            matched.append((table, score))

    matched.sort(key=lambda item: item[1], reverse=True)

    # Tiny schema and nothing matched: offer everything (low confidence).
    if not matched and len(schema) <= 3:
        return list(schema)[:max_tables]
    return [table for table, _ in matched[:max_tables]]


# =========================
# HUMAN SCHEMA DESCRIPTION
# =========================

def describe_schema(max_tables=10):
    """Return a human-readable overview of the available tables.

    Args:
        max_tables: Number of tables to list in detail; at most 5 columns
            are shown per table.

    Returns:
        The formatted overview followed by example questions.
    """
    schema = load_ai_schema()
    total_tables = len(schema)
    parts = [f"Here's the data I currently have access to ({total_tables} tables):\n\n"]

    for table, meta in list(schema.items())[:max_tables]:
        columns = meta.get("columns", {})
        parts.append(f"• **{table.capitalize()}** — {col_desc(meta.get('description'))}\n")
        for col, desc in list(columns.items())[:5]:
            parts.append(f" - {col}: {col_desc(desc)}\n")
        if len(columns) > 5:
            parts.append(f" ... and {len(columns) - 5} more columns\n")
        parts.append("\n")

    if total_tables > max_tables:
        parts.append(f"\n... and {total_tables - max_tables} more tables.\n")
        parts.append("Ask about a specific table to see its details.\n\n")

    parts.append(
        "You can ask things like:\n"
        "• How many patients are there?\n"
        "• Patient count by gender\n"
        "• Admissions by year\n\n"
        "Just tell me what you want to explore "
    )
    return "".join(parts)


# =========================
# TIME HANDLING
# =========================

def get_latest_data_date():
    """Return the most meaningful 'latest date' for the system, or None.

    Priority:
    1. admissions.admittime
    2. icustays.intime
    3. chartevents.charttime

    A table that cannot be queried is skipped.
    """
    checks = [
        ("admissions", "admittime"),
        ("icustays", "intime"),
        ("chartevents", "charttime"),
    ]
    for table, column in checks:
        try:
            result = conn.execute(f"SELECT MAX({column}) FROM {table}").fetchone()
        except sqlite3.Error:
            continue
        if result and result[0]:
            return result[0]
    return None


def normalize_time_question(q):
    """Replace "today"/"yesterday" in *q* with dates relative to the latest data.

    "today" becomes ``on <latest date>`` and "yesterday" becomes
    ``on <latest date minus one day>``. The question is returned unchanged
    when it contains neither word or no latest date is available.
    """
    if "today" not in q and "yesterday" not in q:
        return q  # avoid a database round-trip when there is nothing to replace

    latest = get_latest_data_date()
    if not latest:
        return q
    latest_day = str(latest)[:10]

    if "today" in q:
        return q.replace("today", f"on {latest_day}")

    from datetime import date, timedelta

    try:
        previous_day = (date.fromisoformat(latest_day) - timedelta(days=1)).isoformat()
    except ValueError:
        # Stored value is not an ISO date; fall back to the latest date itself.
        previous_day = latest_day
    return q.replace("yesterday", f"on {previous_day}")


# =========================
# UNSUPPORTED QUESTIONS
# =========================

def is_question_supported(question):
    """Return True when *question* looks answerable from the schema.

    A question is supported when it contains an analytic keyword, hits an
    intent stem for an existing table, or any single table scores at least 2
    (table name +3, exact column name +2, meaningful token inside a column
    name +1, meaningful token shared with the table description +1).
    """
    q = question.lower()

    # 1. Allow analytical intent even without table names.
    analytic_keywords = (
        "count", "total", "average", "avg", "sum", "how many", "number of",
        "trend", "increase", "decrease", "compare", "more than", "less than",
    )
    if any(k in q for k in analytic_keywords):
        return True

    schema = load_ai_schema()

    # 2. Same domain hints the table matcher uses.
    if _hinted_tables(_word_tokens(q), {t.lower() for t in schema}):
        return True

    # 3. Schema relevance, table by table.
    tokens = _meaningful_tokens(q)
    for table, meta in schema.items():
        score = 3 if table.lower() in q else 0
        for col in meta.get("columns", {}):
            col_l = col.lower()
            if col_l in q:
                score += 2
            elif any(tok in col_l for tok in tokens):
                score += 1
        if meta.get("description"):
            score += len(tokens & _word_tokens(col_desc(meta["description"])))
        if score >= 2:
            return True
    return False


# =========================
# SQL GENERATION
# =========================

_PROMPT_MAX_CHARS = 6000

_PROMPT_HEADER = """
You are an expert SQLite query generator.

STRICT RULES:
- Use ONLY the tables and columns listed below
- NEVER invent table or column names
- If the answer cannot be derived, return: NOT_ANSWERABLE
- Do NOT explain the SQL
- Do NOT wrap SQL in markdown
- Use explicit JOIN conditions
- Prefer COUNT(*) for totals

Always use these joins:
- patients.subject_id = admissions.subject_id
- admissions.hadm_id = icustays.hadm_id
- icustays.stay_id = chartevents.stay_id

Schema:
"""

_PROMPT_JOIN_HINTS = """
Join hints:
- patients.subject_id ↔ admissions.subject_id
- admissions.hadm_id ↔ icustays.hadm_id
- icustays.stay_id ↔ chartevents.stay_id
"""


def build_prompt(question):
    """Build the LLM prompt for *question* from the relevant tables.

    Only columns whose name/description contains a question word, plus key
    columns (``*_id`` and a fixed set of join/time columns), are listed.

    Raises:
        ValueError: No table is relevant to the question.
    """
    matched = extract_relevant_tables(question)
    full_schema = load_ai_schema()

    if not matched:
        tables_list = "\n".join(f"- {t}" for t in list(full_schema)[:10])
        if len(full_schema) > 10:
            tables_list += f"\n... and {len(full_schema) - 10} more tables"
        raise ValueError(
            "I couldn't find any relevant tables for your question.\n\n"
            f"Available tables:\n{tables_list}\n\n"
            "Try mentioning a table name or ask: 'what data is available?'"
        )

    important_cols = {
        "subject_id", "hadm_id", "stay_id", "icustay_id",
        "itemid", "charttime", "starttime", "endtime",
    }
    question_words = question.lower().split()

    schema_parts = []
    for table in matched:
        schema_parts.append(f"\nTable: {table}\n")
        for col, desc in full_schema[table].get("columns", {}).items():
            text = f"{col} {col_desc(desc)}".lower()
            relevant = any(w in text for w in question_words)
            is_key = col in important_cols or col.endswith("_id")
            if relevant or is_key:
                schema_parts.append(f"- {col}\n")
    schema_text = "".join(schema_parts)

    footer = (
        _PROMPT_JOIN_HINTS
        + f"\nQuestion: {question}\n"
        + "\nUse EXACT table and column names as shown above."
    )

    # Safety cap: trim only the schema listing. Cutting the tail of the whole
    # prompt would drop the question and the final instruction.
    budget = _PROMPT_MAX_CHARS - len(_PROMPT_HEADER) - len(footer)
    if len(schema_text) > budget:
        schema_text = schema_text[:max(budget, 0)] + "\n\n# Schema truncated for safety\n"

    return _PROMPT_HEADER + schema_text + footer


def call_llm(prompt):
    """Send *prompt* to the OpenAI chat API and return the stripped reply.

    Raises:
        ValueError: The API call failed or returned an empty reply.
    """
    try:
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
                {"role": "user", "content": prompt},
            ],
            temperature=0,
        )
        if not res.choices or not res.choices[0].message.content:
            raise ValueError("Empty response from OpenAI API")
        return res.choices[0].message.content.strip()
    except Exception as e:
        raise ValueError(f"OpenAI API error: {str(e)}") from e


# =========================
# SQL SAFETY
# =========================

# Single-quoted strings and double-quoted identifiers (doubled quotes escape).
_SQL_LITERAL_RE = re.compile(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"")
_FORBIDDEN_SQL_RE = re.compile(r"\b(?:insert|update|delete|drop|alter)\b", re.IGNORECASE)
_TABLE_REF_RE = re.compile(r"\b(from|join)\s+([a-zA-Z_][a-zA-Z0-9_]*)", re.IGNORECASE)
_TABLE_ALIASES = {
    "visit": "admissions",
    "visits": "admissions",
    "provider": "caregiver",
    "providers": "caregiver",
}


def _first_statement(sql):
    """Return *sql* up to its first top-level ``;`` as a single line.

    Quote-aware: semicolons, newlines and ``--`` inside quoted text are kept
    verbatim. Outside quotes, ``--`` comments are dropped and line breaks
    become spaces.
    """
    out = []
    quote = None
    i, n = 0, len(sql)
    while i < n:
        ch = sql[i]
        if quote:
            out.append(ch)
            if ch == quote:
                quote = None  # a doubled quote simply re-opens on the next char
        elif ch in "'\"":
            quote = ch
            out.append(ch)
        elif ch == "-" and sql.startswith("--", i):
            newline = sql.find("\n", i)
            if newline == -1:
                break  # comment runs to the end of the text
            i = newline
            continue
        elif ch == ";":
            break
        elif ch in "\r\n":
            out.append(" ")
        else:
            out.append(ch)
        i += 1
    return "".join(out)


def sanitize_sql(sql):
    """Strip markdown fences and reduce an LLM reply to one SQL statement.

    Returns:
        The first statement, without the trailing ``;``, on one line.
    """
    cleaned = re.sub(r"```(?:sql\b)?", "", sql, flags=re.IGNORECASE).strip()
    # A bare leading "sql" language tag left over from a fence.
    cleaned = re.sub(r"^sql\b\s*", "", cleaned, flags=re.IGNORECASE)
    return _first_statement(cleaned).strip()


def correct_table_names(sql):
    """Rewrite known table aliases after FROM/JOIN to real schema tables.

    Names that already exist in the schema, and aliases whose target is not
    in the schema, are left untouched.
    """
    valid_tables = {t.lower() for t in load_ai_schema()}

    def replace_table(match):
        keyword, table = match.group(1), match.group(2)
        table_l = table.lower()
        if table_l in valid_tables:
            return match.group(0)
        corrected = _TABLE_ALIASES.get(table_l)
        if corrected in valid_tables:
            return f"{keyword} {corrected}"
        return match.group(0)

    return _TABLE_REF_RE.sub(replace_table, sql)


def validate_sql(sql):
    """Reject anything but a single read-only SELECT statement.

    Checks run on the statement with quoted text blanked out, so keywords or
    semicolons inside literals, and identifiers such as ``last_update``, do
    not trigger false alarms.

    Returns:
        *sql* unchanged when it passes.

    Raises:
        ValueError: JOIN without ON/USING, multiple statements, a
            data-modifying keyword, or a non-SELECT statement.
    """
    bare = _SQL_LITERAL_RE.sub("''", sql).strip()
    lowered = bare.lower()

    if re.search(r"\bjoin\b", lowered) and not re.search(r"\b(?:on|using)\b", lowered):
        raise ValueError("JOIN without ON condition is not allowed")

    # A single trailing ';' is tolerated; any earlier one means more statements.
    if ";" in bare[:-1]:
        raise ValueError("Multiple SQL statements are not allowed")

    if _FORBIDDEN_SQL_RE.search(bare):
        raise ValueError("Unsafe SQL detected")

    if not re.match(r"select\b", lowered):
        raise ValueError("Only SELECT allowed")

    return sql


def run_query(sql):
    """Execute *sql* and return ``(column_names, rows)``.

    Raises:
        ValueError: The database rejected the query.
    """
    cur = conn.cursor()
    try:
        rows = cur.execute(sql).fetchall()
        cols = [c[0] for c in cur.description] if cur.description else []
        return cols, rows
    except sqlite3.Error as e:
        raise ValueError(f"Database query error: {str(e)}") from e
    finally:
        cur.close()


# =========================
# AGGREGATE SAFETY
# =========================

_AGG_CALL_RE = re.compile(r"\b(?:count|sum|avg)\s*\(", re.IGNORECASE)
_WINDOW_RE = re.compile(r"\bover\s*\(", re.IGNORECASE)
_GROUP_BY_RE = re.compile(r"\bgroup\s+by\b", re.IGNORECASE)
_FROM_RE = re.compile(r"\bfrom\b", re.IGNORECASE)
_TAIL_CLAUSE_RE = re.compile(r"\b(?:group\s+by|order\s+by|limit|having)\b", re.IGNORECASE)


def is_aggregate_only_query(sql):
    """Return True for COUNT/SUM/AVG queries without GROUP BY or a window.

    Whitespace between the function name (or OVER) and ``(`` is tolerated.
    """
    return (
        bool(_AGG_CALL_RE.search(sql))
        and not _GROUP_BY_RE.search(sql)
        and not _WINDOW_RE.search(sql)
    )


def has_underlying_data(sql):
    """Check whether the FROM/WHERE part of *sql* selects at least one row.

    The original text is re-used with its case intact so that string
    literals in the WHERE clause still match. Returns False when the query
    has no FROM or the probe query fails.
    """
    found = _FROM_RE.search(sql)
    if not found:
        return False

    # Keep everything after FROM up to the first GROUP BY / ORDER BY / LIMIT / HAVING.
    source = _TAIL_CLAUSE_RE.split(sql[found.end():], maxsplit=1)[0].strip()
    if not source:
        return False

    test_sql = "SELECT 1 FROM " + source + " LIMIT 1"
    cur = conn.cursor()
    try:
        return cur.execute(test_sql).fetchone() is not None
    except sqlite3.Error:
        return False
    finally:
        cur.close()


# =========================
# PATIENT SUMMARY
# =========================

def validate_identifier(name):
    """Return True when *name* is a plain SQL identifier.

    Accepts only ASCII letters, digits and underscores, not starting with a
    digit; anything else (including non-strings and the empty string) is
    rejected.
    """
    if not isinstance(name, str):
        return False
    return re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name) is not None


def build_table_summary(table_name):
    """Build a text summary of *table_name*: row count plus top-5 breakdowns.

    Breakdowns are attempted for every column whose name does not look like
    an id, date/time or numeric measure; columns that cannot be grouped are
    skipped.

    Returns:
        The summary text, or an explanatory message when the table is
        unknown, its name is unsafe, or the count query fails.
    """
    schema = load_ai_schema()
    if table_name not in schema:
        return f"Table {table_name} not found in metadata."

    # Identifiers cannot be bound as SQL parameters, so they are validated
    # before being interpolated into the statements below.
    if not validate_identifier(table_name):
        return f"Invalid table name: {table_name}"

    skip_markers = ("id", "_id", "date", "time", "count", "amount", "price")
    cur = conn.cursor()
    try:
        try:
            total = cur.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
        except sqlite3.Error as e:
            return f"Error querying table {table_name}: {str(e)}"

        summary = f"Here's a summary of **{table_name}**:\n\n"
        summary += f"• Total records: {total}\n"

        for col_name in schema[table_name].get("columns", {}):
            if not validate_identifier(col_name):
                continue
            col_lower = col_name.lower()
            if any(marker in col_lower for marker in skip_markers):
                continue  # likely an id / date / numeric column

            try:
                rows = cur.execute(
                    f"""
                    SELECT {col_name}, COUNT(*)
                    FROM {table_name}
                    GROUP BY {col_name}
                    ORDER BY COUNT(*) DESC
                    LIMIT 5
                    """
                ).fetchall()
            except sqlite3.Error:
                # Best effort: a column that cannot be grouped is skipped.
                continue

            if rows:
                summary += f"\n• {col_name.capitalize()} breakdown:\n"
                for val, count in rows:
                    summary += f" - {val}: {count}\n"
    finally:
        cur.close()

    summary += "\nYou can ask more detailed questions about this data."
    return summary


# =========================
# MAIN ENGINE
# =========================

def _message_response(text):
    """Return the standard message-only response dictionary."""
    return {"status": "ok", "message": text, "data": []}


def process_question(question):
    """Answer a natural-language *question* about the database.

    Handles canned intents (data freshness, table summary, schema overview,
    recent admissions) first, then generates SQL with the LLM, validates it
    and runs it.

    Returns:
        A dict with ``"status"`` and ``"data"``; message-only responses carry
        ``"message"``, query responses carry ``"sql"`` and ``"columns"``.
        LLM, validation and query errors are returned as messages.
    """
    global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE

    # Intent keywords are matched against the raw (uncorrected) question.
    q = question.strip().lower()

    # Normalize first.
    question = correct_spelling(question)
    question = normalize_time_question(question)

    LAST_PROMPT_TYPE = None
    LAST_SUGGESTED_DATE = None

    # "Data updated till ..."
    if any(x in q for x in ("updated", "upto", "up to", "latest data")):
        return _message_response(f"Data is available up to {get_latest_data_date()}")

    matched_tables = extract_relevant_tables(question)

    # Table summary, only when the user explicitly asks for one.
    if (
        len(matched_tables) == 1
        and any(k in q for k in ("summary", "overview", "describe"))
        and not any(k in q for k in ("count", "total", "how many", "average"))
    ):
        return _message_response(build_table_summary(matched_tables[0]))

    # NOTE: the former "too many datasets" guard compared len(matched_tables)
    # against 4, which extract_relevant_tables() never exceeds, so it could
    # not trigger and has been removed.

    # Metadata discovery.
    if any(x in q for x in ("what data", "what tables", "which data")):
        return _message_response(humanize(describe_schema()))

    # Last data / recent data.
    if any(x in q for x in ("last data", "latest data")):
        return _message_response(f"Latest data available is from {get_latest_data_date()}")

    if "last" in q and "day" in q and ("visit" in q or "admission" in q):
        # Honour "last N days"; default to 30 when no number is given.
        requested = re.search(r"last\s+(\d+)\s+days?", q)
        days = int(requested.group(1)) if requested else 30
        sql = f"""
        SELECT subject_id, admittime
        FROM admissions
        WHERE admittime >= date(
            (SELECT MAX(admittime) FROM admissions),
            '-{days} days'
        )
        ORDER BY admittime DESC
        """
        try:
            cols, rows = run_query(sql)
        except ValueError as e:
            log_interaction(user_q=question, sql=sql, error=str(e))
            return _message_response(str(e))
        log_interaction(user_q=question, sql=sql, result=rows)
        return {"status": "ok", "sql": sql, "columns": cols, "data": rows}

    # Unsupported question check.
    if not is_question_supported(question):
        log_interaction(user_q=question, error="Unsupported question")
        return _message_response(
            "That information isn’t available in the system.\n\n"
            "You can ask about:\n"
            "• Patients\n"
            "• Admissions / Visits\n"
            "• ICU stays\n"
            "• Diagnoses / Conditions\n"
            "• Vitals & lab measurements"
        )

    # Generate SQL.
    try:
        sql = call_llm(build_prompt(question))
    except ValueError as e:
        log_interaction(user_q=question, error=str(e))
        return _message_response(str(e))

    # Tolerate decorations such as a trailing period or a code fence.
    if "NOT_ANSWERABLE" in sql.upper():
        return _message_response("I don't have enough data to answer that.")

    # Sanitize, correct table names, validate, run. Rejections and database
    # errors are reported the same way as LLM errors instead of escaping.
    try:
        sql = sanitize_sql(sql)
        sql = correct_table_names(sql)
        sql = validate_sql(sql)
        cols, rows = run_query(sql)
    except ValueError as e:
        log_interaction(user_q=question, sql=sql, error=str(e))
        return _message_response(str(e))

    # Log exactly once per executed query.
    log_interaction(user_q=question, sql=sql, result=rows)

    if not rows:
        return _message_response(friendly("No records found."))

    # NOTE: the aggregate "no data" handling that used to follow this return
    # was unreachable and has been removed; is_aggregate_only_query() and
    # has_underlying_data() remain available as helpers.
    return {"status": "ok", "sql": sql, "columns": cols, "data": rows}