Update app.py
app.py CHANGED
@@ -1,12 +1,13 @@
-# app.py —
-#
-#
-#
+# app.py — Train → Save → Download → Use (No tokens) + Adapter Verification
+# - Train tab: live log streaming (runs train.py)
+# - Use tab: pick & load adapter (no fallback), generate strict JSON
+# - Verify: A/B compare (base vs adapter) and display adapter SHA-256
 
-import os, re, json, time, sys, subprocess
+import os, re, json, time, sys, subprocess, hashlib
 from typing import Optional, Tuple
 
 import gradio as gr
+import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 
@@ -46,13 +47,28 @@ def adapter_exists(label: str) -> bool:
     has_info = "MODEL_INFO.json" in files
     return has_cfg and has_wts and has_info
 
+def _adapter_weights_path(label: str) -> Optional[str]:
+    d, _ = model_paths(label)
+    p_safe = os.path.join(d, "adapter_model.safetensors")
+    p_bin = os.path.join(d, "adapter_model.bin")
+    if os.path.isfile(p_safe): return p_safe
+    if os.path.isfile(p_bin): return p_bin
+    return None
+
+def sha256_file(path: str) -> Optional[str]:
+    if not path or not os.path.isfile(path): return None
+    h = hashlib.sha256()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            h.update(chunk)
+    return h.hexdigest()
 
 # --------------------- global loaded model ---------------------
 TOK = None
 MODEL = None
 ACTIVE_LABEL: Optional[str] = None
 ACTIVE_BASE = BASE_MODEL_DEFAULT
-
+ACTIVE_SHA: Optional[str] = None
 
 # --------------------- training (live logs) ---------------------
 def train_model_live(dataset_id, base_model, label, epochs):
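Review note: sha256_file streams the weights in 8 KiB chunks, so large adapters hash without being read fully into memory. The hash surfaced in the UI can be re-checked outside the app; a minimal sketch, assuming a hypothetical local path to the same adapter file:

import hashlib

def sha256_hex(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            h.update(chunk)
    return h.hexdigest()

# "models/quiz-v1/adapter_model.safetensors" is a placeholder path
print(sha256_hex("models/quiz-v1/adapter_model.safetensors")[:12])

The 12-character prefix should match the sha= value shown in the load status below.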
@@ -87,13 +103,13 @@ def train_model_live(dataset_id, base_model, label, epochs):
     rc = process.wait()
     yield buffer[-8000:] + f"\n\n[exit code: {rc}]"
 
-
 # --------------------- loading + generation ---------------------
 def load_selected_model(label: str) -> str:
-    global TOK, MODEL, ACTIVE_LABEL, ACTIVE_BASE
+    global TOK, MODEL, ACTIVE_LABEL, ACTIVE_BASE, ACTIVE_SHA
     if not adapter_exists(label):
         TOK = MODEL = None
         ACTIVE_LABEL = None
+        ACTIVE_SHA = None
         return f"❌ Adapter '{label}' not found. Train it first."
 
     info_path = os.path.join(MODELS_DIR, label, "MODEL_INFO.json")
@@ -103,14 +119,19 @@ def load_selected_model(label: str) -> str:
     except Exception:
         meta, base = {}, BASE_MODEL_DEFAULT
 
+    # Load tokenizer + base, then attach adapter
     TOK = AutoTokenizer.from_pretrained(base)
     base_model = AutoModelForSeq2SeqLM.from_pretrained(base)
     MODEL = PeftModel.from_pretrained(base_model, os.path.join(MODELS_DIR, label))
     MODEL.eval()
+
     ACTIVE_LABEL = label
     ACTIVE_BASE = base
+    ACTIVE_SHA = sha256_file(_adapter_weights_path(label))
+
     ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(meta.get("saved_at", 0))) if meta else "unknown"
-
+    sha_show = (ACTIVE_SHA[:12] + "…") if ACTIVE_SHA else "unknown"
+    return f"✅ Loaded: {label} (base={base}, saved={ts}, sha={sha_show})"
 
 # ---- FEW-SHOT PROMPT to anchor structure ----
 FEW_SHOT = (
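Review note: PeftModel.from_pretrained wraps the freshly loaded base model, and a load that silently failed to attach adapter weights would be easy to miss. A quick hedged check, assuming the MODEL global as loaded above, is to count LoRA tensors by name:

# Sketch only: LoRA parameter names in PEFT contain "lora_"
lora_tensors = [n for n, _ in MODEL.named_parameters() if "lora_" in n]
print(f"{len(lora_tensors)} LoRA tensors attached")  # 0 would mean no adapter weights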
@@ -119,11 +140,11 @@ FEW_SHOT = (
     "q, choices, answer, explanation, age_band, genre, difficulty.\n"
     "Do not include any extra text before/after the JSON.\n\n"
     "Example:\n"
-
-    "
-    "
-    "
-    "
+    'JSON:{"q":"What planet is known as the Red Planet?",'
+    '"choices":["Earth","Mars","Venus","Jupiter"],'
+    '"answer":1,'
+    '"explanation":"Mars appears red due to iron oxide.",'
+    '"age_band":"13-17","genre":"science","difficulty":"easy"}\n\n'
 )
 
 def mk_prompt(age_band, genre, difficulty):
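Review note: the few-shot example doubles as the schema specification, so it is worth confirming the literal is valid JSON and internally consistent. A minimal self-check repeating the same string:

import json

# Same example object as in FEW_SHOT above
example = ('{"q":"What planet is known as the Red Planet?",'
           '"choices":["Earth","Mars","Venus","Jupiter"],'
           '"answer":1,'
           '"explanation":"Mars appears red due to iron oxide.",'
           '"age_band":"13-17","genre":"science","difficulty":"easy"}')
obj = json.loads(example)
assert obj["choices"][obj["answer"]] == "Mars"  # the answer index points at the right choice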
@@ -147,42 +168,51 @@ def try_parse_json(js: str) -> Optional[dict]:
     except Exception:
         return None
 
-def generate(age_band, genre, difficulty):
-
-
-
-    prompt = mk_prompt(age_band, genre, difficulty)
-    inputs = TOK(prompt, return_tensors="pt")
-
-    # PASS 1: deterministic with a minimum length (prevents trivial '4')
-    out = MODEL.generate(
+def _gen_deterministic(model, tok, prompt: str, min_len=80, max_len=220):
+    torch.manual_seed(0)  # reproducible
+    inputs = tok(prompt, return_tensors="pt")
+    out = model.generate(
         **inputs,
-        max_new_tokens=
-        min_new_tokens=
+        max_new_tokens=max_len,
+        min_new_tokens=min_len,
         do_sample=False,
         num_beams=4,
         length_penalty=1.0,
         early_stopping=False,
         no_repeat_ngram_size=3,
     )
-
+    return tok.decode(out[0], skip_special_tokens=True)
+
+def _gen_sampled(model, tok, prompt: str, min_len=80, max_len=220):
+    torch.manual_seed(0)  # keep reproducible for debugging
+    inputs = tok(prompt, return_tensors="pt")
+    out = model.generate(
+        **inputs,
+        max_new_tokens=max_len,
+        min_new_tokens=min_len,
+        do_sample=True,
+        temperature=0.6,
+        top_p=0.9,
+        top_k=50,
+        no_repeat_ngram_size=3,
+        early_stopping=True,
+    )
+    return tok.decode(out[0], skip_special_tokens=True)
+
+def generate(age_band, genre, difficulty):
+    if TOK is None or MODEL is None or ACTIVE_LABEL is None:
+        return "❌ No model loaded. Pick a trained adapter and press *Load model*."
+
+    prompt = mk_prompt(age_band, genre, difficulty)
+
+    # PASS 1: deterministic
+    text = _gen_deterministic(MODEL, TOK, prompt)
     js = extract_json_str(text)
     obj = try_parse_json(js) if js else None
 
-    # PASS 2: sampled retry if too short or
+    # PASS 2: sampled retry if too short or invalid
    if not obj or len(js or "") < 40:
-
-        **inputs,
-        max_new_tokens=220,
-        min_new_tokens=80,
-        do_sample=True,
-        temperature=0.6,
-        top_p=0.9,
-        top_k=50,
-        no_repeat_ngram_size=3,
-        early_stopping=True,
-        )
-        text = TOK.decode(out[0], skip_special_tokens=True)
+        text = _gen_sampled(MODEL, TOK, prompt)
         js = extract_json_str(text)
         obj = try_parse_json(js) if js else None
 
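Review note: torch.manual_seed(0) makes even the sampled PASS 2 repeat exactly across runs, which helps debugging but means a bad draw cannot be re-rolled by clicking again. A tiny sketch of the mechanism:

import torch

torch.manual_seed(0)
a = torch.rand(3)
torch.manual_seed(0)
b = torch.rand(3)
assert torch.equal(a, b)  # same seed, same draws: the "sampled" retry is deterministic too

If genuinely varied retries are wanted, seeding with a per-attempt counter instead of a constant would restore variation.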
@@ -196,7 +226,7 @@ def generate(age_band, genre, difficulty):
     missing = sorted(list(need - set(obj.keys())))
     return f"ERROR: JSON missing keys {missing}\n\nPARSED:\n{json.dumps(obj, indent=2)}"
 
-    # Guardrails
+    # Guardrails
     if not isinstance(obj.get("choices"), list) or len(obj["choices"]) != 4:
         return f"ERROR: choices must be a list of 4.\n\nPARSED:\n{json.dumps(obj, indent=2)}"
     if not isinstance(obj.get("answer"), int) or not (0 <= obj["answer"] <= 3):
@@ -204,6 +234,44 @@ def generate(age_band, genre, difficulty):
 
     return json.dumps(obj, indent=2, ensure_ascii=False)
 
+# -------- Verification: base vs adapter A/B on the same prompt --------
+def verify_adapter(age_band, genre, difficulty):
+    if TOK is None or MODEL is None or ACTIVE_LABEL is None:
+        return "❌ No adapter loaded."
+
+    prompt = mk_prompt(age_band, genre, difficulty)
+
+    # Base-only (fresh load, no adapter)
+    base_tok = AutoTokenizer.from_pretrained(ACTIVE_BASE)
+    base_model = AutoModelForSeq2SeqLM.from_pretrained(ACTIVE_BASE)
+    base_model.eval()
+
+    base_text = _gen_deterministic(base_model, base_tok, prompt)
+    base_json = extract_json_str(base_text)
+    base_ok = bool(try_parse_json(base_json) if base_json else None)
+
+    # Adapter (current MODEL/TOK)
+    adapter_text = _gen_deterministic(MODEL, TOK, prompt)
+    adapter_json = extract_json_str(adapter_text)
+    adapter_ok = bool(try_parse_json(adapter_json) if adapter_json else None)
+
+    sha_show = (ACTIVE_SHA[:12] + "…") if ACTIVE_SHA else "unknown"
+
+    report = {
+        "active_label": ACTIVE_LABEL,
+        "base_model": ACTIVE_BASE,
+        "adapter_sha256": ACTIVE_SHA,
+        "prompt_preview": prompt[:200] + ("…" if len(prompt) > 200 else ""),
+        "base_output_preview": (base_text[:400] + "…") if len(base_text) > 400 else base_text,
+        "base_json_detected": bool(base_json),
+        "base_json_parsed_ok": base_ok,
+        "adapter_output_preview": (adapter_text[:400] + "…") if len(adapter_text) > 400 else adapter_text,
+        "adapter_json_detected": bool(adapter_json),
+        "adapter_json_parsed_ok": adapter_ok,
+        "note": "If adapter_json_parsed_ok != base_json_parsed_ok, the adapter is changing behavior."
+    }
+    return json.dumps(report, indent=2, ensure_ascii=False)
+
 def reload_list():
     items = list_available_models()
     if not items:
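Review note: the report compares only whether each side produced parseable JSON, which can miss the case where both outputs parse but differ in content. A rough extra signal, assuming the base_text and adapter_text strings computed in verify_adapter:

import difflib

# Ratio near 1.0 means near-identical outputs; well below 1.0 suggests the adapter changed behavior
ratio = difflib.SequenceMatcher(None, base_text, adapter_text).ratio()
print(f"output similarity: {ratio:.2f}")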
@@ -217,10 +285,9 @@ def do_load(label):
     can_gen = TOK is not None and MODEL is not None and ACTIVE_LABEL == label
     return status, (zip_path if os.path.isfile(zip_path) else None), gr.update(interactive=can_gen)
 
-
 # --------------------- UI ---------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## Quiz AI — Train → Save → Download → Use (No tokens)")
+    gr.Markdown("## Quiz AI — Train → Save → Download → Use (No tokens) + Verify Adapter")
 
     with gr.Tab("Train"):
         with gr.Row():
@@ -255,9 +322,18 @@ with gr.Blocks() as demo:
         gen_btn = gr.Button("Generate sample", interactive=False)  # disabled until model loaded
         out = gr.Code(label="Model output (JSON expected)")
 
+        # Verify section
+        with gr.Accordion("Verify adapter (A/B vs base)", open=False):
+            v_age = gr.Textbox("13-17", label="Age band")
+            v_gen = gr.Dropdown(choices=["geography","science","history","math"], value="science", label="Genre")
+            v_diff = gr.Dropdown(choices=["easy","medium","hard"], value="easy", label="Difficulty")
+            verify_btn = gr.Button("Run verification")
+            verify_out = gr.Code(label="Verification report (JSON)")
+
     refresh_btn.click(fn=reload_list, outputs=[model_list, status_md, label])
     load_btn.click(fn=do_load, inputs=[model_list], outputs=[status_md, current_zip, gen_btn])
     gen_btn.click(fn=generate, inputs=[age, gen, diff], outputs=out)
+    verify_btn.click(fn=verify_adapter, inputs=[v_age, v_gen, v_diff], outputs=verify_out)
 
 if __name__ == "__main__":
     demo.launch()