jtdearmon committed on
Commit
4909489
·
verified ·
1 Parent(s): f642f6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -110
app.py CHANGED
@@ -22,10 +22,8 @@ from typing import List, Dict, Any, Tuple, Optional
22
 
23
  import gradio as gr
24
  import pandas as pd
25
- import numpy as np
26
 
27
  # -------------------- OpenAI (optional) --------------------
28
- USE_RESPONSES_API = True
29
  OPENAI_AVAILABLE = True
30
  DEFAULT_MODEL = os.getenv("OPENAI_MODEL") # optional override
31
  try:
@@ -37,11 +35,10 @@ except Exception:
37
 
38
  def _candidate_models():
39
  base = [
40
- DEFAULT_MODEL,
41
  "gpt-4o-mini",
42
  "gpt-4o",
43
  "gpt-4.1-mini",
44
- "o3-mini",
45
  ]
46
  seen = set()
47
  return [m for m in base if m and (m not in seen and not seen.add(m))]
@@ -50,7 +47,6 @@ def _candidate_models():
50
  DB_DIR = "/data" if os.path.exists("/data") else "."
51
  DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
52
  EXPORT_DIR = "."
53
- ADMIN_KEY = os.getenv("ADMIN_KEY", "demo")
54
  RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
55
  random.seed(RANDOM_SEED)
56
  SYS_RAND = random.SystemRandom()
@@ -60,8 +56,8 @@ DB_LOCK = threading.RLock()
60
 
61
  def connect_db():
62
  """
63
- Single shared connection that can be used across threads.
64
- All operations (reads + writes) are serialized via DB_LOCK.
65
  WAL mode enables concurrent reads.
66
  """
67
  con = sqlite3.connect(DB_PATH, check_same_thread=False)
@@ -221,137 +217,91 @@ FALLBACK_QUESTIONS = [
221
  "requires_aliases":False,"required_aliases":[]},
222
  ]
223
 
224
- # -------------------- OpenAI JSON schema --------------------
225
  DOMAIN_AND_QUESTIONS_SCHEMA = {
226
- "name": "DomainSQLPack",
227
- "schema": {
228
- "type": "object",
229
- "additionalProperties": False,
230
- "properties": {
231
- "domain": {"type":"string"},
232
- "tables": {
233
- "type":"array",
234
- "items": {
235
- "type":"object",
236
- "additionalProperties": False,
237
- "properties": {
238
- "name": {"type":"string"},
239
- "pk": {"type":"array","items":{"type":"string"}},
240
- "columns": {
241
- "type":"array",
242
- "items": {
243
- "type":"object",
244
- "additionalProperties": False,
245
- "properties": {"name":{"type":"string"}, "type":{"type":"string"}},
246
- "required":["name","type"]
247
- }
248
- },
249
- "fks": {
250
- "type":"array",
251
- "items": {
252
- "type":"object",
253
- "additionalProperties": False,
254
- "properties": {
255
- "columns":{"type":"array","items":{"type":"string"}},
256
- "ref_table":{"type":"string"},
257
- "ref_columns":{"type":"array","items":{"type":"string"}}
258
- },
259
- "required":["columns","ref_table","ref_columns"]
260
- }
261
- },
262
- "rows": {"type":"array","items":{"type":["object","array"]}}
263
- },
264
- "required":["name","pk","columns","fks","rows"]
265
- },
266
- "minItems":3,"maxItems":4
267
- },
268
- "questions": {
269
- "type":"array",
270
- "items": {
271
- "type":"object",
272
- "additionalProperties": False,
273
- "properties": {
274
- "id":{"type":"string"},
275
- "category":{"type":"string"},
276
- "difficulty":{"type":"integer"},
277
- "prompt_md":{"type":"string"},
278
- "answer_sql":{"type":"array","items":{"type":"string"}},
279
- "requires_aliases":{"type":"boolean"},
280
- "required_aliases":{"type":"array","items":{"type":"string"}}
281
- },
282
- "required":["id","category","difficulty","prompt_md","answer_sql"]
283
- },
284
- "minItems":8,"maxItems":12
285
- }
286
- },
287
- "required":["domain","tables","questions"]
288
- },
289
- "strict": True
290
  }
291
 
292
  def _domain_prompt(prev_domain: Optional[str]) -> str:
293
  extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
294
  return f"""
295
- You are designing a small relational dataset and training questions for SQL basics.{extra}
296
-
297
- 1) Choose ONE domain at random from:
298
- - bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
299
-
300
- 2) Produce exactly 3–4 tables that fit together (SQLite-friendly):
301
- - Use snake_case, avoid reserved words.
302
- - Types: INTEGER, REAL, TEXT, NUMERIC, DATE (no advanced features).
303
- - Primary keys (pk) and foreign keys (fks) must align.
304
- - Provide 8–15 small, realistic seed rows per table (not huge).
305
-
306
- 3) Generate 8–12 SQL questions covering basics with varied, natural language:
307
- - Categories from: "SELECT *", "SELECT columns", "WHERE", "Aliases",
308
- "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
309
- - Include a few joins and at least one LEFT JOIN.
310
- - Include one view creation.
311
- - Include one table creation from SELECT (either CTAS or SELECT INTO).
312
- - Prefer SQLite-compatible SQL. DO NOT use RIGHT/FULL OUTER JOIN.
313
- - Offer 1–3 acceptable answer_sql variants per question.
314
- - For 1–2 questions, require table aliases (set requires_aliases=true and list required_aliases).
315
-
316
- Return JSON only.
317
  """
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str]]:
320
  """
321
  Returns (obj, error_message, model_used).
 
322
  """
323
  if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
324
  return None, "OpenAI client not available or OPENAI_API_KEY missing.", None
325
 
326
  errors = []
 
 
327
  for model in _candidate_models():
 
328
  try:
329
- prompt = _domain_prompt(prev_domain)
330
- if USE_RESPONSES_API:
331
- resp = _client.responses.create(
332
  model=model,
333
- response_format={"type":"json_schema","json_schema":DOMAIN_AND_QUESTIONS_SCHEMA},
334
- input=[{"role":"user","content": prompt}],
335
  temperature=0.6,
 
336
  )
337
- data_text = getattr(resp, "output_text", None)
338
- if not data_text:
339
- try:
340
- data_text = resp.output[0].content[0].text # older SDK layout
341
- except Exception:
342
- data_text = None
343
- else:
344
  chat = _client.chat.completions.create(
345
  model=model,
346
- messages=[{"role":"user","content": prompt}],
 
347
  temperature=0.6
348
  )
349
  data_text = chat.choices[0].message.content
350
 
351
- if not data_text:
352
- raise RuntimeError("Empty response from model.")
 
353
 
354
- obj = json.loads(data_text)
 
 
 
355
  # Guardrails: strip RIGHT/FULL joins from answers
356
  clean_qs = []
357
  for q in obj.get("questions", []):
 
22
 
23
  import gradio as gr
24
  import pandas as pd
 
25
 
26
  # -------------------- OpenAI (optional) --------------------
 
27
  OPENAI_AVAILABLE = True
28
  DEFAULT_MODEL = os.getenv("OPENAI_MODEL") # optional override
29
  try:
 
35
 
36
  def _candidate_models():
37
  base = [
38
+ (DEFAULT_MODEL or "").strip() or None,
39
  "gpt-4o-mini",
40
  "gpt-4o",
41
  "gpt-4.1-mini",
 
42
  ]
43
  seen = set()
44
  return [m for m in base if m and (m not in seen and not seen.add(m))]
 
47
  DB_DIR = "/data" if os.path.exists("/data") else "."
48
  DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
49
  EXPORT_DIR = "."
 
50
  RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
51
  random.seed(RANDOM_SEED)
52
  SYS_RAND = random.SystemRandom()
 
56
 
57
  def connect_db():
58
  """
59
+ Single shared connection usable across threads.
60
+ All operations (reads + writes) serialized via DB_LOCK.
61
  WAL mode enables concurrent reads.
62
  """
63
  con = sqlite3.connect(DB_PATH, check_same_thread=False)
 
217
  "requires_aliases":False,"required_aliases":[]},
218
  ]
219
 
220
+ # -------------------- OpenAI JSON schema (validated after parse) --------------------
221
  DOMAIN_AND_QUESTIONS_SCHEMA = {
222
+ "required": ["domain", "tables", "questions"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  }
224
 
225
  def _domain_prompt(prev_domain: Optional[str]) -> str:
226
  extra = f" Avoid using the previous domain '{prev_domain}' if possible." if prev_domain else ""
227
  return f"""
228
+ Return ONLY a valid JSON object (no markdown, no prose).
229
+ The JSON must have: domain (string), tables (3–4 table objects), and questions (8–12 question objects).{extra}
230
+
231
+ Rules:
232
+ - One domain chosen from: bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.
233
+ - Tables: SQLite-friendly. Use snake_case. Each table has: name, pk (list of column names),
234
+ columns (list of {{name,type}}), fks (list of {{columns,ref_table,ref_columns}}), rows (8–15 small seed rows).
235
+ - Questions: diverse natural language. Categories: "SELECT *", "SELECT columns", "WHERE", "Aliases",
236
+ "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
237
+ Include at least one LEFT JOIN, one VIEW creation, one CTAS or SELECT INTO.
238
+ Provide 1–3 'answer_sql' strings per question. Prefer SQLite-compatible SQL. Do NOT use RIGHT/FULL OUTER JOIN.
239
+ For 1–2 questions, set requires_aliases=true and list required_aliases.
240
+
241
+ Example top-level keys (do not include comments in output):
242
+ {{
243
+ "domain": "retail sales",
244
+ "tables": [...],
245
+ "questions": [...]
246
+ }}
 
 
 
247
  """
248
 
249
+ def _loose_json_parse(s: str) -> Optional[dict]:
250
+ """Extract the first JSON object from a possibly-wrapped string."""
251
+ try:
252
+ return json.loads(s)
253
+ except Exception:
254
+ pass
255
+ # Try to find the first {...} block
256
+ start = s.find("{")
257
+ end = s.rfind("}")
258
+ if start != -1 and end != -1 and end > start:
259
+ try:
260
+ return json.loads(s[start:end+1])
261
+ except Exception:
262
+ return None
263
+ return None
264
+
265
  def llm_generate_domain_and_questions(prev_domain: Optional[str]) -> Tuple[Optional[Dict[str,Any]], Optional[str], Optional[str]]:
266
  """
267
  Returns (obj, error_message, model_used).
268
+ Uses Chat Completions JSON mode if available; otherwise falls back to strict-instruction parsing.
269
  """
270
  if not OPENAI_AVAILABLE or not os.getenv("OPENAI_API_KEY"):
271
  return None, "OpenAI client not available or OPENAI_API_KEY missing.", None
272
 
273
  errors = []
274
+ prompt = _domain_prompt(prev_domain)
275
+
276
  for model in _candidate_models():
277
+ # Try JSON mode first (if supported)
278
  try:
279
+ try:
280
+ chat = _client.chat.completions.create(
 
281
  model=model,
282
+ messages=[{"role":"user","content": prompt}],
 
283
  temperature=0.6,
284
+ response_format={"type":"json_object"} # newer SDKs
285
  )
286
+ data_text = chat.choices[0].message.content
287
+ except TypeError:
288
+ # Older SDKs: no response_format argument → plain completion with strict instructions
 
 
 
 
289
  chat = _client.chat.completions.create(
290
  model=model,
291
+ messages=[{"role":"system","content":"Return ONLY a JSON object. No markdown."},
292
+ {"role":"user","content": prompt}],
293
  temperature=0.6
294
  )
295
  data_text = chat.choices[0].message.content
296
 
297
+ obj = _loose_json_parse(data_text or "")
298
+ if not obj:
299
+ raise RuntimeError("Could not parse JSON from model output.")
300
 
301
+ # Minimal validation
302
+ for k in DOMAIN_AND_QUESTIONS_SCHEMA["required"]:
303
+ if k not in obj:
304
+ raise RuntimeError(f"Missing key '{k}'")
305
  # Guardrails: strip RIGHT/FULL joins from answers
306
  clean_qs = []
307
  for q in obj.get("questions", []):