"""Agent Zero Music Workflow Trainer — Gradio Space entry point.

Fine-tunes a causal LM with 4-bit QLoRA (bitsandbytes NF4) on a Hugging
Face dataset via TRL's SFTTrainer and pushes the adapter to the Hub.
Heavy ML imports happen lazily inside the training callback so the UI
starts fast even before dependencies are installed.

Intuition Labs • terminals.tech
"""

import os
import subprocess
import sys

import gradio as gr
import torch

# Shared mutable state: "running" flags an in-flight job, "log" accumulates
# the training transcript that is returned to the UI textbox.
training_status = {"running": False, "log": ""}


def run_training(
    base_model: str,
    dataset_id: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    lora_r: int,
    output_repo: str,
    progress=gr.Progress(),
):
    """Run a 4-bit QLoRA SFT fine-tune and push the adapter to the Hub.

    Args:
        base_model: Hub ID of the base causal LM (e.g. "Qwen/Qwen2.5-7B-Instruct").
        dataset_id: Hub ID of the training dataset; the "train" split is used.
        epochs: Number of training epochs.
        batch_size: Per-device train batch size (grad accumulation fixed at 4).
        learning_rate: Peak learning rate (cosine schedule, 10% warmup).
        lora_r: LoRA rank; alpha is set to 2 * r.
        output_repo: Hub repo ID the trained model is pushed to.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The accumulated log text, whether training succeeded or failed.
    """
    global training_status
    training_status["running"] = True
    training_status["log"] = ""

    def log(msg: str) -> None:
        # Mirror every message into the shared buffer (shown in the UI)
        # and to stdout (visible in the Space container logs).
        training_status["log"] += msg + "\n"
        print(msg)

    try:
        log("=" * 50)
        log("Agent Zero Music Workflow Trainer")
        log("Intuition Labs • terminals.tech")
        log("=" * 50)

        progress(0.05, desc="Installing dependencies...")
        log("\n[1/6] Installing dependencies...")
        # Install via the current interpreter with shell=False; unlike the
        # previous os.system call, a failed install raises and is logged
        # instead of being silently ignored.
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "-q",
            "transformers", "trl", "peft", "datasets", "accelerate", "bitsandbytes",
        ])

        progress(0.1, desc="Loading libraries...")
        log("[2/6] Loading libraries...")
        # Imported lazily: these packages may only exist after the pip step.
        from datasets import load_dataset
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from trl import SFTConfig, SFTTrainer

        progress(0.15, desc="Loading tokenizer...")
        log(f"[3/6] Loading tokenizer: {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Causal LMs often ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        progress(0.2, desc="Loading model with 4-bit quantization...")
        log("[4/6] Loading model with 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        model = prepare_model_for_kbit_training(model)

        log(f"[4/6] Applying LoRA (r={lora_r})...")
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_r * 2,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        progress(0.3, desc="Loading dataset...")
        log(f"[5/6] Loading dataset: {dataset_id}")
        dataset = load_dataset(dataset_id, split="train")

        def format_example(example):
            # Normalize heterogeneous datasets to a single "text" column,
            # using ChatML markup when instruction/response pairs exist.
            if "instruction" in example and "response" in example:
                return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['response']}<|im_end|>"}
            elif "text" in example:
                return {"text": example["text"]}
            else:
                # Fallback: concatenate every string-valued field.
                return {"text": " ".join(str(v) for v in example.values() if isinstance(v, str))}

        dataset = dataset.map(format_example)
        log(f"Dataset size: {len(dataset)} examples")

        progress(0.4, desc="Setting up trainer...")
        log(f"[6/6] Starting training: {epochs} epochs, batch={batch_size}, lr={learning_rate}")

        hub_token = os.environ.get("HF_TOKEN")
        if not hub_token:
            # push_to_hub=True needs credentials; surface the problem early
            # instead of failing only after hours of training.
            log("WARNING: HF_TOKEN is not set; pushing to the Hub may fail.")

        # Use SFTConfig instead of TrainingArguments for newer TRL versions.
        sft_config = SFTConfig(
            output_dir="./outputs",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=100,
            bf16=True,
            gradient_checkpointing=True,
            push_to_hub=True,
            hub_model_id=output_repo,
            hub_token=hub_token,
            max_length=4096,
            dataset_text_field="text",
        )
        trainer = SFTTrainer(
            model=model,
            args=sft_config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        log("\n" + "=" * 50)
        log("TRAINING STARTED")
        log("=" * 50)
        trainer.train()

        progress(0.95, desc="Pushing to Hub...")
        log("\nPushing model to Hub...")
        trainer.push_to_hub()

        progress(1.0, desc="Complete!")
        log("\n" + "=" * 50)
        log("TRAINING COMPLETE!")
        log(f"Model saved to: https://huggingface.co/{output_repo}")
        log("=" * 50)
        return training_status["log"]
    except Exception as e:
        # UI boundary: report the failure in the log pane rather than crash.
        log(f"\nERROR: {str(e)}")
        import traceback
        log(traceback.format_exc())
        return training_status["log"]
    finally:
        # Single reset point for the running flag (was duplicated in both
        # the success and error paths).
        training_status["running"] = False


with gr.Blocks(title="Agent Zero Trainer") as demo:
    gr.Markdown("""
    # Agent Zero Music Workflow Trainer

    **Intuition Labs** • terminals.tech

    Fine-tune models for coherent multi-context orchestration.

    Running on L40S GPU (48GB VRAM) - $1.80/hr
    """)
    with gr.Row():
        with gr.Column():
            base_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model")
            dataset_id = gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID")
            epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
            learning_rate = gr.Number(value=2e-5, label="Learning Rate")
            lora_r = gr.Slider(8, 64, value=16, step=8, label="LoRA Rank")
            output_repo = gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo")
            submit_btn = gr.Button("Start Training", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Training Log", lines=25, max_lines=50)

    submit_btn.click(
        fn=run_training,
        inputs=[base_model, dataset_id, epochs, batch_size, learning_rate, lora_r, output_repo],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()