"""Text_to_sql / engine.py — natural-language-to-SQL engine over a MIMIC-IV SQLite DB.

Original header lines were Hugging Face file-page artifacts (uploader
"bhavika24", "Upload engine.py", commit 2c656c5, "raw / history / blame",
26.8 kB) and were not valid Python; preserved here as documentation.
"""
import os
import re
import sqlite3
from datetime import datetime, timedelta, timezone
from difflib import get_close_matches

from openai import OpenAI
TRANSCRIPT = []  # in-memory log of every question/answer interaction


def log_interaction(user_q, sql=None, result=None, error=None):
    """Append one interaction record to the in-memory TRANSCRIPT.

    Args:
        user_q: the user's question as asked.
        sql: the generated SQL, if any.
        result: the query result; lists are truncated to 10 rows for the preview.
        error: error message, if the interaction failed.
    """
    TRANSCRIPT.append({
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated
        # (and naive) as of Python 3.12.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "question": user_q,
        "sql": sql,
        "result_preview": result[:10] if isinstance(result, list) else result,
        "error": error
    })
# =========================
# SETUP
# =========================
# Validate API key
# Fail fast at import time if the OpenAI credential is missing.
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("OPENAI_API_KEY environment variable is not set")
client = OpenAI(api_key=api_key)

# Shared SQLite connection; check_same_thread=False because callers (e.g. a
# web server) may invoke this module from multiple threads.
conn = sqlite3.connect("mimic_iv.db", check_same_thread=False)

# =========================
# CONVERSATION STATE
# =========================
# Set by process_question(); read by friendly().
LAST_PROMPT_TYPE = None
LAST_SUGGESTED_DATE = None
# =========================
# HUMAN RESPONSE HELPERS
# =========================
def humanize(text):
    """Prefix a response with a short friendly opener."""
    opener = "Sure \n\n"
    return f"{opener}{text}"
def friendly(text):
    """Append the latest-available-data date to *text* when one is known."""
    global LAST_SUGGESTED_DATE
    # Prefer the cached date; fall back to querying the database.
    date = LAST_SUGGESTED_DATE or get_latest_data_date()
    if date:
        return f"{text}\n\nLast data available is {date}"
    return text
def is_confirmation(text):
    """True when *text* is a short affirmative reply (yes/ok/sure...)."""
    affirmatives = {"yes", "yep", "yeah", "ok", "okay", "sure"}
    return text.strip().lower() in affirmatives
def is_why_question(text):
    """True when the message begins with "why" (case-insensitive)."""
    normalized = text.strip().lower()
    return normalized.startswith("why")
# =========================
# SPELL CORRECTION
# =========================
# Domain vocabulary used for fuzzy spelling correction.
KNOWN_TERMS = [
    "patient", "patients",
    "admission", "admissions",
    "icu", "stay", "icustay",
    "diagnosis", "procedure",
    "medication", "lab",
    "year", "month", "recent", "today"
]


def correct_spelling(q):
    """Fix near-miss spellings of known domain terms, leaving other words intact.

    Bug fix: the previous version appended the lowercased, punctuation-stripped
    form of EVERY word (even uncorrected ones), mangling the question
    ("How many patients?" became "how many patients"). Now a word is only
    replaced when a genuine correction is found; otherwise it is kept as typed.
    """
    fixed = []
    for w in q.split():
        clean = w.lower().strip(",.?")
        match = get_close_matches(clean, KNOWN_TERMS, n=1, cutoff=0.8)
        if match and match[0] != clean:
            # Genuine correction (trailing punctuation is dropped here).
            fixed.append(match[0])
        else:
            # Keep the word exactly as typed.
            fixed.append(w)
    return " ".join(fixed)
# =========================
# SCHEMA
# =========================
import json
from functools import lru_cache
def col_desc(desc):
    """Return the human-readable description for a column metadata entry.

    Metadata values may be either a dict with a "description" key or a
    plain string; anything else is coerced with str().
    """
    if not isinstance(desc, dict):
        return str(desc)
    return desc.get("description", "")
@lru_cache(maxsize=1)
def load_ai_schema():
    """Load table metadata from metadata.json (cached after the first call).

    Returns:
        dict mapping table name -> {"description": ..., "columns": {...}}.

    Raises:
        FileNotFoundError: when metadata.json is missing.
        ValueError: on invalid JSON, a non-dict top level, or any other
            load failure.
    """
    try:
        with open("metadata.json", "r") as f:
            schema = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError("metadata.json file not found. Please create it with your table metadata.")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in metadata.json: {str(e)}")
    except Exception as e:
        raise ValueError(f"Error loading metadata: {str(e)}")
    # Bug fix: this check used to sit inside the try block, so its ValueError
    # was caught by the broad `except Exception` above and re-wrapped as
    # "Error loading metadata: ...". Checking after the try keeps the
    # intended message intact.
    if not isinstance(schema, dict):
        raise ValueError("Invalid metadata format: expected a dictionary")
    return schema
# =========================
# TABLE MATCHING (CORE LOGIC)
# =========================
def extract_relevant_tables(question, max_tables=4):
    """Score every schema table against *question* and return the best matches.

    Scoring signals (descending weight):
      - exact table name appearing in the question (+6)
      - intent keyword mapped to the table (+5)
      - exact column name in the question (+3)
      - question token contained in a column name (+1 per column)
      - meaningful overlap with the table description (+0.5 per word)

    Tables must reach a minimum threshold (2 for schemas of <= 5 tables,
    else 4). Returns up to *max_tables* table names, most relevant first;
    for tiny schemas (<= 3 tables) with no matches, returns all tables.
    """
    schema = load_ai_schema()
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())
    matched = []

    # Map natural-language intents to candidate tables, keeping only those
    # tables that actually exist in the loaded schema.
    table_names_lower = [t.lower() for t in schema]
    hint_mappings = {
        # Patients & visits
        "patient": ["patients"],
        "patients": ["patients"],
        "admission": ["admissions"],
        "admissions": ["admissions"],
        "visit": ["admissions", "icustays"],
        "visits": ["admissions", "icustays"],
        # ICU
        "icu": ["icustays", "chartevents"],
        "stay": ["icustays"],
        "stays": ["icustays"],
        # Diagnoses / conditions
        "diagnosis": ["diagnoses_icd"],
        "diagnoses": ["diagnoses_icd"],
        "condition": ["diagnoses_icd"],
        "conditions": ["diagnoses_icd"],
        # Procedures
        "procedure": ["procedures_icd"],
        "procedures": ["procedures_icd"],
        # Medications
        "medication": ["prescriptions", "emar", "pharmacy"],
        "medications": ["prescriptions", "emar", "pharmacy"],
        "drug": ["prescriptions"],
        "drugs": ["prescriptions"],
        # Labs & vitals
        "lab": ["labevents"],
        "labs": ["labevents"],
        "vital": ["chartevents"],
        "vitals": ["chartevents"],
    }
    DOMAIN_HINTS = {}
    for intent, possible_tables in hint_mappings.items():
        existing = [t for t in possible_tables if t in table_names_lower]
        if existing:
            DOMAIN_HINTS[intent] = existing

    for table, meta in schema.items():
        score = 0
        table_l = table.lower()

        # 1) Strong signal: the table name appears verbatim in the question.
        if table_l in q:
            score += 6
        # (A dead early-exit branch checking `score >= 10` here was removed:
        #  a name match alone only reaches 6, so it could never fire.)

        # 2) Column relevance.
        for col in meta["columns"]:
            col_l = col.lower()
            if col_l in q:
                score += 3
            elif any(tok in col_l for tok in tokens):
                score += 1

        # 3) Description overlap (low weight to avoid false positives).
        if meta.get("description"):
            desc_tokens = set(col_desc(meta.get("description", "")).lower().split())
            common_words = {"the", "is", "at", "which", "on", "for", "a", "an"}
            meaningful_matches = tokens & desc_tokens - common_words
            if meaningful_matches:
                score += len(meaningful_matches) * 0.5

        # 4) Semantic intent-to-table hints.
        for intent, tables in DOMAIN_HINTS.items():
            if intent in q and table_l in tables:
                score += 5

        # 5) Minimum-score gate; small schemas get a more lenient threshold.
        threshold = 2 if len(schema) <= 5 else 4
        if score >= threshold:
            matched.append((table, score))

    # Sort by relevance, highest score first.
    matched.sort(key=lambda x: x[1], reverse=True)
    # Tiny schemas with no matches: return everything rather than nothing.
    if not matched and len(schema) <= 3:
        return list(schema.keys())[:max_tables]
    return [t[0] for t in matched[:max_tables]]
# =========================
# HUMAN SCHEMA DESCRIPTION
# =========================
def describe_schema(max_tables=10):
    """Return a human-readable overview of the available tables and columns.

    Shows at most *max_tables* tables and the first 5 columns of each, then
    appends example questions the user can ask. Used to answer
    "what data do you have?" style questions.
    """
    schema = load_ai_schema()
    total_tables = len(schema)
    response = f"Here's the data I currently have access to ({total_tables} tables):\n\n"
    # Show only the first N tables to avoid overwhelming output.
    shown_tables = list(schema.items())[:max_tables]
    for table, meta in shown_tables:
        response += f"• **{table.capitalize()}** — {meta['description']}\n"
        # Show only the first 5 columns per table.
        for col, desc in list(meta["columns"].items())[:5]:
            response += f" - {col}: {col_desc(desc)}\n"
        if len(meta["columns"]) > 5:
            response += f" ... and {len(meta['columns']) - 5} more columns\n"
        response += "\n"
    if total_tables > max_tables:
        response += f"\n... and {total_tables - max_tables} more tables.\n"
        response += "Ask about a specific table to see its details.\n\n"
    response += (
        "You can ask things like:\n"
        "• How many patients are there?\n"
        "• Patient count by gender\n"
        "• Admissions by year\n\n"
        "Just tell me what you want to explore "
    )
    return response
# =========================
# TIME HANDLING
# =========================
def get_latest_data_date():
    """Return the most recent timestamp present in the data, or None.

    Checks, in priority order:
      1. admissions.admittime
      2. icustays.intime
      3. chartevents.charttime
    """
    checks = [
        ("admissions", "admittime"),
        ("icustays", "intime"),
        ("chartevents", "charttime"),
    ]
    for table, column in checks:
        try:
            row = conn.execute(
                f"SELECT MAX({column}) FROM {table}"
            ).fetchone()
        except sqlite3.Error:
            # Narrowed from a bare `except Exception`: only database errors
            # (e.g. the table does not exist) should fall through to the
            # next source; anything else is a real bug and should surface.
            continue
        if row and row[0]:
            return row[0]
    return None
def normalize_time_question(q):
    """Rewrite relative day references using the dataset's latest date.

    MIMIC timestamps are de-identified/shifted, so "today" is anchored to
    the most recent date in the data. Bug fix: "yesterday" used to be
    replaced with the SAME date as "today"; it is now the day before.
    """
    latest = get_latest_data_date()
    if not latest:
        return q
    anchor = latest[:10]  # YYYY-MM-DD prefix of the timestamp
    if "today" in q:
        return q.replace("today", f"on {anchor}")
    if "yesterday" in q:
        try:
            prev = (datetime.fromisoformat(anchor) - timedelta(days=1)).date().isoformat()
        except ValueError:
            # Unparseable date string: fall back to the anchor itself.
            prev = anchor
        return q.replace("yesterday", f"on {prev}")
    return q
# =========================
# UNSUPPORTED QUESTIONS
# =========================
def is_question_supported(question):
    """Heuristically decide whether a question can be answered from the schema.

    A question is supported when it contains analytical phrasing, or when at
    least one schema table scores >= 2 on name/column/description relevance.
    """
    q = question.lower()
    tokens = set(q.replace("?", "").replace(",", "").split())

    # 1) Analytical phrasing is always allowed, even without table names.
    analytic_keywords = {
        "count", "total", "average", "avg", "sum",
        "how many", "number of", "trend",
        "increase", "decrease", "compare", "more than", "less than"
    }
    for keyword in analytic_keywords:
        if keyword in q:
            return True

    # 2) Otherwise require at least one table to look relevant enough.
    for table, meta in load_ai_schema().items():
        score = 3 if table.lower() in q else 0
        for col in meta["columns"]:
            col_l = col.lower()
            if col_l in q:
                score += 2
            elif any(tok in col_l for tok in tokens):
                score += 1
        description = meta.get("description")
        if description:
            desc_tokens = set(col_desc(description).lower().split())
            score += len(tokens & desc_tokens)
        if score >= 2:
            return True
    return False
# =========================
# SQL GENERATION
# =========================
def build_prompt(question):
    """Assemble the LLM prompt: strict rules + relevant schema subset + question.

    Only tables matched by extract_relevant_tables() are included, and only
    columns that are either relevant to the question or join/key columns.

    Raises:
        ValueError: when no relevant tables match the question.
    """
    matched = extract_relevant_tables(question)
    full_schema = load_ai_schema()
    if not matched:
        available_tables = list(full_schema.keys())[:10]
        tables_list = "\n".join(f"- {t}" for t in available_tables)
        if len(full_schema) > 10:
            tables_list += f"\n... and {len(full_schema) - 10} more tables"
        raise ValueError(
            "I couldn't find any relevant tables for your question.\n\n"
            f"Available tables:\n{tables_list}\n\n"
            "Try mentioning a table name or ask: 'what data is available?'"
        )
    # Restrict the prompt to only the matched tables.
    schema = {t: full_schema[t] for t in matched}
    # Join/key columns that are always included regardless of relevance.
    IMPORTANT_COLS = {
        "subject_id", "hadm_id", "stay_id",
        "icustay_id", "itemid",
        "charttime", "starttime", "endtime"
    }
    prompt = """
You are an expert SQLite query generator.
STRICT RULES:
- Use ONLY the tables and columns listed below
- NEVER invent table or column names
- If the answer cannot be derived, return: NOT_ANSWERABLE
- Do NOT explain the SQL
- Do NOT wrap SQL in markdown
- Use explicit JOIN conditions
- Prefer COUNT(*) for totals
Always use these joins:
- patients.subject_id = admissions.subject_id
- admissions.hadm_id = icustays.hadm_id
- icustays.stay_id = chartevents.stay_id
Schema:
"""
    for table, meta in schema.items():
        prompt += f"\nTable: {table}\n"
        for col, desc in meta["columns"].items():
            text = f"{col} {col_desc(desc)}".lower()
            # Keep columns whose name/description mentions a question word...
            if any(w in text for w in question.lower().split()):
                prompt += f"- {col}\n"
            # ...and always keep join / key columns.
            elif col in IMPORTANT_COLS or col.endswith("_id"):
                prompt += f"- {col}\n"
    # Join hints help the LLM with MIMIC's linkage columns.
    prompt += """
Join hints:
- patients.subject_id ↔ admissions.subject_id
- admissions.hadm_id ↔ icustays.hadm_id
- icustays.stay_id ↔ chartevents.stay_id
"""
    prompt += f"\nQuestion: {question}\n"
    prompt += "\nUse EXACT table and column names as shown above."
    # Safety cap to keep the prompt within a bounded size.
    if len(prompt) > 6000:
        prompt = prompt[:6000] + "\n\n# Schema truncated for safety\n"
    return prompt
def call_llm(prompt):
    """Send *prompt* to the chat model and return the raw reply text.

    The model is instructed to return only SQL or the literal string
    NOT_ANSWERABLE.

    Raises:
        ValueError: when the API call fails or the response is empty.
    """
    try:
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": "Return only SQL or NOT_ANSWERABLE"},
                {"role": "user", "content": prompt}
            ],
            temperature=0
        )
    except Exception as e:
        # Network / auth / rate-limit failures surface as a single ValueError.
        raise ValueError(f"OpenAI API error: {str(e)}")
    # Bug fix: checked OUTSIDE the try block — previously this ValueError was
    # caught by the broad `except Exception` above and re-wrapped as
    # "OpenAI API error: Empty response from OpenAI API".
    if not res.choices or not res.choices[0].message.content:
        raise ValueError("Empty response from OpenAI API")
    return res.choices[0].message.content.strip()
# =========================
# SQL SAFETY
# =========================
def sanitize_sql(sql):
    """Strip markdown fencing and keep only the first SQL statement, one line."""
    cleaned = sql.replace("```sql", "").replace("```", "").strip()
    # Some models prefix a bare "sql" tag even without a code fence.
    if cleaned.startswith("sql"):
        cleaned = cleaned[3:].strip()
    # Keep only the first statement, then collapse to a single line.
    first_statement = cleaned.split(";")[0]
    return first_statement.replace("\n", " ").strip()
def correct_table_names(sql):
    """Rewrite known-wrong table names after FROM/JOIN to real schema tables.

    Only substitutes a synonym when the original name is NOT a valid schema
    table and the replacement IS one; everything else is left untouched.
    """
    valid_tables = {t.lower() for t in load_ai_schema().keys()}
    synonyms = {
        "visit": "admissions",
        "visits": "admissions",
        "provider": "caregiver",
        "providers": "caregiver"
    }

    def _fix(m):
        keyword, table = m.group(1), m.group(2)
        lowered = table.lower()
        if lowered not in valid_tables:
            replacement = synonyms.get(lowered)
            if replacement is not None and replacement in valid_tables:
                return f"{keyword} {replacement}"
        return m.group(0)

    return re.sub(
        r"\b(from|join)\s+([a-zA-Z_][a-zA-Z0-9_]*)",
        _fix,
        sql,
        flags=re.IGNORECASE,
    )
def validate_sql(sql):
    """Reject anything other than a single, safe SELECT statement.

    Returns the SQL unchanged when valid.

    Raises:
        ValueError: on JOIN without ON, multiple statements, write
            keywords, or a non-SELECT statement.
    """
    lowered = sql.lower()
    if " join " in lowered and " on " not in lowered:
        raise ValueError("JOIN without ON condition is not allowed")
    if ";" in sql.strip()[:-1]:
        raise ValueError("Multiple SQL statements are not allowed")
    # Bug fix: match whole words only, so identifiers such as "last_update"
    # or a table named "updates" are not falsely rejected (the old substring
    # check flagged any occurrence of these keywords).
    FORBIDDEN = ["insert", "update", "delete", "drop", "alter"]
    if any(re.search(rf"\b{kw}\b", lowered) for kw in FORBIDDEN):
        raise ValueError("Unsafe SQL detected")
    if not lowered.startswith("select"):
        raise ValueError("Only SELECT allowed")
    return sql
def run_query(sql):
    """Execute *sql* and return (column_names, rows).

    Raises:
        ValueError: wrapping any sqlite3 error from execution.
    """
    cur = conn.cursor()
    try:
        rows = cur.execute(sql).fetchall()
    except sqlite3.Error as e:
        raise ValueError(f"Database query error: {str(e)}")
    cols = [c[0] for c in cur.description] if cur.description else []
    return cols, rows
# =========================
# AGGREGATE SAFETY
# =========================
def is_aggregate_only_query(sql):
    """True when the query is a bare aggregate: COUNT/SUM/AVG with no
    GROUP BY and no window function.

    Bug fix: window detection now uses a word-boundary regex with optional
    whitespace, so the common "OVER (" spelling is recognized (the old
    substring check "over(" missed it, and also false-positived on names
    like "recover(").
    """
    s = sql.lower()
    has_aggregate = any(fn in s for fn in ["count(", "sum(", "avg("])
    has_window = re.search(r"\bover\s*\(", s) is not None
    return has_aggregate and "group by" not in s and not has_window
def has_underlying_data(sql):
    """Probe whether the FROM/WHERE portion of *sql* matches any row at all.

    Rewrites the query as `SELECT 1 FROM <source+filters> LIMIT 1` by
    trimming trailing GROUP BY / ORDER BY / LIMIT / HAVING clauses, then
    executes the probe. Returns False when the probe errors or finds nothing.
    """
    lowered = sql.lower()
    if "from" not in lowered:
        return False
    from_clause = lowered.split("from", 1)[1]
    # Trim trailing clauses so only the source tables + filters remain.
    for clause in ("group by", "order by", "limit", "having"):
        from_clause = from_clause.split(clause)[0]
    probe = "SELECT 1 FROM " + from_clause.strip() + " LIMIT 1"
    try:
        return conn.cursor().execute(probe).fetchone() is not None
    except sqlite3.Error:
        return False
# =========================
# PATIENT SUMMARY
# =========================
def validate_identifier(name):
    """True when *name* is a safe SQL identifier (letters/digits/underscore,
    not starting with a digit).

    Used before interpolating table/column names into SQL, since SQLite
    cannot parameterize identifiers.

    Bug fix: uses re.fullmatch instead of re.match + "$" — the "$" anchor
    matches just before a trailing newline, so the old regex accepted
    "abc\\n" and relied on a separate forbidden-character list to catch it.
    fullmatch closes that hole, making the character list redundant.
    """
    if not name or not isinstance(name, str):
        return False
    return re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name) is not None
def build_table_summary(table_name):
    """Build a human-readable summary (row count + top value breakdowns).

    *table_name* must exist in metadata.json. Identifiers are validated with
    validate_identifier() before interpolation, because SQLite cannot
    parameterize table/column names.
    """
    # Validate the table name against metadata first.
    schema = load_ai_schema()
    if table_name not in schema:
        return f"Table {table_name} not found in metadata."
    # Additional safety check before string interpolation.
    if not validate_identifier(table_name):
        return f"Invalid table name: {table_name}"
    cur = conn.cursor()
    try:
        total = cur.execute(
            f"SELECT COUNT(*) FROM {table_name}"
        ).fetchone()[0]
    except sqlite3.Error as e:
        return f"Error querying table {table_name}: {str(e)}"
    columns = schema[table_name]["columns"]  # {col_name: description, ...}
    summary = f"Here's a summary of **{table_name}**:\n\n"
    summary += f"• Total records: {total}\n"
    # Try to summarize likely-categorical columns using metadata.
    # Bug fix: the loop variable was previously named "col_desc", shadowing
    # the module-level col_desc() helper.
    for col_name, _description in columns.items():
        if not validate_identifier(col_name):
            continue
        col_lower = col_name.lower()
        # Skip columns that look like ids / dates / numerics.
        if any(skip in col_lower for skip in ["id", "_id", "date", "time", "count", "amount", "price"]):
            continue
        try:
            rows = cur.execute(
                f"""
                SELECT {col_name}, COUNT(*)
                FROM {table_name}
                GROUP BY {col_name}
                ORDER BY COUNT(*) DESC
                LIMIT 5
                """
            ).fetchall()
        except sqlite3.Error:
            # sqlite3.OperationalError is a subclass of sqlite3.Error, so the
            # old two-type except was redundant. Not groupable -> skip.
            continue
        if rows:
            summary += f"\n• {col_name.capitalize()} breakdown:\n"
            for val, count in rows:
                summary += f" - {val}: {count}\n"
    summary += "\nYou can ask more detailed questions about this data."
    return summary
# =========================
# MAIN ENGINE
# =========================
def process_question(question):
    """Main engine entry point: answer a natural-language question.

    Routes the question through, in order: latest-date shortcuts, relevant
    table extraction, an explicit summary path, a hard-coded recent-admissions
    query, a supportability check, and finally LLM SQL generation, validation,
    and execution.

    Returns a dict with "status" plus either "message"/"data" (informational
    responses and errors) or "sql"/"columns"/"data" (query results).
    """
    global LAST_PROMPT_TYPE, LAST_SUGGESTED_DATE
    q = question.strip().lower()
    # ----------------------------------
    # Normalize first (spelling fixes, then "today"/"yesterday" anchoring)
    # ----------------------------------
    question = correct_spelling(question)
    question = normalize_time_question(question)
    LAST_PROMPT_TYPE = None
    LAST_SUGGESTED_DATE = None
    # ----------------------------------
    # Handle "data updated till"
    # ----------------------------------
    if any(x in q for x in ["updated", "upto", "up to", "latest data"]):
        return {
            "status": "ok",
            "message": f"Data is available up to {get_latest_data_date()}",
            "data": []
        }
    # ----------------------------------
    # Extract relevant tables
    # ----------------------------------
    matched_tables = extract_relevant_tables(question)
    # ----------------------------------
    # Summary only if the user explicitly asks for one (and the question
    # is not analytical)
    # ----------------------------------
    if (
        len(matched_tables) == 1
        and any(k in q for k in ["summary", "overview", "describe"])
        and not any(k in q for k in ["count", "total", "how many", "average"])
    ):
        return {
            "status": "ok",
            "message": build_table_summary(matched_tables[0]),
            "data": []
        }
    # Only block if too many tables matched AND it's not an analytical
    # question — analytical questions often legitimately span tables.
    is_analytical = any(k in q for k in [
        "how many", "count", "total", "number of",
        "average", "avg", "sum", "more than", "less than",
        "compare", "trend"
    ])
    if len(matched_tables) > 4 and not is_analytical:
        return {
            "status": "ok",
            "message": (
                "Your question matches too many datasets:\n"
                + "\n".join(f"- {t}" for t in matched_tables[:5])
                + "\n\nPlease be more specific about what you want to know."
            ),
            "data": []
        }
    # ----------------------------------
    # Metadata discovery
    # ----------------------------------
    if any(x in q for x in ["what data", "what tables", "which data"]):
        return {
            "status": "ok",
            "message": humanize(describe_schema()),
            "data": []
        }
    # ----------------------------------
    # Last data / recent data handling
    # NOTE(review): "latest data" is already caught by the "updated" check
    # above, so only "last data" can actually reach this branch.
    # ----------------------------------
    if any(x in q for x in ["last data", "latest data"]):
        return {
            "status": "ok",
            "message": f"Latest data available is from {get_latest_data_date()}",
            "data": []
        }
    # Hard-coded path: "last N days" of visits/admissions is anchored to the
    # newest admission (MIMIC dates are shifted, so wall-clock "now" is
    # meaningless). NOTE(review): always uses 30 days regardless of N.
    if "last" in q and "day" in q and ("visit" in q or "admission" in q):
        sql = """
            SELECT subject_id, admittime
            FROM admissions
            WHERE admittime >= date(
                (SELECT MAX(admittime) FROM admissions),
                '-30 days'
            )
            ORDER BY admittime DESC
        """
        cols, rows = run_query(sql)
        log_interaction(
            user_q=question,
            sql=sql,
            result=rows
        )
        return {
            "status": "ok",
            "sql": sql,
            "columns": cols,
            "data": rows
        }
    # ----------------------------------
    # Unsupported question check
    # ----------------------------------
    if not is_question_supported(question):
        log_interaction(
            user_q=question,
            error="Unsupported question"
        )
        return {
            "status": "ok",
            "message": (
                "That information isn’t available in the system.\n\n"
                "You can ask about:\n"
                "• Patients\n"
                "• Admissions / Visits\n"
                "• ICU stays\n"
                "• Diagnoses / Conditions\n"
                "• Vitals & lab measurements"
            ),
            "data": []
        }
    # ----------------------------------
    # Generate SQL via the LLM
    # ----------------------------------
    try:
        sql = call_llm(build_prompt(question))
    except ValueError as e:
        log_interaction(
            user_q=question,
            error=str(e)
        )
        return {
            "status": "ok",
            "message": str(e),
            "data": []
        }
    if sql == "NOT_ANSWERABLE":
        return {
            "status": "ok",
            "message": "I don't have enough data to answer that.",
            "data": []
        }
    # Sanitize (strip fences), correct table names, then validate — in that
    # order so fencing/bad names don't trip the safety checks.
    sql = sanitize_sql(sql)
    sql = correct_table_names(sql)
    sql = validate_sql(sql)
    cols, rows = run_query(sql)
    # Log exactly once per executed query.
    log_interaction(
        user_q=question,
        sql=sql,
        result=rows
    )
    if not rows:
        return {
            "status": "ok",
            "message": friendly("No records found."),
            "data": []
        }
    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows
    }
    # ----------------------------------
    # No data handling
    # NOTE(review): everything below is UNREACHABLE — the empty-result and
    # success returns above exit the function on every path. This richer
    # handling (aggregate "no data for that period" notes, setting
    # LAST_PROMPT_TYPE / LAST_SUGGESTED_DATE) looks like the intended final
    # flow; confirm and remove either this tail or the two returns above.
    # ----------------------------------
    if is_aggregate_only_query(sql) and not has_underlying_data(sql):
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No data is available for that time period."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }
    if not rows:
        log_interaction(
            user_q=question,
            sql=sql,
            result=[]
        )
        LAST_PROMPT_TYPE = "NO_DATA"
        LAST_SUGGESTED_DATE = get_latest_data_date()
        return {
            "status": "ok",
            "message": friendly("No records found."),
            "note": f"Available data is only up to {LAST_SUGGESTED_DATE}.",
            "data": []
        }
    # ----------------------------------
    # Success
    # ----------------------------------
    return {
        "status": "ok",
        "sql": sql,
        "columns": cols,
        "data": rows
    }