Text_to_sql / engine.py
bhavika24's picture
Upload engine.py
726ac48 verified
raw
history blame
7.34 kB
import os
import re
import sqlite3
from difflib import get_close_matches

from openai import OpenAI
# =========================
# Setup
# =========================
# OpenAI client; the API key is taken from the OPENAI_API_KEY environment
# variable (os.getenv yields None when the variable is not set).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# One module-wide SQLite connection to hospital.db, used by the schema loader
# and the query runner in this file. check_same_thread=False disables
# sqlite3's same-thread check on the connection.
conn = sqlite3.connect("hospital.db", check_same_thread=False)
# =========================
# Known Terms for Spell Correction
# =========================
KNOWN_TERMS = [
    "patient", "patients", "condition", "conditions", "diagnosis", "encounter", "encounters",
    "visit", "visits", "observation", "observations", "lab", "labs", "test", "tests",
    "medication", "medications", "drug", "drugs", "prescription", "prescriptions",
    "diabetes", "hypertension", "asthma", "cancer", "admitted", "admission"
]


def correct_spelling(question: str) -> str:
    """Replace misspelled domain words in *question* with their known spelling.

    Each whitespace-separated word is compared (lower-cased, without
    surrounding ``,``, ``.`` or ``?``) against ``KNOWN_TERMS`` using
    ``difflib.get_close_matches`` with a 0.8 similarity cutoff.

    Only the word itself is replaced: punctuation attached to the word is
    kept, and words that already are known terms are returned untouched so
    their capitalisation and punctuation survive.

    Args:
        question: Free-text user question.

    Returns:
        The question with close misspellings corrected, words joined by
        single spaces.
    """
    punctuation = ",.?"
    corrected_words = []
    for word in question.split():
        core = word.strip(punctuation).lower()
        if not core or core in KNOWN_TERMS:
            # Pure punctuation, or already spelled correctly: leave as typed.
            corrected_words.append(word)
            continue

        matches = get_close_matches(core, KNOWN_TERMS, n=1, cutoff=0.8)
        if not matches:
            corrected_words.append(word)
            continue

        # Re-attach whatever punctuation surrounded the misspelled word so
        # "diabtes?" becomes "diabetes?" rather than "diabetes".
        leading = word[: len(word) - len(word.lstrip(punctuation))]
        trailing = word[len(word.rstrip(punctuation)):]
        corrected_words.append(f"{leading}{matches[0]}{trailing}")
    return " ".join(corrected_words)
# =========================
# Metadata Loader
# =========================
def load_ai_schema():
    """Read the AI-visible schema from the ``ai_tables`` / ``ai_columns`` tables.

    Returns:
        A dict keyed by table name. Each value is a dict with the table's
        ``"description"`` and a ``"columns"`` list of
        ``(column_name, description)`` rows, restricted to tables with
        ``ai_enabled = 1`` and columns with ``ai_allowed = 1``.
    """
    tables_sql = (
        "\n"
        "SELECT table_name, description\n"
        "FROM ai_tables\n"
        "WHERE ai_enabled = 1\n"
    )
    columns_sql = (
        "\n"
        "SELECT column_name, description\n"
        "FROM ai_columns\n"
        "WHERE table_name = ? AND ai_allowed = 1\n"
    )

    cursor = conn.cursor()
    enabled_tables = cursor.execute(tables_sql).fetchall()

    schema = {}
    for name, description in enabled_tables:
        # One lookup per enabled table for the columns the model may use.
        allowed_columns = cursor.execute(columns_sql, (name,)).fetchall()
        schema[name] = {"description": description, "columns": allowed_columns}
    return schema
# =========================
# Prompt Builder
# =========================
def build_prompt(question: str) -> str:
    """Assemble the LLM prompt: fixed rules, the AI-visible schema, the question.

    Args:
        question: The (already pre-processed) user question.

    Returns:
        The complete prompt text. The schema section lists every table from
        ``load_ai_schema()`` followed by its allowed columns.
    """
    header = (
        "\n"
        "You are a hospital data assistant.\n"
        "Rules:\n"
        "- Generate only SELECT SQL queries.\n"
        "- Use only the tables and columns provided.\n"
        "- Do not invent tables or columns.\n"
        "- This database is SQLite. Use SQLite-compatible date functions.\n"
        "- For recent days use: date('now', '-N day')\n"
        "- Use case-insensitive matching for text fields.\n"
        "- Prefer LIKE with wildcards for medical condition names.\n"
        "- Use COUNT, AVG, MIN, MAX, GROUP BY when the question asks for totals, averages, or comparisons.\n"
        "- If the question cannot be answered using the schema, return NOT_ANSWERABLE.\n"
        "- Do not explain the query.\n"
        "- Return only SQL or NOT_ANSWERABLE.\n"
        "Available schema:\n"
    )

    sections = [header]
    for table, meta in load_ai_schema().items():
        sections.append(f"\nTable: {table} - {meta['description']}\n")
        sections.extend(f" - {col}: {desc}\n" for col, desc in meta["columns"])
    sections.append(f"\nUser question: {question}\n")
    return "".join(sections)
# =========================
# LLM Call
# =========================
def call_llm(prompt: str) -> str:
    """Send *prompt* to the chat model and return its text reply.

    The request uses model ``gpt-4.1-mini`` at temperature 0.0 with a system
    message that asks for SQL only.

    Args:
        prompt: Full user-message text (rules, schema and question).

    Returns:
        The first choice's message content with surrounding whitespace
        removed; an empty string when the message carries no text content.
    """
    response = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": "You are a SQL generator. Return only SQL. No explanation."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.0
    )
    content = response.choices[0].message.content
    # The API models `content` as nullable; calling .strip() on None would
    # raise AttributeError, so normalise a missing reply to "".
    return (content or "").strip()
# =========================
# SQL Generation
# =========================
def generate_sql(question: str) -> str:
    """Turn a natural-language question into the model's raw SQL reply.

    Builds the prompt for *question*, sends it to the LLM and returns the
    reply with surrounding whitespace removed. The reply is not validated
    here.
    """
    return call_llm(build_prompt(question)).strip()
# =========================
# SQL Cleaning & Validation
# =========================
def clean_sql(sql: str) -> str:
    """Strip whitespace and a surrounding markdown code fence from *sql*.

    When the text starts with a ``` fence, only the content up to the next
    fence (or the end of the text, if the fence is never closed) is kept.
    A language tag on the opening fence line (``sql``, ``sqlite`` or
    ``sqlite3``, any case) is dropped. Nothing inside the query body is
    altered.

    Args:
        sql: Raw model output.

    Returns:
        The unfenced, stripped SQL text.
    """
    fence = "```"
    sql = sql.strip()
    if not sql.startswith(fence):
        return sql

    fenced = sql[len(fence):].split(fence, 1)[0]

    # Only the first line can be a language tag. Removing the tag there,
    # instead of replacing "sql\n" everywhere, keeps query lines that happen
    # to end in "sql" intact.
    first_line, newline, remainder = fenced.partition("\n")
    if newline and first_line.strip().lower() in {"sql", "sqlite", "sqlite3"}:
        fenced = remainder
    return fenced.strip()
# Write/DDL keywords matched as whole words only, so identifiers that merely
# contain one of them (updated_at, is_deleted, dropoff_time, ...) do not
# trigger a rejection.
_FORBIDDEN_SQL_RE = re.compile(r"\b(?:insert|update|delete|drop|alter|truncate)\b")


def validate_sql(sql: str) -> str:
    """Clean *sql* and make sure it is a read-only SELECT statement.

    Args:
        sql: Raw model output, possibly wrapped in a markdown code fence.

    Returns:
        The cleaned SQL text.

    Raises:
        Exception: If the cleaned text does not start with ``SELECT``
            (this includes a ``NOT_ANSWERABLE`` reply), or if it contains a
            forbidden write/DDL keyword as a standalone word.
    """
    sql = clean_sql(sql)
    lowered = sql.lower()
    if not lowered.startswith("select"):
        raise Exception("Only SELECT queries allowed")
    if _FORBIDDEN_SQL_RE.search(lowered):
        raise Exception("Forbidden SQL operation detected")
    return sql
# =========================
# Query Runner
# =========================
def run_query(sql: str):
    """Execute *sql* on the shared connection and fetch every row.

    Returns:
        A ``(columns, rows)`` pair: the column names taken from the cursor
        description, and the list of all result rows.
    """
    cursor = conn.cursor()
    cursor.execute(sql)
    rows = cursor.fetchall()
    column_names = [column[0] for column in cursor.description]
    return column_names, rows
# =========================
# Guardrails
# =========================
def is_question_answerable(question):
    """Cheap guardrail: does the question mention any hospital-data keyword?

    The check is a case-insensitive substring test, so plurals and other
    words containing a keyword also count.

    Returns:
        True when at least one keyword occurs in *question*, else False.
    """
    domain_keywords = (
        "patient", "encounter", "condition", "observation",
        "medication", "visit", "diagnosis", "lab", "vital", "admitted",
    )
    lowered = question.lower()
    return any(keyword in lowered for keyword in domain_keywords)
# =========================
# Time Awareness
# =========================
def get_latest_data_date():
    """Return the largest ``start_date`` stored in the ``encounters`` table.

    The value is returned exactly as the query yields it (first column of
    the first row).
    """
    _columns, rows = run_query("SELECT MAX(start_date) FROM encounters;")
    first_row = rows[0]
    return first_row[0]
def check_time_relevance(question: str):
    """Build a data-coverage note for questions that use relative time words.

    Returns:
        A sentence naming the latest available data date when *question*
        contains one of the time phrases (case-insensitive substring match),
        otherwise None.
    """
    time_phrases = ("last", "recent", "today", "this month", "this year")
    lowered = question.lower()
    if not any(phrase in lowered for phrase in time_phrases):
        return None
    latest = get_latest_data_date()
    return f"Latest available data is from {latest}."
# =========================
# Empty Result Interpreter
# =========================
def interpret_empty_result(question: str):
    """Return the user-facing message for a query that produced no rows.

    The message states the latest date covered by the data. *question* is
    accepted but not used.
    """
    newest_date = get_latest_data_date()
    return f"No results found. Available data is up to {newest_date}."
# =========================
# ORCHESTRATOR (Single Entry Point)
# =========================
def process_question(question: str):
    """Answer a natural-language question end to end (single entry point).

    Steps: spell-correct the question, apply the keyword guardrail, compute
    the time-coverage note, generate SQL with the LLM, validate it, run it,
    and shape the response.

    Args:
        question: Free-text user question.

    Returns:
        A dict with ``"status"``:

        * ``"rejected"`` with a ``"message"`` when the guardrail fails or the
          model replies ``NOT_ANSWERABLE``.
        * ``"ok"`` with ``"sql"``, ``"message"``, empty ``"data"`` and
          ``"note"`` when the query returns no rows.
        * ``"ok"`` with ``"sql"``, ``"columns"``, ``"data"`` (at most 50
          rows) and ``"note"`` otherwise.

    Raises:
        Exception: Propagated from ``validate_sql`` when the generated text
            is not an allowed SELECT; errors from executing the query are
            not caught here either.
    """
    not_supported = {
        "status": "rejected",
        "message": "This question is not supported by the available data."
    }

    # 0. Spell correction
    question = correct_spelling(question)

    # 1. Guardrail
    if not is_question_answerable(question):
        return not_supported

    # 2. Time relevance
    time_note = check_time_relevance(question)

    # 3. Generate SQL
    raw_sql = generate_sql(question)

    # The prompt instructs the model to answer NOT_ANSWERABLE when the schema
    # cannot serve the question. Treat that as a normal rejection instead of
    # letting validation raise "Only SELECT queries allowed".
    if clean_sql(raw_sql).rstrip(";.").strip().upper() == "NOT_ANSWERABLE":
        return not_supported

    # 4. Validate SQL
    sql = validate_sql(raw_sql)

    # 5. Execute query
    columns, rows = run_query(sql)

    # 6. Handle empty result with data coverage awareness
    if not rows:
        # time_note is set exactly when the question used a relative time
        # phrase, so it is reused here rather than re-checking a second,
        # slightly different word list (the old one omitted "today").
        if time_note is not None:
            latest = get_latest_data_date()
            return {
                "status": "ok",
                "sql": sql,
                "message": f"No data available for the requested time period. Latest available data is from {latest}.",
                "data": [],
                "note": None
            }
        return {
            "status": "ok",
            "sql": sql,
            "message": interpret_empty_result(question),
            "data": [],
            "note": time_note
        }

    # 7. Normal response
    return {
        "status": "ok",
        "sql": sql,
        "columns": columns,
        "data": rows[:50],  # demo safety limit
        "note": time_note
    }