import os
import io
import zipfile
import shutil
import subprocess
import sys
import json
import time
import glob
import tempfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr

WORKDIR = Path(".")
DATASET_PATH = WORKDIR / "dataset.jsonl"
LOG_PATH = WORKDIR / "train.log"
MODEL_DIR = WORKDIR / "trained_model"      # training output folder
ZIP_PATH = WORKDIR / "trained_model.zip"   # zipped after train
MODELS_ROOT = WORKDIR                      # where we scan for saved AIs


# ---------- helpers ----------

def _safe_unzip(zip_file: str, out_dir: Path) -> str:
    """Extract *zip_file* into *out_dir*, refusing entries that escape it.

    The archive comes from a user upload, so every member path is validated
    against zip-slip ("../"-style or absolute paths) before extraction.

    Returns:
        The single top-level directory inside the archive if there is exactly
        one, otherwise *out_dir* itself (as a string path).

    Raises:
        ValueError: if any archive member would extract outside *out_dir*.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    out_root = out_dir.resolve()
    with zipfile.ZipFile(zip_file, "r") as z:
        for member in z.infolist():
            # Zip-slip guard: the resolved destination must stay under out_root.
            target = (out_root / member.filename).resolve()
            try:
                target.relative_to(out_root)
            except ValueError:
                raise ValueError(f"Unsafe path in zip: {member.filename!r}")
        z.extractall(out_root)
    # Return the inner model folder if the zip contained a single directory.
    subdirs = [p for p in out_dir.iterdir() if p.is_dir()]
    return str(subdirs[0] if len(subdirs) == 1 else out_dir)


def _list_local_models() -> List[str]:
    """
    Return model folders found under MODELS_ROOT that look like HF models.
    We include any folder that has a tokenizer.json or tokenizer_config.json.
    """
    candidates = []
    for p in MODELS_ROOT.iterdir():
        if not p.is_dir():
            continue
        if (p / "tokenizer.json").exists() or (p / "tokenizer_config.json").exists():
            candidates.append(str(p))
    return sorted(candidates)


def _start_training_subprocess() -> int:
    """Run train.py as a child process, teeing all output into LOG_PATH.

    Blocks until the child exits. Returns the child's exit code
    (0 on success). Any previous model folder / zip is removed first.
    """
    # Clear old outputs so a fresh run never mixes with stale artifacts.
    if MODEL_DIR.exists():
        shutil.rmtree(MODEL_DIR)
    if ZIP_PATH.exists():
        ZIP_PATH.unlink(missing_ok=True)
    cmd = [
        # Use the exact interpreter running this app; a bare "python" on PATH
        # may resolve to a different (or missing) interpreter inside a venv.
        sys.executable,
        "train.py",
        "--dataset", str(DATASET_PATH),
        "--output", str(MODEL_DIR),
        # sensible defaults for quick, real training; adjust in train.py if needed
        "--model_name", "Salesforce/codegen-350M-multi",
        "--epochs", "1",
        "--batch_size", "2",
        "--block_size", "256",
        "--learning_rate", "5e-5",
    ]
    LOG_PATH.write_text("🔥 Starting training...\n", encoding="utf-8")
    with open(LOG_PATH, "a", encoding="utf-8") as lf:
        # stderr is folded into stdout so the log captures tracebacks too.
        proc = subprocess.Popen(cmd, stdout=lf, stderr=subprocess.STDOUT)
        return proc.wait()


def _zip_model_folder() -> bool:
    """Zip MODEL_DIR into ZIP_PATH; return True if the zip now exists."""
    if not MODEL_DIR.exists():
        return False
    if ZIP_PATH.exists():
        ZIP_PATH.unlink()
    # make_archive appends ".zip" itself, so pass the base name without suffix.
    shutil.make_archive(ZIP_PATH.with_suffix("").as_posix(), "zip", MODEL_DIR)
    return ZIP_PATH.exists()


# ---------- UI callbacks ----------

def upload_dataset(file) -> str:
    """Copy the uploaded JSONL into the canonical DATASET_PATH location."""
    if file is None:
        return "❌ No file selected."
    shutil.copy(file.name, DATASET_PATH)
    return f"✅ Uploaded: {file.name} → {DATASET_PATH.name}"


def start_training() -> Tuple[str, str, gr.File]:
    """Kick off a (blocking) training run and surface status + download link.

    Returns a (status message, model info / log tail, File-component update)
    triple matching the [status_box, model_info, dl] outputs.
    """
    if not DATASET_PATH.exists():
        return ("❌ Please upload a JSONL first.", "", gr.File.update(visible=False))
    exit_code = _start_training_subprocess()
    # After training, try to zip and expose the artifact for download.
    if exit_code == 0 and _zip_model_folder():
        status = "✅ Training complete."
        model_info = f"Saved: {MODEL_DIR.name} | Zip: {ZIP_PATH.name}"
        return (status, model_info, gr.File.update(value=str(ZIP_PATH), visible=True))
    else:
        # Surface the tail of the log for quick diagnosis.
        tail = ""
        if LOG_PATH.exists():
            with open(LOG_PATH, "r", encoding="utf-8") as f:
                lines = f.readlines()[-30:]
            tail = "".join(lines)
        return (f"❌ Training failed (code {exit_code}).", tail, gr.File.update(visible=False))


def read_logs() -> str:
    """Return the most recent chunk of the training log (last ~20k chars)."""
    if LOG_PATH.exists():
        return LOG_PATH.read_text(encoding="utf-8")[-20_000:]
    return "⏳ Waiting for logs..."


def refresh_model_list() -> List[str]:
    """Re-scan MODELS_ROOT for candidate model folders."""
    return _list_local_models()


def upload_model_zip(zip_file) -> Tuple[str, List[str]]:
    """Import a user-supplied model zip and refresh the model dropdown."""
    if zip_file is None:
        return "❌ No zip provided.", refresh_model_list()
    out = WORKDIR / f"imported_{int(time.time())}"
    path = _safe_unzip(zip_file.name, out)
    msg = f"✅ Imported model at: {path}"
    return msg, refresh_model_list()


def generate(model_path: str, prompt: str) -> str:
    """Load the model at *model_path* and generate a completion for *prompt*.

    Imports transformers lazily so the UI stays usable without it installed;
    any load/generation error is returned as a user-visible message rather
    than raised.
    """
    if not model_path:
        return "❌ Select a model."
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

        tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        # Many causal LMs ship without a pad token; reuse EOS so batching works.
        if tok.pad_token_id is None and tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_path)
        gen = pipeline("text-generation", model=model, tokenizer=tok)
        # Decoding tuned for code: low temperature + repetition penalties.
        out = gen(
            prompt,
            max_new_tokens=220,
            do_sample=True,
            temperature=0.2,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=4,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
            truncation=True,
        )[0]["generated_text"]
        return out
    except Exception as e:
        return f"❌ Error: {e}"


# ---------- UI ----------
# NOTE(review): gr.File.update(...) is the Gradio 3.x component-update API and
# was removed in Gradio 4.x (use gr.File(...) there) — confirm pinned version.

with gr.Blocks(title="Python AI Trainer") as demo:
    gr.Markdown("## 🧠 Python AI Trainer\nUpload JSONL, train, then test your model.")

    with gr.Tab("Train"):
        file_in = gr.File(label="📥 Upload JSONL Dataset",
                          file_types=[".jsonl", ".jsonl.gz", ".json"])
        up_status = gr.Textbox(label="Upload Status", interactive=False)
        start_btn = gr.Button("🚀 Start Training", variant="primary")
        logs_box = gr.Textbox(label="📜 Live Logs (click Refresh)", lines=16)
        refresh_logs = gr.Button("Refresh Logs")
        status_box = gr.Textbox(label="Status", interactive=False)
        model_info = gr.Textbox(label="Model Output", interactive=False)
        dl = gr.File(label="📦 Download Trained Model (.zip)", visible=False)
        refresh_dl = gr.Button("Refresh Download Area")

        file_in.change(fn=upload_dataset, inputs=file_in, outputs=up_status)
        start_btn.click(fn=start_training, outputs=[status_box, model_info, dl])
        refresh_logs.click(fn=read_logs, outputs=logs_box)
        refresh_dl.click(
            fn=lambda: gr.File.update(value=str(ZIP_PATH), visible=ZIP_PATH.exists()),
            outputs=dl,
        )

    with gr.Tab("Test"):
        gr.Markdown("### 🔬 Choose a stored AI and prompt it")
        refresh_models_btn = gr.Button("↻ Refresh AI List")
        model_list = gr.Dropdown(choices=_list_local_models(),
                                 label="Available AIs", interactive=True)
        up_zip = gr.File(label="Or upload a model .zip to test",
                         file_types=[".zip"])
        zip_status = gr.Textbox(label="Model Import Status", interactive=False)
        prompt = gr.Textbox(
            label="Prompt", lines=6,
            placeholder="### Instruction:\nPython: write a function ...\n### Response:\n",
        )
        generate_btn = gr.Button("Generate")
        output = gr.Textbox(label="AI Response", lines=20)

        refresh_models_btn.click(fn=refresh_model_list, outputs=model_list)
        up_zip.change(fn=upload_model_zip, inputs=up_zip,
                      outputs=[zip_status, model_list])
        generate_btn.click(fn=generate, inputs=[model_list, prompt], outputs=output)

# Launch only when executed as a script, so the module stays importable.
if __name__ == "__main__":
    demo.launch()