Percy3822's picture
Create app.py
719f624 verified
raw
history blame
7.19 kB
import glob
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
import zipfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr
# All paths are relative to the app's working directory (the Space root).
WORKDIR = Path(".")
DATASET_PATH = WORKDIR / "dataset.jsonl"  # uploaded training data (JSONL)
LOG_PATH = WORKDIR / "train.log"  # stdout+stderr of the training subprocess
MODEL_DIR = WORKDIR / "trained_model" # training output folder
ZIP_PATH = WORKDIR / "trained_model.zip" # zipped after train
MODELS_ROOT = WORKDIR # where we scan for saved AIs
# ---------- helpers ----------
def _safe_unzip(zip_file: str, out_dir: Path) -> str:
out_dir.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(zip_file, "r") as z:
z.extractall(out_dir)
# return the inner model folder if zip contained a single directory
subdirs = [p for p in out_dir.iterdir() if p.is_dir()]
return str(subdirs[0] if len(subdirs) == 1 else out_dir)
def _list_local_models() -> List[str]:
    """Return folders under MODELS_ROOT that look like HF models.

    A directory qualifies when it contains either ``tokenizer.json`` or
    ``tokenizer_config.json``.
    """

    def _looks_like_model(folder: Path) -> bool:
        return (folder / "tokenizer.json").exists() or (
            folder / "tokenizer_config.json"
        ).exists()

    return sorted(
        str(folder)
        for folder in MODELS_ROOT.iterdir()
        if folder.is_dir() and _looks_like_model(folder)
    )
def _start_training_subprocess() -> int:
    """Run train.py as a child process and block until it exits.

    Clears artifacts from any previous run, streams the child's
    stdout+stderr into LOG_PATH, and returns the subprocess exit code
    (0 on success).
    """
    # Clear old outputs so a failed run can't serve stale files.
    if MODEL_DIR.exists():
        shutil.rmtree(MODEL_DIR)
    ZIP_PATH.unlink(missing_ok=True)  # missing_ok makes an exists() guard redundant
    cmd = [
        # Use the interpreter running this app, not whatever "python" happens
        # to resolve to on PATH (they can differ in venvs / HF Spaces).
        sys.executable, "train.py",
        "--dataset", str(DATASET_PATH),
        "--output", str(MODEL_DIR),
        # sensible defaults for quick, real training; adjust in train.py if needed
        "--model_name", "Salesforce/codegen-350M-multi",
        "--epochs", "1",
        "--batch_size", "2",
        "--block_size", "256",
        "--learning_rate", "5e-5",
    ]
    LOG_PATH.write_text("πŸ”₯ Starting training...\n", encoding="utf-8")
    with open(LOG_PATH, "a", encoding="utf-8") as log_file:
        proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
        return proc.wait()
def _zip_model_folder() -> bool:
    """Archive MODEL_DIR into ZIP_PATH; True when the zip now exists."""
    if not MODEL_DIR.exists():
        return False
    if ZIP_PATH.exists():
        ZIP_PATH.unlink()
    # make_archive wants the base name without the ".zip" suffix.
    archive_base = ZIP_PATH.with_suffix("").as_posix()
    shutil.make_archive(archive_base, "zip", MODEL_DIR)
    return ZIP_PATH.exists()
# ---------- UI callbacks ----------
def upload_dataset(file) -> str:
    """Copy the uploaded JSONL into DATASET_PATH; return a status message."""
    if file is None:
        return "❌ No file selected."
    source_path = file.name
    shutil.copy(source_path, DATASET_PATH)
    return f"βœ… Uploaded: {source_path} β†’ {DATASET_PATH.name}"
def start_training() -> Tuple[str, str, gr.File]:
    """Train on the uploaded dataset.

    Returns (status message, model info, download-file component update).
    """
    if not DATASET_PATH.exists():
        return ("❌ Please upload a JSONL first.", "", gr.File.update(visible=False))
    exit_code = _start_training_subprocess()
    # After training, try to zip the output and expose it for download.
    if exit_code == 0 and _zip_model_folder():
        return (
            "βœ… Training complete.",
            f"Saved: {MODEL_DIR.name} | Zip: {ZIP_PATH.name}",
            gr.File.update(value=str(ZIP_PATH), visible=True),
        )
    # On failure, surface the tail of the log for quick diagnosis.
    tail = ""
    if LOG_PATH.exists():
        with open(LOG_PATH, "r", encoding="utf-8") as log_file:
            tail = "".join(log_file.readlines()[-30:])
    return (f"❌ Training failed (code {exit_code}).", tail, gr.File.update(visible=False))
def read_logs() -> str:
    """Return the tail (last ~20k chars) of the training log, if any."""
    if not LOG_PATH.exists():
        return "⏳ Waiting for logs..."
    return LOG_PATH.read_text(encoding="utf-8")[-20_000:]
def refresh_model_list() -> List[str]:
    """Rescan disk for stored models (callback for the dropdown refresh)."""
    models = _list_local_models()
    return models
def upload_model_zip(zip_file) -> Tuple[str, List[str]]:
    """Import a zipped model for testing; return (status, refreshed choices)."""
    if zip_file is None:
        return "❌ No zip provided.", refresh_model_list()
    # Timestamped folder keeps repeated imports from clobbering each other.
    dest = WORKDIR / f"imported_{int(time.time())}"
    extracted = _safe_unzip(zip_file.name, dest)
    return f"βœ… Imported model at: {extracted}", refresh_model_list()
def generate(model_path: str, prompt: str) -> str:
    """Load the model at *model_path* and complete *prompt*.

    Returns the generated text, or an error string for display in the UI.
    """
    if not model_path:
        return "❌ Select a model."
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        # Some causal LMs ship without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_path)
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        # Decoding settings tuned for code: low temperature plus repetition guards.
        decode_kwargs = dict(
            max_new_tokens=220,
            do_sample=True,
            temperature=0.2,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=4,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            truncation=True,
        )
        result = generator(prompt, **decode_kwargs)
        return result[0]["generated_text"]
    except Exception as exc:  # surfaced in the UI textbox rather than raised
        return f"❌ Error: {exc}"
# ---------- UI ----------
# Two tabs: "Train" (upload JSONL, run training, download zipped model) and
# "Test" (pick or import a stored model and prompt it).
with gr.Blocks(title="Python AI Trainer") as demo:
    gr.Markdown("## 🧠 Python AI Trainer\nUpload JSONL, train, then test your model.")
    with gr.Tab("Train"):
        file_in = gr.File(label="πŸ“₯ Upload JSONL Dataset", file_types=[".jsonl", ".jsonl.gz", ".json"])
        up_status = gr.Textbox(label="Upload Status", interactive=False)
        start_btn = gr.Button("πŸš€ Start Training", variant="primary")
        # Logs are pulled on demand (no streaming); the user clicks Refresh.
        logs_box = gr.Textbox(label="πŸ“œ Live Logs (click Refresh)", lines=16)
        refresh_logs = gr.Button("Refresh Logs")
        status_box = gr.Textbox(label="Status", interactive=False)
        model_info = gr.Textbox(label="Model Output", interactive=False)
        # Hidden until a trained zip exists; start_training / refresh reveal it.
        dl = gr.File(label="πŸ“¦ Download Trained Model (.zip)", visible=False)
        refresh_dl = gr.Button("Refresh Download Area")
        # Event wiring: upload -> status; train -> status/info/download; log pull.
        file_in.change(fn=upload_dataset, inputs=file_in, outputs=up_status)
        start_btn.click(fn=start_training, outputs=[status_box, model_info, dl])
        refresh_logs.click(fn=read_logs, outputs=logs_box)
        refresh_dl.click(fn=lambda: (gr.File.update(value=str(ZIP_PATH), visible=ZIP_PATH.exists())),
                         outputs=dl)
    with gr.Tab("Test"):
        gr.Markdown("### πŸ”¬ Choose a stored AI and prompt it")
        refresh_models_btn = gr.Button("↻ Refresh AI List")
        # Choices are scanned once at build time; the refresh button rescans.
        model_list = gr.Dropdown(choices=_list_local_models(), label="Available AIs", interactive=True)
        # Alternative path: import a zipped model to test without training here.
        up_zip = gr.File(label="Or upload a model .zip to test", file_types=[".zip"])
        zip_status = gr.Textbox(label="Model Import Status", interactive=False)
        prompt = gr.Textbox(label="Prompt", lines=6, placeholder="### Instruction:\nPython: write a function ...\n### Response:\n")
        generate_btn = gr.Button("Generate")
        output = gr.Textbox(label="AI Response", lines=20)
        refresh_models_btn.click(fn=refresh_model_list, outputs=model_list)
        up_zip.change(fn=upload_model_zip, inputs=up_zip, outputs=[zip_status, model_list])
        generate_btn.click(fn=generate, inputs=[model_list, prompt], outputs=output)
# Launch the Gradio server (blocking call).
demo.launch()