# vliw-optimizer / app.py
# CreativeEngineer's picture
# Revert to Gradio 5.49.1 and ASCII logs
# c5c47e3
"""
HF Spaces app for VLIW kernel optimization via RL.
Uses actual simulator for correctness-gated cycle-count rewards.
"""
import os
import sys
import gradio as gr
import threading
import time
import random
import re
from copy import copy
from pathlib import Path
# Collect startup diagnostics; each dependency probe appends one line here.
startup_log = []


def check_import(name, import_fn):
    """Probe a dependency and record the outcome in ``startup_log``.

    Args:
        name: label used in the log line.
        import_fn: zero-arg callable that imports the package and returns
            a detail string (typically its version).

    Returns:
        True when the probe succeeded, False otherwise.
    """
    try:
        detail = import_fn()
    except Exception as exc:
        # Truncate long tracebacks/messages so the status box stays readable.
        startup_log.append(f"[ERR] {name}: {str(exc)[:80]}")
        return False
    startup_log.append(f"[OK] {name}: {detail}")
    return True


check_import("torch", lambda: __import__("torch").__version__)
check_import("transformers", lambda: __import__("transformers").__version__)
check_import("datasets", lambda: __import__("datasets").__version__)
check_import("peft", lambda: __import__("peft").__version__)
check_import("trl", lambda: __import__("trl").__version__)
check_import("huggingface_hub", lambda: __import__("huggingface_hub").__version__)
# GRPO classes are probed separately from the generic trl version check
# above, since GRPO support depends on the installed TRL version.
try:
    from trl import GRPOConfig, GRPOTrainer
    startup_log.append("[OK] GRPOTrainer: OK")
except Exception as e:
    startup_log.append(f"[ERR] GRPOTrainer: {e}")

# Report GPU availability; training is expected to run on CUDA.
try:
    import torch
    if torch.cuda.is_available():
        startup_log.append(f"[OK] CUDA: {torch.cuda.get_device_name(0)}")
    else:
        startup_log.append("[ERR] CUDA: Not available")
except Exception as e:
    startup_log.append(f"[ERR] CUDA check: {e}")

# Prefer simulator + KernelBuilder from bundled original_performance_takehome.
# In Spaces, this keeps evaluation consistent and enables correctness checks.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PERF_TAKEHOME_PATH = os.path.join(THIS_DIR, "original_performance_takehome")
if os.path.isdir(PERF_TAKEHOME_PATH):
    sys.path.insert(0, PERF_TAKEHOME_PATH)

# Import simulator components. SIMULATOR_AVAILABLE gates all reward
# evaluation below (verify_perf_takehome_code returns zero without it).
try:
    from problem import (
        Machine, Tree, Input, DebugInfo,
        build_mem_image, reference_kernel2,
        SLOT_LIMITS, VLEN, N_CORES, SCRATCH_SIZE, CoreState
    )
    from perf_takehome import KernelBuilder, HASH_STAGES
    startup_log.append("[OK] VLIW Simulator: OK")
    SIMULATOR_AVAILABLE = True
except Exception as e:
    startup_log.append(f"[ERR] VLIW Simulator: {e}")
    SIMULATOR_AVAILABLE = False

# Hugging Face Hub adapter persistence via dataset repo.
try:
    from huggingface_hub import HfApi, snapshot_download
    startup_log.append("[OK] huggingface_hub: OK")
    HF_HUB_AVAILABLE = True
except Exception as e:
    startup_log.append(f"[ERR] huggingface_hub: {str(e)[:80]}")
    HF_HUB_AVAILABLE = False
# Constants
BASELINE_CYCLES = 147734  # cycle count of the unoptimized reference kernel
TARGET_CYCLES = 1363      # goal cycle count (~108x speedup over baseline)
SCORE_SCALE = 3000.0      # reward numerator: score = SCORE_SCALE / cycles
# /data is the HF Spaces persistent volume when one is attached; fall back
# to the working directory otherwise.
PERSIST_DIR = "/data" if os.path.isdir("/data") else "."
ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapters", "perf_takehome_latest")
# Dataset repo used to persist the LoRA adapter across Space restarts.
ADAPTER_DATASET_REPO = os.environ.get("ADAPTER_DATASET_REPO", "CreativeEngineer/vliw-optimizer-adapters")
ADAPTER_DATASET_SUBDIR = os.environ.get("ADAPTER_DATASET_SUBDIR", "perf_takehome_latest")

# Training state shared between the trainer thread and the Gradio UI.
# Every access must hold state_lock.
training_state = {
    "is_training": False,   # a training thread is currently active
    "should_stop": False,   # UI requested an early stop
    "log": [],              # timestamped log lines for the UI poller
    "best_cycles": BASELINE_CYCLES,  # best verified cycle count so far
    "best_code": None,      # code that achieved best_cycles
    "step": 0,              # optimizer steps completed
}
state_lock = threading.Lock()
# seed -> cached simulator fixtures; populated by _get_eval_context.
_eval_context = {}
def get_status():
    """Render the startup diagnostics, one check per line."""
    lines = list(startup_log)
    return "\n".join(lines)
def extract_code_block(text: str) -> str:
    """Pull the last fenced code block out of *text*.

    Preference order: last closed ```python fence, then last closed generic
    fence, then the tail of an unclosed fence (generation often truncates
    before the closing backticks). Falls back to the stripped input.
    """
    # Closed fences, python-tagged first.
    for fence_re in (r"```python\s*(.*?)```", r"```\s*(.*?)```"):
        found = re.findall(fence_re, text, re.DOTALL)
        if found:
            return found[-1].strip()
    # Unclosed fences: take everything after the opening marker, trimming
    # at a later closing fence if one exists.
    for marker in ("```python", "```"):
        if marker in text:
            tail = text.split(marker, 1)[1]
            if "```" in tail:
                tail = tail.split("```", 1)[0]
            return tail.strip()
    return text.strip()
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
def _ensure_dir(path: str) -> None:
Path(path).mkdir(parents=True, exist_ok=True)
def _adapter_exists(path: str) -> bool:
return os.path.exists(os.path.join(path, "adapter_config.json"))
def _try_download_adapter(add_log) -> None:
    """Restore the latest LoRA adapter from the HF dataset repo, if any.

    Downloads ADAPTER_DATASET_SUBDIR from ADAPTER_DATASET_REPO next to
    ADAPTER_DIR and copies its files into ADAPTER_DIR. All failures are
    logged and swallowed: adapter restore is best-effort.

    Fixes vs. the previous version: files are copied with shutil.copyfile
    (chunked) instead of reading each whole file into memory, and the
    non-ASCII "ℹ" log marker is replaced with "[INFO]" to match the file's
    ASCII [OK]/[ERR] log convention.

    Args:
        add_log: callable taking one message string (training log sink).
    """
    import shutil  # local import: only needed for this best-effort copy

    if not HF_HUB_AVAILABLE:
        add_log("[ERR] Hub sync disabled: huggingface_hub not available")
        return
    _ensure_dir(os.path.dirname(ADAPTER_DIR))
    allow = [f"{ADAPTER_DATASET_SUBDIR}/**"]
    try:
        snapshot_download(
            repo_id=ADAPTER_DATASET_REPO,
            repo_type="dataset",
            allow_patterns=allow,
            local_dir=os.path.dirname(ADAPTER_DIR),
            local_dir_use_symlinks=False,
            token=_hf_token(),
        )
        downloaded = os.path.join(os.path.dirname(ADAPTER_DIR), ADAPTER_DATASET_SUBDIR)
        if _adapter_exists(downloaded):
            if downloaded != ADAPTER_DIR:
                _ensure_dir(os.path.dirname(ADAPTER_DIR))
                # Simple overwrite by copying files into ADAPTER_DIR.
                _ensure_dir(ADAPTER_DIR)
                for root, _, files in os.walk(downloaded):
                    rel = os.path.relpath(root, downloaded)
                    dst_root = ADAPTER_DIR if rel == "." else os.path.join(ADAPTER_DIR, rel)
                    _ensure_dir(dst_root)
                    for name in files:
                        # Chunked copy; avoids loading large adapter
                        # shards fully into memory.
                        shutil.copyfile(
                            os.path.join(root, name),
                            os.path.join(dst_root, name),
                        )
            add_log(f"[OK] Downloaded adapter from dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
        else:
            add_log("[INFO] No adapter found in dataset yet")
    except Exception as e:
        add_log(f"[INFO] Adapter download skipped: {str(e)[:160]}")
def _try_upload_adapter(add_log) -> None:
    """Push the saved LoRA adapter in ADAPTER_DIR to the HF dataset repo.

    Best-effort: a missing hub library, missing adapter, or missing token
    is logged and skipped rather than raised.

    Fix vs. the previous version: the non-ASCII "ℹ" log marker is replaced
    with "[INFO]" to match the file's ASCII [OK]/[ERR] log convention.

    Args:
        add_log: callable taking one message string (training log sink).
    """
    if not HF_HUB_AVAILABLE:
        add_log("[ERR] Hub sync disabled: huggingface_hub not available")
        return
    if not _adapter_exists(ADAPTER_DIR):
        add_log("[INFO] No adapter to upload yet")
        return
    token = _hf_token()
    if token is None:
        add_log("[INFO] No HF token set (HF_TOKEN/HUGGINGFACE_HUB_TOKEN); skipping upload")
        return
    try:
        api = HfApi(token=token)
        # exist_ok avoids failing when the dataset repo already exists.
        api.create_repo(repo_id=ADAPTER_DATASET_REPO, repo_type="dataset", exist_ok=True)
        api.upload_folder(
            repo_id=ADAPTER_DATASET_REPO,
            repo_type="dataset",
            folder_path=ADAPTER_DIR,
            path_in_repo=ADAPTER_DATASET_SUBDIR,
            commit_message="Update perf_takehome adapter",
        )
        add_log(f"[OK] Uploaded adapter to dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
    except Exception as e:
        add_log(f"[INFO] Adapter upload skipped: {str(e)[:160]}")
def _run_machine_with_cycle_limit(machine: Machine, max_cycles: int) -> bool:
    """Drive the simulator until all cores stop or *max_cycles* is reached.

    Returns True when the machine halted on its own; False when the cycle
    limit was hit (all cores are then force-stopped).
    """
    # Resume any cores the simulator left paused (the caller disables pause
    # support, but be defensive).
    for core in machine.cores:
        if core.state == CoreState.PAUSED:
            core.state = CoreState.RUNNING
    while any(c.state == CoreState.RUNNING for c in machine.cores):
        has_non_debug = False
        for core in machine.cores:
            if core.state != CoreState.RUNNING:
                continue
            if core.pc >= len(machine.program):
                # Fell off the end of the program: treat as a clean stop.
                core.state = CoreState.STOPPED
                continue
            instr = machine.program[core.pc]
            core.pc += 1
            machine.step(instr, core)
            if any(name != "debug" for name in instr.keys()):
                has_non_debug = True
        # Only bundles containing a non-debug slot consume a cycle --
        # presumably mirroring the simulator's own accounting (confirm
        # against problem.py).
        if has_non_debug:
            machine.cycle += 1
            if machine.cycle >= max_cycles:
                # Runaway kernel: stop every core and report failure.
                for core in machine.cores:
                    core.state = CoreState.STOPPED
                return False
    return True
def _get_eval_context(seed: int) -> dict:
    """Build (and cache) the evaluation fixtures for *seed*.

    Generates the forest/input pair, the initial memory image, and the
    expected output produced by the reference kernel. Results are cached
    per seed in the module-level _eval_context dict.
    """
    with state_lock:
        cached = _eval_context.get(seed)
        if cached is not None:
            return cached
    # NOTE(review): the lock is released between the cache check and the
    # insert below, so two threads may build the same context concurrently;
    # harmless (last write wins, both values are equivalent).
    random.seed(seed)  # seed the global RNG so generation is deterministic
    forest = Tree.generate(10)
    inp = Input.generate(forest, 256, 16)
    mem0 = build_mem_image(forest, inp)
    # Drive the reference-kernel generator to exhaustion; ref_mem ends up
    # holding its final yielded memory state.
    ref_mem = None
    for ref_mem in reference_kernel2(list(mem0)):
        pass
    if ref_mem is None:
        raise RuntimeError("Reference kernel produced no output")
    # mem[6] holds the pointer to the values array (per the memory layout
    # described in PERF_TAKEHOME_PROMPT -- confirm against problem.py).
    inp_values_p = ref_mem[6]
    expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
    ctx = {
        "forest": forest,
        "inp": inp,
        "mem0": mem0,
        "expected": expected,
        "inp_values_p": inp_values_p,
    }
    with state_lock:
        _eval_context[seed] = ctx
    return ctx
def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
    """Execute candidate kernel-builder code in the simulator and score it.

    The candidate must define ``OptimizedKernelBuilder`` (a KernelBuilder
    subclass) and a ``run`` function. It is exec'd with a restricted
    builtins dict, the kernel is built and simulated, and the output is
    compared against the cached reference result for *seed*.

    Returns a dict with keys:
        score:       SCORE_SCALE / cycles on success, else 0.0
        correctness: 1.0 only when output matches the reference, else 0.0
        cycles:      measured cycle count (int) when available, else None
        msg:         human-readable diagnostic
    """
    if not SIMULATOR_AVAILABLE:
        return {
            "score": 0.0,
            "correctness": 0.0,
            "cycles": None,
            "msg": "Simulator unavailable",
        }
    try:
        code = code.strip()
        if not code:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": None,
                "msg": "Empty code",
            }
        # Cheap textual pre-checks before exec'ing anything.
        if "OptimizedKernelBuilder" not in code:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": None,
                "msg": "Missing OptimizedKernelBuilder",
            }
        if "def run" not in code:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": None,
                "msg": "Missing run()",
            }
        # Restricted builtins exposed to the generated code.
        # NOTE(review): this is NOT a security sandbox -- exec'd code can
        # still reach arbitrary objects via attribute traversal. It only
        # discourages accidental imports/IO in generated kernels; acceptable
        # here because the Space evaluates its own model's output.
        safe_builtins = {
            "abs": abs,
            "all": all,
            "any": any,
            "dict": dict,
            "enumerate": enumerate,
            "int": int,
            "len": len,
            "list": list,
            "max": max,
            "min": min,
            "range": range,
            "sum": sum,
            "tuple": tuple,
            "zip": zip,
        }
        exec_globals = {
            "__builtins__": safe_builtins,
            "KernelBuilder": KernelBuilder,
            "HASH_STAGES": HASH_STAGES,
            "VLEN": VLEN,
            "SLOT_LIMITS": SLOT_LIMITS,
        }
        exec(code, exec_globals)
        if "OptimizedKernelBuilder" not in exec_globals:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": None,
                "msg": "OptimizedKernelBuilder not defined after exec",
            }
        # Deterministic fixtures shared across evaluations for this seed.
        ctx = _get_eval_context(seed)
        forest = ctx["forest"]
        inp = ctx["inp"]
        mem0 = ctx["mem0"]
        kb = exec_globals["OptimizedKernelBuilder"]()
        # Build args presumably mirror the take-home harness
        # (height=10, n_nodes, batch=256, rounds=16) -- confirm in
        # perf_takehome.py.
        kb.build_kernel(10, len(forest.values), 256, 16)
        machine = Machine(
            list(mem0),  # copy so the cached image is not mutated by the run
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            trace=False,
        )
        machine.enable_pause = False
        machine.enable_debug = False
        ok = _run_machine_with_cycle_limit(machine, max_cycles=250000)
        if not ok:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": int(machine.cycle),
                "msg": f"Exceeded cycle limit (cycles={machine.cycle})",
            }
        cycles = machine.cycle
        # Reject degenerate kernels that halt almost immediately (they
        # cannot have done the real work).
        if cycles <= 100:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": int(cycles),
                "msg": f"Suspiciously low cycles ({cycles})",
            }
        # Reject pathologically slow kernels even if they finished.
        if cycles > 200000:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": int(cycles),
                "msg": f"Cycles too high ({cycles})",
            }
        # Compare the values region of memory against the reference output.
        inp_values_p = ctx["inp_values_p"]
        expected = ctx["expected"]
        actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        if expected != actual:
            return {
                "score": 0.0,
                "correctness": 0.0,
                "cycles": int(cycles),
                "msg": f"Incorrect output (cycles={cycles})",
            }
        # Reward is inversely proportional to cycle count.
        score = SCORE_SCALE / cycles
        return {
            "score": float(score),
            "correctness": 1.0,
            "cycles": int(cycles),
            "msg": f"Success: {cycles} cycles",
        }
    except Exception as e:
        # Any failure in exec/build/simulation yields zero reward.
        return {
            "score": 0.0,
            "correctness": 0.0,
            "cycles": None,
            "msg": f"Execution error: {str(e)[:200]}",
        }
def perf_takehome_reward_fn(completions, prompts=None, **kwargs):
    """GRPO reward function: correctness-gated cycle score per completion.

    Incorrect or invalid kernels score 0.0; correct ones score
    SCORE_SCALE/cycles plus a flat +1.0 correctness bonus. Also tracks the
    best verified kernel in the shared training_state under state_lock.
    """
    def _completion_text(item):
        # Chat-style completions arrive as a list of message dicts; plain
        # completions as strings.
        if isinstance(item, list):
            return item[0].get("content", "") if item else ""
        return str(item)

    scores = []
    for item in completions:
        code = extract_code_block(_completion_text(item))
        outcome = verify_perf_takehome_code(code)
        value = 0.0
        if outcome.get("correctness", 0.0) > 0:
            value = float(outcome["score"]) + 1.0
        cycle_count = outcome.get("cycles")
        with state_lock:
            if isinstance(cycle_count, int) and cycle_count < training_state["best_cycles"]:
                training_state["best_cycles"] = cycle_count
                training_state["best_code"] = code
        scores.append(float(value))
    return scores
# Prompt template for VLIW optimization.
# f-string: BASELINE_CYCLES / TARGET_CYCLES are interpolated once at import
# time; doubled braces render as literal braces in the final prompt.
PERF_TAKEHOME_PROMPT = f"""Write an optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
API (KernelBuilder):
- alloc_scratch(name, length) -> addr
- scratch_const(val, name) -> addr
- add(engine, slot): engine in {{alu, valu, load, store, flow}}
- alu: (op, dst, src1, src2) where op in {{+,-,*,//,%,^,&,|,<<,>>,<,==,!=,<=,>=,>}}
- valu: same ops but on vectors (VLEN=8)
- load: (load,dst,addr), (vload,dst,addr), (const,dst,val), (vbroadcast,dst,scalar_addr)
- store: (store,addr,src), (vstore,addr,src)
- flow: (select,dst,cond,t,f), (vselect,dst,cond,t,f), (cond_jump,cond,pc), (jump,pc), (halt,)
- label(name): mark code position
- build(slots, vliw=True): pack slots into VLIW bundle
MEMORY: mem[4]=forest_values, mem[5]=inp_indices, mem[6]=inp_values (256 elements each)
ALGORITHM: 16 rounds x 256 items:
load idx,val
node = tree[idx]
val = hash(val ^ node) using HASH_STAGES
idx = 2*idx + (1 if val%2==0 else 2)
idx = 0 if idx >= n_nodes else idx
store idx,val
RULES:
- Output exactly one python code block.
- The code block must define:
- class OptimizedKernelBuilder(KernelBuilder): override build_kernel() and emit instructions using add()/build()
- def run(): return any tuple (ignored), but must exist
- No imports.
Baseline: {BASELINE_CYCLES:,} cycles. Target: <{TARGET_CYCLES:,} cycles.
"""
def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
    """Run GRPO + LoRA training with correctness-gated perf_takehome rewards.

    Intended to run on a background thread (see start_training). Progress
    and logs are published into the shared ``training_state`` under
    ``state_lock`` so the UI can poll them.

    Fixes vs. the previous version: the logging callback no longer applies
    the ``:.4f`` format to a non-numeric reward (which raised ValueError
    when TRL logged a loss without a reward), and the best-effort cleanup
    uses ``except Exception`` instead of a bare ``except``.

    Args:
        model_name: HF model id for the 4-bit quantized base model.
        chunk_steps: optimizer steps per chunk; the adapter is saved (and
            uploaded) after every chunk.
        max_total_steps: overall step budget (only with auto_continue).
        max_minutes: wall-clock budget in minutes (only with auto_continue).
        auto_continue: chain chunks until a budget/target/stop is hit.

    Returns:
        The full training log as a newline-joined string.
    """
    import torch
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig
    from peft import PeftModel
    from trl import GRPOConfig, GRPOTrainer
    from transformers import TrainerCallback

    log = []

    def add_log(msg):
        # Timestamped append, mirrored into shared state for the UI poller.
        log.append(f"[{time.strftime('%H:%M:%S')}] {msg}")
        with state_lock:
            training_state["log"] = log.copy()

    with state_lock:
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["log"] = []
        training_state["best_cycles"] = BASELINE_CYCLES
        training_state["best_code"] = None
        training_state["step"] = 0
    try:
        add_log("Starting VLIW optimization training")
        add_log(f"Model: {model_name}")
        add_log(f"Chunk steps: {chunk_steps}")
        add_log(f"Auto-continue: {auto_continue} (max_total_steps={max_total_steps}, max_minutes={max_minutes})")
        add_log(f"Baseline: {BASELINE_CYCLES:,} cycles, Target: {TARGET_CYCLES:,} cycles")
        add_log(f"Adapter dir: {ADAPTER_DIR}")
        add_log(f"Adapter dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")

        # Load tokenizer
        add_log("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        add_log("[OK] Tokenizer ready")

        # Load model with 4-bit quantization
        add_log("Loading model (4-bit quantization)...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        add_log(f"[OK] Base model loaded on {next(base_model.parameters()).device}")

        # Try to restore adapter from dataset before loading it
        _try_download_adapter(add_log)

        # Resume LoRA adapter if present
        resume_adapter = False
        if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
            add_log("Loading existing LoRA adapter (resume)...")
            model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
            add_log("[OK] Adapter loaded")
            resume_adapter = True
        else:
            model = base_model

        # Create dataset with prompts (GRPO samples completions per prompt).
        add_log("Creating VLIW optimization dataset...")
        prompts = [PERF_TAKEHOME_PROMPT] * 16
        dataset = Dataset.from_dict({"prompt": prompts})
        add_log(f"[OK] Dataset ready: {len(prompts)} prompts")

        # LoRA config (only applied on a fresh run; see trainer_kwargs below).
        add_log("Setting up LoRA...")
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        progress = {"step": 0}
        start_time = time.time()
        max_seconds = float(max_minutes) * 60.0 if auto_continue else float("inf")
        total_target_steps = int(max_total_steps) if auto_continue else int(chunk_steps)

        # Custom callback for logging + early stop
        class VLIWCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                with state_lock:
                    progress["step"] += 1
                    training_state["step"] = progress["step"]
                    if training_state["should_stop"]:
                        control.should_training_stop = True
                    if training_state["best_cycles"] <= TARGET_CYCLES:
                        control.should_training_stop = True
                return control

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    loss = logs.get("loss", "N/A")
                    reward = logs.get("reward", logs.get("mean_reward", "N/A"))
                    step = progress["step"]
                    # Fix: format values only when BOTH are numeric. The old
                    # code applied :.4f to `reward` whenever `loss` was a
                    # float, raising ValueError when reward was "N/A".
                    if isinstance(loss, float) and isinstance(reward, float):
                        add_log(f"Step {step}: loss={loss:.4f}, reward={reward:.4f}")
                    else:
                        add_log(f"Step {step}: {logs}")

        add_log("Creating GRPO trainer with perf_takehome rewards...")
        output_dir = os.path.join(PERSIST_DIR, "grpo_perf_takehome_output")
        os.makedirs(output_dir, exist_ok=True)
        add_log("[OK] Trainer config ready")
        add_log("Starting training loop...")
        add_log("(Stops early if target reached; can auto-continue in chunks)")

        chunk_idx = 0
        while True:
            # Evaluate all stop conditions before starting another chunk.
            with state_lock:
                if training_state["should_stop"]:
                    break
                if training_state["best_cycles"] <= TARGET_CYCLES:
                    break
                if progress["step"] >= total_target_steps:
                    break
                if (time.time() - start_time) >= max_seconds:
                    break
            remaining = total_target_steps - progress["step"]
            this_chunk_steps = min(int(chunk_steps), int(remaining))
            if this_chunk_steps <= 0:
                break
            chunk_idx += 1
            add_log(f"Chunk {chunk_idx}: training {this_chunk_steps} steps...")
            config = GRPOConfig(
                output_dir=output_dir,
                num_train_epochs=1,
                max_steps=this_chunk_steps,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=4,
                learning_rate=1e-5,
                logging_steps=1,
                save_steps=999999,  # effectively disables auto-checkpoints; adapter saved manually below
                report_to="none",
                remove_unused_columns=False,
                max_completion_length=2048,
                num_generations=4,
            )
            trainer_kwargs = {
                "model": model,
                "args": config,
                "train_dataset": dataset,
                "reward_funcs": perf_takehome_reward_fn,
                "processing_class": tokenizer,
                "callbacks": [VLIWCallback()],
            }
            # Pass a fresh LoRA config only when not resuming an adapter;
            # otherwise the trainer would wrap the PeftModel a second time.
            if not resume_adapter:
                trainer_kwargs["peft_config"] = lora_config
            trainer = GRPOTrainer(**trainer_kwargs)
            train_result = trainer.train()
            metrics = train_result.metrics
            add_log(f"Chunk {chunk_idx} done: steps={metrics.get('train_steps', this_chunk_steps)}")
            # Save adapter after each chunk so it persists across restarts
            try:
                os.makedirs(os.path.dirname(ADAPTER_DIR), exist_ok=True)
                trainer.save_model(ADAPTER_DIR)
                add_log(f"[OK] Saved adapter to {ADAPTER_DIR}")
                _try_upload_adapter(add_log)
            except Exception as e:
                add_log(f"[ERR] Failed to save adapter: {str(e)[:120]}")
            if not auto_continue:
                break

        # Test generation: sample one kernel and verify it in the simulator.
        add_log("Testing trained model...")
        inputs = tokenizer(PERF_TAKEHOME_PROMPT, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        code = extract_code_block(result)
        verify_out = verify_perf_takehome_code(code)
        if verify_out.get("correctness", 0.0) > 0:
            cycles = verify_out.get("cycles")
            add_log(f"Generated kernel verified: {cycles:,} cycles")
            speedup = BASELINE_CYCLES / max(int(cycles), 1) if isinstance(cycles, int) else 0.0
            add_log(f"Speedup: {speedup:.2f}x over baseline")
        else:
            add_log(f"Generated kernel invalid: {verify_out.get('msg', '')[:160]}")
        add_log("\n[OK] All done!")
    except Exception as e:
        import traceback
        add_log(f"[ERR] Error: {e}")
        add_log(traceback.format_exc()[:800])
    finally:
        with state_lock:
            training_state["is_training"] = False
        # Best-effort GPU cleanup; `model` is unbound if loading failed.
        try:
            del model
            torch.cuda.empty_cache()
        except Exception:
            pass
    return "\n".join(log)
def start_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
    """Launch run_training on a daemon thread; no-op if already running.

    When training is already in progress, returns the recent log (or a
    please-wait message) instead of starting a second run.
    """
    with state_lock:
        if training_state["is_training"]:
            recent = "\n".join(training_state["log"][-200:])
            return recent or "Training already in progress. Please wait."
        # Mark as running immediately so a double-click cannot race in a
        # second thread before run_training flips the flag itself.
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["log"] = [f"[{time.strftime('%H:%M:%S')}] Starting training..."]
        training_state["step"] = 0
    worker = threading.Thread(
        target=run_training,
        args=(
            model_name,
            int(chunk_steps),
            int(max_total_steps),
            float(max_minutes),
            bool(auto_continue),
        ),
        daemon=True,
    )
    worker.start()
    return "Training started. Logs will stream below."
def stop_training():
    """Ask the training loop to stop after the current step."""
    with state_lock:
        if not training_state["is_training"]:
            return "No training in progress"
        training_state["should_stop"] = True
    return "Stop requested. Training will stop after current step."
# Gradio UI: status panel, training controls, and a pollable log textbox.
with gr.Blocks(title="VLIW Optimizer") as demo:
    gr.Markdown("# VLIW Kernel Optimizer - RL Training")
    # Overview panel; constants are interpolated once at import time.
    gr.Markdown(f"""
Train a language model with reinforcement learning (LoRA) at test time to generate correct, fast VLIW/SIMD kernels.
**Goal:** Reduce cycle count from **{BASELINE_CYCLES:,}** (baseline) to **<{TARGET_CYCLES:,}** (108x speedup)
**How it works:**
1. Model generates Python kernel builder code
2. Simulator checks correctness vs reference and measures cycles
3. GRPO updates LoRA weights; adapter is saved and reloaded from `{ADAPTER_DIR}`
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Read-only startup diagnostics (imports, CUDA, simulator).
            status_box = gr.Textbox(
                label="System Status",
                value=get_status(),
                lines=12,
                interactive=False,
            )
        with gr.Column(scale=2):
            # Training hyperparameter controls.
            model_dropdown = gr.Dropdown(
                choices=[
                    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
                    "Qwen/Qwen2.5-Coder-3B-Instruct",
                ],
                value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
                label="Model",
            )
            chunk_steps_slider = gr.Slider(
                minimum=5,
                maximum=100,
                value=20,
                step=5,
                label="Chunk Steps",
            )
            auto_continue_checkbox = gr.Checkbox(
                value=False,
                label="Auto-continue (chain chunks)",
            )
            max_total_steps_slider = gr.Slider(
                minimum=5,
                maximum=500,
                value=100,
                step=5,
                label="Max Total Steps",
            )
            max_minutes_number = gr.Number(
                value=60,
                precision=0,
                label="Max Minutes",
            )
    with gr.Row():
        start_btn = gr.Button("Start Training", variant="primary")
        stop_btn = gr.Button("Stop", variant="stop")
    output_box = gr.Textbox(
        label="Training Log",
        lines=25,
        interactive=False,
        value="Click 'Start Training' to begin VLIW optimization.",
    )

    def poll_log():
        # Snapshot the shared log under the lock; cap at the last 200 lines
        # and 400 chars per line to keep the textbox responsive.
        with state_lock:
            if not training_state["log"]:
                return ""
            lines = training_state["log"][-200:]
            return "\n".join(line[:400] for line in lines)

    # queue=False: these handlers run directly rather than via the Gradio
    # request queue.
    start_btn.click(
        start_training,
        [model_dropdown, chunk_steps_slider, max_total_steps_slider, max_minutes_number, auto_continue_checkbox],
        [output_box],
        queue=False,
    )
    stop_btn.click(stop_training, [], [output_box], queue=False)
    refresh_btn = gr.Button("Refresh Log")
    refresh_btn.click(poll_log, outputs=[output_box], queue=False)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)