Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
# app.py β Gradio UI
|
| 2 |
-
#
|
| 3 |
-
#
|
| 4 |
-
#
|
| 5 |
|
| 6 |
import os, re, json, time, sys, subprocess
|
| 7 |
from typing import Optional, Tuple
|
|
@@ -46,13 +46,15 @@ def adapter_exists(label: str) -> bool:
|
|
| 46 |
has_info= "MODEL_INFO.json" in files
|
| 47 |
return has_cfg and has_wts and has_info
|
| 48 |
|
|
|
|
| 49 |
# --------------------- global loaded model ---------------------
|
| 50 |
TOK = None
|
| 51 |
MODEL = None
|
| 52 |
ACTIVE_LABEL: Optional[str] = None
|
| 53 |
ACTIVE_BASE = BASE_MODEL_DEFAULT
|
| 54 |
|
| 55 |
-
|
|
|
|
| 56 |
def train_model_live(dataset_id, base_model, label, epochs):
|
| 57 |
if not label:
|
| 58 |
label = time.strftime("run_%Y%m%d_%H%M%S")
|
|
@@ -64,8 +66,10 @@ def train_model_live(dataset_id, base_model, label, epochs):
|
|
| 64 |
"--label", label.strip(),
|
| 65 |
"--epochs", str(int(epochs)),
|
| 66 |
]
|
|
|
|
| 67 |
yield f"$ {' '.join(cmd)}\n\n"
|
| 68 |
|
|
|
|
| 69 |
process = subprocess.Popen(
|
| 70 |
cmd,
|
| 71 |
stdout=subprocess.PIPE,
|
|
@@ -81,10 +85,10 @@ def train_model_live(dataset_id, base_model, label, epochs):
|
|
| 81 |
break
|
| 82 |
if line:
|
| 83 |
buffer += line
|
| 84 |
-
#
|
| 85 |
-
yield buffer[-6000:] # keep the tail to avoid giant textbox
|
| 86 |
rc = process.wait()
|
| 87 |
-
yield buffer[-
|
|
|
|
| 88 |
|
| 89 |
# --------------------- loading + generation ---------------------
|
| 90 |
def load_selected_model(label: str) -> str:
|
|
@@ -99,7 +103,7 @@ def load_selected_model(label: str) -> str:
|
|
| 99 |
meta = json.load(open(info_path, "r", encoding="utf-8"))
|
| 100 |
base = meta.get("base_model", BASE_MODEL_DEFAULT)
|
| 101 |
except Exception:
|
| 102 |
-
base = BASE_MODEL_DEFAULT
|
| 103 |
|
| 104 |
TOK = AutoTokenizer.from_pretrained(base)
|
| 105 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(base)
|
|
@@ -107,7 +111,8 @@ def load_selected_model(label: str) -> str:
|
|
| 107 |
MODEL.eval()
|
| 108 |
ACTIVE_LABEL = label
|
| 109 |
ACTIVE_BASE = base
|
| 110 |
-
|
|
|
|
| 111 |
|
| 112 |
PROMPT_HEADER = (
|
| 113 |
"You are a quiz question generator. "
|
|
@@ -129,7 +134,8 @@ def try_parse_json(js: str) -> Optional[dict]:
|
|
| 129 |
try:
|
| 130 |
return json.loads(js)
|
| 131 |
except Exception:
|
| 132 |
-
|
|
|
|
| 133 |
try:
|
| 134 |
return json.loads(js2)
|
| 135 |
except Exception:
|
|
@@ -144,7 +150,7 @@ def generate(age_band, genre, difficulty):
|
|
| 144 |
out = MODEL.generate(
|
| 145 |
**inputs,
|
| 146 |
max_new_tokens=160,
|
| 147 |
-
do_sample=False, # deterministic
|
| 148 |
num_beams=4,
|
| 149 |
early_stopping=True,
|
| 150 |
no_repeat_ngram_size=3,
|
|
@@ -177,13 +183,15 @@ def do_load(label):
|
|
| 177 |
can_gen = TOK is not None and MODEL is not None and ACTIVE_LABEL == label
|
| 178 |
return status, (zip_path if os.path.isfile(zip_path) else None), gr.update(interactive=can_gen)
|
| 179 |
|
|
|
|
| 180 |
# --------------------- UI ---------------------
|
| 181 |
with gr.Blocks() as demo:
|
| 182 |
-
gr.Markdown("## Quiz AI β Train β Save β Download β Use")
|
| 183 |
|
| 184 |
with gr.Tab("Train"):
|
| 185 |
with gr.Row():
|
| 186 |
-
dataset = gr.Textbox(value=os.environ.get("DATASET_ID", "Percy3822/quiz_ai_dataset"),
|
|
|
|
| 187 |
base = gr.Textbox(value=BASE_MODEL_DEFAULT, label="Base model", scale=2)
|
| 188 |
with gr.Row():
|
| 189 |
label = gr.Textbox(placeholder="e.g., quiz_v1 (leave blank for timestamp)", label="Adapter label", scale=2)
|
|
|
|
| 1 |
+
# app.py β Gradio UI with:
|
| 2 |
+
# - Train tab: live log streaming while running train.py
|
| 3 |
+
# - Use tab: pick a trained adapter, load (no fallback), generate strict JSON
|
| 4 |
+
# - Downloads: provides artifacts/<label>.zip for the loaded adapter
|
| 5 |
|
| 6 |
import os, re, json, time, sys, subprocess
|
| 7 |
from typing import Optional, Tuple
|
|
|
|
| 46 |
has_info= "MODEL_INFO.json" in files
|
| 47 |
return has_cfg and has_wts and has_info
|
| 48 |
|
| 49 |
+
|
| 50 |
# --------------------- global loaded model ---------------------
|
| 51 |
TOK = None
|
| 52 |
MODEL = None
|
| 53 |
ACTIVE_LABEL: Optional[str] = None
|
| 54 |
ACTIVE_BASE = BASE_MODEL_DEFAULT
|
| 55 |
|
| 56 |
+
|
| 57 |
+
# --------------------- training (live logs) ---------------------
|
| 58 |
def train_model_live(dataset_id, base_model, label, epochs):
|
| 59 |
if not label:
|
| 60 |
label = time.strftime("run_%Y%m%d_%H%M%S")
|
|
|
|
| 66 |
"--label", label.strip(),
|
| 67 |
"--epochs", str(int(epochs)),
|
| 68 |
]
|
| 69 |
+
# Show command first
|
| 70 |
yield f"$ {' '.join(cmd)}\n\n"
|
| 71 |
|
| 72 |
+
# Stream stdout line-by-line
|
| 73 |
process = subprocess.Popen(
|
| 74 |
cmd,
|
| 75 |
stdout=subprocess.PIPE,
|
|
|
|
| 85 |
break
|
| 86 |
if line:
|
| 87 |
buffer += line
|
| 88 |
+
yield buffer[-8000:] # keep tail visible
|
|
|
|
| 89 |
rc = process.wait()
|
| 90 |
+
yield buffer[-8000:] + f"\n\n[exit code: {rc}]"
|
| 91 |
+
|
| 92 |
|
| 93 |
# --------------------- loading + generation ---------------------
|
| 94 |
def load_selected_model(label: str) -> str:
|
|
|
|
| 103 |
meta = json.load(open(info_path, "r", encoding="utf-8"))
|
| 104 |
base = meta.get("base_model", BASE_MODEL_DEFAULT)
|
| 105 |
except Exception:
|
| 106 |
+
meta, base = {}, BASE_MODEL_DEFAULT
|
| 107 |
|
| 108 |
TOK = AutoTokenizer.from_pretrained(base)
|
| 109 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(base)
|
|
|
|
| 111 |
MODEL.eval()
|
| 112 |
ACTIVE_LABEL = label
|
| 113 |
ACTIVE_BASE = base
|
| 114 |
+
ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(meta.get("saved_at", 0))) if meta else "unknown"
|
| 115 |
+
return f"β
Loaded: {label} (base={base}, saved={ts})"
|
| 116 |
|
| 117 |
PROMPT_HEADER = (
|
| 118 |
"You are a quiz question generator. "
|
|
|
|
| 134 |
try:
|
| 135 |
return json.loads(js)
|
| 136 |
except Exception:
|
| 137 |
+
# minor cleanup for trailing commas
|
| 138 |
+
js2 = re.sub(r",\s*([}\]])", r"\1", js)
|
| 139 |
try:
|
| 140 |
return json.loads(js2)
|
| 141 |
except Exception:
|
|
|
|
| 150 |
out = MODEL.generate(
|
| 151 |
**inputs,
|
| 152 |
max_new_tokens=160,
|
| 153 |
+
do_sample=False, # deterministic, cleaner JSON
|
| 154 |
num_beams=4,
|
| 155 |
early_stopping=True,
|
| 156 |
no_repeat_ngram_size=3,
|
|
|
|
| 183 |
can_gen = TOK is not None and MODEL is not None and ACTIVE_LABEL == label
|
| 184 |
return status, (zip_path if os.path.isfile(zip_path) else None), gr.update(interactive=can_gen)
|
| 185 |
|
| 186 |
+
|
| 187 |
# --------------------- UI ---------------------
|
| 188 |
with gr.Blocks() as demo:
|
| 189 |
+
gr.Markdown("## Quiz AI β Train β Save β Download β Use (No tokens)")
|
| 190 |
|
| 191 |
with gr.Tab("Train"):
|
| 192 |
with gr.Row():
|
| 193 |
+
dataset = gr.Textbox(value=os.environ.get("DATASET_ID", "Percy3822/quiz_ai_dataset"),
|
| 194 |
+
label="Dataset ID (HF) β ignored if data/train.jsonl exists", scale=2)
|
| 195 |
base = gr.Textbox(value=BASE_MODEL_DEFAULT, label="Base model", scale=2)
|
| 196 |
with gr.Row():
|
| 197 |
label = gr.Textbox(placeholder="e.g., quiz_v1 (leave blank for timestamp)", label="Adapter label", scale=2)
|