# app.py
"""Gradio app for a small train-and-test loop:

* Train tab: upload a JSONL dataset, launch ``train.py`` as a subprocess with
  live log streaming, and download the resulting model zip.
* Test tab: pick (or import from zip) a trained model folder and prompt it via
  a cached ``transformers`` text-generation pipeline.
"""

import io
import os
import shutil
import subprocess
import sys
import traceback
import zipfile
from datetime import datetime
from pathlib import Path

import gradio as gr

# ----------------- Paths -----------------
ROOT = Path(__file__).resolve().parent
DATA = ROOT / "dataset.jsonl"   # single well-known dataset location
LOG = ROOT / "train.log"        # shared append-only log file
RUNS = ROOT / "runs"            # one sub-folder + zip per training run
RUNS.mkdir(exist_ok=True)


# ----------------- Logging -----------------
def append_log(msg: str) -> None:
    """Append one line to the shared log file; best-effort, never raises."""
    msg = (msg or "").rstrip("\n")
    try:
        with open(LOG, "a", encoding="utf-8") as lf:
            lf.write(msg + "\n")
    except Exception:
        # Logging must never break the UI callbacks that call it.
        pass


def read_logs() -> str:
    """Return the tail (last ~20 KB) of the log file, or a waiting banner."""
    return LOG.read_text(encoding="utf-8")[-20000:] if LOG.exists() else "⏳ Waiting…"


# ----------------- Workspace & Models -----------------
def ls_workspace() -> str:
    """Render a directory listing of ROOT (dirs first) as a plain-text table."""
    rows = []
    for p in sorted(ROOT.iterdir(), key=lambda x: (x.is_file(), x.name.lower())):
        try:
            size = p.stat().st_size
        except Exception:
            size = 0  # e.g. broken symlink — show it anyway
        rows.append(f"{'[DIR]' if p.is_dir() else '     '}\t{size:>10}\t{p.name}")
    return "\n".join(rows) or "(empty)"


def list_models() -> list:
    """Find folders under ROOT/RUNS that look like saved HF models.

    A folder qualifies when it has a ``config.json`` plus a tokenizer file.
    Returns a sorted, de-duplicated list of absolute path strings.
    """
    out = []
    for base in [ROOT, RUNS]:
        if not base.exists():
            continue
        for p in base.iterdir():
            if p.is_dir() and (p / "config.json").exists() and (
                (p / "tokenizer.json").exists() or (p / "tokenizer_config.json").exists()
            ):
                out.append(str(p))
    return sorted(set(out))


def dropdown_update_safe(models, prefer=None):
    """Build a gr.update for the model dropdown, keeping ``prefer`` if valid."""
    val = prefer if (prefer and prefer in models) else (models[0] if models else None)
    return gr.update(choices=models, value=val)


# ----------------- Dataset Upload -----------------
def upload_dataset(file):
    """Copy an uploaded JSONL file to the fixed DATA path.

    Returns (status message, workspace listing).
    """
    append_log("📥 upload_dataset clicked")
    if not file:
        return "❌ No file selected.", ls_workspace()
    if hasattr(file, "name") and os.path.isfile(file.name):
        shutil.copy(file.name, DATA)
        return f"✅ Uploaded → {DATA.name}", ls_workspace()
    return "⚠ Unexpected item; please upload a .jsonl file.", ls_workspace()


# ----------------- Training (Live Logs) -----------------
def start_training_live(run_name):
    """Generator callback: run train.py as a subprocess and stream its output.

    Yields 5-tuples for (status, download_file, workspace, logs, model_list)
    so the UI updates live while training runs.
    """
    append_log("🚀 start_training_live clicked")
    if not DATA.exists():
        msg = "❌ dataset.jsonl not found. \nUpload a JSONL dataset first."
        append_log(msg)
        yield (msg, gr.update(value=None, visible=False), ls_workspace(),
               read_logs(), dropdown_update_safe(list_models()))
        return

    run_id = (run_name or "").strip() or datetime.now().strftime("run_%Y%m%d_%H%M%S")
    out_dir = RUNS / run_id
    zip_path = RUNS / f"{run_id}.zip"

    # clean only this run (never other runs)
    if out_dir.exists():
        shutil.rmtree(out_dir, ignore_errors=True)
    if zip_path.exists():
        zip_path.unlink()

    # (re)init the shared log for this run
    LOG.write_text(f"🔥 Training started…\nRun: {run_id}\n", encoding="utf-8")
    append_log(f"Workspace:\n{ls_workspace()}")

    # sys.executable: guarantee the same interpreter/venv as this app
    # (a bare "python" could resolve to a different installation on PATH).
    cmd = [
        sys.executable, str(ROOT / "train.py"),
        "--dataset", str(DATA),
        "--output", str(out_dir),
        "--zip_path", str(zip_path),
        "--model_name", "Salesforce/codegen-350M-multi",
        "--epochs", "1",
        "--batch_size", "2",
        "--block_size", "256",
        "--learning_rate", "5e-5",
    ]
    append_log("▶ " + " ".join(cmd))

    # start subprocess with live stdout (stderr merged in)
    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            bufsize=1,
            universal_newlines=True,
            encoding="utf-8",
            errors="replace",
        )
    except Exception as e:
        err = "❌ Failed to start train.py: " + "".join(
            traceback.format_exception_only(type(e), e))
        append_log(err)
        yield (err, gr.update(value=None, visible=False), ls_workspace(),
               read_logs(), dropdown_update_safe(list_models()))
        return

    live_log = io.StringIO()
    status_msg = f"🚀 Training run '{run_id}' in progress…"
    # Fix: initialize so the loop body is safe even if the process exits
    # before producing any output (previously could hit an unbound local).
    text = ""

    # stream loop: one UI update per output line
    while True:
        line = proc.stdout.readline()
        if line == "" and proc.poll() is not None:
            break
        if line:
            append_log(line.rstrip("\n"))
            live_log.write(line)
            text = live_log.getvalue()[-20000:]
            # Single yield per line that also reflects the zip's current
            # state — avoids the download widget flickering visible/hidden.
            if zip_path.exists():
                yield ("📦 Model zip created during run.",
                       gr.update(value=str(zip_path), visible=True),
                       ls_workspace(), text,
                       dropdown_update_safe(list_models(), prefer=None))
            else:
                yield (status_msg,
                       gr.update(value=None, visible=False),
                       ls_workspace(), text,
                       dropdown_update_safe(list_models(), prefer=None))

    code = proc.wait()
    models = list_models()
    model_update = dropdown_update_safe(
        models, prefer=str(out_dir) if out_dir.exists() else None)
    final_logs = read_logs()

    if code == 0 and zip_path.exists():
        info = f"✅ Training complete. Saved: {out_dir.name} | Zip: {zip_path.name}"
        append_log(info)
        yield (info, gr.update(value=str(zip_path), visible=True),
               ls_workspace(), final_logs, model_update)
    else:
        info = f"❌ Training failed (exit {code}). Check logs below."
        append_log(info)
        yield (info, gr.update(value=None, visible=False),
               ls_workspace(), final_logs, model_update)


def refresh_download():
    """Point the download widget at the most recent run zip (if any)."""
    append_log("↻ refresh_download clicked")
    zips = sorted(RUNS.glob("*.zip"), key=lambda p: p.stat().st_mtime, reverse=True)
    latest = zips[0] if zips else None
    models = list_models()
    return (
        gr.update(value=(str(latest) if latest else None), visible=bool(latest)),
        ls_workspace(),
        dropdown_update_safe(models),
    )


# ----------------- Import a Zip as Model Folder -----------------
def import_zip(zfile):
    """Extract an uploaded model zip into ROOT/imported_model.

    Returns (status message, dropdown update). Fix: previously returned the
    bare model list, which Gradio would treat as the dropdown *value* rather
    than its choices; now returns a proper gr.update via dropdown_update_safe.
    """
    append_log("📦 import_zip clicked")
    if not zfile:
        return "❌ No zip selected.", dropdown_update_safe(list_models())
    dest = ROOT / "imported_model"
    if dest.exists():
        shutil.rmtree(dest, ignore_errors=True)
    dest.mkdir(parents=True, exist_ok=True)
    # NOTE(review): extractall on a user-supplied zip is vulnerable to
    # zip-slip path traversal for hostile archives — acceptable for a local
    # tool, but sanitize member names before deploying this publicly.
    with zipfile.ZipFile(zfile.name, "r") as z:
        z.extractall(dest)
    return f"✅ Imported to {dest.name}", dropdown_update_safe(list_models())


# ----------------- Generation (cached pipeline) -----------------
# One-slot cache: reloading a pipeline is expensive, so keep the last one.
_GEN_CACHE = {"path": None, "pipe": None}


def get_generation_pipeline(model_path: str):
    """Load (or return cached) a text-generation pipeline for ``model_path``.

    Ensures a pad token exists and resizes embeddings if the tokenizer grew.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    import torch

    if _GEN_CACHE["path"] == model_path and _GEN_CACHE["pipe"] is not None:
        return _GEN_CACHE["pipe"]

    append_log(f"🧩 Loading pipeline from: {model_path}")
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token_id is None:
        if tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
            append_log("ℹ No pad_token; using eos_token as pad_token.")
        else:
            tok.add_special_tokens({"pad_token": "[PAD]"})
            append_log("ℹ Added [PAD] token to tokenizer.")

    model = AutoModelForCausalLM.from_pretrained(model_path)
    # If the tokenizer gained tokens (e.g. [PAD] above), grow the embeddings.
    if getattr(model, "config", None) and getattr(model.config, "vocab_size", None) \
            and len(tok) > model.config.vocab_size:
        model.resize_token_embeddings(len(tok))
        append_log(f"ℹ Resized embeddings to {len(tok)}.")

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    _GEN_CACHE["path"] = model_path
    _GEN_CACHE["pipe"] = pipe
    append_log("✅ Pipeline loaded.")
    return pipe


# ----------------- Test Tab Helpers -----------------
def ping():
    """Sanity check that UI event wiring works at all."""
    append_log("🔔 Ping pressed (UI wiring OK)")
    return "✅ UI is connected and responding."


def _coerce_model_path(model_path):
    """Dropdowns may pass a list; reduce to a single value or None."""
    if isinstance(model_path, list):
        return model_path[0] if model_path else None
    return model_path


def load_selected_model(model_path):
    """Eagerly load the selected model into the pipeline cache."""
    append_log("📦 load_selected_model clicked")
    model_path = _coerce_model_path(model_path)
    if not model_path:
        return "❌ Select a model first."
    if not isinstance(model_path, str):
        # Fix: was type(...)._name_ (AttributeError); must be __name__.
        return f"❌ Invalid model path type: {type(model_path).__name__}"
    p = Path(model_path)
    if not p.exists() or not p.is_dir():
        return f"❌ Model folder not found: {model_path}"
    try:
        append_log(f"📦 Load request → {model_path}")
        _ = get_generation_pipeline(model_path)
        append_log(f"✅ Loaded pipeline: {model_path}")
        return f"✅ Loaded: {model_path}"
    except Exception as e:
        tb = traceback.format_exc()
        append_log("❌ Load error:\n" + tb)
        return "❌ Error while loading model:\n" + "".join(
            traceback.format_exception_only(type(e), e))


def generate_once(model_path, prompt):
    """Non-streaming fallback: validate inputs, run one generation, return text."""
    append_log("▶ generate_once clicked")
    model_path = _coerce_model_path(model_path)
    # validate
    if not model_path:
        msg = "❌ Select a model from the dropdown first."
        append_log(msg)
        return msg
    if not isinstance(model_path, str):
        # Fix: was type(...)._name_ (AttributeError); must be __name__.
        msg = f"❌ Invalid model path type: {type(model_path).__name__}"
        append_log(msg)
        return msg
    if not Path(model_path).exists():
        msg = f"❌ Model folder not found: {model_path}"
        append_log(msg)
        return msg
    if not prompt or not prompt.strip():
        msg = "❌ Enter a prompt."
        append_log(msg)
        return msg

    try:
        pipe = get_generation_pipeline(model_path)
        append_log(f"📝 Generating once… prompt_len={len(prompt)}")
        result = pipe(
            prompt.strip(),
            max_new_tokens=80,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            no_repeat_ngram_size=4,
            truncation=True,
            return_full_text=True,
        )
        text = result[0].get("generated_text", "")
        if not text:
            append_log("⚠ Empty generated_text")
            return "⚠ Model returned empty text. Try lowering temperature or adding more context."
        append_log("✅ Generation OK.")
        return text
    except Exception as e:
        tb = traceback.format_exc()
        append_log("❌ Generation error:\n" + tb)
        return "❌ Error during generation:\n" + "".join(
            traceback.format_exception_only(type(e), e))


def generate_stream(model_path, prompt):
    """Streaming version (if frontend streaming is healthy)."""
    yield "⏳ Loading model…"
    append_log("▶ generate_stream clicked")
    model_path = _coerce_model_path(model_path)
    # validate
    if not model_path:
        msg = "❌ Select a model from the dropdown first."
        append_log(msg)
        yield msg
        return
    if not isinstance(model_path, str):
        # Fix: was type(...)._name_ (AttributeError); must be __name__.
        msg = f"❌ Invalid model path type: {type(model_path).__name__}"
        append_log(msg)
        yield msg
        return
    if not Path(model_path).exists():
        msg = f"❌ Model folder not found: {model_path}"
        append_log(msg)
        yield msg
        return
    if not prompt or not prompt.strip():
        msg = "❌ Enter a prompt."
        append_log(msg)
        yield msg
        return

    try:
        pipe = get_generation_pipeline(model_path)
        yield "⚙ Generating… (this may take a bit on CPU)"
        append_log(f"📝 Generating (stream)… prompt_len={len(prompt)}")
        result = pipe(
            prompt.strip(),
            max_new_tokens=80,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            no_repeat_ngram_size=4,
            truncation=True,
            return_full_text=True,
        )
        text = result[0].get("generated_text", "")
        if not text:
            append_log("⚠ Empty generated_text")
            yield "⚠ Model returned empty text. Try lowering temperature or adding more context."
            return
        append_log("✅ Generation OK.")
        yield text
    except Exception as e:
        tb = traceback.format_exc()
        append_log("❌ Generation error:\n" + tb)
        yield "❌ Error during generation:\n" + "".join(
            traceback.format_exception_only(type(e), e))


# ----------------- UI -----------------
with gr.Blocks(title="Python AI — Train & Test") as app:
    gr.Markdown("## 🧠 Python AI — Train & Test\n• Unique runs • Safe download • Cached generation • Live logs\n")

    # ---------- Test Tab ----------
    with gr.Tab("Test"):
        gr.Markdown("### Choose a model folder or upload a .zip, then prompt it")
        with gr.Row():
            refresh_btn = gr.Button("↻ Refresh Model List")
            ping_btn = gr.Button("🔔 Ping UI")  # sanity check
        model_list = gr.Dropdown(
            choices=list_models(),
            label="Available AIs",
            interactive=True,
            allow_custom_value=True,
            multiselect=False,
        )
        load_btn = gr.Button("📦 Load Model")
        load_status = gr.Textbox(label="Model Status", interactive=False)
        zip_in = gr.File(label="Or upload a model .zip", file_types=[".zip"])
        import_status = gr.Textbox(label="Import Status", interactive=False)
        prompt = gr.Textbox(
            label="Prompt",
            lines=8,
            placeholder="### Instruction:\nPython: write a function ...\n### Response:\n",
        )
        with gr.Row():
            go_stream = gr.Button("Generate (stream)")
            go_once = gr.Button("Generate (once)")
        out = gr.Textbox(label="AI Response", lines=20)

    # ---------- Train Tab ----------
    with gr.Tab("Train"):
        with gr.Row():
            ds = gr.File(label="📥 Upload JSONL", file_types=[".jsonl"])
            ws = gr.Textbox(label="Workspace", lines=16, value=ls_workspace())
        run_name = gr.Textbox(label="Run name (optional)", placeholder="e.g., python_small_v1")
        up_status = gr.Textbox(label="Upload Status", interactive=False)
        start = gr.Button("🚀 Start Training (Live Logs)", variant="primary")
        logs = gr.Textbox(label="📜 Training Logs (live)", lines=18, value=read_logs())
        status = gr.Textbox(label="Status", interactive=False)
        download_file = gr.File(label="📦 Latest trained zip", visible=False)
        refresh_dl_btn = gr.Button("Refresh Download")

    # ---------- Wiring ----------
    ds.change(upload_dataset, inputs=ds, outputs=[up_status, ws])
    start.click(
        start_training_live,
        inputs=[run_name],
        outputs=[status, download_file, ws, logs, model_list],
    )
    refresh_dl_btn.click(
        refresh_download,
        outputs=[download_file, ws, model_list],
    )
    refresh_btn.click(lambda: dropdown_update_safe(list_models()), outputs=model_list)
    ping_btn.click(ping, outputs=out)
    load_btn.click(load_selected_model, inputs=[model_list], outputs=[load_status])
    zip_in.change(import_zip, inputs=zip_in, outputs=[import_status, model_list])
    go_stream.click(generate_stream, inputs=[model_list, prompt], outputs=out)
    go_once.click(generate_once, inputs=[model_list, prompt], outputs=out)

# Critical: disable SSR; ensure queue is enabled
app.queue(default_concurrency_limit=1)
app.launch(ssr_mode=False, show_error=True)