Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,10 +22,8 @@ from typing import List, Dict, Any, Tuple, Optional
|
|
| 22 |
|
| 23 |
import gradio as gr
|
| 24 |
import pandas as pd
|
| 25 |
-
import numpy as np
|
| 26 |
|
| 27 |
# -------------------- OpenAI (optional) --------------------
|
| 28 |
-
USE_RESPONSES_API = True
|
| 29 |
OPENAI_AVAILABLE = True
|
| 30 |
DEFAULT_MODEL = os.getenv("OPENAI_MODEL") # optional override
|
| 31 |
try:
|
|
@@ -37,11 +35,10 @@ except Exception:
|
|
| 37 |
|
| 38 |
def _candidate_models():
    """Return the model names to try, in priority order and without duplicates.

    The optional DEFAULT_MODEL override is considered first, followed by the
    built-in fallback models. Falsy entries (such as an unset override) are
    dropped, and a name that appears more than once keeps its first position.
    """
    ordered = []
    for name in (DEFAULT_MODEL, "gpt-4o-mini", "gpt-4o", "gpt-4.1-mini", "o3-mini"):
        # Skip unset entries and anything already queued earlier in the list.
        if name and name not in ordered:
            ordered.append(name)
    return ordered
|
|
@@ -50,7 +47,6 @@ def _candidate_models():
|
|
| 50 |
DB_DIR = "/data" if os.path.exists("/data") else "."
|
| 51 |
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
|
| 52 |
EXPORT_DIR = "."
|
| 53 |
-
ADMIN_KEY = os.getenv("ADMIN_KEY", "demo")
|
| 54 |
RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
|
| 55 |
random.seed(RANDOM_SEED)
|
| 56 |
SYS_RAND = random.SystemRandom()
|
|
@@ -60,8 +56,8 @@ DB_LOCK = threading.RLock()
|
|
| 60 |
|
| 61 |
def connect_db():
|
| 62 |
"""
|
| 63 |
-
Single shared connection
|
| 64 |
-
All operations (reads + writes)
|
| 65 |
WAL mode enables concurrent reads.
|
| 66 |
"""
|
| 67 |
con = sqlite3.connect(DB_PATH, check_same_thread=False)
|
|
@@ -221,137 +217,91 @@ FALLBACK_QUESTIONS = [
|
|
| 221 |
"requires_aliases":False,"required_aliases":[]},
|
| 222 |
]
|
| 223 |
|
| 224 |
-
# -------------------- OpenAI JSON schema --------------------
|
| 225 |
DOMAIN_AND_QUESTIONS_SCHEMA = {
|
| 226 |
-
"
|
| 227 |
-
"schema": {
|
| 228 |
-
"type": "object",
|
| 229 |
-
"additionalProperties": False,
|
| 230 |
-
"properties": {
|
| 231 |
-
"domain": {"type":"string"},
|
| 232 |
-
"tables": {
|
| 233 |
-
"type":"array",
|
| 234 |
-
"items": {
|
| 235 |
-
"type":"object",
|
| 236 |
-
"additionalProperties": False,
|
| 237 |
-
"properties": {
|
| 238 |
-
"name": {"type":"string"},
|
| 239 |
-
"pk": {"type":"array","items":{"type":"string"}},
|
| 240 |
-
"columns": {
|
| 241 |
-
"type":"array",
|
| 242 |
-
"items": {
|
| 243 |
-
"type":"object",
|
| 244 |
-
"additionalProperties": False,
|
| 245 |
-
"properties": {"name":{"type":"string"}, "type":{"type":"string"}},
|
| 246 |
-
"required":["name","type"]
|
| 247 |
-
}
|
| 248 |
-
},
|
| 249 |
-
"fks": {
|
| 250 |
-
"type":"array",
|
| 251 |
-
"items": {
|
| 252 |
-
"type":"object",
|
| 253 |
-
"additionalProperties": False,
|
| 254 |
-
"properties": {
|
| 255 |
-
"columns":{"type":"array","items":{"type":"string"}},
|
| 256 |
-
"ref_table":{"type":"string"},
|
| 257 |
-
"ref_columns":{"type":"array","items":{"type":"string"}}
|
| 258 |
-
},
|
| 259 |
-
"required":["columns","ref_table","ref_columns"]
|
| 260 |
-
}
|
| 261 |
-
},
|
| 262 |
-
"rows": {"type":"array","items":{"type":["object","array"]}}
|
| 263 |
-
},
|
| 264 |
-
"required":["name","pk","columns","fks","rows"]
|
| 265 |
-
},
|
| 266 |
-
"minItems":3,"maxItems":4
|
| 267 |
-
},
|
| 268 |
-
"questions": {
|
| 269 |
-
"type":"array",
|
| 270 |
-
"items": {
|
| 271 |
-
"type":"object",
|
| 272 |
-
"additionalProperties": False,
|
| 273 |
-
"properties": {
|
| 274 |
-
"id":{"type":"string"},
|
| 275 |
-
"category":{"type":"string"},
|
| 276 |
-
"difficulty":{"type":"integer"},
|
| 277 |
-
"prompt_md":{"type":"string"},
|
| 278 |
-
"answer_sql":{"type":"array","items":{"type":"string"}},
|
| 279 |
-
"requires_aliases":{"type":"boolean"},
|
| 280 |
-
"required_aliases":{"type":"array","items":{"type":"string"}}
|
| 281 |
-
},
|
| 282 |
-
"required":["id","category","difficulty","prompt_md","answer_sql"]
|
| 283 |
-
},
|
| 284 |
-
"minItems":8,"maxItems":12
|
| 285 |
-
}
|
| 286 |
-
},
|
| 287 |
-
"required":["domain","tables","questions"]
|
| 288 |
-
},
|
| 289 |
-
"strict": True
|
| 290 |
}
|
| 291 |
|
| 292 |
def _domain_prompt(prev_domain: Optional[str]) -> str:
|
| 293 |
extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
|
| 294 |
return f"""
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
- For 1–2 questions, require table aliases (set requires_aliases=true and list required_aliases).
|
| 315 |
-
|
| 316 |
-
Return JSON only.
|
| 317 |
"""
|
| 318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str]]:
|
| 320 |
"""
|
| 321 |
Returns (obj, error_message, model_used).
|
|
|
|
| 322 |
"""
|
| 323 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 324 |
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None
|
| 325 |
|
| 326 |
errors = []
|
|
|
|
|
|
|
| 327 |
for model in _candidate_models():
|
|
|
|
| 328 |
try:
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
resp = _client.responses.create(
|
| 332 |
model=model,
|
| 333 |
-
|
| 334 |
-
input=[{"role":"user","content": prompt}],
|
| 335 |
temperature=0.6,
|
|
|
|
| 336 |
)
|
| 337 |
-
data_text =
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
data_text = resp.output[0].content[0].text # older SDK layout
|
| 341 |
-
except Exception:
|
| 342 |
-
data_text = None
|
| 343 |
-
else:
|
| 344 |
chat = _client.chat.completions.create(
|
| 345 |
model=model,
|
| 346 |
-
messages=[{"role":"
|
|
|
|
| 347 |
temperature=0.6
|
| 348 |
)
|
| 349 |
data_text = chat.choices[0].message.content
|
| 350 |
|
| 351 |
-
|
| 352 |
-
|
|
|
|
| 353 |
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
| 355 |
# Guardrails: strip RIGHT/FULL joins from answers
|
| 356 |
clean_qs = []
|
| 357 |
for q in obj.get("questions", []):
|
|
|
|
| 22 |
|
| 23 |
import gradio as gr
|
| 24 |
import pandas as pd
|
|
|
|
| 25 |
|
| 26 |
# -------------------- OpenAI (optional) --------------------
|
|
|
|
| 27 |
OPENAI_AVAILABLE = True
|
| 28 |
DEFAULT_MODEL = os.getenv("OPENAI_MODEL") # optional override
|
| 29 |
try:
|
|
|
|
| 35 |
|
| 36 |
def _candidate_models():
    """Return the model names to try, in priority order and without duplicates.

    The DEFAULT_MODEL override (whitespace-stripped; ignored when empty or
    unset) comes first, followed by the built-in fallback models.
    """
    override = (DEFAULT_MODEL or "").strip() or None
    candidates = (override, "gpt-4o-mini", "gpt-4o", "gpt-4.1-mini")
    # dict keys keep insertion order, so fromkeys() removes duplicates while
    # preserving the priority of the first occurrence.
    return list(dict.fromkeys(name for name in candidates if name))
|
|
|
|
| 47 |
DB_DIR = "/data" if os.path.exists("/data") else "."
|
| 48 |
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
|
| 49 |
EXPORT_DIR = "."
|
|
|
|
| 50 |
RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
|
| 51 |
random.seed(RANDOM_SEED)
|
| 52 |
SYS_RAND = random.SystemRandom()
|
|
|
|
| 56 |
|
| 57 |
def connect_db():
|
| 58 |
"""
|
| 59 |
+
Single shared connection usable across threads.
|
| 60 |
+
All operations (reads + writes) serialized via DB_LOCK.
|
| 61 |
WAL mode enables concurrent reads.
|
| 62 |
"""
|
| 63 |
con = sqlite3.connect(DB_PATH, check_same_thread=False)
|
|
|
|
| 217 |
"requires_aliases":False,"required_aliases":[]},
|
| 218 |
]
|
| 219 |
|
| 220 |
+
# -------------------- OpenAI JSON schema (validated after parse) --------------------
|
| 221 |
# Top-level keys that the parsed model output is checked for after parsing.
DOMAIN_AND_QUESTIONS_SCHEMA = {"required": ["domain", "tables", "questions"]}
|
| 224 |
|
| 225 |
def _domain_prompt(prev_domain: Optional[str]) -> str:
|
| 226 |
extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
|
| 227 |
return f"""
|
| 228 |
+
Return ONLY a valid JSON object (no markdown, no prose).
|
| 229 |
+
The JSON must have: domain (string), tables (3–4 table objects), and questions (8–12 question objects).{extra}
|
| 230 |
+
|
| 231 |
+
Rules:
|
| 232 |
+
- One domain chosen from: bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
|
| 233 |
+
- Tables: SQLite-friendly. Use snake_case. Each table has: name, pk (list of column names),
|
| 234 |
+
columns (list of {{name,type}}), fks (list of {{columns,ref_table,ref_columns}}), rows (8–15 small seed rows).
|
| 235 |
+
- Questions: diverse natural language. Categories: "SELECT *", "SELECT columns", "WHERE", "Aliases",
|
| 236 |
+
"JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
|
| 237 |
+
Include at least one LEFT JOIN, one VIEW creation, one CTAS or SELECT INTO.
|
| 238 |
+
Provide 1–3 'answer_sql' strings per question. Prefer SQLite-compatible SQL. Do NOT use RIGHT/FULL OUTER JOIN.
|
| 239 |
+
For 1–2 questions, set requires_aliases=true and list required_aliases.
|
| 240 |
+
|
| 241 |
+
Example top-level keys (do not include comments in output):
|
| 242 |
+
{{
|
| 243 |
+
"domain": "retail sales",
|
| 244 |
+
"tables": [...],
|
| 245 |
+
"questions": [...]
|
| 246 |
+
}}
|
|
|
|
|
|
|
|
|
|
| 247 |
"""
|
| 248 |
|
| 249 |
+
def _loose_json_parse(s: str) -> Optional[dict]:
|
| 250 |
+
"""Extract the first JSON object from a possibly-wrapped string."""
|
| 251 |
+
try:
|
| 252 |
+
return json.loads(s)
|
| 253 |
+
except Exception:
|
| 254 |
+
pass
|
| 255 |
+
# Try to find the first {...} block
|
| 256 |
+
start = s.find("{")
|
| 257 |
+
end = s.rfind("}")
|
| 258 |
+
if start != -1 and end != -1 and end > start:
|
| 259 |
+
try:
|
| 260 |
+
return json.loads(s[start:end+1])
|
| 261 |
+
except Exception:
|
| 262 |
+
return None
|
| 263 |
+
return None
|
| 264 |
+
|
| 265 |
def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str]]:
|
| 266 |
"""
|
| 267 |
Returns (obj, error_message, model_used).
|
| 268 |
+
Uses Chat Completions JSON mode if available; otherwise falls back to strict-instruction parsing.
|
| 269 |
"""
|
| 270 |
if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
|
| 271 |
return None, "OpenAI client not available or OPENAI_API_KEY missing.", None
|
| 272 |
|
| 273 |
errors = []
|
| 274 |
+
prompt = _domain_prompt(prev_domain)
|
| 275 |
+
|
| 276 |
for model in _candidate_models():
|
| 277 |
+
# Try JSON mode first (if supported)
|
| 278 |
try:
|
| 279 |
+
try:
|
| 280 |
+
chat = _client.chat.completions.create(
|
|
|
|
| 281 |
model=model,
|
| 282 |
+
messages=[{"role":"user","content": prompt}],
|
|
|
|
| 283 |
temperature=0.6,
|
| 284 |
+
response_format={"type":"json_object"} # newer SDKs
|
| 285 |
)
|
| 286 |
+
data_text = chat.choices[0].message.content
|
| 287 |
+
except TypeError:
|
| 288 |
+
# Older SDKs: no response_format argument → plain completion with strict instructions
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
chat = _client.chat.completions.create(
|
| 290 |
model=model,
|
| 291 |
+
messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
|
| 292 |
+
{"role":"user","content": prompt}],
|
| 293 |
temperature=0.6
|
| 294 |
)
|
| 295 |
data_text = chat.choices[0].message.content
|
| 296 |
|
| 297 |
+
obj = _loose_json_parse(data_text or "")
|
| 298 |
+
if not obj:
|
| 299 |
+
raise RuntimeError("Could not parse JSON from model output.")
|
| 300 |
|
| 301 |
+
# Minimal validation
|
| 302 |
+
for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
|
| 303 |
+
if k not in obj:
|
| 304 |
+
raise RuntimeError(f"Missing key '{k}'")
|
| 305 |
# Guardrails: strip RIGHT/FULL joins from answers
|
| 306 |
clean_qs = []
|
| 307 |
for q in obj.get("questions", []):
|