Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import sqlite3 | |
| from openai import OpenAI | |
| from difflib import get_close_matches | |
# =========================
# SETUP
# =========================
# Validate API key
# Fail fast at import time if the key is missing, rather than on the first API call.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set")
client = OpenAI(api_key=api_key)
# check_same_thread=False lets this single connection be used from other
# threads (e.g. a web-framework worker). NOTE(review): confirm callers never
# write concurrently — SQLite access here is not otherwise synchronized.
conn = sqlite3.connect("mimic_iv_demo.db", check_same_thread=False)
# =========================
# CONVERSATION STATE
# =========================
# Module-level conversation memory, reset and mutated by process_question().
# LAST_PROMPT_TYPE records the last special outcome (e.g. "NO_DATA");
# LAST_SUGGESTED_DATE caches the newest date found in the database so that
# follow-up replies (see friendly()) can mention it.
LAST_PROMPT_TYPE = None
LAST_SUGGESTED_DATE = None
| # ========================= | |
| # HUMAN RESPONSE HELPERS | |
| # ========================= | |
def humanize(text):
    """Prefix *text* with a short friendly opener."""
    opener = "Sure \n\n"
    return opener + text
def friendly(text):
    """Append a "last data available" hint to *text* when a date is known.

    Prefers the cached LAST_SUGGESTED_DATE; otherwise queries the database
    once via get_latest_data_date(). Returns the bare text when no date
    can be determined.
    """
    global LAST_SUGGESTED_DATE
    known_date = LAST_SUGGESTED_DATE or get_latest_data_date()
    if known_date:
        return f"{text}\n\nLast data available is {known_date}"
    return text
def is_confirmation(text):
    """Return True when *text* is a short affirmative reply (yes/ok/sure...)."""
    affirmatives = {"yes", "yep", "yeah", "ok", "okay", "sure"}
    return text.strip().lower() in affirmatives
def is_why_question(text):
    """Return True when *text* begins with the word "why" (case-insensitive)."""
    normalized = text.strip().lower()
    return normalized.startswith("why")
| # ========================= | |
| # SPELL CORRECTION | |
| # ========================= | |
# Vocabulary of domain words the spell-corrector snaps user input onto.
KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions",
    "encounter", "encounters", "visit", "visits",
    "medication", "medications",
    "admitted", "admission",
    "year", "month", "last", "recent", "today"
]

def correct_spelling(q):
    """Replace near-miss words in *q* with the closest known domain term.

    Each whitespace token is lowercased and stripped of trailing
    punctuation before fuzzy matching (difflib, cutoff 0.8); tokens
    without a close match are kept verbatim.
    """
    corrected = []
    for word in q.split():
        candidate = word.lower().strip(",.?")
        hits = get_close_matches(candidate, KNOWN_TERMS, n=1, cutoff=0.8)
        corrected.append(hits[0] if hits else word)
    return " ".join(corrected)
| # ========================= | |
| # SCHEMA | |
| # ========================= | |
| import json | |
| from functools import lru_cache | |
def load_ai_schema():
    """Load table metadata from hospital_metadata.json.

    Returns the parsed dict.

    Raises:
        FileNotFoundError: when the metadata file is missing.
        ValueError: for malformed JSON, a non-dict top level, or any other
            read failure.
    """
    try:
        with open("hospital_metadata.json", "r") as f:
            schema = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError("hospital_metadata.json file not found. Please create it with your table metadata.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in hospital_metadata.json: {str(e)}")
    except Exception as e:
        raise ValueError(f"Error loading metadata: {str(e)}")
    # Validated OUTSIDE the try: previously this ValueError was caught by the
    # broad `except Exception` above and re-wrapped as "Error loading
    # metadata: ...", masking the intended message.
    if not isinstance(schema, dict):
        raise ValueError("Invalid metadata format: expected a dictionary")
    return schema
| # ========================= | |
| # TABLE MATCHING (CORE LOGIC) | |
| # ========================= | |
def extract_relevant_tables(question, max_tables=4):
    """Score every table in the metadata against *question* and return the
    names of the best-matching tables (at most *max_tables*).

    Scoring signals, strongest first:
      - table name appears verbatim in the question (+6)
      - semantic intent hints, e.g. "doctor" -> encounters (+5)
      - column name appears verbatim (+3) or shares a token (+1 each)
      - meaningful description-word overlap (+0.5 each)

    A table must reach a threshold (2 for schemas of <= 5 tables, else 4)
    to be returned; for very small schemas (<= 3 tables) with no match at
    all, every table is returned as a low-confidence fallback.
    """
    schema = load_ai_schema()
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())
    matched = []

    # Map natural-language intent words to candidate table names, keeping
    # only candidates that actually exist in the loaded schema.
    table_names_lower = [t.lower() for t in schema.keys()]
    hint_mappings = {
        "consultant": ["encounter", "encounters", "visit", "visits"],
        "doctor": ["encounter", "encounters", "provider", "providers"],
        "visit": ["encounter", "encounters", "visit", "visits"],
        "visited": ["encounter", "encounters", "visit", "visits"],
        "visits": ["encounter", "encounters", "visit", "visits"],
        "appointment": ["encounter", "encounters", "appointment", "appointments"],
        "patient": ["patient", "patients"],
        "medication": ["medication", "medications", "drug", "drugs"],
        "drug": ["medication", "medications", "drug", "drugs"],
        "condition": ["condition", "conditions", "diagnosis", "diagnoses"],
        "diagnosis": ["condition", "conditions", "diagnosis", "diagnoses"]
    }
    DOMAIN_HINTS = {}
    for intent, possible_tables in hint_mappings.items():
        existing = [t for t in possible_tables if t in table_names_lower]
        if existing:
            DOMAIN_HINTS[intent] = existing

    for table, meta in schema.items():
        score = 0
        table_l = table.lower()

        # 1. Strong signal: table name mentioned verbatim.
        if table_l in q:
            score += 6
        # (Removed dead code: an "early exit" here compared score against 10
        # immediately after the +6 bonus, so it could never trigger.)

        # 2. Column relevance: exact mention beats mere token overlap.
        for col, desc in meta["columns"].items():
            col_l = col.lower()
            if col_l in q:
                score += 3
            elif any(tok in col_l for tok in tokens):
                score += 1

        # 3. Description overlap, down-weighted and with stop-words removed
        #    to limit false positives. (`-` binds tighter than `&`, so this
        #    is tokens & (desc_tokens - common_words).)
        if meta.get("description"):
            desc_tokens = set(meta["description"].lower().split())
            common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
            meaningful_matches = tokens & desc_tokens - common_words
            if meaningful_matches:
                score += len(meaningful_matches) * 0.5

        # 4. Semantic intent mapping (e.g. "doctor" implies encounters).
        for intent, tables in DOMAIN_HINTS.items():
            if intent in q and table_l in tables:
                score += 5

        # 5. Keep only tables above the relevance threshold; small schemas
        #    get a more lenient bar.
        threshold = 2 if len(schema) <= 5 else 4
        if score >= threshold:
            matched.append((table, score))

    # Most relevant first.
    matched.sort(key=lambda x: x[1], reverse=True)

    # Fallback: tiny schema and nothing matched -> return everything.
    if not matched and len(schema) <= 3:
        return list(schema.keys())[:max_tables]
    return [t[0] for t in matched[:max_tables]]
| # ========================= | |
| # HUMAN SCHEMA DESCRIPTION | |
| # ========================= | |
def describe_schema(max_tables=10):
    """Render a human-readable overview of the available tables.

    Shows at most *max_tables* tables and the first five columns of each,
    followed by example questions the user can ask.
    """
    schema = load_ai_schema()
    total_tables = len(schema)
    parts = [f"Here's the data I currently have access to ({total_tables} tables):\n\n"]
    # Cap the listing so huge schemas don't overwhelm the reply.
    for table, meta in list(schema.items())[:max_tables]:
        parts.append(f"• **{table.capitalize()}** — {meta['description']}\n")
        columns = list(meta["columns"].items())
        for col, desc in columns[:5]:
            parts.append(f" - {col}: {desc}\n")
        if len(columns) > 5:
            parts.append(f" ... and {len(columns) - 5} more columns\n")
        parts.append("\n")
    if total_tables > max_tables:
        parts.append(f"\n... and {total_tables - max_tables} more tables.\n")
        parts.append("Ask about a specific table to see its details.\n\n")
    parts.append(
        "You can ask things like:\n"
        "• How many patients are there?\n"
        "• Patient count by gender\n"
        "• Admissions by year\n\n"
        "Just tell me what you want to explore "
    )
    return "".join(parts)
| # ========================= | |
| # TIME HANDLING | |
| # ========================= | |
def get_latest_data_date():
    """Return the newest date value found in any date-like column, or None.

    Scans the metadata for columns whose name contains a known date token
    and probes each with MAX(column) until one yields a value. Errors on
    individual tables/columns are skipped so one bad table cannot hide
    dates in another.
    """
    schema = load_ai_schema()
    # Substrings that mark a column as date-like.
    date_columns = ["date", "start_date", "end_date", "admission_date", "admittime", "dischtime", "created_at", "updated_at"]
    for table_name, meta in schema.items():
        columns = meta.get("columns", {})
        for col_name in columns:
            col_lower = col_name.lower()
            if not any(date_col in col_lower for date_col in date_columns):
                continue
            try:
                # Identifiers come from trusted local metadata; SQLite
                # cannot parameterize identifiers.
                result = conn.execute(
                    f"SELECT MAX({col_name}) FROM {table_name}"
                ).fetchone()
            except sqlite3.Error:
                # Fixed: previously caught (sqlite3.Error, sqlite3.OperationalError);
                # OperationalError subclasses Error, so one catch suffices.
                continue  # Try next table/column
            if result and result[0]:
                return result[0]
    return None
def normalize_time_question(q):
    """Rewrite relative day words ("today"/"yesterday") to the newest data date.

    Demo databases are frozen in time, so relative words are mapped onto the
    last date that actually has data (both words map to the same date, as in
    the original design). Returns *q* unchanged when no date is known.
    """
    latest = get_latest_data_date()
    if not latest:
        return q
    day = latest[:10]
    for word in ("today", "yesterday"):
        if word in q:
            return q.replace(word, f"on {day}")
    return q
| # ========================= | |
| # UNSUPPORTED QUESTIONS | |
| # ========================= | |
def is_question_supported(question):
    """Heuristically decide whether the schema can plausibly answer *question*.

    Analytical phrasing (counts, totals, trends, comparisons...) is always
    accepted; otherwise at least one table must accumulate a small relevance
    score from table-name, column-name, or description overlap.
    """
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())

    # 1. Analytical intent is supported even without explicit table names.
    analytic_keywords = {
        "count", "total", "average", "avg", "sum",
        "how many", "number of", "trend",
        "increase", "decrease", "compare", "more than", "less than"
    }
    for keyword in analytic_keywords:
        if keyword in q:
            return True

    # 2. Otherwise require some overlap with at least one table.
    schema = load_ai_schema()
    for table, meta in schema.items():
        relevance = 0
        # Table name mentioned verbatim.
        if table.lower() in q:
            relevance += 3
        # Column names: exact mention beats token overlap.
        for col in meta["columns"]:
            col_l = col.lower()
            if col_l in q:
                relevance += 2
            elif any(tok in col_l for tok in tokens):
                relevance += 1
        # Raw description-word overlap.
        if meta.get("description"):
            desc_tokens = set(meta["description"].lower().split())
            relevance += len(tokens & desc_tokens)
        # Any sufficiently relevant table means the question is supported.
        if relevance >= 2:
            return True
    return False
| # ========================= | |
| # SQL GENERATION | |
| # ========================= | |
def build_prompt(question):
    """Assemble the LLM prompt: rules + schemas of matched tables + question.

    Raises ValueError with a helpful table listing when no relevant table is
    found, instead of shipping the entire schema to the model.
    """
    matched = extract_relevant_tables(question)
    full_schema = load_ai_schema()
    if not matched:
        # Don't send all 100+ tables — surface what exists instead.
        available_tables = list(full_schema.keys())[:10]
        tables_list = "\n".join(f"- {t}" for t in available_tables)
        if len(full_schema) > 10:
            tables_list += f"\n... and {len(full_schema) - 10} more tables"
        raise ValueError(
            f"I couldn't find any relevant tables for your question.\n\n"
            f"Available tables:\n{tables_list}\n\n"
            f"Please try mentioning a specific table name or use 'what data' to see all available tables."
        )
    schema = {t: full_schema[t] for t in matched}
    sections = ["""
You are a hospital SQL assistant.
Rules:
- Use only SELECT
- SQLite syntax
- Use ONLY the exact table names listed below (do not create or infer table names)
- Use only listed tables/columns
- Return ONLY SQL or NOT_ANSWERABLE
IMPORTANT: Use EXACTLY the table names provided in the list below. Do not pluralize, modify, or guess table names.
"""]
    for table, meta in schema.items():
        sections.append(f"\nTable: {table}\n")
        for col, desc in meta["columns"].items():
            sections.append(f"- {col}: {desc}\n")
    sections.append(f"\nQuestion: {question}\n")
    sections.append("\nRemember: Use EXACT table names from the list above. Do not pluralize or modify table names.")
    return "".join(sections)
def call_llm(prompt):
    """Send *prompt* to the chat model and return its stripped reply.

    Raises:
        ValueError: "Empty response from OpenAI API" for a blank completion,
            or "OpenAI API error: ..." for transport/API failures.
    """
    try:
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
                {"role": "user", "content": prompt}
            ],
            temperature=0
        )
    except Exception as e:
        # Network / auth / rate-limit errors surface here.
        raise ValueError(f"OpenAI API error: {str(e)}")
    # Checked OUTSIDE the try: previously this ValueError was caught by the
    # broad `except Exception` above and double-wrapped as
    # "OpenAI API error: Empty response from OpenAI API".
    if not res.choices or not res.choices[0].message.content:
        raise ValueError("Empty response from OpenAI API")
    return res.choices[0].message.content.strip()
| # ========================= | |
| # SQL SAFETY | |
| # ========================= | |
def sanitize_sql(sql):
    """Strip markdown code fences, keep only the first statement, flatten newlines."""
    cleaned = sql.replace("```sql", "").replace("```", "").strip()
    # Some models emit a bare "sql" language tag without backticks.
    if cleaned.startswith("sql"):
        cleaned = cleaned[3:].strip()
    # Keep only the first statement; drop anything after a semicolon.
    first_statement = cleaned.split(";")[0]
    return first_statement.replace("\n", " ").strip()
def correct_table_names(sql):
    """Fix common table name mistakes in generated SQL.

    Replaces well-known wrong names (e.g. "visits") with the canonical
    table, but only when the wrong name is absent from the schema and the
    correct one exists. Word boundaries prevent partial replacements.
    """
    schema = load_ai_schema()
    valid_tables = set(schema.keys())
    sql_corrected = sql
    # (Removed dead code: an unused `sql_lower` local was computed here.)
    # wrong name -> canonical name (applied case-insensitively)
    table_corrections = {
        "visits": "encounters",
        "visit": "encounters",
        "providers": "encounters",  # if this table doesn't exist
    }
    for wrong_name, correct_name in table_corrections.items():
        # Only rewrite when the wrong table genuinely doesn't exist AND the
        # replacement does — otherwise leave the model's choice alone.
        if wrong_name.lower() not in valid_tables and correct_name.lower() in valid_tables:
            # Word boundaries avoid corrupting substrings of longer names.
            pattern = r'\b' + re.escape(wrong_name) + r'\b'
            sql_corrected = re.sub(pattern, correct_name, sql_corrected, flags=re.IGNORECASE)
    return sql_corrected
def validate_sql(sql):
    """Reject anything that is not a SELECT statement; return *sql* unchanged."""
    is_select = sql.lower().startswith("select")
    if not is_select:
        raise ValueError("Only SELECT allowed")
    return sql
def run_query(sql):
    """Execute *sql* and return (column_names, rows).

    Wraps sqlite3 errors in ValueError with a readable message.
    """
    cur = conn.cursor()
    try:
        rows = cur.execute(sql).fetchall()
    except sqlite3.Error as e:
        raise ValueError(f"Database query error: {str(e)}")
    # cur.description is None for statements that return no result columns.
    cols = [c[0] for c in cur.description] if cur.description else []
    return cols, rows
| # ========================= | |
| # AGGREGATE SAFETY | |
| # ========================= | |
def is_aggregate_only_query(sql):
    """True for pure aggregate queries (COUNT/SUM/AVG) with no GROUP BY."""
    lowered = sql.lower()
    has_aggregate = any(fn in lowered for fn in ("count(", "sum(", "avg("))
    return has_aggregate and "group by" not in lowered
def has_underlying_data(sql):
    """Probe whether the FROM clause of *sql* (with its WHERE filter) matches any row.

    Distinguishes "COUNT(*) is 0 because no rows match" from a genuine zero
    count. Returns False when the probe cannot be built or executed.
    """
    base = sql.lower()
    if "from" not in base:
        return False
    base = base.split("from", 1)[1]
    # Drop trailing clauses so only FROM ... [WHERE ...] remains.
    for clause in ("group by", "order by", "limit", "having"):
        base = base.split(clause)[0]
    probe = "SELECT 1 FROM " + base.strip() + " LIMIT 1"
    try:
        return conn.cursor().execute(probe).fetchone() is not None
    except sqlite3.Error:
        return False
| # ========================= | |
| # PATIENT SUMMARY | |
| # ========================= | |
def validate_identifier(name):
    """Return True when *name* is a safe SQL identifier.

    Accepts only non-empty strings matching [A-Za-z_][A-Za-z0-9_]* in full.
    Uses re.fullmatch instead of re.match(r'^...$', ...): '$' also matches
    just before a trailing newline, so "patients\\n" would otherwise slip
    through and the old character blacklist existed only to plug that hole.
    """
    if not name or not isinstance(name, str):
        return False
    return re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', name) is not None
def build_table_summary(table_name):
    """Build summary for a table using metadata.

    Returns a markdown-ish string with the total row count plus a top-5
    value breakdown for each column that looks categorical. On any problem
    (unknown/invalid table, query error) returns an explanatory message
    instead of raising.
    """
    # Validate table name against metadata first
    schema = load_ai_schema()
    if table_name not in schema:
        return f"Table {table_name} not found in metadata."
    # Additional safety check
    if not validate_identifier(table_name):
        return f"Invalid table name: {table_name}"
    cur = conn.cursor()
    # Total rows (still need to query actual data for count)
    # Note: SQLite doesn't support parameterized table names
    # Since we validated table_name against metadata, it's safe
    try:
        total = cur.execute(
            f"SELECT COUNT(*) FROM {table_name}"
        ).fetchone()[0]
    except sqlite3.Error as e:
        return f"Error querying table {table_name}: {str(e)}"
    columns = schema[table_name]["columns"]  # {col_name: description, ...}
    summary = f"Here's a summary of **{table_name}**:\n\n"
    summary += f"• Total records: {total}\n"
    # Try to summarize categorical columns using metadata
    for col_name, col_desc in columns.items():
        # Validate column name (skip rather than fail on a bad one)
        if not validate_identifier(col_name):
            continue
        # Try to determine if it's a categorical column based on name/description
        # Skip likely numeric/date columns
        col_lower = col_name.lower()
        if any(skip in col_lower for skip in ["id", "_id", "date", "time", "count", "amount", "price"]):
            continue
        # Try to get breakdown for text-like columns
        try:
            # Note: SQLite doesn't support parameterized identifiers, so we validate
            rows = cur.execute(
                f"""
                SELECT {col_name}, COUNT(*)
                FROM {table_name}
                GROUP BY {col_name}
                ORDER BY COUNT(*) DESC
                LIMIT 5
                """
            ).fetchall()
            if rows:
                summary += f"\n• {col_name.capitalize()} breakdown:\n"
                for val, count in rows:
                    summary += f" - {val}: {count}\n"
        except (sqlite3.Error, sqlite3.OperationalError) as e:
            # Ignore columns that can't be grouped (likely not categorical).
            # NOTE(review): OperationalError subclasses sqlite3.Error, so the
            # tuple is redundant, and `e`/`col_desc` are unused — candidates
            # for cleanup.
            pass
    summary += "\nYou can ask more detailed questions about this data."
    return summary
| # ========================= | |
| # MAIN ENGINE | |
| # ========================= | |
def process_question(question):
    """Main entry point: answer a natural-language *question* about the DB.

    Pipeline: spell/time normalization -> special intents (data freshness,
    table summary, schema discovery) -> support check -> LLM SQL generation
    -> sanitize/correct/validate/execute -> result or a friendly "no data"
    message.

    Always returns a dict with at least {"status": "ok", ...}; most errors
    are converted into user-facing messages rather than raised.
    """
    global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
    # Lowercased copy of the RAW question used for keyword routing below;
    # `question` itself is spell/time normalized before reaching the LLM.
    q = question.strip().lower()
    # ----------------------------------
    # Normalize first
    # ----------------------------------
    question = correct_spelling(question)
    question = normalize_time_question(question)
    # Reset per-turn conversation state.
    LAST_PROMPT_TYPE = None
    LAST_SUGGESTED_DATE = None
    # ----------------------------------
    # Handle "data updated till"
    # ----------------------------------
    if any(x in q for x in ["updated", "upto", "up to", "latest data"]):
        return {
            "status": "ok",
            "message": f"Data is available up to {get_latest_data_date()}",
            "data": []
        }
    # ----------------------------------
    # Extract relevant tables
    # ----------------------------------
    matched_tables = extract_relevant_tables(question)
    # ----------------------------------
    # SUMMARY ONLY IF USER ASKS FOR IT
    # ----------------------------------
    if (
        len(matched_tables) == 1
        and any(k in q for k in ["summary", "overview", "describe"])
        and not any(k in q for k in ["count", "total", "how many", "average"])
    ):
        return {
            "status": "ok",
            "message": build_table_summary(matched_tables[0]),
            "data": []
        }
    # Only block if too many tables matched AND it's not an analytical question
    # Analytical questions (how many, count, etc.) often need multiple tables
    is_analytical = any(k in q for k in [
        "how many", "count", "total", "number of",
        "average", "avg", "sum", "more than", "less than",
        "compare", "trend"
    ])
    # NOTE(review): extract_relevant_tables caps its result at max_tables=4,
    # so this branch appears unreachable — confirm before relying on it.
    if len(matched_tables) > 4 and not is_analytical:
        return {
            "status": "ok",
            "message": (
                "Your question matches too many datasets:\n"
                + "\n".join(f"- {t}" for t in matched_tables[:5])
                + "\n\nPlease be more specific about what you want to know."
            ),
            "data": []
        }
    # ----------------------------------
    # Metadata discovery
    # ----------------------------------
    if any(x in q for x in ["what data", "what tables", "which data"]):
        return {
            "status": "ok",
            "message": humanize(describe_schema()),
            "data": []
        }
    # ----------------------------------
    # Unsupported question check
    # ----------------------------------
    if not is_question_supported(question):
        return {
            "status": "ok",
            "message": (
                "That information isn’t available in the system.\n\n"
                "You can ask about:\n"
                "• Patients\n"
                "• Visits\n"
                "• Conditions\n"
                "• Medications"
            ),
            "data": []
        }
    # ----------------------------------
    # Generate SQL
    # ----------------------------------
    try:
        sql = call_llm(build_prompt(question))
    except ValueError as e:
        # Handle case where no relevant tables found (build_prompt raises),
        # as well as OpenAI API failures surfaced by call_llm.
        return {
            "status": "ok",
            "message": str(e),
            "data": []
        }
    if sql == "NOT_ANSWERABLE":
        return {
            "status": "ok",
            "message": "I don't have enough data to answer that.",
            "data": []
        }
    # Sanitize, correct table names, then validate
    # NOTE(review): validate_sql/run_query can raise ValueError here and it
    # propagates to the caller, unlike the handled branches above — confirm
    # that is intended.
    sql = sanitize_sql(sql)
    sql = correct_table_names(sql)
    sql = validate_sql(sql)
    cols, rows = run_query(sql)
    # ----------------------------------
    # No data handling
    # ----------------------------------
    # Aggregate queries return a row even with no data (e.g. COUNT(*)=0),
    # so probe the FROM/WHERE clause separately.
    if is_aggregate_only_query(sql) and not has_underlying_data(sql):
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No data is available for that time period."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }
    if not rows:
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No records found."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }
    # ----------------------------------
    # Success
    # ----------------------------------
    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows
    }