Percy3822 committed · verified
Commit 3da5a15 · 1 Parent(s): 50b5ef2

Update app.py

Files changed (1)
  1. app.py +385 -59
app.py CHANGED
@@ -1,85 +1,411 @@
- # app.py (minimal callback sanity check)
- import os, time
  from pathlib import Path
  import gradio as gr

  ROOT = Path(__file__).resolve().parent
- LOG = ROOT / "callback_test.log"
  RUNS = ROOT / "runs"
  RUNS.mkdir(exist_ok=True)

  def append_log(msg: str):
      try:
-         with open(LOG, "a", encoding="utf-8") as f:
-             f.write(msg.rstrip() + "\n")
      except Exception:
          pass

- def ping():
-     append_log("PING called")
-     return "βœ… Backend alive (ping)"
-
- def stream():
-     append_log("STREAM start")
-     for i in range(5):
-         time.sleep(0.3)
-         yield f"tick {i}"
-     append_log("STREAM done")
-     yield "βœ… Stream ok"

  def list_models():
-     # Just list subfolders under runs/ to prove dropdown gets values
-     items = [str(p) for p in sorted(RUNS.glob("*")) if p.is_dir()]
-     append_log(f"LIST_MODELS -> {items}")
-     return items
-
- def refresh_models():
-     return gr.update(choices=list_models(), value=None)
-
- def load_model(path):
-     append_log(f"LOAD_MODEL clicked: {path}")
-     if not path:
-         return "❌ Select a folder under runs/ (create one manually if empty)."
-     p = Path(path)
      if not p.exists() or not p.is_dir():
-         return f"❌ Not found: {path}"
-     return f"βœ… (Mock) loaded folder: {path}"

- def generate_once(path, prompt):
-     append_log(f"GENERATE clicked: path={path!r}, len(prompt)={len(prompt or '')}")
-     if not path:
-         return "❌ Pick a folder in the dropdown."
      if not prompt or not prompt.strip():
-         return "❌ Enter a prompt."
-     # No model here, just echo to prove the callback path works.
-     return f"πŸ€– MOCK RESPONSE\nModelFolder: {path}\nPrompt: {prompt.strip()[:120]}"

- with gr.Blocks(title="Callback Sanity Check") as demo:
-     gr.Markdown("## πŸ”§ Minimal Callback Test\nIf these buttons do nothing, the issue is front-end/runtime (not your code).")
-
-     with gr.Row():
-         ping_btn = gr.Button("πŸ”” Ping")
-         stream_btn = gr.Button("πŸ“‘ Stream Test")
-     out = gr.Textbox(label="Output", lines=8)
-
-     gr.Markdown("### Model list mock (folders under runs/)")
-     with gr.Row():
-         refresh_btn = gr.Button("↻ Refresh List")
-         model_dd = gr.Dropdown(choices=list_models(), label="Available Folders", interactive=True)
-
-     load_btn = gr.Button("πŸ“¦ Load Folder (mock)")
-     load_status = gr.Textbox(label="Load Status", interactive=False)
-
-     prompt = gr.Textbox(label="Prompt", lines=4, placeholder="Type anything…")
-     gen_btn = gr.Button("Generate (mock)")
-     gen_out = gr.Textbox(label="Generated Text", lines=10)
-
-     # Wiring
      ping_btn.click(ping, outputs=out)
-     stream_btn.click(stream, outputs=out)
-     refresh_btn.click(refresh_models, outputs=model_dd)
-     load_btn.click(load_model, inputs=model_dd, outputs=load_status)
-     gen_btn.click(generate_once, inputs=[model_dd, prompt], outputs=gen_out)

- demo.queue(default_concurrency_limit=1)
- demo.launch(ssr_mode=False, show_error=True)
 
+ # app.py
+ import os, shutil, subprocess, zipfile, traceback, io
  from pathlib import Path
+ from datetime import datetime
  import gradio as gr

+ # ----------------- Paths -----------------
  ROOT = Path(__file__).resolve().parent
+ DATA = ROOT / "dataset.jsonl"
+ LOG = ROOT / "train.log"
  RUNS = ROOT / "runs"
  RUNS.mkdir(exist_ok=True)

+ # ----------------- Logging -----------------
  def append_log(msg: str):
+     msg = (msg or "").rstrip("\n")
      try:
+         with open(LOG, "a", encoding="utf-8") as lf:
+             lf.write(msg + "\n")
      except Exception:
          pass

+ def read_logs():
+     return LOG.read_text(encoding="utf-8")[-20000:] if LOG.exists() else "⏳ Waiting…"

+ # ----------------- Workspace & Models -----------------
+ def ls_workspace() -> str:
+     rows = []
+     for p in sorted(ROOT.iterdir(), key=lambda x: (x.is_file(), x.name.lower())):
+         try:
+             size = p.stat().st_size
+         except Exception:
+             size = 0
+         rows.append(f"{'[DIR]' if p.is_dir() else ' '}\t{size:>10}\t{p.name}")
+     return "\n".join(rows) or "(empty)"

  def list_models():
+     out = []
+     for base in [ROOT, RUNS]:
+         if not base.exists():
+             continue
+         for p in base.iterdir():
+             if p.is_dir() and (p / "config.json").exists() and (
+                 (p / "tokenizer.json").exists() or (p / "tokenizer_config.json").exists()
+             ):
+                 out.append(str(p))
+     return sorted(set(out))
+
+ def dropdown_update_safe(models, prefer=None):
+     val = prefer if (prefer and prefer in models) else (models[0] if models else None)
+     return gr.update(choices=models, value=val)
+
+ # ----------------- Dataset Upload -----------------
+ def upload_dataset(file):
+     append_log("πŸ“₯ upload_dataset clicked")
+     if not file:
+         return "❌ No file selected.", ls_workspace()
+     if hasattr(file, "name") and os.path.isfile(file.name):
+         shutil.copy(file.name, DATA)
+         return f"βœ… Uploaded β†’ {DATA.name}", ls_workspace()
+     return "⚠ Unexpected item; please upload a .jsonl file.", ls_workspace()
+
+ # ----------------- Training (Live Logs) -----------------
+ def start_training_live(run_name):
+     append_log("πŸš€ start_training_live clicked")
+     if not DATA.exists():
+         msg = "❌ dataset.jsonl not found. Upload a JSONL dataset first."
+         append_log(msg)
+         yield (msg, gr.update(value=None, visible=False), ls_workspace(), read_logs(), dropdown_update_safe(list_models()))
+         return
+
+     run_id = (run_name or "").strip() or datetime.now().strftime("run_%Y%m%d_%H%M%S")
+     out_dir = RUNS / run_id
+     zip_path = RUNS / f"{run_id}.zip"
+
+     # clean only this run
+     if out_dir.exists():
+         shutil.rmtree(out_dir, ignore_errors=True)
+     if zip_path.exists():
+         zip_path.unlink()
+
+     # init log
+     LOG.write_text(f"πŸ”₯ Training started…\nRun: {run_id}\n", encoding="utf-8")
+     append_log(f"Workspace:\n{ls_workspace()}")
+
+     cmd = [
+         "python", str(ROOT / "train.py"),
+         "--dataset", str(DATA),
+         "--output", str(out_dir),
+         "--zip_path", str(zip_path),
+         "--model_name", "Salesforce/codegen-350M-multi",
+         "--epochs", "1",
+         "--batch_size", "2",
+         "--block_size", "256",
+         "--learning_rate", "5e-5",
+     ]
+     append_log("β–Ά " + " ".join(cmd))
+
+     # start subprocess with live stdout
+     try:
+         proc = subprocess.Popen(
+             cmd,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             bufsize=1,
+             universal_newlines=True,
+             encoding="utf-8",
+             errors="replace",
+         )
+     except Exception as e:
+         err = "❌ Failed to start train.py: " + "".join(traceback.format_exception_only(type(e), e))
+         append_log(err)
+         yield (err, gr.update(value=None, visible=False), ls_workspace(), read_logs(), dropdown_update_safe(list_models()))
+         return
+
+     live_log = io.StringIO()
+     status_msg = f"πŸš€ Training run '{run_id}' in progress…"
+
+     # stream loop
+     while True:
+         line = proc.stdout.readline()
+         if line == "" and proc.poll() is not None:
+             break
+         if line:
+             append_log(line.rstrip("\n"))
+             live_log.write(line)
+             text = live_log.getvalue()[-20000:]
+             yield (
+                 status_msg,
+                 gr.update(value=None, visible=False),
+                 ls_workspace(),
+                 text,
+                 dropdown_update_safe(list_models(), prefer=None),
+             )
+             if zip_path.exists():
+                 yield (
+                     "πŸ“¦ Model zip created during run.",
+                     gr.update(value=str(zip_path), visible=True),
+                     ls_workspace(),
+                     text,
+                     dropdown_update_safe(list_models(), prefer=None),
+                 )
+
+     code = proc.wait()
+
+     models = list_models()
+     model_update = dropdown_update_safe(models, prefer=str(out_dir) if out_dir.exists() else None)
+     final_logs = read_logs()
+
+     if code == 0 and zip_path.exists():
+         info = f"βœ… Training complete. Saved: {out_dir.name} | Zip: {zip_path.name}"
+         append_log(info)
+         yield (info, gr.update(value=str(zip_path), visible=True), ls_workspace(), final_logs, model_update)
+     else:
+         info = f"❌ Training failed (exit {code}). Check logs below."
+         append_log(info)
+         yield (info, gr.update(value=None, visible=False), ls_workspace(), final_logs, model_update)
+
+ def refresh_download():
+     append_log("↻ refresh_download clicked")
+     zips = sorted(RUNS.glob("*.zip"), key=lambda p: p.stat().st_mtime, reverse=True)
+     latest = zips[0] if zips else None
+     models = list_models()
+     return (
+         gr.update(value=(str(latest) if latest else None), visible=bool(latest)),
+         ls_workspace(),
+         dropdown_update_safe(models)
+     )
+
+ # ----------------- Import a Zip as Model Folder -----------------
+ def import_zip(zfile):
+     append_log("πŸ“¦ import_zip clicked")
+     if not zfile:
+         # return an update object so the dropdown's choices refresh correctly
+         return "❌ No zip selected.", dropdown_update_safe(list_models())
+     dest = ROOT / "imported_model"
+     if dest.exists():
+         shutil.rmtree(dest, ignore_errors=True)
+     dest.mkdir(parents=True, exist_ok=True)
+     with zipfile.ZipFile(zfile.name, "r") as z:
+         z.extractall(dest)
+     return f"βœ… Imported to {dest.name}", dropdown_update_safe(list_models())
+
+ # ----------------- Generation (cached pipeline) -----------------
+ _GEN_CACHE = {"path": None, "pipe": None}
+
+ def get_generation_pipeline(model_path: str):
+     from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+     import torch
+
+     if _GEN_CACHE["path"] == model_path and _GEN_CACHE["pipe"] is not None:
+         return _GEN_CACHE["pipe"]
+
+     append_log(f"🧩 Loading pipeline from: {model_path}")
+     tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+     if tok.pad_token_id is None:
+         if tok.eos_token_id is not None:
+             tok.pad_token = tok.eos_token
+             append_log("β„Ή No pad_token; using eos_token as pad_token.")
+         else:
+             tok.add_special_tokens({"pad_token": "[PAD]"})
+             append_log("β„Ή Added [PAD] token to tokenizer.")
+     model = AutoModelForCausalLM.from_pretrained(model_path)
+     if getattr(model, "config", None) and getattr(model.config, "vocab_size", None) and len(tok) > model.config.vocab_size:
+         model.resize_token_embeddings(len(tok))
+         append_log(f"β„Ή Resized embeddings to {len(tok)}.")
+
+     pipe = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tok,
+         device_map="auto" if torch.cuda.is_available() else None,
+     )
+     _GEN_CACHE["path"] = model_path
+     _GEN_CACHE["pipe"] = pipe
+     append_log("βœ… Pipeline loaded.")
+     return pipe
+
+ # ----------------- Test Tab Helpers -----------------
+ def ping():
+     append_log("πŸ”” Ping pressed (UI wiring OK)")
+     return "βœ… UI is connected and responding."
+
+ def load_selected_model(model_path):
+     append_log("πŸ“¦ load_selected_model clicked")
+     # Dropdown may pass a list; coerce to string
+     if isinstance(model_path, list):
+         model_path = model_path[0] if model_path else None
+     if not model_path:
+         return "❌ Select a model first."
+     if not isinstance(model_path, str):
+         return f"❌ Invalid model path type: {type(model_path).__name__}"
+     p = Path(model_path)
      if not p.exists() or not p.is_dir():
+         return f"❌ Model folder not found: {model_path}"
+     try:
+         append_log(f"πŸ“¦ Load request β†’ {model_path}")
+         _ = get_generation_pipeline(model_path)
+         append_log(f"βœ… Loaded pipeline: {model_path}")
+         return f"βœ… Loaded: {model_path}"
+     except Exception as e:
+         tb = traceback.format_exc()
+         append_log("❌ Load error:\n" + tb)
+         return "❌ Error while loading model:\n" + "".join(traceback.format_exception_only(type(e), e))
+
+ def generate_once(model_path, prompt):
+     """Non-streaming fallback."""
+     append_log("β–Ά generate_once clicked")
+     # Coerce
+     if isinstance(model_path, list):
+         model_path = model_path[0] if model_path else None

+     # validate
+     if not model_path:
+         msg = "❌ Select a model from the dropdown first."
+         append_log(msg); return msg
+     if not isinstance(model_path, str):
+         msg = f"❌ Invalid model path type: {type(model_path).__name__}"
+         append_log(msg); return msg
+     if not Path(model_path).exists():
+         msg = f"❌ Model folder not found: {model_path}"
+         append_log(msg); return msg
      if not prompt or not prompt.strip():
+         msg = "❌ Enter a prompt."
+         append_log(msg); return msg

+     try:
+         pipe = get_generation_pipeline(model_path)
+         append_log(f"πŸ“ Generating once… prompt_len={len(prompt)}")
+         result = pipe(
+             prompt.strip(),
+             max_new_tokens=80,
+             do_sample=True,
+             temperature=0.3,
+             top_p=0.9,
+             repetition_penalty=1.15,
+             no_repeat_ngram_size=4,
+             truncation=True,
+             return_full_text=True,
+         )
+         text = result[0].get("generated_text", "")
+         if not text:
+             append_log("⚠ Empty generated_text")
+             return "⚠ Model returned empty text. Try lowering temperature or adding more context."
+         append_log("βœ… Generation OK.")
+         return text
+     except Exception as e:
+         tb = traceback.format_exc()
+         append_log("❌ Generation error:\n" + tb)
+         return "❌ Error during generation:\n" + "".join(traceback.format_exception_only(type(e), e))

+ def generate_stream(model_path, prompt):
+     """Streaming version (if frontend streaming is healthy)."""
+     yield "⏳ Loading model…"
+     append_log("β–Ά generate_stream clicked")

+     # Coerce
+     if isinstance(model_path, list):
+         model_path = model_path[0] if model_path else None

+     # validate
+     if not model_path:
+         msg = "❌ Select a model from the dropdown first."
+         append_log(msg); yield msg; return
+     if not isinstance(model_path, str):
+         msg = f"❌ Invalid model path type: {type(model_path).__name__}"
+         append_log(msg); yield msg; return
+     if not Path(model_path).exists():
+         msg = f"❌ Model folder not found: {model_path}"
+         append_log(msg); yield msg; return
+     if not prompt or not prompt.strip():
+         msg = "❌ Enter a prompt."
+         append_log(msg); yield msg; return

+     try:
+         pipe = get_generation_pipeline(model_path)
+         yield "βš™ Generating… (this may take a bit on CPU)"
+         append_log(f"πŸ“ Generating (stream)… prompt_len={len(prompt)}")
+         result = pipe(
+             prompt.strip(),
+             max_new_tokens=80,
+             do_sample=True,
+             temperature=0.3,
+             top_p=0.9,
+             repetition_penalty=1.15,
+             no_repeat_ngram_size=4,
+             truncation=True,
+             return_full_text=True,
+         )
+         text = result[0].get("generated_text", "")
+         if not text:
+             append_log("⚠ Empty generated_text")
+             yield "⚠ Model returned empty text. Try lowering temperature or adding more context."
+             return
+         append_log("βœ… Generation OK.")
+         yield text
+     except Exception as e:
+         tb = traceback.format_exc()
+         append_log("❌ Generation error:\n" + tb)
+         yield "❌ Error during generation:\n" + "".join(traceback.format_exception_only(type(e), e))

+ # ----------------- UI -----------------
+ with gr.Blocks(title="Python AI β€” Train & Test") as app:
+     gr.Markdown("## 🧠 Python AI β€” Train & Test\nβ€’ Unique runs β€’ Safe download β€’ Cached generation β€’ Live logs\n")
+
+     # ---------- Test Tab ----------
+     with gr.Tab("Test"):
+         gr.Markdown("### Choose a model folder or upload a .zip, then prompt it")
+         with gr.Row():
+             refresh_btn = gr.Button("↻ Refresh Model List")
+             ping_btn = gr.Button("πŸ”” Ping UI")  # sanity check
+         model_list = gr.Dropdown(
+             choices=list_models(),
+             label="Available AIs",
+             interactive=True,
+             allow_custom_value=True,
+             multiselect=False
+         )
+         load_btn = gr.Button("πŸ“¦ Load Model")
+         load_status = gr.Textbox(label="Model Status", interactive=False)
+
+         zip_in = gr.File(label="Or upload a model .zip", file_types=[".zip"])
+         import_status = gr.Textbox(label="Import Status", interactive=False)
+
+         prompt = gr.Textbox(
+             label="Prompt",
+             lines=8,
+             placeholder="### Instruction:\nPython: write a function ...\n### Response:\n"
+         )
+         with gr.Row():
+             go_stream = gr.Button("Generate (stream)")
+             go_once = gr.Button("Generate (once)")
+         out = gr.Textbox(label="AI Response", lines=20)
+
+     # ---------- Train Tab ----------
+     with gr.Tab("Train"):
+         with gr.Row():
+             ds = gr.File(label="πŸ“₯ Upload JSONL", file_types=[".jsonl"])
+             ws = gr.Textbox(label="Workspace", lines=16, value=ls_workspace())
+         run_name = gr.Textbox(label="Run name (optional)", placeholder="e.g., python_small_v1")
+         up_status = gr.Textbox(label="Upload Status", interactive=False)
+         start = gr.Button("πŸš€ Start Training (Live Logs)", variant="primary")
+         logs = gr.Textbox(label="πŸ“œ Training Logs (live)", lines=18, value=read_logs())
+         status = gr.Textbox(label="Status", interactive=False)
+         download_file = gr.File(label="πŸ“¦ Latest trained zip", visible=False)
+         refresh_dl_btn = gr.Button("Refresh Download")
+
+     # ---------- Wiring ----------
+     ds.change(upload_dataset, inputs=ds, outputs=[up_status, ws])
+
+     start.click(
+         start_training_live,
+         inputs=[run_name],
+         outputs=[status, download_file, ws, logs, model_list]
+     )
+
+     refresh_dl_btn.click(
+         refresh_download,
+         outputs=[download_file, ws, model_list]
+     )
+
+     refresh_btn.click(lambda: dropdown_update_safe(list_models()), outputs=model_list)
      ping_btn.click(ping, outputs=out)
+     load_btn.click(load_selected_model, inputs=[model_list], outputs=[load_status])
+     zip_in.change(import_zip, inputs=zip_in, outputs=[import_status, model_list])
+
+     go_stream.click(generate_stream, inputs=[model_list, prompt], outputs=out)
+     go_once.click(generate_once, inputs=[model_list, prompt], outputs=out)

+ # Critical: disable SSR; ensure queue is enabled
+ app.queue(default_concurrency_limit=1)
+ app.launch(ssr_mode=False, show_error=True)
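
Note: app.py shells out to a train.py that is not part of this commit. Below is a minimal sketch of the CLI contract the app assumes; the flag names and values come from the cmd list above, while the stub body is hypothetical (real training logic would replace the print).

# train.py (hypothetical stub of the interface app.py expects)
import argparse
from pathlib import Path

def parse_args():
    ap = argparse.ArgumentParser(description="Fine-tune a causal LM (stub)")
    ap.add_argument("--dataset", type=Path, required=True)    # the uploaded dataset.jsonl
    ap.add_argument("--output", type=Path, required=True)     # run folder under runs/
    ap.add_argument("--zip_path", type=Path, required=True)   # zip offered for download; app.py treats its existence as success
    ap.add_argument("--model_name", default="Salesforce/codegen-350M-multi")
    ap.add_argument("--epochs", type=int, default=1)
    ap.add_argument("--batch_size", type=int, default=2)
    ap.add_argument("--block_size", type=int, default=256)
    ap.add_argument("--learning_rate", type=float, default=5e-5)
    return ap.parse_args()

if __name__ == "__main__":
    args = parse_args()
    # Printing with flush=True matters: app.py streams this process's
    # stdout line by line into the live log textbox.
    print(f"Training {args.model_name} on {args.dataset} -> {args.output}", flush=True)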
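
The schema of dataset.jsonl is likewise not defined in this commit. Judging from the Prompt placeholder in the Test tab, each line is presumably one JSON object holding a single instruction/response example; a hypothetical line (the "text" field name is an assumption):

{"text": "### Instruction:\nPython: write a function that reverses a string\n### Response:\ndef reverse_string(s):\n    return s[::-1]"}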