Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- engine.py +12 -30
- mimic_iv_demo.db +0 -0
engine.py
CHANGED
|
@@ -13,7 +13,8 @@ api_key = os.getenv("OPENAI_API_KEY")
|
|
| 13 |
if not api_key:
|
| 14 |
raise ValueError("OPENAI_API_KEY environment variable is not set")
|
| 15 |
client = OpenAI(api_key=api_key)
|
| 16 |
-
conn = sqlite3.connect("
|
|
|
|
| 17 |
|
| 18 |
# =========================
|
| 19 |
# CONVERSATION STATE
|
|
@@ -74,32 +75,13 @@ def correct_spelling(q):
|
|
| 74 |
# =========================
|
| 75 |
# SCHEMA
|
| 76 |
# =========================
|
|
|
|
| 77 |
from functools import lru_cache
|
| 78 |
|
| 79 |
@lru_cache(maxsize=1)
|
| 80 |
def load_ai_schema():
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
tables = cur.execute("""
|
| 85 |
-
SELECT table_name, description
|
| 86 |
-
FROM ai_tables
|
| 87 |
-
WHERE ai_enabled = 1
|
| 88 |
-
""").fetchall()
|
| 89 |
-
|
| 90 |
-
for table, desc in tables:
|
| 91 |
-
cols = cur.execute("""
|
| 92 |
-
SELECT column_name, description
|
| 93 |
-
FROM ai_columns
|
| 94 |
-
WHERE table_name = ? AND ai_allowed = 1
|
| 95 |
-
""", (table,)).fetchall()
|
| 96 |
-
|
| 97 |
-
schema[table] = {
|
| 98 |
-
"description": desc,
|
| 99 |
-
"columns": cols
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
return schema
|
| 103 |
|
| 104 |
# =========================
|
| 105 |
# TABLE MATCHING (CORE LOGIC)
|
|
@@ -216,14 +198,14 @@ def describe_schema(max_tables=10):
|
|
| 216 |
# =========================
|
| 217 |
|
| 218 |
def get_latest_data_date():
|
| 219 |
-
"""Get the latest data date from encounters table."""
|
| 220 |
-
cur = conn.cursor()
|
| 221 |
try:
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
| 225 |
return None
|
| 226 |
|
|
|
|
| 227 |
def normalize_time_question(q):
|
| 228 |
latest = get_latest_data_date()
|
| 229 |
if not latest:
|
|
@@ -319,7 +301,7 @@ If the question mentions "consultant" or "doctor", use the table name "encounter
|
|
| 319 |
|
| 320 |
for table, meta in schema.items():
|
| 321 |
prompt += f"\nTable: {table}\n"
|
| 322 |
-
for col, desc in meta["columns"]:
|
| 323 |
prompt += f"- {col}: {desc}\n"
|
| 324 |
|
| 325 |
prompt += f"\nQuestion: {question}\n"
|
|
@@ -471,7 +453,7 @@ def build_table_summary(table_name):
|
|
| 471 |
summary += f"• Total records: {total}\n"
|
| 472 |
|
| 473 |
# Try to summarize categorical columns using metadata
|
| 474 |
-
for col_name, col_desc in columns:
|
| 475 |
# Validate column name
|
| 476 |
if not validate_identifier(col_name):
|
| 477 |
continue
|
|
|
|
| 13 |
if not api_key:
|
| 14 |
raise ValueError("OPENAI_API_KEY environment variable is not set")
|
| 15 |
client = OpenAI(api_key=api_key)
|
| 16 |
+
conn = sqlite3.connect("mimic_iv_demo.db", check_same_thread=False)
|
| 17 |
+
|
| 18 |
|
| 19 |
# =========================
|
| 20 |
# CONVERSATION STATE
|
|
|
|
| 75 |
# =========================
|
| 76 |
# SCHEMA
|
| 77 |
# =========================
|
| 78 |
+
import json
|
| 79 |
from functools import lru_cache
|
| 80 |
|
| 81 |
@lru_cache(maxsize=1)
|
| 82 |
def load_ai_schema():
|
| 83 |
+
with open("hospital_metadata.json", "r") as f:
|
| 84 |
+
return json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
# =========================
|
| 87 |
# TABLE MATCHING (CORE LOGIC)
|
|
|
|
| 198 |
# =========================
|
| 199 |
|
| 200 |
def get_latest_data_date():
|
|
|
|
|
|
|
| 201 |
try:
|
| 202 |
+
return conn.execute(
|
| 203 |
+
"SELECT MAX(admittime) FROM admissions"
|
| 204 |
+
).fetchone()[0]
|
| 205 |
+
except:
|
| 206 |
return None
|
| 207 |
|
| 208 |
+
|
| 209 |
def normalize_time_question(q):
|
| 210 |
latest = get_latest_data_date()
|
| 211 |
if not latest:
|
|
|
|
| 301 |
|
| 302 |
for table, meta in schema.items():
|
| 303 |
prompt += f"\nTable: {table}\n"
|
| 304 |
+
for col, desc in meta["columns"].items():
|
| 305 |
prompt += f"- {col}: {desc}\n"
|
| 306 |
|
| 307 |
prompt += f"\nQuestion: {question}\n"
|
|
|
|
| 453 |
summary += f"• Total records: {total}\n"
|
| 454 |
|
| 455 |
# Try to summarize categorical columns using metadata
|
| 456 |
+
for col_name, col_desc in columns.items():
|
| 457 |
# Validate column name
|
| 458 |
if not validate_identifier(col_name):
|
| 459 |
continue
|
mimic_iv_demo.db
ADDED
|
Binary file (8.19 kB). View file
|
|
|