Percy3822 committed
Commit 53f1707 · verified · 1 Parent(s): cfd0695

Update app.py

Files changed (1)
  1. app.py +118 -42
app.py CHANGED
@@ -1,12 +1,13 @@
-# app.py — Gradio UI with:
-# - Train tab: live log streaming while running train.py
-# - Use tab: pick a trained adapter, load (no fallback), generate strict JSON
-# - Downloads: provides artifacts/<label>.zip for the loaded adapter
+# app.py — Train ➜ Save ➜ Download ➜ Use (No tokens) + Adapter Verification
+# - Train tab: live log streaming (runs train.py)
+# - Use tab: pick & load adapter (no fallback), generate strict JSON
+# - Verify: A/B compare (base vs adapter) and display adapter SHA-256
 
-import os, re, json, time, sys, subprocess
+import os, re, json, time, sys, subprocess, hashlib
 from typing import Optional, Tuple
 
 import gradio as gr
+import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 
@@ -46,13 +47,28 @@ def adapter_exists(label: str) -> bool:
     has_info= "MODEL_INFO.json" in files
     return has_cfg and has_wts and has_info
 
+def _adapter_weights_path(label: str) -> Optional[str]:
+    d, _ = model_paths(label)
+    p_safe = os.path.join(d, "adapter_model.safetensors")
+    p_bin = os.path.join(d, "adapter_model.bin")
+    if os.path.isfile(p_safe): return p_safe
+    if os.path.isfile(p_bin): return p_bin
+    return None
+
+def sha256_file(path: str) -> Optional[str]:
+    if not path or not os.path.isfile(path): return None
+    h = hashlib.sha256()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(8192), b""):
+            h.update(chunk)
+    return h.hexdigest()
 
 # --------------------- global loaded model ---------------------
 TOK = None
 MODEL = None
 ACTIVE_LABEL: Optional[str] = None
 ACTIVE_BASE = BASE_MODEL_DEFAULT
-
+ACTIVE_SHA: Optional[str] = None
 
 # --------------------- training (live logs) ---------------------
 def train_model_live(dataset_id, base_model, label, epochs):
@@ -87,13 +103,13 @@ def train_model_live(dataset_id, base_model, label, epochs):
     rc = process.wait()
     yield buffer[-8000:] + f"\n\n[exit code: {rc}]"
 
-
 # --------------------- loading + generation ---------------------
 def load_selected_model(label: str) -> str:
-    global TOK, MODEL, ACTIVE_LABEL, ACTIVE_BASE
+    global TOK, MODEL, ACTIVE_LABEL, ACTIVE_BASE, ACTIVE_SHA
     if not adapter_exists(label):
         TOK = MODEL = None
         ACTIVE_LABEL = None
+        ACTIVE_SHA = None
         return f"🛑 Adapter '{label}' not found. Train it first."
 
     info_path = os.path.join(MODELS_DIR, label, "MODEL_INFO.json")
@@ -103,14 +119,19 @@ def load_selected_model(label: str) -> str:
     except Exception:
         meta, base = {}, BASE_MODEL_DEFAULT
 
+    # Load tokenizer + base, then attach adapter
     TOK = AutoTokenizer.from_pretrained(base)
     base_model = AutoModelForSeq2SeqLM.from_pretrained(base)
     MODEL = PeftModel.from_pretrained(base_model, os.path.join(MODELS_DIR, label))
     MODEL.eval()
+
     ACTIVE_LABEL = label
     ACTIVE_BASE = base
+    ACTIVE_SHA = sha256_file(_adapter_weights_path(label))
+
     ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(meta.get("saved_at", 0))) if meta else "unknown"
-    return f"✅ Loaded: {label} (base={base}, saved={ts})"
+    sha_show = (ACTIVE_SHA[:12] + "…") if ACTIVE_SHA else "unknown"
+    return f"✅ Loaded: {label} (base={base}, saved={ts}, sha={sha_show})"
 
 # ---- FEW-SHOT PROMPT to anchor structure ----
 FEW_SHOT = (
@@ -119,11 +140,11 @@ FEW_SHOT = (
     "q, choices, answer, explanation, age_band, genre, difficulty.\n"
     "Do not include any extra text before/after the JSON.\n\n"
     "Example:\n"
-    "JSON:{\"q\":\"What planet is known as the Red Planet?\","
-    "\"choices\":[\"Earth\",\"Mars\",\"Venus\",\"Jupiter\"],"
-    "\"answer\":1,"
-    "\"explanation\":\"Mars appears red due to iron oxide.\","
-    "\"age_band\":\"13-17\",\"genre\":\"science\",\"difficulty\":\"easy\"}\n\n"
+    'JSON:{"q":"What planet is known as the Red Planet?",'
+    '"choices":["Earth","Mars","Venus","Jupiter"],'
+    '"answer":1,'
+    '"explanation":"Mars appears red due to iron oxide.",'
+    '"age_band":"13-17","genre":"science","difficulty":"easy"}\n\n'
 )
 
 def mk_prompt(age_band, genre, difficulty):
@@ -147,42 +168,51 @@ def try_parse_json(js: str) -> Optional[dict]:
     except Exception:
         return None
 
-def generate(age_band, genre, difficulty):
-    if TOK is None or MODEL is None or ACTIVE_LABEL is None:
-        return "🛑 No model loaded. Pick a trained adapter and press *Load model*."
-
-    prompt = mk_prompt(age_band, genre, difficulty)
-    inputs = TOK(prompt, return_tensors="pt")
-
-    # PASS 1: deterministic with a minimum length (prevents trivial '4')
-    out = MODEL.generate(
+def _gen_deterministic(model, tok, prompt: str, min_len=80, max_len=220):
+    torch.manual_seed(0)  # reproducible
+    inputs = tok(prompt, return_tensors="pt")
+    out = model.generate(
         **inputs,
-        max_new_tokens=220,
-        min_new_tokens=80,
+        max_new_tokens=max_len,
+        min_new_tokens=min_len,
        do_sample=False,
        num_beams=4,
        length_penalty=1.0,
        early_stopping=False,
        no_repeat_ngram_size=3,
     )
-    text = TOK.decode(out[0], skip_special_tokens=True)
+    return tok.decode(out[0], skip_special_tokens=True)
+
+def _gen_sampled(model, tok, prompt: str, min_len=80, max_len=220):
+    torch.manual_seed(0)  # keep reproducible for debugging
+    inputs = tok(prompt, return_tensors="pt")
+    out = model.generate(
+        **inputs,
+        max_new_tokens=max_len,
+        min_new_tokens=min_len,
+        do_sample=True,
+        temperature=0.6,
+        top_p=0.9,
+        top_k=50,
+        no_repeat_ngram_size=3,
+        early_stopping=True,
+    )
+    return tok.decode(out[0], skip_special_tokens=True)
+
+def generate(age_band, genre, difficulty):
+    if TOK is None or MODEL is None or ACTIVE_LABEL is None:
+        return "🛑 No model loaded. Pick a trained adapter and press *Load model*."
+
+    prompt = mk_prompt(age_band, genre, difficulty)
+
+    # PASS 1: deterministic
+    text = _gen_deterministic(MODEL, TOK, prompt)
     js = extract_json_str(text)
     obj = try_parse_json(js) if js else None
 
-    # PASS 2: sampled retry if too short or not valid JSON
+    # PASS 2: sampled retry if too short or invalid
     if not obj or len(js or "") < 40:
-        out = MODEL.generate(
-            **inputs,
-            max_new_tokens=220,
-            min_new_tokens=80,
-            do_sample=True,
-            temperature=0.6,
-            top_p=0.9,
-            top_k=50,
-            no_repeat_ngram_size=3,
-            early_stopping=True,
-        )
-        text = TOK.decode(out[0], skip_special_tokens=True)
+        text = _gen_sampled(MODEL, TOK, prompt)
         js = extract_json_str(text)
         obj = try_parse_json(js) if js else None
 
@@ -196,7 +226,7 @@ def generate(age_band, genre, difficulty):
         missing = sorted(list(need - set(obj.keys())))
         return f"ERROR: JSON missing keys {missing}\n\nPARSED:\n{json.dumps(obj, indent=2)}"
 
-    # Guardrails: exactly 4 choices and valid 0-3 answer index
+    # Guardrails
     if not isinstance(obj.get("choices"), list) or len(obj["choices"]) != 4:
         return f"ERROR: choices must be a list of 4.\n\nPARSED:\n{json.dumps(obj, indent=2)}"
     if not isinstance(obj.get("answer"), int) or not (0 <= obj["answer"] <= 3):
@@ -204,6 +234,44 @@
 
     return json.dumps(obj, indent=2, ensure_ascii=False)
 
+# -------- Verification: base vs adapter A/B on the same prompt --------
+def verify_adapter(age_band, genre, difficulty):
+    if TOK is None or MODEL is None or ACTIVE_LABEL is None:
+        return "🛑 No adapter loaded."
+
+    prompt = mk_prompt(age_band, genre, difficulty)
+
+    # Base-only (fresh load, no adapter)
+    base_tok = AutoTokenizer.from_pretrained(ACTIVE_BASE)
+    base_model = AutoModelForSeq2SeqLM.from_pretrained(ACTIVE_BASE)
+    base_model.eval()
+
+    base_text = _gen_deterministic(base_model, base_tok, prompt)
+    base_json = extract_json_str(base_text)
+    base_ok = bool(try_parse_json(base_json) if base_json else None)
+
+    # Adapter (current MODEL/TOK)
+    adapter_text = _gen_deterministic(MODEL, TOK, prompt)
+    adapter_json = extract_json_str(adapter_text)
+    adapter_ok = bool(try_parse_json(adapter_json) if adapter_json else None)
+
+    sha_show = (ACTIVE_SHA[:12] + "…") if ACTIVE_SHA else "unknown"
+
+    report = {
+        "active_label": ACTIVE_LABEL,
+        "base_model": ACTIVE_BASE,
+        "adapter_sha256": ACTIVE_SHA,
+        "prompt_preview": prompt[:200] + ("…" if len(prompt) > 200 else ""),
+        "base_output_preview": (base_text[:400] + "…") if len(base_text) > 400 else base_text,
+        "base_json_detected": bool(base_json),
+        "base_json_parsed_ok": base_ok,
+        "adapter_output_preview": (adapter_text[:400] + "…") if len(adapter_text) > 400 else adapter_text,
+        "adapter_json_detected": bool(adapter_json),
+        "adapter_json_parsed_ok": adapter_ok,
+        "note": "If adapter_json_parsed_ok != base_json_parsed_ok, the adapter is changing behavior."
+    }
+    return json.dumps(report, indent=2, ensure_ascii=False)
+
 def reload_list():
     items = list_available_models()
     if not items:
@@ -217,10 +285,9 @@ def do_load(label):
     can_gen = TOK is not None and MODEL is not None and ACTIVE_LABEL == label
     return status, (zip_path if os.path.isfile(zip_path) else None), gr.update(interactive=can_gen)
 
-
 # --------------------- UI ---------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## Quiz AI — Train ➜ Save ➜ Download ➜ Use (No tokens)")
+    gr.Markdown("## Quiz AI — Train ➜ Save ➜ Download ➜ Use (No tokens) + Verify Adapter")
 
     with gr.Tab("Train"):
        with gr.Row():
@@ -255,9 +322,18 @@ with gr.Blocks() as demo:
        gen_btn = gr.Button("Generate sample", interactive=False)  # disabled until model loaded
        out = gr.Code(label="Model output (JSON expected)")
 
+        # Verify section
+        with gr.Accordion("Verify adapter (A/B vs base)", open=False):
+            v_age = gr.Textbox("13-17", label="Age band")
+            v_gen = gr.Dropdown(choices=["geography","science","history","math"], value="science", label="Genre")
+            v_diff = gr.Dropdown(choices=["easy","medium","hard"], value="easy", label="Difficulty")
+            verify_btn = gr.Button("Run verification")
+            verify_out = gr.Code(label="Verification report (JSON)")
+
        refresh_btn.click(fn=reload_list, outputs=[model_list, status_md, label])
        load_btn.click(fn=do_load, inputs=[model_list], outputs=[status_md, current_zip, gen_btn])
        gen_btn.click(fn=generate, inputs=[age, gen, diff], outputs=out)
+        verify_btn.click(fn=verify_adapter, inputs=[v_age, v_gen, v_diff], outputs=verify_out)
 
 if __name__ == "__main__":
     demo.launch()
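
A quick way to sanity-check the new checksum display: the sha= value in the load status can be recomputed outside the app with the same algorithm as the sha256_file helper added above. The sketch below is illustrative only; the "models" directory and the "my-adapter" label are assumptions, since MODELS_DIR and the actual adapter labels are defined elsewhere in app.py.

# check_adapter_sha.py — recompute the digest shown in the "Load model" status.
# Assumptions: adapters live under models/<label>/ (MODELS_DIR is defined
# elsewhere in app.py) and "my-adapter" is a placeholder label.
import hashlib, os, sys

MODELS_DIR = "models"  # assumed location of saved adapters

def sha256_file(path: str) -> str:
    # Same streaming digest as the helper introduced in this commit.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()

if __name__ == "__main__":
    label = sys.argv[1] if len(sys.argv) > 1 else "my-adapter"
    for name in ("adapter_model.safetensors", "adapter_model.bin"):
        p = os.path.join(MODELS_DIR, label, name)
        if os.path.isfile(p):
            # The UI shows only the first 12 hex characters of this digest.
            print(f"{name}: {sha256_file(p)}")
            break
    else:
        print(f"No adapter weights found for '{label}'.")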