File size: 42,367 Bytes
51df8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d86caf
 
 
 
 
 
 
 
51df8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d86caf
51df8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c666316
51df8be
6d86caf
51df8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d86caf
 
 
 
 
 
51df8be
 
 
6d86caf
51df8be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
# Adaptive SQL Trainer — Domain Randomized with OpenAI (Gradio + SQLite)
# - Randomizes a relational domain via OpenAI (bookstore, retail sales, wholesaler,
#   sales tax, oil & gas wells, marketing) OR falls back to a built-in dataset.
# - Builds 3–4 related tables (schema + seed rows) in SQLite.
# - Generates 8–12 randomized SQL questions with varied phrasings.
# - Validates answers by executing canonical SQL and comparing result sets.
# - Provides tailored feedback (SQLite dialect, cartesian products, aggregates, aliases).
# - Shows data results at the bottom pane for every run (SELECT or preview for VIEW/CTAS).
#
# Hugging Face Spaces: set OPENAI_API_KEY as a secret to enable LLM randomization.

import os
import re
import json
import time
import random
import sqlite3
from dataclasses import dataclass, asdict, field
from datetime import datetime, timezone
from typing import List, Dict, Any, Tuple, Optional

import gradio as gr
import pandas as pd
import numpy as np

# Matplotlib for ERD drawing (headless)
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image

# -------------------- OpenAI (optional) --------------------
# Prefer the Responses API over chat.completions when calling OpenAI.
USE_RESPONSES_API = True
OPENAI_AVAILABLE = True
MODEL_ID = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
try:
    from openai import OpenAI
    _client = OpenAI()  # requires OPENAI_API_KEY
except Exception:
    # SDK missing or no API key: run fully offline using the fallback dataset.
    OPENAI_AVAILABLE = False
    _client = None

# -------------------- Global settings --------------------
# Persist the DB under /data when it exists (Hugging Face Spaces persistent storage).
DB_DIR = "/data" if os.path.exists("/data") else "."
DB_PATH = os.path.join(DB_DIR, "sql_trainer_dynamic.db")
EXPORT_DIR = "."
ADMIN_KEY = os.getenv("ADMIN_KEY", "demo")
RANDOM_SEED = int(os.getenv("RANDOM_SEED", "7"))
random.seed(RANDOM_SEED)  # deterministic question picking unless overridden via env
SYS_RAND = random.SystemRandom()  # OS-entropy RNG, kept separate from the seeded one

# ERD figure rendering parameters.
PLOT_FIGSIZE = (6.8, 3.4)
PLOT_DPI = 110
PLOT_HEIGHT = 300

# -------------------- ERD helpers --------------------
def _to_pil(fig) -> Image.Image:
    """Serialize a Matplotlib figure to an in-memory PNG and return it as a PIL Image."""
    png_stream = BytesIO()
    fig.tight_layout()
    fig.savefig(png_stream, format="png", dpi=PLOT_DPI, bbox_inches="tight")
    plt.close(fig)  # release the figure; only the rendered bytes are kept
    png_stream.seek(0)
    return Image.open(png_stream)

def draw_dynamic_erd(schema: Dict[str, Any]) -> Image.Image:
    """
    Render a simple one-row ERD (tables as boxes, FKs as arrows) for a schema.

    schema = {
      "domain": "bookstore",
      "tables": [
          {"name":"authors","columns":[{"name":"author_id","type":"INTEGER",...}, ...],
           "pk":["author_id"], "fks":[{"columns":["author_id"],"ref_table":"...","ref_columns":["..."]}],
           "rows":[{...}, {...}]}
      ]
    }
    """
    fig, ax = plt.subplots(figsize=PLOT_FIGSIZE)
    ax.axis("off")
    tables = schema.get("tables", [])
    n = max(1, len(tables))  # guard against an empty schema (avoids division by zero below)
    # Lay out boxes horizontally
    margin = 0.03
    width = (1 - margin*(n+1)) / n
    height = 0.65
    y = 0.25
    boxes = {}  # table name -> (x, y, w, h) in axes coordinates; used for FK arrows
    for i, t in enumerate(tables):
        x = margin + i*(width + margin)
        boxes[t["name"]] = (x, y, width, height)
        ax.add_patch(plt.Rectangle((x, y), width, height, fill=False))
        # NOTE(review): matplotlib does not interpret markdown, so the "**" around
        # the table name renders literally instead of making it bold — confirm intent.
        ax.text(x + 0.01, y + height - 0.05, f"**{t['name']}**", fontsize=10, ha="left", va="top")
        yy = y + height - 0.10  # first column label, just below the title
        pk = set(t.get("pk", []))
        cols = t.get("columns", [])
        for col in cols:
            nm = col["name"]
            mark = " (PK)" if nm in pk else ""
            ax.text(x + 0.02, yy, f"{nm}{mark}", fontsize=9, ha="left", va="top")
            yy -= 0.06  # fixed line spacing; very wide tables may overflow the box

    # Draw FK arrows
    for t in tables:
        for fk in t.get("fks", []):
            src_tbl = t["name"]
            dst_tbl = fk.get("ref_table")
            # Only draw arrows between tables that were actually laid out above.
            if src_tbl in boxes and dst_tbl in boxes:
                (x1, y1, w1, h1) = boxes[src_tbl]
                (x2, y2, w2, h2) = boxes[dst_tbl]
                ax.annotate("", xy=(x2 + w2/2, y2 + h2), xytext=(x1 + w1/2, y1),
                            arrowprops=dict(arrowstyle="->", lw=1.1))
    ax.text(0.5, 0.06, f"Domain: {schema.get('domain','unknown')}", fontsize=9, ha="center")
    return _to_pil(fig)

# -------------------- SQLite helpers --------------------
def connect_db():
    """Open a connection to the trainer database with foreign-key enforcement on."""
    connection = sqlite3.connect(DB_PATH)
    connection.execute("PRAGMA foreign_keys = ON;")
    return connection

CONN = connect_db()

def init_progress_tables(con: sqlite3.Connection):
    """Create the trainer's internal bookkeeping tables if absent (idempotent)."""
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS users (
            user_id TEXT PRIMARY KEY,
            name TEXT,
            created_at TEXT
        )
    """,
        """
        CREATE TABLE IF NOT EXISTS attempts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id TEXT,
            question_id TEXT,
            category TEXT,
            correct INTEGER,
            sql_text TEXT,
            timestamp TEXT,
            time_taken REAL,
            difficulty INTEGER,
            source TEXT,
            notes TEXT
        )
    """,
        """
        CREATE TABLE IF NOT EXISTS session_meta (
            id INTEGER PRIMARY KEY CHECK (id=1),
            domain TEXT,
            schema_json TEXT
        )
    """,
    )
    cursor = con.cursor()
    for ddl in ddl_statements:
        cursor.execute(ddl)
    con.commit()

init_progress_tables(CONN)

# -------------------- Fallback dataset (if no OpenAI) --------------------
FALLBACK_SCHEMA = {
    "domain": "bookstore",
    "tables": [
        {
            "name": "authors",
            "pk": ["author_id"],
            "columns": [
                {"name":"author_id","type":"INTEGER"},
                {"name":"name","type":"TEXT"},
                {"name":"country","type":"TEXT"},
                {"name":"birth_year","type":"INTEGER"},
            ],
            "fks": [],
            "rows": [
                {"author_id":1,"name":"Isaac Asimov","country":"USA","birth_year":1920},
                {"author_id":2,"name":"Ursula K. Le Guin","country":"USA","birth_year":1929},
                {"author_id":3,"name":"Haruki Murakami","country":"Japan","birth_year":1949},
                {"author_id":4,"name":"Chinua Achebe","country":"Nigeria","birth_year":1930},
                {"author_id":5,"name":"Jane Austen","country":"UK","birth_year":1775},
                {"author_id":6,"name":"J.K. Rowling","country":"UK","birth_year":1965},
                {"author_id":7,"name":"Yuval Noah Harari","country":"Israel","birth_year":1976},
                {"author_id":8,"name":"New Author","country":"Nowhere","birth_year":1990},
            ],
        },
        {
            "name": "bookstores",
            "pk": ["store_id"],
            "columns": [
                {"name":"store_id","type":"INTEGER"},
                {"name":"name","type":"TEXT"},
                {"name":"city","type":"TEXT"},
                {"name":"state","type":"TEXT"},
            ],
            "fks": [],
            "rows": [
                {"store_id":1,"name":"Downtown Books","city":"Oklahoma City","state":"OK"},
                {"store_id":2,"name":"Harbor Books","city":"Seattle","state":"WA"},
                {"store_id":3,"name":"Desert Pages","city":"Phoenix","state":"AZ"},
            ],
        },
        {
            "name": "books",
            "pk": ["book_id"],
            "columns": [
                {"name":"book_id","type":"INTEGER"},
                {"name":"title","type":"TEXT"},
                {"name":"author_id","type":"INTEGER"},
                {"name":"store_id","type":"INTEGER"},
                {"name":"category","type":"TEXT"},
                {"name":"price","type":"REAL"},
                {"name":"published_year","type":"INTEGER"},
            ],
            "fks": [
                {"columns":["author_id"],"ref_table":"authors","ref_columns":["author_id"]},
                {"columns":["store_id"],"ref_table":"bookstores","ref_columns":["store_id"]},
            ],
            "rows": [
                {"book_id":101,"title":"Foundation","author_id":1,"store_id":1,"category":"Sci-Fi","price":14.99,"published_year":1951},
                {"book_id":102,"title":"I, Robot","author_id":1,"store_id":1,"category":"Sci-Fi","price":12.50,"published_year":1950},
                {"book_id":103,"title":"The Left Hand of Darkness","author_id":2,"store_id":2,"category":"Sci-Fi","price":16.00,"published_year":1969},
                {"book_id":104,"title":"A Wizard of Earthsea","author_id":2,"store_id":2,"category":"Fantasy","price":11.50,"published_year":1968},
                {"book_id":105,"title":"Norwegian Wood","author_id":3,"store_id":3,"category":"Fiction","price":18.00,"published_year":1987},
                {"book_id":106,"title":"Kafka on the Shore","author_id":3,"store_id":1,"category":"Fiction","price":21.00,"published_year":2002},
                {"book_id":107,"title":"Things Fall Apart","author_id":4,"store_id":1,"category":"Fiction","price":10.00,"published_year":1958},
                {"book_id":108,"title":"Pride and Prejudice","author_id":5,"store_id":2,"category":"Fiction","price":9.00,"published_year":1813},
                {"book_id":109,"title":"Harry Potter and the Sorcerer's Stone","author_id":6,"store_id":3,"category":"Children","price":22.00,"published_year":1997},
                {"book_id":110,"title":"Harry Potter and the Chamber of Secrets","author_id":6,"store_id":3,"category":"Children","price":23.00,"published_year":1998},
                {"book_id":111,"title":"Sapiens","author_id":7,"store_id":1,"category":"History","price":26.00,"published_year":2011},
                {"book_id":112,"title":"Homo Deus","author_id":7,"store_id":2,"category":"History","price":28.00,"published_year":2015},
            ],
        },
    ]
}

FALLBACK_QUESTIONS = [
    {
        "id":"Q01","category":"SELECT *","difficulty":1,
        "prompt_md":"Select all rows and columns from `authors`.",
        "answer_sql":["SELECT * FROM authors;"],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q02","category":"SELECT columns","difficulty":1,
        "prompt_md":"Show `title` and `price` from `books`.",
        "answer_sql":["SELECT title, price FROM books;"],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q03","category":"WHERE","difficulty":1,
        "prompt_md":"List Sci‑Fi books under $15 (show title, price).",
        "answer_sql":["SELECT title, price FROM books WHERE category='Sci-Fi' AND price < 15;"],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q04","category":"Aliases","difficulty":1,
        "prompt_md":"Using aliases `b` and `a`, join `books` to `authors` and show `b.title` and `a.name` as `author_name`.",
        "answer_sql":["SELECT b.title, a.name AS author_name FROM books b JOIN authors a ON b.author_id=a.author_id;"],
        "requires_aliases":True,"required_aliases":["a","b"]
    },
    {
        "id":"Q05","category":"JOIN (INNER)","difficulty":2,
        "prompt_md":"Inner join `books` and `bookstores`. Return `title`, `name` as `store`.",
        "answer_sql":[
            "SELECT b.title, s.name AS store FROM books b INNER JOIN bookstores s ON b.store_id=s.store_id;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q06","category":"JOIN (LEFT)","difficulty":2,
        "prompt_md":"List each author and their number of books (include authors with zero): columns `name`, `book_count`.",
        "answer_sql":[
            "SELECT a.name, COUNT(b.book_id) AS book_count FROM authors a LEFT JOIN books b ON a.author_id=b.author_id GROUP BY a.name;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q07","category":"VIEW","difficulty":2,
        "prompt_md":"Create a view `vw_pricy` with `title`, `price` for books priced > 25.",
        "answer_sql":[
            "CREATE VIEW vw_pricy AS SELECT title, price FROM books WHERE price > 25;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
    {
        "id":"Q08","category":"CTAS / SELECT INTO","difficulty":2,
        "prompt_md":"Create a table `cheap_books` containing books priced < 12. Use CTAS or SELECT INTO.",
        "answer_sql":[
            "CREATE TABLE cheap_books AS SELECT * FROM books WHERE price < 12;",
            "SELECT * INTO cheap_books FROM books WHERE price < 12;"
        ],
        "requires_aliases":False,"required_aliases":[]
    },
]

# -------------------- OpenAI prompts --------------------
DOMAIN_AND_QUESTIONS_SCHEMA = {
    "name": "DomainSQLPack",
    "schema": {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "domain": {"type":"string"},
            "tables": {
                "type":"array",
                "items": {
                    "type":"object",
                    "additionalProperties": False,
                    "properties": {
                        "name": {"type":"string"},
                        "pk": {"type":"array","items":{"type":"string"}},
                        "columns": {
                            "type":"array",
                            "items": {
                                "type":"object",
                                "additionalProperties": False,
                                "properties": {
                                    "name":{"type":"string"},
                                    "type":{"type":"string"}
                                },
                                "required":["name","type"]
                            }
                        },
                        "fks": {
                            "type":"array",
                            "items": {
                                "type":"object",
                                "additionalProperties": False,
                                "properties": {
                                    "columns":{"type":"array","items":{"type":"string"}},
                                    "ref_table":{"type":"string"},
                                    "ref_columns":{"type":"array","items":{"type":"string"}}
                                },
                                "required":["columns","ref_table","ref_columns"]
                            }
                        },
                        "rows": {"type":"array","items":{"type":["object","array"]}}
                    },
                    "required":["name","pk","columns","fks","rows"]
                },
                "minItems":3,"maxItems":4
            },
            "questions": {
                "type":"array",
                "items": {
                    "type":"object",
                    "additionalProperties": False,
                    "properties": {
                        "id":{"type":"string"},
                        "category":{"type":"string"},
                        "difficulty":{"type":"integer"},
                        "prompt_md":{"type":"string"},
                        "answer_sql":{"type":"array","items":{"type":"string"}},
                        "requires_aliases":{"type":"boolean"},
                        "required_aliases":{"type":"array","items":{"type":"string"}}
                    },
                    "required":["id","category","difficulty","prompt_md","answer_sql"]
                },
                "minItems":8,"maxItems":12
            }
        },
        "required":["domain","tables","questions"]
    },
    "strict": True
}

DOMAIN_AND_QUESTIONS_PROMPT = """
You are designing a small relational dataset and training questions for SQL basics.

1) Choose ONE domain at random from:
   - bookstore, retail sales, wholesaler, sales tax, oil and gas wells, marketing.

2) Produce exactly 3–4 tables that fit together (SQLite-friendly):
   - Use snake_case, avoid reserved words.
   - Types: INTEGER, REAL, TEXT, NUMERIC, DATE (but no advanced features).
   - Primary keys (pk) and foreign keys (fks) must align.
   - Provide 8–15 small, realistic seed rows per table (not huge).

3) Generate 8–12 SQL questions covering basics with varied, natural language:
   - Categories from: "SELECT *", "SELECT columns", "WHERE", "Aliases",
     "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO".
   - Include a few joins and at least one LEFT JOIN.
   - Include one view creation.
   - Include one table creation from SELECT (either CTAS or SELECT INTO).
   - Prefer SQLite-compatible SQL. DO NOT use RIGHT/FULL OUTER JOIN.
   - Offer 1–3 acceptable answer_sql variants per question.
   - For 1–2 questions, require table aliases (set requires_aliases=true and list required_aliases).

Return JSON only.
"""

def llm_generate_domain_and_questions() -> Optional[Dict[str,Any]]:
    """Ask the OpenAI model for a domain pack (schema + questions) as JSON.

    Returns the parsed dict, or None when the client is unavailable, the API
    call fails, or the response is not parseable JSON. Callers treat None as
    "use the built-in fallback dataset".
    """
    if not OPENAI_AVAILABLE:
        return None
    try:
        if USE_RESPONSES_API:
            # NOTE(review): recent openai SDKs expect structured output on
            # responses.create via text={"format": ...}; confirm the pinned SDK
            # version accepts this response_format kwarg.
            resp = _client.responses.create(
                model=MODEL_ID,
                response_format={"type":"json_schema","json_schema":DOMAIN_AND_QUESTIONS_SCHEMA},
                input=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
                temperature=0.6,
            )
            # output_text is a convenience property; fall back to None if absent.
            data_text = getattr(resp, "output_text", None)
        else:
            chat = _client.chat.completions.create(
                model=MODEL_ID,
                messages=[{"role":"user","content": DOMAIN_AND_QUESTIONS_PROMPT}],
                temperature=0.6
            )
            data_text = chat.choices[0].message.content
        obj = json.loads(data_text) if data_text else None
        return obj
    except Exception:
        # Deliberate best-effort: any failure silently falls back to the offline dataset.
        return None

# -------------------- Schema install & question handling --------------------
def drop_existing_domain_tables(con: sqlite3.Connection, keep_internal=True):
    """Drop every user table/view, optionally preserving the trainer's internal tables.

    Fixes over the previous version:
    - Identifiers are double-quoted in the DROP statement, so LLM-generated
      names that collide with SQL keywords or contain unusual characters
      still drop cleanly.
    - SQLite-managed objects (``sqlite_sequence`` etc.) are skipped up front,
      since DROP on them is not permitted.

    Args:
        con: open SQLite connection.
        keep_internal: when True, users/attempts/session_meta are preserved.
    """
    cur = con.cursor()
    cur.execute("SELECT name, type FROM sqlite_master WHERE type IN ('table','view')")
    for name, typ in cur.fetchall():
        if name.startswith("sqlite_"):
            continue  # internal SQLite objects cannot be dropped
        if keep_internal and name in ("users", "attempts", "session_meta"):
            continue
        try:
            cur.execute(f'DROP {typ.upper()} IF EXISTS "{name}"')
        except Exception:
            pass  # best-effort cleanup; leave stragglers rather than abort the reset
    con.commit()

def install_schema(con: sqlite3.Connection, schema: Dict[str,Any]):
    """Rebuild the domain tables from a schema dict and persist the schema JSON.

    Drops prior domain objects, creates every table (supporting composite
    primary keys), inserts seed rows, and records the domain name plus full
    schema JSON into session_meta row 1 so a restarted session can restore it.

    Seed rows may be dicts keyed by column name (missing keys become NULL) or
    positional lists/tuples (padded with NULLs / truncated to the column count).
    Rows of any other shape are skipped rather than crashing the install.
    """
    drop_existing_domain_tables(con, keep_internal=True)
    cur = con.cursor()
    # Create all tables first so seed-row inserts see every FK target.
    for t in schema.get("tables", []):
        cols_sql = [f"{c['name']} {c.get('type', 'TEXT')}" for c in t.get("columns", [])]
        pk = t.get("pk", [])
        if pk:
            cols_sql.append(f"PRIMARY KEY ({', '.join(pk)})")
        cur.execute(f"CREATE TABLE {t['name']} ({', '.join(cols_sql)})")
    # Insert seed rows.
    for t in schema.get("tables", []):
        if not t.get("rows"):
            continue
        cols = [c["name"] for c in t.get("columns", [])]
        qmarks = ",".join(["?"] * len(cols))
        insert_sql = f"INSERT INTO {t['name']} ({', '.join(cols)}) VALUES ({qmarks})"
        for r in t["rows"]:
            if isinstance(r, dict):
                vals = [r.get(col) for col in cols]
            elif isinstance(r, (list, tuple)):  # single isinstance with a type tuple
                vals = (list(r) + [None] * (len(cols) - len(r)))[:len(cols)]
            else:
                continue  # unknown row shape from the LLM; skip rather than crash
            cur.execute(insert_sql, vals)
    con.commit()
    # Persist schema JSON (single-row table: id is CHECKed to 1).
    cur.execute("INSERT OR REPLACE INTO session_meta(id, domain, schema_json) VALUES (1, ?, ?)",
                (schema.get("domain","unknown"), json.dumps(schema)))
    con.commit()

def run_df(con: sqlite3.Connection, sql: str) -> pd.DataFrame:
    """Execute *sql* against *con* and return the result set as a DataFrame."""
    return pd.read_sql_query(sql, con)

def rewrite_select_into(sql: str) -> Tuple[str, Optional[str]]:
    """Translate T-SQL style ``SELECT ... INTO tbl FROM ...`` into SQLite CTAS.

    Returns (possibly-rewritten SQL, created-table name or None). Statements
    without the SELECT/INTO/FROM shape are returned unchanged.
    """
    stripped = sql.strip().strip(";")
    looks_like_into = re.search(r"\bselect\b.+\binto\b.+\bfrom\b", stripped,
                                flags=re.IGNORECASE | re.DOTALL)
    if looks_like_into:
        parts = re.match(r"(?is)^\s*select\s+(.*?)\s+into\s+([A-Za-z_][A-Za-z0-9_]*)\s+from\s+(.*)$",
                         stripped)
        if parts:
            select_list, table_name, tail = parts.groups()
            return f"CREATE TABLE {table_name} AS SELECT {select_list} FROM {tail}", table_name
    return sql, None

def detect_unsupported_joins(sql: str) -> Optional[str]:
    """Return a dialect-help message if *sql* uses constructs SQLite lacks.

    Uses word-boundary regexes instead of the previous single-space substring
    checks, which missed ``RIGHT OUTER JOIN`` entirely and failed on tabs,
    newlines, or multiple spaces between keywords.

    Returns:
        A human-readable hint string, or None when nothing unsupported is found.
    """
    if re.search(r"\bright\s+(outer\s+)?join\b", sql, flags=re.IGNORECASE):
        return "SQLite does not support RIGHT JOIN. Use LEFT JOIN in the opposite direction."
    if re.search(r"\bfull\s+(outer\s+)?join\b", sql, flags=re.IGNORECASE):
        return "SQLite does not support FULL OUTER JOIN. Use LEFT JOIN plus UNION for the other side."
    if re.search(r"\bilike\b", sql, flags=re.IGNORECASE):
        return "SQLite has no ILIKE. Use `LOWER(col) LIKE LOWER('%pattern%')`."
    return None

def detect_cartesian(con: sqlite3.Connection, sql: str, df_result: pd.DataFrame) -> Optional[str]:
    """Heuristically warn when a query likely produced a cartesian product.

    Flags explicit CROSS JOINs, comma-style FROM lists, and JOINs with no
    ON/USING/NATURAL clause anywhere in the statement; when two table names
    can be extracted, confirms by checking whether the result row count
    equals |t1| x |t2|. Returns a warning string, or None when nothing
    suspicious is detected.
    """
    low = sql.lower()
    if " cross join " in low:
        return "Query uses CROSS JOIN (cartesian product). Ensure this is intended."
    # Old-style implicit join: "FROM a, b" (heuristic; only the first pair is examined).
    comma_from = re.search(r"\bfrom\b\s+([a-z_]\w*)\s*,\s*([a-z_]\w*)", low)
    # A JOIN keyword with no join condition anywhere in the statement.
    missing_on = (" join " in low) and (" on " not in low) and (" using " not in low) and (" natural " not in low)
    if comma_from or missing_on:
        try:
            cur = con.cursor()
            if comma_from:
                t1, t2 = comma_from.groups()
            else:
                m = re.search(r"\bfrom\b\s+([a-z_]\w*)", low)
                j = re.search(r"\bjoin\b\s+([a-z_]\w*)", low)
                if not m or not j:
                    return "Possible cartesian product: no join condition detected."
                t1, t2 = m.group(1), j.group(1)
            cur.execute(f"SELECT COUNT(*) FROM {t1}")
            n1 = cur.fetchone()[0]
            cur.execute(f"SELECT COUNT(*) FROM {t2}")
            n2 = cur.fetchone()[0]
            prod = n1 * n2
            # An exact product match is strong evidence the join predicate is missing.
            if len(df_result) == prod and prod > 0:
                return f"Result row count equals {n1}×{n2}={prod}. Likely cartesian product (missing join)."
        except Exception:
            # Extracted names may be aliases or subqueries the COUNT probe can't resolve.
            return "Possible cartesian product: no join condition detected."
    return None

def results_equal(df_a: pd.DataFrame, df_b: pd.DataFrame) -> bool:
    """Compare two result sets, ignoring row order and column-name case.

    Column order and dtypes still matter, matching pandas ``DataFrame.equals``.
    """
    if df_a.shape != df_b.shape:
        return False
    left, right = df_a.copy(), df_b.copy()
    left.columns = [c.lower() for c in left.columns]
    right.columns = [c.lower() for c in right.columns]
    # Canonicalize row order before the cell-by-cell comparison.
    left = left.sort_values(list(left.columns)).reset_index(drop=True)
    right = right.sort_values(list(right.columns)).reset_index(drop=True)
    return left.equals(right)

def aliases_present(sql: str, required_aliases: List[str]) -> bool:
    """Check that every required table alias is actually used in *sql*.

    An alias counts as present when it is dereferenced (``al.col``) or
    introduced with AS (``... AS al``). Word-boundary regexes replace the old
    space-delimited substring checks, which missed aliases preceded by '(' or
    ',' and aliases appearing at the very end of the statement.
    """
    low = re.sub(r"\s+", " ", sql.lower())
    for al in required_aliases:
        alias = re.escape(al.lower())
        used = (re.search(rf"\b{alias}\s*\.", low)
                or re.search(rf"\bas\s+{alias}\b", low))
        if not used:
            return False
    return True

# -------------------- Question model --------------------
@dataclass
class SQLQuestion:
    """A single trainer question with one or more acceptable canonical answers."""
    id: str                      # question identifier, e.g. "Q04"
    category: str                # one of CATEGORIES_ORDER
    difficulty: int              # 1 = easiest
    prompt_md: str               # markdown prompt shown to the learner
    answer_sql: List[str]        # acceptable canonical SQL variants
    requires_aliases: bool = False
    # default_factory fixes the previous `= None` default, which contradicted
    # the List[str] annotation and forced callers to None-check; a bare []
    # default would have been a shared mutable default.
    required_aliases: List[str] = field(default_factory=list)

def to_question_dict(q) -> Dict[str,Any]:
    """Copy a question mapping, guaranteeing the alias-related keys exist."""
    normalized = dict(q)
    normalized.setdefault("requires_aliases", False)
    normalized.setdefault("required_aliases", [])
    return normalized

def load_questions(obj_list: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
    """Normalize every raw question dict via to_question_dict.

    A comprehension replaces the previous manual append loop (same behavior,
    idiomatic construction).
    """
    return [to_question_dict(o) for o in obj_list]

# -------------------- Domain bootstrap --------------------
def bootstrap_domain_with_llm_or_fallback() -> Tuple[Dict[str,Any], List[Dict[str,Any]]]:
    """Fetch an LLM-generated domain pack, or fall back to the built-in dataset.

    Guardrail: answer variants containing RIGHT/FULL joins are discarded
    (SQLite-unfriendly); questions left without any acceptable answer are
    dropped entirely. Returns (schema_dict, questions_list).
    """
    pack = llm_generate_domain_and_questions()
    if pack is None:
        return FALLBACK_SCHEMA, FALLBACK_QUESTIONS
    vetted = []
    for question in pack.get("questions", []):
        ok_answers = [
            candidate for candidate in question.get("answer_sql", [])
            if " right join " not in candidate.lower() and " full " not in candidate.lower()
        ]
        if not ok_answers:
            continue
        question["answer_sql"] = ok_answers
        question.setdefault("requires_aliases", False)
        question.setdefault("required_aliases", [])
        vetted.append(question)
    pack["questions"] = vetted
    return pack, vetted

def install_new_domain():
    """Generate (or fall back to) a domain pack and install it into SQLite."""
    new_schema, new_questions = bootstrap_domain_with_llm_or_fallback()
    install_schema(CONN, new_schema)
    return new_schema, new_questions

# -------------------- Session state --------------------
CURRENT_SCHEMA, CURRENT_QS = install_new_domain()

# -------------------- Progress + mastery --------------------
def upsert_user(con: sqlite3.Connection, user_id: str, name: str):
    """Insert a user, or update the display name if the user already exists.

    Uses a single atomic SQLite UPSERT (``ON CONFLICT ... DO UPDATE``) instead
    of the previous SELECT-then-INSERT/UPDATE pair, which could race with a
    concurrent writer. ``created_at`` is only written on first insert and is
    never overwritten on subsequent calls.
    """
    cur = con.cursor()
    cur.execute(
        """
        INSERT INTO users (user_id, name, created_at) VALUES (?, ?, ?)
        ON CONFLICT(user_id) DO UPDATE SET name = excluded.name
        """,
        (user_id, name, datetime.now(timezone.utc).isoformat()),
    )
    con.commit()

CATEGORIES_ORDER = [
    "SELECT *", "SELECT columns", "WHERE", "Aliases",
    "JOIN (INNER)", "JOIN (LEFT)", "Aggregation", "VIEW", "CTAS / SELECT INTO"
]

def topic_stats(df_attempts: pd.DataFrame) -> pd.DataFrame:
    """Per-category attempt/correct counts and accuracy over all attempts.

    Categories with no attempts report accuracy 0.0 (the division is guarded
    by max(attempts, 1)). Always returns one row per entry in CATEGORIES_ORDER.
    """
    summary = []
    for category in CATEGORIES_ORDER:
        if df_attempts.empty:
            subset = pd.DataFrame()
        else:
            subset = df_attempts[df_attempts["category"] == category]
        attempts = 0 if subset.empty else int(subset.shape[0])
        correct = 0 if subset.empty else int(subset["correct"].sum())
        summary.append({
            "category": category,
            "attempts": attempts,
            "correct": correct,
            "accuracy": float(correct / max(attempts, 1)),
        })
    return pd.DataFrame(summary)

def fetch_attempts(con: sqlite3.Connection, user_id: str) -> pd.DataFrame:
    """Return all attempt rows for *user_id*, newest first."""
    query = "SELECT * FROM attempts WHERE user_id=? ORDER BY id DESC"
    return pd.read_sql_query(query, con, params=(user_id,))

def pick_next_question(user_id: str) -> Dict[str,Any]:
    """Adaptively pick a question from the learner's weakest category.

    Category stats are ordered by (accuracy asc, attempts asc), so the
    least-mastered, least-practiced topic is targeted first. Falls back to
    the full question pool if the current domain has no questions in that
    category. Returns a copy of the chosen question dict.
    """
    attempts = fetch_attempts(CONN, user_id)
    stats = topic_stats(attempts).sort_values(by=["accuracy","attempts"], ascending=[True, True])
    weakest = CATEGORIES_ORDER[0] if stats.empty else stats.iloc[0]["category"]
    pool = [q for q in CURRENT_QS if q["category"] == weakest] or CURRENT_QS
    return dict(random.choice(pool))

# -------------------- Execution & feedback --------------------
def exec_student_sql(sql_text: str) -> Tuple[Optional[pd.DataFrame], Optional[str], Optional[str], Optional[str]]:
    """Execute a student's SQL against the shared connection and shape feedback.

    Returns a 4-tuple ``(df, error, warning, note)``:
      * df      -- result rows (empty DataFrame for a DDL/DML statement with
                   nothing to preview); None when execution failed
      * error   -- human-readable failure message, or None on success
      * warning -- non-fatal advisory (e.g. suspected cartesian product), or None
      * note    -- informational message (e.g. the SELECT INTO rewrite), or None
    """
    if not sql_text or not sql_text.strip():
        return None, "Enter a SQL statement.", None, None

    # Normalize: single statement, trailing semicolon removed.
    sql_raw = sql_text.strip().rstrip(";")
    # Translate T-SQL style `SELECT ... INTO tbl` into SQLite CTAS; the helper
    # also reports the table it will create (None when no rewrite happened).
    sql_rew, created_tbl = rewrite_select_into(sql_raw)
    note = None
    if sql_rew != sql_raw:
        note = "Rewrote `SELECT ... INTO` to `CREATE TABLE ... AS SELECT ...` for SQLite."

    # Reject joins SQLite can't run (per detect_unsupported_joins) up front.
    unsup = detect_unsupported_joins(sql_rew)
    if unsup:
        return None, unsup, None, note

    try:
        low = sql_rew.lower()
        if low.startswith("select"):
            df = run_df(CONN, sql_rew)
            warn = detect_cartesian(CONN, sql_rew, df)
            return df, None, warn, note
        else:
            # Non-SELECT (DDL/DML): execute directly and commit.
            cur = CONN.cursor()
            cur.execute(sql_rew)
            CONN.commit()
            # Preview newly created objects
            if low.startswith("create view"):
                # Matched against the lowercased SQL, so the captured name is
                # lowercase; SQLite resolves identifiers case-insensitively.
                m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+(select.*)$", low)
                name = m.group(2) if m else None
                if name:
                    try:
                        df = run_df(CONN, f"SELECT * FROM {name}")
                        return df, None, None, note
                    except Exception:
                        return None, "View created but could not be queried.", None, note
            if low.startswith("create table"):
                # Prefer the table name reported by the SELECT INTO rewrite;
                # otherwise parse a plain CTAS statement for it.
                tbl = created_tbl
                if not tbl:
                    m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
                    tbl = m.group(2) if m else None
                if tbl:
                    try:
                        df = run_df(CONN, f"SELECT * FROM {tbl}")
                        return df, None, None, note
                    except Exception:
                        return None, "Table created but could not be queried.", None, note
            # Statement succeeded but produced nothing worth previewing.
            return pd.DataFrame(), None, None, note
    except Exception as e:
        # Tailored messages
        # Order matters: most specific substring checks first; the final
        # return is the generic catch-all.
        msg = str(e)
        if "no such table" in msg.lower():
            return None, f"{msg}. Check table names for this randomized domain.", None, note
        if "no such column" in msg.lower():
            return None, f"{msg}. Use correct column names or prefixes (alias.column).", None, note
        if "ambiguous column name" in msg.lower():
            return None, f"{msg}. Qualify the column with a table alias.", None, note
        if "misuse of aggregate" in msg.lower() or "aggregate functions are not allowed in" in msg.lower():
            return None, f"{msg}. You might need a GROUP BY for non-aggregated columns.", None, note
        if "near \"into\"" in msg.lower() and "syntax error" in msg.lower():
            return None, "SQLite doesn’t support `SELECT ... INTO`. I can rewrite it automatically—try again.", None, note
        if "syntax error" in msg.lower():
            return None, f"Syntax error. Check commas, keywords, and parentheses. Raw error: {msg}", None, note
        return None, f"SQL error: {msg}", None, note

def answer_df(answer_sql: List[str]) -> Optional[pd.DataFrame]:
    """Build the canonical result DataFrame from a question's answer SQL.

    Tries each candidate statement in order and returns the first preview it
    can produce. Returns None when no candidate yields a DataFrame (pure
    side-effect answers); the caller treats that as "grade by successful
    execution" instead of comparing result sets.
    """
    for sql in answer_sql:
        try:
            low = sql.strip().lower()
            if low.startswith("select"):
                return run_df(CONN, sql)
            if low.startswith("create view"):
                # temp preview
                # Drop first so re-grading the same question doesn't fail on
                # an already-existing view.
                m = re.match(r"(?is)^\s*create\s+view\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
                view_name = m.group(2) if m else "vw_tmp"
                cur = CONN.cursor()
                cur.execute(f"DROP VIEW IF EXISTS {view_name}")
                cur.execute(sql)
                CONN.commit()
                return run_df(CONN, f"SELECT * FROM {view_name}")
            if low.startswith("create table"):
                m = re.match(r"(?is)^\s*create\s+table\s+(if\s+not\s+exists\s+)?([a-z_]\w*)\s+as\s+select.*$", low)
                tbl = m.group(2) if m else None
                cur = CONN.cursor()
                if tbl:
                    cur.execute(f"DROP TABLE IF EXISTS {tbl}")
                cur.execute(sql)
                CONN.commit()
                if tbl:
                    return run_df(CONN, f"SELECT * FROM {tbl}")
        except Exception:
            # Best-effort: a failing candidate just falls through to the next.
            continue
    return None

def validate_answer(q: Dict[str,Any], student_sql: str, df_student: Optional[pd.DataFrame]) -> Tuple[bool, str]:
    """Grade a student's result against the question's canonical answer.

    Args:
        q: question dict; ``q["answer_sql"]`` holds canonical statements.
        student_sql: the raw submitted SQL. Currently unused here (alias
            checks happen in the caller); kept for interface stability.
        df_student: the student's executed result, or None if nothing ran.

    Returns:
        (is_correct, explanation_markdown).
    """
    df_expected = answer_df(q["answer_sql"])
    # If we can't build a canonical DF (e.g., DDL side effect), we accept any
    # successful execution as correct.
    # Fix: the explanation strings below were pointless f-strings with no
    # placeholders (ruff F541); plain literals are identical at runtime.
    if df_expected is None:
        return (df_student is not None), "**Explanation:** Your statement executed successfully for this task."
    if df_student is None:
        return False, "**Explanation:** Expected data result differs."
    return results_equal(df_student, df_expected), "**Explanation:** Compare your result to a canonical solution."

def log_attempt(user_id: str, qid: str, category: str, correct: bool, sql_text: str,
                time_taken: float, difficulty: int, source: str, notes: str):
    """Append one graded attempt to the shared attempts table and commit."""
    recorded_at = datetime.now(timezone.utc).isoformat()
    row = (user_id, qid, category, int(correct), sql_text, recorded_at,
           time_taken, difficulty, source, notes)
    cur = CONN.cursor()
    cur.execute("""
        INSERT INTO attempts (user_id, question_id, category, correct, sql_text, timestamp, time_taken, difficulty, source, notes)
        VALUES (?,?,?,?,?,?,?,?,?,?)
    """, row)
    CONN.commit()

# -------------------- UI callbacks --------------------
def start_session(name: str, session: dict):
    """Create or resume a session and serve the first adaptive question.

    Outputs (in wiring order): session state, question markdown, SQL-input
    visibility, preview markdown, ERD image, next-button visibility,
    mastery stats DataFrame, result-preview DataFrame.
    """
    cleaned = (name or "").strip()
    if not cleaned:
        # No name yet: prompt the user and keep the practice widgets hidden.
        return (
            session,
            gr.update(value="Please enter your name to begin.", visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            None,
            gr.update(visible=False),
            pd.DataFrame(),
            pd.DataFrame(),
        )

    # Derive a stable user id from the name (lowercased, hyphenated, capped).
    slug = "-".join(cleaned.lower().split())
    user_id = slug[:64] if slug else f"user-{int(time.time())}"
    upsert_user(CONN, user_id, cleaned)

    question = pick_next_question(user_id)
    session = {
        "user_id": user_id,
        "name": cleaned,
        "qid": question["id"],
        "start_ts": time.time(),
        "q": question,
    }

    mastery = topic_stats(fetch_attempts(CONN, user_id))
    header = f"**Question {question['id']}**\n\n{question['prompt_md']}"
    return (
        session,
        gr.update(value=header, visible=True),
        gr.update(visible=True),            # show SQL input
        gr.update(value="", visible=True),  # preview block
        draw_dynamic_erd(CURRENT_SCHEMA),
        gr.update(visible=False),           # next btn hidden until submit
        mastery,
        pd.DataFrame(),
    )

def render_preview_and_erd(sql_text: str, session: dict):
    """Echo the typed SQL as a fenced markdown preview; always refresh the ERD."""
    erd = draw_dynamic_erd(CURRENT_SCHEMA)
    text = (sql_text or "").strip()
    # Hide the preview when no session is active or nothing has been typed.
    if not session or "q" not in session or not text:
        return gr.update(value="", visible=False), erd
    return gr.update(value=f"**Preview:**\n\n```sql\n{text}\n```", visible=True), erd

def submit_answer(sql_text: str, session: dict):
    """Grade a submission: execute, validate, log the attempt, build feedback.

    Outputs (in wiring order): feedback markdown, result DataFrame,
    next-button visibility, mastery stats DataFrame.
    """
    if not session or "user_id" not in session or "q" not in session:
        return gr.update(value="Start a session first.", visible=True), pd.DataFrame(), gr.update(visible=False), pd.DataFrame()
    user_id = session["user_id"]
    q = session["q"]
    # Time-on-task since the question was shown; clamped to non-negative.
    elapsed = max(0.0, time.time() - session.get("start_ts", time.time()))

    df, err, warn, note = exec_student_sql(sql_text)
    details = []
    if note: details.append(f"ℹ️ {note}")
    if err:
        # Execution failed: log as incorrect and surface the error verbatim.
        fb = f"❌ **Did not run**\n\n{err}"
        if details: fb += "\n\n" + "\n".join(details)
        log_attempt(user_id, q["id"], q["category"], False, sql_text, elapsed, int(q["difficulty"]), "bank", " | ".join([err] + details))
        stats = topic_stats(fetch_attempts(CONN, user_id))
        return gr.update(value=fb, visible=True), pd.DataFrame(), gr.update(visible=True), stats

    # Validate correctness
    # Missing aliases are advisory only: mentioned in feedback, not graded.
    alias_msg = None
    if q.get("requires_aliases"):
        if not aliases_present(sql_text, q.get("required_aliases", [])):
            alias_msg = f"⚠️ This task asked for aliases {q.get('required_aliases', [])}. I didn’t detect them."

    is_correct, explanation = validate_answer(q, sql_text, df)
    if warn: details.append(f"⚠️ {warn}")
    if alias_msg: details.append(alias_msg)

    prefix = "✅ **Correct!**" if is_correct else "❌ **Not quite.**"
    feedback = prefix
    if details:
        feedback += "\n\n" + "\n".join(details)
    # Always reveal one canonical solution (trailing semicolon normalized).
    feedback += "\n\n" + explanation + "\n\n**One acceptable solution:**\n```sql\n" + q["answer_sql"][0].rstrip(";") + ";\n```"

    log_attempt(user_id, q["id"], q["category"], bool(is_correct), sql_text, elapsed, int(q["difficulty"]), "bank", " | ".join(details))
    stats = topic_stats(fetch_attempts(CONN, user_id))
    return gr.update(value=feedback, visible=True), (df if df is not None else pd.DataFrame()), gr.update(visible=True), stats

def next_question(session: dict):
    """Advance the session to a freshly picked adaptive question.

    Outputs (in wiring order): session state, question markdown, SQL-input
    reset, ERD image, next-button visibility.
    """
    if not session or "user_id" not in session:
        return (session,
                gr.update(value="Start a session first.", visible=True),
                gr.update(visible=False),
                draw_dynamic_erd(CURRENT_SCHEMA),
                gr.update(visible=False))
    q = pick_next_question(session["user_id"])
    # Restart the per-question timer alongside the new question.
    session.update({"qid": q["id"], "q": q, "start_ts": time.time()})
    header = f"**Question {q['id']}**\n\n{q['prompt_md']}"
    return (session,
            gr.update(value=header, visible=True),
            gr.update(value="", visible=True),
            draw_dynamic_erd(CURRENT_SCHEMA),
            gr.update(visible=False))

def show_hint(session: dict):
    """Surface a category-specific one-liner without revealing the solution."""
    if not session or "q" not in session:
        return gr.update(value="Start a session first.", visible=True)
    # Lightweight hint policy: one canned tip per category, generic fallback.
    hints_by_category = {
        "SELECT *": "Use `SELECT * FROM table_name`.",
        "SELECT columns": "List columns: `SELECT col1, col2 FROM table_name`.",
        "WHERE": "Filter with `WHERE` and combine conditions using AND/OR.",
        "Aliases": "Use `table_name t` and qualify: `t.col`.",
        "JOIN (INNER)": "Join with `... INNER JOIN ... ON left.key = right.key`.",
        "JOIN (LEFT)": "LEFT JOIN keeps all rows from the left table.",
        "Aggregation": "Use aggregate functions and `GROUP BY` non-aggregated columns.",
        "VIEW": "`CREATE VIEW view_name AS SELECT ...`.",
        "CTAS / SELECT INTO": "SQLite uses `CREATE TABLE name AS SELECT ...`."
    }
    fallback = "Read the ER diagram and identify keys to join on."
    tip = hints_by_category.get(session["q"]["category"], fallback)
    return gr.update(value=f"**Hint:** {tip}", visible=True)

def export_progress(user_name: str):
    """Write a student's attempt history to CSV and return the file path.

    The user id is re-derived from the name exactly as start_session does
    (lowercased, hyphen-joined, capped at 64 chars). Returns None when the
    name is blank; an empty history still yields a one-row placeholder CSV.
    """
    slug = "-".join((user_name or "").lower().split())
    if not slug:
        return None
    user_id = slug[:64]
    history = fetch_attempts(CONN, user_id)
    os.makedirs(EXPORT_DIR, exist_ok=True)
    path = os.path.abspath(os.path.join(EXPORT_DIR, f"{user_id}_progress.csv"))
    payload = history if not history.empty else pd.DataFrame([{"info": "No attempts yet."}])
    payload.to_csv(path, index=False)
    return path

def regenerate_domain():
    """Swap in a freshly generated domain and redraw the ERD."""
    global CURRENT_SCHEMA, CURRENT_QS
    CURRENT_SCHEMA, CURRENT_QS = install_new_domain()
    banner = gr.update(value="✅ Domain regenerated.", visible=True)
    return banner, draw_dynamic_erd(CURRENT_SCHEMA)

def preview_table(tbl: str):
    """Return the first 20 rows of *tbl*, or a one-row error DataFrame.

    NOTE(review): the identifier is interpolated unparameterized (identifiers
    can't be bound params); inputs come from the dropdown choices, and any
    failure is caught and reported rather than raised.
    """
    try:
        return run_df(CONN, f"SELECT * FROM {tbl} LIMIT 20")
    except Exception as exc:
        return pd.DataFrame([{"error": str(exc)}])

def list_tables_for_preview():
    """Names of user-facing tables/views, excluding the app's bookkeeping tables.

    Returns a one-element placeholder list when the catalog is empty so the
    dropdown always has something to show.
    """
    catalog = run_df(CONN, "SELECT name, type FROM sqlite_master WHERE type in ('table','view') AND name NOT IN ('users','attempts','session_meta') ORDER BY type, name")
    return ["(no tables)"] if catalog.empty else catalog["name"].tolist()

# -------------------- UI --------------------
# Gradio Blocks layout: left column = controls (session, dataset, instructor,
# table preview); right column = question, SQL input, ERD, feedback, stats.
with gr.Blocks(title="Adaptive SQL Trainer — Randomized Domains") as demo:
    gr.Markdown(
        """
        # 🧪 Adaptive SQL Trainer — Randomized Domains (SQLite)
        - Uses **OpenAI** (if configured) to randomize a domain (bookstore, retail sales, wholesaler,
          sales tax, oil & gas wells, marketing), generate **3–4 tables** and **8–12** questions.
        - Practice `SELECT`, `WHERE`, `JOIN` (INNER/LEFT), **aliases**, **views**, and **CTAS / SELECT INTO**.
        - The app explains **SQLite quirks** (no RIGHT/FULL JOIN) and flags likely **cartesian products**.

        > Set your `OPENAI_API_KEY` in the Space secrets to enable randomization.
        """
    )

    with gr.Row():
        # -------- Left column: controls + quick preview ----------
        with gr.Column(scale=1):
            name_box = gr.Textbox(label="Your Name", placeholder="e.g., Jordan Alvarez")
            start_btn = gr.Button("Start / Resume Session", variant="primary")
            # Per-tab session dict; populated by start_session().
            session_state = gr.State({"user_id": None, "name": None, "qid": None, "start_ts": None, "q": None})

            gr.Markdown("---")
            gr.Markdown("### Dataset Controls")
            regen_btn = gr.Button("🔀 Randomize Dataset (OpenAI)")
            regen_fb = gr.Markdown(visible=False)

            gr.Markdown("---")
            gr.Markdown("### Instructor Tools")
            export_name = gr.Textbox(label="Export a student's progress (enter name)")
            export_btn = gr.Button("Export CSV")
            export_file = gr.File(label="Download progress")

            gr.Markdown("---")
            gr.Markdown("### Quick Table/View Preview (top 20 rows)")
            # Choices computed once at build time; refreshed by the second
            # regen_btn handler below.
            tbl_dd = gr.Dropdown(choices=list_tables_for_preview(), label="Pick table/view", interactive=True)
            tbl_btn = gr.Button("Preview")
            preview_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)

        # -------- Right column: task + feedback + mastery + results ----------
        with gr.Column(scale=2):
            prompt_md = gr.Markdown(visible=False)
            sql_input = gr.Textbox(label="Your SQL", placeholder="Type SQL here (end ; optional).", lines=6, visible=False)

            preview_md = gr.Markdown(visible=False)
            er_image = gr.Image(label="Entity Diagram", value=draw_dynamic_erd(CURRENT_SCHEMA), height=PLOT_HEIGHT)

            with gr.Row():
                submit_btn = gr.Button("Run & Submit", variant="primary")
                hint_btn = gr.Button("Hint")
                next_btn = gr.Button("Next Question ▶", visible=False)

            feedback_md = gr.Markdown("")

            gr.Markdown("---")
            gr.Markdown("### Your Progress by Category")
            mastery_df = gr.Dataframe(
                headers=["category","attempts","correct","accuracy"],
                col_count=(4, "dynamic"),
                row_count=(0, "dynamic"),
                interactive=False
            )

            gr.Markdown("---")
            gr.Markdown("### Result Preview")
            result_df = gr.Dataframe(value=pd.DataFrame(), interactive=False)

    # Wire events
    # Output lists must match each callback's return tuple order exactly.
    start_btn.click(
        start_session,
        inputs=[name_box, session_state],
        outputs=[session_state, prompt_md, sql_input, preview_md, er_image, next_btn, mastery_df, result_df],
    )
    sql_input.change(
        render_preview_and_erd,
        inputs=[sql_input, session_state],
        outputs=[preview_md, er_image],
    )
    submit_btn.click(
        submit_answer,
        inputs=[sql_input, session_state],
        outputs=[feedback_md, result_df, next_btn, mastery_df],
    )
    next_btn.click(
        next_question,
        inputs=[session_state],
        outputs=[session_state, prompt_md, sql_input, er_image, next_btn],
    )
    hint_btn.click(
        show_hint,
        inputs=[session_state],
        outputs=[feedback_md],
    )
    export_btn.click(
        export_progress,
        inputs=[export_name],
        outputs=[export_file],
    )
    regen_btn.click(
        regenerate_domain,
        inputs=[],
        outputs=[regen_fb, er_image],
    )
    # NOTE(review): the lambda wrapper is redundant; preview_table could be
    # passed directly.
    tbl_btn.click(
        lambda name: preview_table(name),
        inputs=[tbl_dd],
        outputs=[preview_df]
    )
    # Keep dropdown fresh after regeneration
    # (second listener on the same button; both fire on click).
    regen_btn.click(
        lambda: gr.update(choices=list_tables_for_preview()),
        inputs=[],
        outputs=[tbl_dd]
    )

if __name__ == "__main__":
    # Direct execution: start the Gradio server (blocks until shutdown).
    demo.launch()