# Source: Percy3822 / HF Space — "Update app.py" (commit 3da5a15, verified)
# app.py
import os, shutil, subprocess, zipfile, traceback, io
from pathlib import Path
from datetime import datetime
import gradio as gr
# ----------------- Paths -----------------
ROOT = Path(__file__).resolve().parent
DATA = ROOT / "dataset.jsonl"
LOG = ROOT / "train.log"
RUNS = ROOT / "runs"
RUNS.mkdir(exist_ok=True)
# ----------------- Logging -----------------
def append_log(msg: str):
    """Append *msg* as one line to the shared train.log file, best-effort."""
    line = (msg or "").rstrip("\n")
    try:
        with open(LOG, "a", encoding="utf-8") as handle:
            handle.write(f"{line}\n")
    except Exception:
        # Logging must never crash the app; I/O failures are swallowed on purpose.
        pass
def read_logs():
    """Return the tail (last 20 000 chars) of the log file, or a waiting banner."""
    if not LOG.exists():
        return "⏳ Waiting…"
    return LOG.read_text(encoding="utf-8")[-20000:]
# ----------------- Workspace & Models -----------------
def ls_workspace() -> str:
    """Render a one-entry-per-line listing of the app root (directories first)."""
    lines = []
    for entry in sorted(ROOT.iterdir(), key=lambda e: (e.is_file(), e.name.lower())):
        try:
            nbytes = entry.stat().st_size
        except Exception:
            # stat can fail on transient files; show zero rather than crash
            nbytes = 0
        tag = "[DIR]" if entry.is_dir() else " "
        lines.append(f"{tag}\t{nbytes:>10}\t{entry.name}")
    return "\n".join(lines) or "(empty)"
def list_models():
    """Find model folders (config.json plus a tokenizer file) under ROOT and RUNS."""
    found = set()
    for base in (ROOT, RUNS):
        if not base.exists():
            continue
        for cand in base.iterdir():
            if not cand.is_dir():
                continue
            has_tokenizer = (cand / "tokenizer.json").exists() or (cand / "tokenizer_config.json").exists()
            if has_tokenizer and (cand / "config.json").exists():
                found.add(str(cand))
    return sorted(found)
def dropdown_update_safe(models, prefer=None):
    """Build a gr.update for the model dropdown, selecting *prefer* when it is a valid choice."""
    if prefer and prefer in models:
        chosen = prefer
    elif models:
        chosen = models[0]
    else:
        chosen = None
    return gr.update(choices=models, value=chosen)
# ----------------- Dataset Upload -----------------
def upload_dataset(file):
    """Copy an uploaded .jsonl file into place as dataset.jsonl; return (status, listing)."""
    append_log("📥 upload_dataset clicked")
    if not file:
        return "❌ No file selected.", ls_workspace()
    # Gradio file objects expose their temp path via .name
    path = getattr(file, "name", None)
    if path and os.path.isfile(path):
        shutil.copy(path, DATA)
        return f"✅ Uploaded → {DATA.name}", ls_workspace()
    return "⚠ Unexpected item; please upload a .jsonl file.", ls_workspace()
# ----------------- Training (Live Logs) -----------------
def start_training_live(run_name):
    """Launch train.py as a subprocess and stream live status to the UI.

    Generator yielding 5-tuples for the wired outputs:
    (status message, zip-download update, workspace listing, log text,
    model-dropdown update).
    """
    append_log("🚀 start_training_live clicked")
    if not DATA.exists():
        msg = "❌ dataset.jsonl not found. Upload a JSONL dataset first."
        append_log(msg)
        yield (msg, gr.update(value=None, visible=False), ls_workspace(), read_logs(), dropdown_update_safe(list_models()))
        return

    run_id = (run_name or "").strip() or datetime.now().strftime("run_%Y%m%d_%H%M%S")
    out_dir = RUNS / run_id
    zip_path = RUNS / f"{run_id}.zip"

    # Clean only this run's artifacts; other runs are left untouched.
    if out_dir.exists():
        shutil.rmtree(out_dir, ignore_errors=True)
    if zip_path.exists():
        zip_path.unlink()

    # Start a fresh log for this run.
    LOG.write_text(f"🔥 Training started…\nRun: {run_id}\n", encoding="utf-8")
    append_log(f"Workspace:\n{ls_workspace()}")

    cmd = [
        "python", str(ROOT / "train.py"),
        "--dataset", str(DATA),
        "--output", str(out_dir),
        "--zip_path", str(zip_path),
        "--model_name", "Salesforce/codegen-350M-multi",
        "--epochs", "1",
        "--batch_size", "2",
        "--block_size", "256",
        "--learning_rate", "5e-5",
    ]
    append_log("▶ " + " ".join(cmd))

    # Start the subprocess with line-buffered, merged stdout/stderr so the
    # stream loop below sees every line as it is produced.
    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            bufsize=1,
            universal_newlines=True,
            encoding="utf-8",
            errors="replace",
        )
    except Exception as e:
        err = "❌ Failed to start train.py: " + "".join(traceback.format_exception_only(type(e), e))
        append_log(err)
        yield (err, gr.update(value=None, visible=False), ls_workspace(), read_logs(), dropdown_update_safe(list_models()))
        return

    live_log = io.StringIO()
    status_msg = f"🚀 Training run '{run_id}' in progress…"
    # FIX: text must exist even when the subprocess produces no output at all,
    # otherwise the zip-check yield below raises NameError.
    text = ""

    # Stream loop: forward each stdout line to the log file and the UI.
    while True:
        line = proc.stdout.readline()
        if line == "" and proc.poll() is not None:
            break
        if line:
            append_log(line.rstrip("\n"))
            live_log.write(line)
            text = live_log.getvalue()[-20000:]
            yield (
                status_msg,
                gr.update(value=None, visible=False),
                ls_workspace(),
                text,
                dropdown_update_safe(list_models(), prefer=None),
            )
            # Surface the zip in the download widget as soon as train.py
            # produces it mid-run.
            if zip_path.exists():
                yield (
                    "📦 Model zip created during run.",
                    gr.update(value=str(zip_path), visible=True),
                    ls_workspace(),
                    text,
                    dropdown_update_safe(list_models(), prefer=None),
                )

    code = proc.wait()
    models = list_models()
    model_update = dropdown_update_safe(models, prefer=str(out_dir) if out_dir.exists() else None)
    final_logs = read_logs()
    if code == 0 and zip_path.exists():
        info = f"✅ Training complete. Saved: {out_dir.name} | Zip: {zip_path.name}"
        append_log(info)
        yield (info, gr.update(value=str(zip_path), visible=True), ls_workspace(), final_logs, model_update)
    else:
        info = f"❌ Training failed (exit {code}). Check logs below."
        append_log(info)
        yield (info, gr.update(value=None, visible=False), ls_workspace(), final_logs, model_update)
def refresh_download():
    """Point the download widget at the newest run zip and refresh listings."""
    append_log("↻ refresh_download clicked")
    zips_by_age = sorted(RUNS.glob("*.zip"), key=lambda z: z.stat().st_mtime, reverse=True)
    newest = zips_by_age[0] if zips_by_age else None
    return (
        gr.update(value=(str(newest) if newest else None), visible=bool(newest)),
        ls_workspace(),
        dropdown_update_safe(list_models()),
    )
# ----------------- Import a Zip as Model Folder -----------------
def import_zip(zfile):
    """Extract an uploaded model .zip into ROOT/imported_model.

    SECURITY: the zip is user-supplied, so every entry is checked against
    zip-slip (absolute paths or '..' components that would escape the
    destination) before anything is extracted.

    Returns (status message, model list) for the UI.
    """
    append_log("📦 import_zip clicked")
    if not zfile:
        return "❌ No zip selected.", list_models()
    dest = ROOT / "imported_model"
    if dest.exists():
        shutil.rmtree(dest, ignore_errors=True)
    dest.mkdir(parents=True, exist_ok=True)
    dest_root = dest.resolve()
    with zipfile.ZipFile(zfile.name, "r") as z:
        # Validate all entries first; refuse the whole archive on any bad path.
        for member in z.namelist():
            target = (dest_root / member).resolve()
            if target != dest_root and dest_root not in target.parents:
                append_log(f"⛔ Unsafe zip entry rejected: {member}")
                return f"❌ Unsafe entry in zip: {member}", list_models()
        z.extractall(dest)
    return f"✅ Imported to {dest.name}", list_models()
# ----------------- Generation (cached pipeline) -----------------
_GEN_CACHE = {"path": None, "pipe": None}
def get_generation_pipeline(model_path: str):
    """Return a text-generation pipeline for *model_path*, memoized in _GEN_CACHE.

    Repeated generations against the same model skip the (slow) reload.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    import torch

    if _GEN_CACHE["pipe"] is not None and _GEN_CACHE["path"] == model_path:
        return _GEN_CACHE["pipe"]

    append_log(f"🧩 Loading pipeline from: {model_path}")
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token_id is None:
        # Many causal LMs ship without a pad token; reuse EOS when available.
        if tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
            append_log("ℹ No pad_token; using eos_token as pad_token.")
        else:
            tok.add_special_tokens({"pad_token": "[PAD]"})
            append_log("ℹ Added [PAD] token to tokenizer.")

    model = AutoModelForCausalLM.from_pretrained(model_path)
    cfg = getattr(model, "config", None)
    if cfg and getattr(cfg, "vocab_size", None) and len(tok) > cfg.vocab_size:
        # Tokenizer grew (e.g. [PAD] was added): widen the embedding matrix.
        model.resize_token_embeddings(len(tok))
        append_log(f"ℹ Resized embeddings to {len(tok)}.")

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    _GEN_CACHE["path"] = model_path
    _GEN_CACHE["pipe"] = pipe
    append_log("✅ Pipeline loaded.")
    return pipe
# ----------------- Test Tab Helpers -----------------
def ping():
    """Sanity check that the button wiring actually reaches the backend."""
    append_log("🔔 Ping pressed (UI wiring OK)")
    return "✅ UI is connected and responding."
def load_selected_model(model_path):
    """Validate *model_path* and eagerly load its generation pipeline.

    Returns a human-readable status string for the UI.
    """
    append_log("📦 load_selected_model clicked")
    # Dropdown may pass a list; coerce to a single string
    if isinstance(model_path, list):
        model_path = model_path[0] if model_path else None
    if not model_path:
        return "❌ Select a model first."
    if not isinstance(model_path, str):
        # FIX: was type(...)._name_ — AttributeError at runtime; the dunder is __name__
        return f"❌ Invalid model path type: {type(model_path).__name__}"
    p = Path(model_path)
    if not p.exists() or not p.is_dir():
        return f"❌ Model folder not found: {model_path}"
    try:
        append_log(f"📦 Load request → {model_path}")
        _ = get_generation_pipeline(model_path)
        append_log(f"✅ Loaded pipeline: {model_path}")
        return f"✅ Loaded: {model_path}"
    except Exception as e:
        tb = traceback.format_exc()
        append_log("❌ Load error:\n" + tb)
        return "❌ Error while loading model:\n" + "".join(traceback.format_exception_only(type(e), e))
def generate_once(model_path, prompt):
    """Non-streaming fallback: generate once and return the full text.

    Returns either the generated text or a human-readable error string.
    """
    append_log("▶ generate_once clicked")
    # Dropdown may pass a list; coerce to a single string
    if isinstance(model_path, list):
        model_path = model_path[0] if model_path else None
    # Validate inputs before touching the (slow) model machinery.
    if not model_path:
        msg = "❌ Select a model from the dropdown first."
        append_log(msg); return msg
    if not isinstance(model_path, str):
        # FIX: was type(...)._name_ — AttributeError at runtime; the dunder is __name__
        msg = f"❌ Invalid model path type: {type(model_path).__name__}"
        append_log(msg); return msg
    if not Path(model_path).exists():
        msg = f"❌ Model folder not found: {model_path}"
        append_log(msg); return msg
    if not prompt or not prompt.strip():
        msg = "❌ Enter a prompt."
        append_log(msg); return msg
    try:
        pipe = get_generation_pipeline(model_path)
        append_log(f"📝 Generating once… prompt_len={len(prompt)}")
        result = pipe(
            prompt.strip(),
            max_new_tokens=80,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            no_repeat_ngram_size=4,
            truncation=True,
            return_full_text=True,
        )
        text = result[0].get("generated_text", "")
        if not text:
            append_log("⚠ Empty generated_text")
            return "⚠ Model returned empty text. Try lowering temperature or adding more context."
        append_log("✅ Generation OK.")
        return text
    except Exception as e:
        tb = traceback.format_exc()
        append_log("❌ Generation error:\n" + tb)
        return "❌ Error during generation:\n" + "".join(traceback.format_exception_only(type(e), e))
def generate_stream(model_path, prompt):
    """Streaming variant of generate_once: yields status updates then the text.

    Generator yielding human-readable strings for the output textbox.
    """
    yield "⏳ Loading model…"
    append_log("▶ generate_stream clicked")
    # Dropdown may pass a list; coerce to a single string
    if isinstance(model_path, list):
        model_path = model_path[0] if model_path else None
    # Validate inputs before touching the (slow) model machinery.
    if not model_path:
        msg = "❌ Select a model from the dropdown first."
        append_log(msg); yield msg; return
    if not isinstance(model_path, str):
        # FIX: was type(...)._name_ — AttributeError at runtime; the dunder is __name__
        msg = f"❌ Invalid model path type: {type(model_path).__name__}"
        append_log(msg); yield msg; return
    if not Path(model_path).exists():
        msg = f"❌ Model folder not found: {model_path}"
        append_log(msg); yield msg; return
    if not prompt or not prompt.strip():
        msg = "❌ Enter a prompt."
        append_log(msg); yield msg; return
    try:
        pipe = get_generation_pipeline(model_path)
        yield "⚙ Generating… (this may take a bit on CPU)"
        append_log(f"📝 Generating (stream)… prompt_len={len(prompt)}")
        result = pipe(
            prompt.strip(),
            max_new_tokens=80,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            no_repeat_ngram_size=4,
            truncation=True,
            return_full_text=True,
        )
        text = result[0].get("generated_text", "")
        if not text:
            append_log("⚠ Empty generated_text")
            yield "⚠ Model returned empty text. Try lowering temperature or adding more context."
            return
        append_log("✅ Generation OK.")
        yield text
    except Exception as e:
        tb = traceback.format_exc()
        append_log("❌ Generation error:\n" + tb)
        yield "❌ Error during generation:\n" + "".join(traceback.format_exception_only(type(e), e))
# ----------------- UI -----------------
# Declarative Gradio layout; component creation order fixes the on-screen order,
# so statements here are kept exactly as written.
with gr.Blocks(title="Python AI — Train & Test") as app:
    gr.Markdown("## 🧠 Python AI — Train & Test\n• Unique runs • Safe download • Cached generation • Live logs\n")
    # ---------- Test Tab ----------
    with gr.Tab("Test"):
        gr.Markdown("### Choose a model folder or upload a .zip, then prompt it")
        with gr.Row():
            refresh_btn = gr.Button("↻ Refresh Model List")
            ping_btn = gr.Button("🔔 Ping UI")  # sanity check
        model_list = gr.Dropdown(
            choices=list_models(),
            label="Available AIs",
            interactive=True,
            allow_custom_value=True,
            multiselect=False
        )
        load_btn = gr.Button("📦 Load Model")
        load_status = gr.Textbox(label="Model Status", interactive=False)
        zip_in = gr.File(label="Or upload a model .zip", file_types=[".zip"])
        import_status = gr.Textbox(label="Import Status", interactive=False)
        prompt = gr.Textbox(
            label="Prompt",
            lines=8,
            placeholder="### Instruction:\nPython: write a function ...\n### Response:\n"
        )
        with gr.Row():
            go_stream = gr.Button("Generate (stream)")
            go_once = gr.Button("Generate (once)")
        out = gr.Textbox(label="AI Response", lines=20)
    # ---------- Train Tab ----------
    with gr.Tab("Train"):
        with gr.Row():
            ds = gr.File(label="📥 Upload JSONL", file_types=[".jsonl"])
            ws = gr.Textbox(label="Workspace", lines=16, value=ls_workspace())
        run_name = gr.Textbox(label="Run name (optional)", placeholder="e.g., python_small_v1")
        up_status = gr.Textbox(label="Upload Status", interactive=False)
        start = gr.Button("🚀 Start Training (Live Logs)", variant="primary")
        logs = gr.Textbox(label="📜 Training Logs (live)", lines=18, value=read_logs())
        status = gr.Textbox(label="Status", interactive=False)
        download_file = gr.File(label="📦 Latest trained zip", visible=False)
        refresh_dl_btn = gr.Button("Refresh Download")
    # ---------- Wiring ----------
    # Event handlers: each callback's return tuple maps positionally onto `outputs`.
    ds.change(upload_dataset, inputs=ds, outputs=[up_status, ws])
    start.click(
        start_training_live,
        inputs=[run_name],
        outputs=[status, download_file, ws, logs, model_list]
    )
    refresh_dl_btn.click(
        refresh_download,
        outputs=[download_file, ws, model_list]
    )
    refresh_btn.click(lambda: dropdown_update_safe(list_models()), outputs=model_list)
    ping_btn.click(ping, outputs=out)
    load_btn.click(load_selected_model, inputs=[model_list], outputs=[load_status])
    zip_in.change(import_zip, inputs=zip_in, outputs=[import_status, model_list])
    go_stream.click(generate_stream, inputs=[model_list, prompt], outputs=out)
    go_once.click(generate_once, inputs=[model_list, prompt], outputs=out)
# Critical: disable SSR; ensure queue is enabled so generator callbacks can stream
app.queue(default_concurrency_limit=1)
app.launch(ssr_mode=False, show_error=True)