Spaces:
Sleeping
Sleeping
Upload engine.py
Browse files
engine.py
CHANGED
|
@@ -47,6 +47,12 @@ KNOWN_TERMS = [
|
|
| 47 |
"admitted", "admission",
|
| 48 |
"year", "month", "last", "recent", "today"
|
| 49 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def correct_spelling(q):
|
| 52 |
words = q.split()
|
|
@@ -93,27 +99,42 @@ def load_ai_schema():
|
|
| 93 |
# TABLE MATCHING (CORE LOGIC)
|
| 94 |
# =========================
|
| 95 |
|
| 96 |
-
def extract_relevant_tables(question, max_tables=
|
| 97 |
schema = load_ai_schema()
|
| 98 |
q = question.lower()
|
| 99 |
-
|
| 100 |
tokens = set(q.replace("?", "").replace(",", "").split())
|
|
|
|
| 101 |
matched = []
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
for table, meta in schema.items():
|
| 104 |
score = 0
|
| 105 |
table_l = table.lower()
|
| 106 |
|
| 107 |
-
# 1️⃣
|
| 108 |
if table_l in q:
|
| 109 |
-
score +=
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
score += len(tokens & set(desc_words)) * 2
|
| 115 |
|
| 116 |
-
#
|
| 117 |
for col, _ in meta["columns"]:
|
| 118 |
col_l = col.lower()
|
| 119 |
if col_l in q:
|
|
@@ -121,46 +142,52 @@ def extract_relevant_tables(question, max_tables=5):
|
|
| 121 |
elif any(tok in col_l for tok in tokens):
|
| 122 |
score += 1
|
| 123 |
|
| 124 |
-
#
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
"medication": ["drug", "medicine"],
|
| 129 |
-
"admission": ["admit", "admission"],
|
| 130 |
-
"date": ["date", "year", "month"]
|
| 131 |
-
}
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
| 136 |
|
| 137 |
-
if
|
|
|
|
| 138 |
matched.append((table, score))
|
| 139 |
|
| 140 |
# Sort by relevance
|
| 141 |
matched.sort(key=lambda x: x[1], reverse=True)
|
| 142 |
|
| 143 |
-
# Return top N tables
|
| 144 |
return [t[0] for t in matched[:max_tables]]
|
| 145 |
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
# =========================
|
| 150 |
# HUMAN SCHEMA DESCRIPTION
|
| 151 |
# =========================
|
| 152 |
|
| 153 |
-
def describe_schema():
|
| 154 |
schema = load_ai_schema()
|
|
|
|
| 155 |
|
| 156 |
-
response = "Here
|
| 157 |
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
| 159 |
response += f"• **{table.capitalize()}** — {meta['description']}\n"
|
| 160 |
-
|
|
|
|
| 161 |
response += f" - {col}: {desc}\n"
|
|
|
|
|
|
|
| 162 |
response += "\n"
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
response += (
|
| 165 |
"You can ask things like:\n"
|
| 166 |
"• How many patients are there?\n"
|
|
@@ -168,9 +195,6 @@ def describe_schema():
|
|
| 168 |
"• Admissions by year\n\n"
|
| 169 |
"Just tell me what you want to explore "
|
| 170 |
)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
|
| 175 |
return response
|
| 176 |
|
|
@@ -204,21 +228,21 @@ def is_question_supported(question):
|
|
| 204 |
q = question.lower()
|
| 205 |
tokens = set(q.replace("?", "").replace(",", "").split())
|
| 206 |
|
| 207 |
-
# 1️⃣ Allow analytical intent even
|
| 208 |
analytic_keywords = {
|
| 209 |
"count", "total", "average", "avg", "sum",
|
| 210 |
-
"how many", "number of", "trend",
|
| 211 |
-
"increase", "decrease", "compare"
|
| 212 |
}
|
| 213 |
|
| 214 |
if any(k in q for k in analytic_keywords):
|
| 215 |
return True
|
| 216 |
|
| 217 |
-
# 2️⃣
|
| 218 |
schema = load_ai_schema()
|
| 219 |
-
score = 0
|
| 220 |
|
| 221 |
for table, meta in schema.items():
|
|
|
|
| 222 |
table_l = table.lower()
|
| 223 |
|
| 224 |
# Table name match
|
|
@@ -235,11 +259,15 @@ def is_question_supported(question):
|
|
| 235 |
|
| 236 |
# Description match
|
| 237 |
if meta.get("description"):
|
| 238 |
-
desc_tokens = meta["description"].lower().split()
|
| 239 |
-
score += len(tokens &
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
-
# 3️⃣ Threshold — prevents random questions
|
| 242 |
-
return score >= 2
|
| 243 |
|
| 244 |
|
| 245 |
# =========================
|
|
@@ -252,7 +280,11 @@ def build_prompt(question):
|
|
| 252 |
if matched:
|
| 253 |
schema = {t: load_ai_schema()[t] for t in matched}
|
| 254 |
else:
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
prompt = """
|
| 258 |
You are a hospital SQL assistant.
|
|
@@ -323,9 +355,6 @@ def has_underlying_data(sql):
|
|
| 323 |
cur = conn.cursor()
|
| 324 |
return cur.execute(test_sql).fetchone() is not None
|
| 325 |
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
# =========================
|
| 330 |
# PATIENT SUMMARY
|
| 331 |
# =========================
|
|
@@ -466,12 +495,20 @@ def process_question(question):
|
|
| 466 |
# ----------------------------------
|
| 467 |
# Generate SQL
|
| 468 |
# ----------------------------------
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
if sql == "NOT_ANSWERABLE":
|
| 472 |
return {
|
| 473 |
"status": "ok",
|
| 474 |
-
"message": "I don
|
| 475 |
"data": []
|
| 476 |
}
|
| 477 |
|
|
|
|
| 47 |
"admitted", "admission",
|
| 48 |
"year", "month", "last", "recent", "today"
|
| 49 |
]
|
| 50 |
+
DOMAIN_ALIASES = {
|
| 51 |
+
"consultant": ["provider", "encounter"],
|
| 52 |
+
"doctor": ["provider"],
|
| 53 |
+
"appointment": ["encounter"],
|
| 54 |
+
"visit": ["encounter"],
|
| 55 |
+
}
|
| 56 |
|
| 57 |
def correct_spelling(q):
|
| 58 |
words = q.split()
|
|
|
|
| 99 |
# TABLE MATCHING (CORE LOGIC)
|
| 100 |
# =========================
|
| 101 |
|
| 102 |
+
def extract_relevant_tables(question, max_tables=4):
|
| 103 |
schema = load_ai_schema()
|
| 104 |
q = question.lower()
|
|
|
|
| 105 |
tokens = set(q.replace("?", "").replace(",", "").split())
|
| 106 |
+
|
| 107 |
matched = []
|
| 108 |
|
| 109 |
+
# Lightweight intent hints (NO hard dependency)
|
| 110 |
+
DOMAIN_HINTS = {
|
| 111 |
+
"consultant": ["encounters"],
|
| 112 |
+
"doctor": ["encounters"],
|
| 113 |
+
"visit": ["encounters"],
|
| 114 |
+
"appointment": ["encounters"],
|
| 115 |
+
"patient": ["patients"],
|
| 116 |
+
"medication": ["medications"],
|
| 117 |
+
"drug": ["medications"],
|
| 118 |
+
"condition": ["conditions"],
|
| 119 |
+
"diagnosis": ["conditions"]
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Early exit threshold - if we find a perfect match, we can stop early
|
| 123 |
+
VERY_HIGH_SCORE = 10
|
| 124 |
+
|
| 125 |
for table, meta in schema.items():
|
| 126 |
score = 0
|
| 127 |
table_l = table.lower()
|
| 128 |
|
| 129 |
+
# 1️⃣ Strong signal: table name (exact match is very high confidence)
|
| 130 |
if table_l in q:
|
| 131 |
+
score += 6
|
| 132 |
+
# Early exit optimization: if exact table match found, prioritize it
|
| 133 |
+
if score >= VERY_HIGH_SCORE:
|
| 134 |
+
matched.append((table, score))
|
| 135 |
+
continue
|
|
|
|
| 136 |
|
| 137 |
+
# 2️⃣ Column relevance
|
| 138 |
for col, _ in meta["columns"]:
|
| 139 |
col_l = col.lower()
|
| 140 |
if col_l in q:
|
|
|
|
| 142 |
elif any(tok in col_l for tok in tokens):
|
| 143 |
score += 1
|
| 144 |
|
| 145 |
+
# 3️⃣ Description relevance
|
| 146 |
+
if meta.get("description"):
|
| 147 |
+
desc_tokens = set(meta["description"].lower().split())
|
| 148 |
+
score += len(tokens & desc_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
+
# 4️⃣ Semantic intent mapping (important)
|
| 151 |
+
for intent, tables in DOMAIN_HINTS.items():
|
| 152 |
+
if intent in q and table_l in tables:
|
| 153 |
+
score += 5
|
| 154 |
|
| 155 |
+
# 5️⃣ Only add if meets minimum threshold (prevents low-quality matches)
|
| 156 |
+
if score >= 3:
|
| 157 |
matched.append((table, score))
|
| 158 |
|
| 159 |
# Sort by relevance
|
| 160 |
matched.sort(key=lambda x: x[1], reverse=True)
|
| 161 |
|
|
|
|
| 162 |
return [t[0] for t in matched[:max_tables]]
|
| 163 |
|
| 164 |
|
|
|
|
|
|
|
| 165 |
# =========================
|
| 166 |
# HUMAN SCHEMA DESCRIPTION
|
| 167 |
# =========================
|
| 168 |
|
| 169 |
+
def describe_schema(max_tables=10):
|
| 170 |
schema = load_ai_schema()
|
| 171 |
+
total_tables = len(schema)
|
| 172 |
|
| 173 |
+
response = f"Here's the data I currently have access to ({total_tables} tables):\n\n"
|
| 174 |
|
| 175 |
+
# Show only top N tables to avoid overwhelming output
|
| 176 |
+
shown_tables = list(schema.items())[:max_tables]
|
| 177 |
+
|
| 178 |
+
for table, meta in shown_tables:
|
| 179 |
response += f"• **{table.capitalize()}** — {meta['description']}\n"
|
| 180 |
+
# Show only first 5 columns per table
|
| 181 |
+
for col, desc in list(meta["columns"])[:5]:
|
| 182 |
response += f" - {col}: {desc}\n"
|
| 183 |
+
if len(meta["columns"]) > 5:
|
| 184 |
+
response += f" ... and {len(meta['columns']) - 5} more columns\n"
|
| 185 |
response += "\n"
|
| 186 |
|
| 187 |
+
if total_tables > max_tables:
|
| 188 |
+
response += f"\n... and {total_tables - max_tables} more tables.\n"
|
| 189 |
+
response += "Ask about a specific table to see its details.\n\n"
|
| 190 |
+
|
| 191 |
response += (
|
| 192 |
"You can ask things like:\n"
|
| 193 |
"• How many patients are there?\n"
|
|
|
|
| 195 |
"• Admissions by year\n\n"
|
| 196 |
"Just tell me what you want to explore "
|
| 197 |
)
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
return response
|
| 200 |
|
|
|
|
| 228 |
q = question.lower()
|
| 229 |
tokens = set(q.replace("?", "").replace(",", "").split())
|
| 230 |
|
| 231 |
+
# 1️⃣ Allow analytical intent even without table names
|
| 232 |
analytic_keywords = {
|
| 233 |
"count", "total", "average", "avg", "sum",
|
| 234 |
+
"how many", "number of", "trend",
|
| 235 |
+
"increase", "decrease", "compare", "more than", "less than"
|
| 236 |
}
|
| 237 |
|
| 238 |
if any(k in q for k in analytic_keywords):
|
| 239 |
return True
|
| 240 |
|
| 241 |
+
# 2️⃣ Check schema relevance (table-by-table)
|
| 242 |
schema = load_ai_schema()
|
|
|
|
| 243 |
|
| 244 |
for table, meta in schema.items():
|
| 245 |
+
score = 0
|
| 246 |
table_l = table.lower()
|
| 247 |
|
| 248 |
# Table name match
|
|
|
|
| 259 |
|
| 260 |
# Description match
|
| 261 |
if meta.get("description"):
|
| 262 |
+
desc_tokens = set(meta["description"].lower().split())
|
| 263 |
+
score += len(tokens & desc_tokens)
|
| 264 |
+
|
| 265 |
+
# ✅ If any table is relevant enough → supported
|
| 266 |
+
if score >= 2:
|
| 267 |
+
return True
|
| 268 |
+
|
| 269 |
+
return False
|
| 270 |
|
|
|
|
|
|
|
| 271 |
|
| 272 |
|
| 273 |
# =========================
|
|
|
|
| 280 |
if matched:
|
| 281 |
schema = {t: load_ai_schema()[t] for t in matched}
|
| 282 |
else:
|
| 283 |
+
# 🚫 Don't send all 100+ tables! Return a helpful error instead
|
| 284 |
+
raise ValueError(
|
| 285 |
+
"I couldn't find any relevant tables for your question. "
|
| 286 |
+
"Please try mentioning a specific table name or use 'what data' to see available tables."
|
| 287 |
+
)
|
| 288 |
|
| 289 |
prompt = """
|
| 290 |
You are a hospital SQL assistant.
|
|
|
|
| 355 |
cur = conn.cursor()
|
| 356 |
return cur.execute(test_sql).fetchone() is not None
|
| 357 |
|
|
|
|
|
|
|
|
|
|
| 358 |
# =========================
|
| 359 |
# PATIENT SUMMARY
|
| 360 |
# =========================
|
|
|
|
| 495 |
# ----------------------------------
|
| 496 |
# Generate SQL
|
| 497 |
# ----------------------------------
|
| 498 |
+
try:
|
| 499 |
+
sql = call_llm(build_prompt(question))
|
| 500 |
+
except ValueError as e:
|
| 501 |
+
# Handle case where no relevant tables found
|
| 502 |
+
return {
|
| 503 |
+
"status": "ok",
|
| 504 |
+
"message": str(e),
|
| 505 |
+
"data": []
|
| 506 |
+
}
|
| 507 |
|
| 508 |
if sql == "NOT_ANSWERABLE":
|
| 509 |
return {
|
| 510 |
"status": "ok",
|
| 511 |
+
"message": "I don't have enough data to answer that.",
|
| 512 |
"data": []
|
| 513 |
}
|
| 514 |
|