# Text_to_sql / engine.py
# Uploaded by bhavika24 (revision 2a990de, 23.6 kB).
import os
import re
import sqlite3
from openai import OpenAI
from difflib import get_close_matches
# =========================
# SETUP
# =========================
# Fail fast if the OpenAI credential is missing so the error surfaces at
# import time rather than on the first request.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is not set")
client = OpenAI(api_key=api_key)
# Single shared connection. check_same_thread=False lets other threads use
# it; NOTE(review): sqlite3 connections are not safe for concurrent use —
# confirm the caller serializes access.
conn = sqlite3.connect("mimic_iv_demo.db", check_same_thread=False)
# =========================
# CONVERSATION STATE
# =========================
# Module-level state mutated by process_question() and read by friendly().
LAST_PROMPT_TYPE = None
LAST_SUGGESTED_DATE = None
# =========================
# HUMAN RESPONSE HELPERS
# =========================
def humanize(text):
    """Wrap a reply with a short friendly opener."""
    opener = "Sure \n\n"
    return f"{opener}{text}"
def friendly(text):
    """Append the latest available data date to *text* when one is known."""
    latest = LAST_SUGGESTED_DATE
    if not latest:
        # Date not cached yet — look it up on demand.
        latest = get_latest_data_date()
    if latest:
        return f"{text}\n\nLast data available is {latest}"
    return text
def is_confirmation(text):
    """Return True when the user's reply is a short affirmative word."""
    affirmatives = {"yes", "yep", "yeah", "ok", "okay", "sure"}
    return text.strip().lower() in affirmatives
def is_why_question(text):
    """Return True when the message begins with the word 'why' (any case)."""
    normalized = text.strip().lower()
    return normalized.startswith("why")
# =========================
# SPELL CORRECTION
# =========================
KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions",
    "encounter", "encounters", "visit", "visits",
    "medication", "medications",
    "admitted", "admission",
    "year", "month", "last", "recent", "today"
]


def correct_spelling(q):
    """Fix near-miss spellings of known domain terms in *q*.

    Each whitespace-separated word is compared (case-insensitively,
    ignoring surrounding ",.?" punctuation) against KNOWN_TERMS. A close
    match replaces only the word's core, so punctuation such as a
    trailing "?" is preserved. Words that are already spelled correctly
    are left untouched, preserving their original casing.
    """
    fixed = []
    for w in q.split():
        # Separate the core of the word from leading/trailing punctuation.
        start = len(w) - len(w.lstrip(",.?"))
        end = len(w.rstrip(",.?"))
        core = w[start:end]
        if not core:
            # Word is all punctuation — nothing to correct.
            fixed.append(w)
            continue
        clean = core.lower()
        match = get_close_matches(clean, KNOWN_TERMS, n=1, cutoff=0.8)
        if match and match[0] != clean:
            # Genuine misspelling: swap in the corrected core, keep punctuation.
            fixed.append(w[:start] + match[0] + w[end:])
        else:
            # Already correct (or unknown term): keep the original word as-is.
            fixed.append(w)
    return " ".join(fixed)
# =========================
# SCHEMA
# =========================
import json
from functools import lru_cache
@lru_cache(maxsize=1)
def load_ai_schema():
    """Load and cache the table metadata from hospital_metadata.json.

    Returns:
        dict: mapping of table name -> {"description": ..., "columns": {...}}.

    Raises:
        FileNotFoundError: if the metadata file is missing.
        ValueError: if the file contains invalid JSON, is not a dict, or
            any other error occurs while reading it.
    """
    try:
        with open("hospital_metadata.json", "r") as f:
            schema = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(
            "hospital_metadata.json file not found. Please create it with your table metadata."
        )
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in hospital_metadata.json: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Error loading metadata: {str(e)}") from e
    # Validate outside the try block: previously this ValueError was caught
    # by the generic handler above and re-wrapped as "Error loading metadata".
    if not isinstance(schema, dict):
        raise ValueError("Invalid metadata format: expected a dictionary")
    return schema
# =========================
# TABLE MATCHING (CORE LOGIC)
# =========================
def extract_relevant_tables(question, max_tables=4):
    """Score each metadata table against *question* and return up to
    *max_tables* best-matching table names (highest score first).

    Scoring blends: exact table-name mention (+6), column-name mention
    (+3) or token overlap (+1 each), weighted description-word overlap
    (+0.5 each), and domain-intent hints (+5). Tables below a threshold
    (2 for small schemas, else 4) are dropped.
    """
    schema = load_ai_schema()
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())
    matched = []
    # Lightweight intent hints — dynamically filtered so only tables that
    # actually exist in the schema are hinted at.
    all_tables = list(schema.keys())
    table_names_lower = [t.lower() for t in all_tables]
    DOMAIN_HINTS = {}
    # Map natural-language terms to candidate table names (checked below).
    hint_mappings = {
        "consultant": ["encounter", "encounters", "visit", "visits"],
        "doctor": ["encounter", "encounters", "provider", "providers"],
        "visit": ["encounter", "encounters", "visit", "visits"],
        "visited": ["encounter", "encounters", "visit", "visits"],
        "visits": ["encounter", "encounters", "visit", "visits"],
        "appointment": ["encounter", "encounters", "appointment", "appointments"],
        "patient": ["patient", "patients"],
        "medication": ["medication", "medications", "drug", "drugs"],
        "drug": ["medication", "medications", "drug", "drugs"],
        "condition": ["condition", "conditions", "diagnosis", "diagnoses"],
        "diagnosis": ["condition", "conditions", "diagnosis", "diagnoses"]
    }
    # Only include hints whose target tables exist in the schema.
    for intent, possible_tables in hint_mappings.items():
        matching_tables = [t for t in possible_tables if t in table_names_lower]
        if matching_tables:
            DOMAIN_HINTS[intent] = matching_tables
    # Early-exit threshold — intended to stop scoring on a perfect match.
    VERY_HIGH_SCORE = 10
    for table, meta in schema.items():
        score = 0
        table_l = table.lower()
        # 1. Strong signal: the table name is mentioned verbatim.
        if table_l in q:
            score += 6
            # NOTE(review): score is exactly 6 at this point, so this
            # early-exit branch (>= 10) is currently unreachable; it only
            # fires if the weight above is ever increased.
            if score >= VERY_HIGH_SCORE:
                matched.append((table, score))
                continue
        # 2. Column relevance: exact mention beats mere token overlap.
        for col, desc in meta["columns"].items():
            col_l = col.lower()
            if col_l in q:
                score += 3
            elif any(tok in col_l for tok in tokens):
                score += 1
        # 3. Description relevance (lower weight to avoid false positives).
        if meta.get("description"):
            desc_tokens = set(meta["description"].lower().split())
            # Only count meaningful word matches, not common stop words.
            common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
            # Operator precedence: "-" binds tighter than "&", so this is
            # tokens & (desc_tokens - common_words) — the intended meaning.
            meaningful_matches = tokens & desc_tokens - common_words
            if meaningful_matches:
                score += len(meaningful_matches) * 0.5  # Reduced weight
        # 4. Semantic intent mapping (strong signal).
        for intent, tables in DOMAIN_HINTS.items():
            if intent in q and table_l in tables:
                score += 5
        # 5. Minimum threshold prevents low-quality matches; small schemas
        # (<= 5 tables) get a more lenient cutoff.
        threshold = 2 if len(schema) <= 5 else 4
        if score >= threshold:
            matched.append((table, score))
    # Sort by relevance (highest score first).
    matched.sort(key=lambda x: x[1], reverse=True)
    # If nothing matched but the schema is tiny, return everything rather
    # than failing the lookup outright.
    if not matched and len(schema) <= 3:
        return list(schema.keys())[:max_tables]
    return [t[0] for t in matched[:max_tables]]
# =========================
# HUMAN SCHEMA DESCRIPTION
# =========================
def describe_schema(max_tables=10):
    """Produce a human-readable overview of the available tables.

    Shows at most *max_tables* tables and five columns each, followed by
    example questions the user can ask.
    """
    schema = load_ai_schema()
    total_tables = len(schema)
    parts = [f"Here's the data I currently have access to ({total_tables} tables):\n\n"]
    # Cap the listing so the reply stays readable.
    for table, meta in list(schema.items())[:max_tables]:
        parts.append(f"• **{table.capitalize()}** — {meta['description']}\n")
        columns = meta["columns"]
        for col, desc in list(columns.items())[:5]:
            parts.append(f" - {col}: {desc}\n")
        hidden = len(columns) - 5
        if hidden > 0:
            parts.append(f" ... and {hidden} more columns\n")
        parts.append("\n")
    if total_tables > max_tables:
        parts.append(f"\n... and {total_tables - max_tables} more tables.\n")
        parts.append("Ask about a specific table to see its details.\n\n")
    parts.append(
        "You can ask things like:\n"
        "• How many patients are there?\n"
        "• Patient count by gender\n"
        "• Admissions by year\n\n"
        "Just tell me what you want to explore "
    )
    return "".join(parts)
# =========================
# TIME HANDLING
# =========================
def get_latest_data_date():
    """Get the latest data date by checking tables with date columns.

    Scans the metadata for the first table whose column name contains a
    date-ish substring and returns MAX(column) from it. Returns None when
    no table yields a usable value.
    """
    schema = load_ai_schema()
    # Substrings that mark a column as date-like (e.g. "admission_date").
    date_columns = ["date", "start_date", "end_date", "admission_date", "admittime", "dischtime", "created_at", "updated_at"]
    # Try to find a table with a date column.
    for table_name in schema.keys():
        columns = schema[table_name].get("columns", {})
        # Check if the table has any date-like column.
        for col_name in columns.keys():
            col_lower = col_name.lower()
            if any(date_col in col_lower for date_col in date_columns):
                # NOTE(review): identifiers come from the metadata file and
                # are interpolated directly — SQLite cannot parameterize
                # identifiers; confirm the metadata is trusted.
                try:
                    result = conn.execute(
                        f"SELECT MAX({col_name}) FROM {table_name}"
                    ).fetchone()
                    if result and result[0]:
                        return result[0]
                except (sqlite3.Error, sqlite3.OperationalError):
                    continue  # Try next table/column
    # No table produced a usable date.
    return None
def normalize_time_question(q):
    """Rewrite relative day references using the newest date in the data.

    Both "today" and "yesterday" map to the latest available date, since
    the demo dataset has no true current day.
    """
    latest = get_latest_data_date()
    if not latest:
        return q
    anchor = f"on {latest[:10]}"
    for word in ("today", "yesterday"):
        if word in q:
            return q.replace(word, anchor)
    return q
# =========================
# UNSUPPORTED QUESTIONS
# =========================
def is_question_supported(question):
    """Heuristically decide whether *question* can be answered from the data.

    Returns True when the question carries analytical intent (count,
    total, trend, ...) or mentions a table, column, or description term
    from the metadata; otherwise False.
    """
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())
    # 1. Allow analytical intent even without table names. Multi-word
    #    phrases work because each keyword is substring-tested against the
    #    full question text, not the token set.
    analytic_keywords = {
        "count", "total", "average", "avg", "sum",
        "how many", "number of", "trend",
        "increase", "decrease", "compare", "more than", "less than"
    }
    if any(k in q for k in analytic_keywords):
        return True
    # 2. Check schema relevance, table by table.
    schema = load_ai_schema()
    for table, meta in schema.items():
        score = 0
        table_l = table.lower()
        # Table name mentioned directly.
        if table_l in q:
            score += 3
        # Column relevance: exact mention beats token overlap.
        for col, desc in meta["columns"].items():
            col_l = col.lower()
            if col_l in q:
                score += 2
            elif any(tok in col_l for tok in tokens):
                score += 1
        # Description word overlap.
        if meta.get("description"):
            desc_tokens = set(meta["description"].lower().split())
            score += len(tokens & desc_tokens)
        # If any single table is relevant enough, the question is supported.
        if score >= 2:
            return True
    return False
# =========================
# SQL GENERATION
# =========================
def build_prompt(question):
    """Build the LLM prompt: rules plus only the schema slice relevant to
    *question*.

    Raises:
        ValueError: when no relevant tables match — the message lists a
            sample of available tables so the caller can show it directly.
    """
    matched = extract_relevant_tables(question)
    full_schema = load_ai_schema()
    if matched:
        schema = {t: full_schema[t] for t in matched}
    else:
        # Don't send all 100+ tables to the model! Fail with a helpful
        # error listing a sample of what is available.
        available_tables = list(full_schema.keys())[:10]  # Show first 10 tables
        tables_list = "\n".join(f"- {t}" for t in available_tables)
        if len(full_schema) > 10:
            tables_list += f"\n... and {len(full_schema) - 10} more tables"
        raise ValueError(
            f"I couldn't find any relevant tables for your question.\n\n"
            f"Available tables:\n{tables_list}\n\n"
            f"Please try mentioning a specific table name or use 'what data' to see all available tables."
        )
    prompt = """
You are a hospital SQL assistant.
Rules:
- Use only SELECT
- SQLite syntax
- Use ONLY the exact table names listed below (do not create or infer table names)
- Use only listed tables/columns
- Return ONLY SQL or NOT_ANSWERABLE
IMPORTANT: Use EXACTLY the table names provided in the list below. Do not pluralize, modify, or guess table names.
"""
    # Append the matched slice of the schema.
    for table, meta in schema.items():
        prompt += f"\nTable: {table}\n"
        for col, desc in meta["columns"].items():
            prompt += f"- {col}: {desc}\n"
    prompt += f"\nQuestion: {question}\n"
    prompt += "\nRemember: Use EXACT table names from the list above. Do not pluralize or modify table names."
    return prompt
def call_llm(prompt):
    """Send *prompt* to the chat model and return the raw text reply.

    Returns the stripped message content (expected to be SQL or the
    literal string "NOT_ANSWERABLE").

    Raises:
        ValueError: if the API call fails or returns an empty response.
    """
    try:
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
                {"role": "user", "content": prompt}
            ],
            temperature=0
        )
    except Exception as e:
        # Surface transport/API failures as the single error type callers
        # already handle, keeping the original exception chained.
        raise ValueError(f"OpenAI API error: {str(e)}") from e
    # Validate outside the try block: previously this ValueError was caught
    # by the generic handler above and re-wrapped as "OpenAI API error".
    if not res.choices or not res.choices[0].message.content:
        raise ValueError("Empty response from OpenAI API")
    return res.choices[0].message.content.strip()
# =========================
# SQL SAFETY
# =========================
def sanitize_sql(sql):
    """Strip markdown code fences from LLM output and keep only the first
    SQL statement, collapsed onto a single line."""
    for marker in ("```sql", "```"):
        sql = sql.replace(marker, "")
    sql = sql.strip()
    # Some replies still start with a bare "sql" language tag.
    if sql.startswith("sql"):
        sql = sql[len("sql"):].strip()
    # Keep only the first statement; drop anything after a semicolon.
    statement, _, _ = sql.partition(";")
    return statement.replace("\n", " ").strip()
def correct_table_names(sql):
    """Fix common table-name mistakes in LLM-generated SQL.

    A name is rewritten only when the wrong name is absent from the
    metadata AND the replacement table exists, so schemas that really do
    have e.g. a "visits" table are never touched.
    """
    schema = load_ai_schema()
    # NOTE(review): keys are compared against .lower() names below — this
    # assumes the metadata uses lowercase table names; confirm.
    valid_tables = set(schema.keys())
    sql_corrected = sql
    # Known wrong-name -> canonical-name mappings.
    table_corrections = {
        "visits": "encounters",
        "visit": "encounters",
        "providers": "encounters",  # if this table doesn't exist
    }
    for wrong_name, correct_name in table_corrections.items():
        # Only correct if the wrong table doesn't exist AND the correct one does.
        if wrong_name.lower() not in valid_tables and correct_name.lower() in valid_tables:
            # Word boundaries prevent partial replacements (e.g. "revisits").
            pattern = r'\b' + re.escape(wrong_name) + r'\b'
            sql_corrected = re.sub(pattern, correct_name, sql_corrected, flags=re.IGNORECASE)
    return sql_corrected
def validate_sql(sql):
    """Reject anything that is not a SELECT statement; return it otherwise."""
    if sql.lower().startswith("select"):
        return sql
    raise ValueError("Only SELECT allowed")
def run_query(sql):
    """Execute *sql* and return (column_names, rows).

    Any sqlite error is re-raised as ValueError so callers handle a
    single exception type.
    """
    cursor = conn.cursor()
    try:
        result_rows = cursor.execute(sql).fetchall()
        # cursor.description is None for statements that return no columns.
        column_names = [d[0] for d in cursor.description] if cursor.description else []
        return column_names, result_rows
    except sqlite3.Error as e:
        raise ValueError(f"Database query error: {str(e)}")
# =========================
# AGGREGATE SAFETY
# =========================
def is_aggregate_only_query(sql):
    """Return True for queries that aggregate without any GROUP BY."""
    lowered = sql.lower()
    has_aggregate = any(fn in lowered for fn in ("count(", "sum(", "avg("))
    return has_aggregate and "group by" not in lowered
def has_underlying_data(sql):
    """Check if underlying data exists for the SQL query.

    Rewrites the query as "SELECT 1 FROM <from-clause> LIMIT 1" — keeping
    any WHERE filter, dropping GROUP BY/ORDER BY/LIMIT/HAVING — and probes
    for at least one row. Returns False on any parse or database problem.
    """
    base = sql.lower()
    if "from" not in base:
        return False
    # NOTE(review): naive substring split — a string literal containing
    # "from" or a clause keyword would confuse this extraction.
    base = base.split("from", 1)[1]
    # Split at trailing clauses so only FROM ... [WHERE ...] remains.
    for clause in ["group by", "order by", "limit", "having"]:
        base = base.split(clause)[0]
    test_sql = "SELECT 1 FROM " + base.strip() + " LIMIT 1"
    cur = conn.cursor()
    try:
        return cur.execute(test_sql).fetchone() is not None
    except sqlite3.Error:
        return False
# =========================
# PATIENT SUMMARY
# =========================
def validate_identifier(name):
    """Return True when *name* is a safe SQL identifier.

    Accepts only non-empty strings of the form [A-Za-z_][A-Za-z0-9_]*.
    Everything else — punctuation, quotes, whitespace, comment markers
    such as "--" or "/* */" — is rejected, so a validated name can be
    interpolated into SQL as a table or column name.
    """
    if not name or not isinstance(name, str):
        return False
    # re.fullmatch anchors at the true end of the string; a $-anchored
    # re.match would wrongly accept a trailing newline (e.g. "abc\n"),
    # which the old explicit character blacklist had to guard against.
    return re.fullmatch(r'[A-Za-z_][A-Za-z0-9_]*', name) is not None
def build_table_summary(table_name):
    """Build a human-readable summary for *table_name* using the metadata.

    Reports the total row count, then a top-5 value breakdown for each
    column that looks categorical. Returns an error string (rather than
    raising) when the table is unknown or a query fails.
    """
    # Validate table name against the metadata first.
    schema = load_ai_schema()
    if table_name not in schema:
        return f"Table {table_name} not found in metadata."
    # Additional safety check before interpolating into SQL.
    if not validate_identifier(table_name):
        return f"Invalid table name: {table_name}"
    cur = conn.cursor()
    # Total rows (still need to query actual data for the count).
    # Note: SQLite doesn't support parameterized table names; since
    # table_name was validated against the metadata, interpolation is safe.
    try:
        total = cur.execute(
            f"SELECT COUNT(*) FROM {table_name}"
        ).fetchone()[0]
    except sqlite3.Error as e:
        return f"Error querying table {table_name}: {str(e)}"
    columns = schema[table_name]["columns"]  # {col_name: description, ...}
    summary = f"Here's a summary of **{table_name}**:\n\n"
    summary += f"• Total records: {total}\n"
    # Try to summarize categorical columns using the metadata.
    for col_name, col_desc in columns.items():
        # Skip anything that fails the identifier safety check.
        if not validate_identifier(col_name):
            continue
        # Heuristic: skip likely numeric/date/id columns by name.
        col_lower = col_name.lower()
        if any(skip in col_lower for skip in ["id", "_id", "date", "time", "count", "amount", "price"]):
            continue
        # Try to get a value breakdown for text-like columns.
        try:
            # SQLite doesn't support parameterized identifiers, so we rely
            # on the validation above.
            rows = cur.execute(
                f"""
                SELECT {col_name}, COUNT(*)
                FROM {table_name}
                GROUP BY {col_name}
                ORDER BY COUNT(*) DESC
                LIMIT 5
                """
            ).fetchall()
            if rows:
                summary += f"\n• {col_name.capitalize()} breakdown:\n"
                for val, count in rows:
                    summary += f" - {val}: {count}\n"
        except (sqlite3.Error, sqlite3.OperationalError) as e:
            # Ignore columns that can't be grouped (likely not categorical).
            pass
    summary += "\nYou can ask more detailed questions about this data."
    return summary
# =========================
# MAIN ENGINE
# =========================
def process_question(question):
    """Main entry point: answer a natural-language question about the data.

    Returns a dict with "status" plus either a "message" (informational
    and error paths) or "sql"/"columns"/"data" on success. May let a
    ValueError from validate_sql()/run_query() propagate to the caller.
    """
    global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
    # Lowercased copy used for intent checks; the original-cased question
    # is what gets normalized and sent to the LLM.
    q = question.strip().lower()
    # ----------------------------------
    # Normalize first
    # ----------------------------------
    question = correct_spelling(question)
    question = normalize_time_question(question)
    LAST_PROMPT_TYPE = None
    LAST_SUGGESTED_DATE = None
    # ----------------------------------
    # Handle "data updated till" style questions
    # ----------------------------------
    if any(x in q for x in ["updated", "upto", "up to", "latest data"]):
        return {
            "status": "ok",
            "message": f"Data is available up to {get_latest_data_date()}",
            "data": []
        }
    # ----------------------------------
    # Extract relevant tables
    # ----------------------------------
    matched_tables = extract_relevant_tables(question)
    # ----------------------------------
    # Summary only if the user explicitly asks for it
    # ----------------------------------
    if (
        len(matched_tables) == 1
        and any(k in q for k in ["summary", "overview", "describe"])
        and not any(k in q for k in ["count", "total", "how many", "average"])
    ):
        return {
            "status": "ok",
            "message": build_table_summary(matched_tables[0]),
            "data": []
        }
    # Only block if too many tables matched AND it's not an analytical
    # question (analytical questions often need multiple tables).
    is_analytical = any(k in q for k in [
        "how many", "count", "total", "number of",
        "average", "avg", "sum", "more than", "less than",
        "compare", "trend"
    ])
    # NOTE(review): extract_relevant_tables caps its result at
    # max_tables=4, so len(matched_tables) > 4 is currently unreachable.
    if len(matched_tables) > 4 and not is_analytical:
        return {
            "status": "ok",
            "message": (
                "Your question matches too many datasets:\n"
                + "\n".join(f"- {t}" for t in matched_tables[:5])
                + "\n\nPlease be more specific about what you want to know."
            ),
            "data": []
        }
    # ----------------------------------
    # Metadata discovery
    # ----------------------------------
    if any(x in q for x in ["what data", "what tables", "which data"]):
        return {
            "status": "ok",
            "message": humanize(describe_schema()),
            "data": []
        }
    # ----------------------------------
    # Unsupported question check
    # ----------------------------------
    if not is_question_supported(question):
        return {
            "status": "ok",
            "message": (
                "That information isn’t available in the system.\n\n"
                "You can ask about:\n"
                "• Patients\n"
                "• Visits\n"
                "• Conditions\n"
                "• Medications"
            ),
            "data": []
        }
    # ----------------------------------
    # Generate SQL
    # ----------------------------------
    try:
        sql = call_llm(build_prompt(question))
    except ValueError as e:
        # build_prompt raises when no relevant tables were found;
        # call_llm raises on API failures. Both messages are user-facing.
        return {
            "status": "ok",
            "message": str(e),
            "data": []
        }
    if sql == "NOT_ANSWERABLE":
        return {
            "status": "ok",
            "message": "I don't have enough data to answer that.",
            "data": []
        }
    # Sanitize, correct table names, then validate before execution.
    sql = sanitize_sql(sql)
    sql = correct_table_names(sql)
    sql = validate_sql(sql)
    cols, rows = run_query(sql)
    # ----------------------------------
    # No-data handling
    # ----------------------------------
    # Aggregates always return a row (e.g. COUNT(*) = 0), so probe the
    # underlying data separately to detect an empty time period.
    if is_aggregate_only_query(sql) and not has_underlying_data(sql):
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No data is available for that time period."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }
    if not rows:
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No records found."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }
    # ----------------------------------
    # Success
    # ----------------------------------
    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows
    }