# wheattoast11's picture
# Upload folder using huggingface_hub
# fa1982f verified
import os
import subprocess
import sys
import traceback

import gradio as gr
import torch
# Shared mutable state between the training run and the Gradio UI:
# "running" flags an in-flight job, "log" accumulates the streamed log text.
training_status = {"running": False, "log": ""}
def run_training(
    base_model: str,
    dataset_id: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    lora_r: int,
    output_repo: str,
    progress=gr.Progress(),
):
    """Fine-tune a causal LM with 4-bit QLoRA and push the adapter to the Hub.

    Heavy ML dependencies are installed and imported lazily inside the
    function so the Gradio UI can start without them.

    Args:
        base_model: Hub id of the model to fine-tune.
        dataset_id: Hub id of the training dataset (``train`` split is used).
        epochs: Number of training epochs.
        batch_size: Per-device train batch size.
        learning_rate: Peak learning rate (cosine schedule, 10% warmup).
        lora_r: LoRA rank; LoRA alpha is set to ``2 * lora_r``.
        output_repo: Hub repo the trained model is pushed to.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The accumulated training log as a single string (also on error).
    """
    global training_status
    training_status["running"] = True
    training_status["log"] = ""

    def log(msg):
        # Accumulate into the shared status dict so the UI textbox can show it.
        training_status["log"] += msg + "\n"
        print(msg)

    try:
        log("=" * 50)
        log("Agent Zero Music Workflow Trainer")
        log("Intuition Labs • terminals.tech")
        log("=" * 50)

        progress(0.05, desc="Installing dependencies...")
        log("\n[1/6] Installing dependencies...")
        # Use the running interpreter's pip (a bare "pip" via os.system may
        # belong to a different Python) and avoid shell string interpolation.
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install", "-q",
                "transformers", "trl", "peft", "datasets",
                "accelerate", "bitsandbytes",
            ],
            check=False,
        )

        progress(0.1, desc="Loading libraries...")
        log("[2/6] Loading libraries...")
        from datasets import load_dataset
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from trl import SFTTrainer, SFTConfig

        progress(0.15, desc="Loading tokenizer...")
        log(f"[3/6] Loading tokenizer: {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Many causal LMs ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        progress(0.2, desc="Loading model with 4-bit quantization...")
        log("[4/6] Loading model with 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        model = prepare_model_for_kbit_training(model)

        log(f"[4/6] Applying LoRA (r={lora_r})...")
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_r * 2,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        progress(0.3, desc="Loading dataset...")
        log(f"[5/6] Loading dataset: {dataset_id}")
        dataset = load_dataset(dataset_id, split="train")

        def format_example(example):
            # Normalize heterogeneous datasets to a single "text" column,
            # preferring ChatML formatting when instruction/response exist.
            if "instruction" in example and "response" in example:
                return {
                    "text": (
                        f"<|im_start|>user\n{example['instruction']}<|im_end|>\n"
                        f"<|im_start|>assistant\n{example['response']}<|im_end|>"
                    )
                }
            if "text" in example:
                return {"text": example["text"]}
            # Fallback: concatenate every string field of the example.
            return {"text": " ".join(str(v) for v in example.values() if isinstance(v, str))}

        dataset = dataset.map(format_example)
        log(f"Dataset size: {len(dataset)} examples")

        progress(0.4, desc="Setting up trainer...")
        log(f"[6/6] Starting training: {epochs} epochs, batch={batch_size}, lr={learning_rate}")
        # Use SFTConfig instead of TrainingArguments for newer TRL
        sft_config = SFTConfig(
            output_dir="./outputs",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=100,
            bf16=True,
            gradient_checkpointing=True,
            push_to_hub=True,
            hub_model_id=output_repo,
            hub_token=os.environ.get("HF_TOKEN"),
            max_length=4096,
            dataset_text_field="text",
        )
        trainer = SFTTrainer(
            model=model,
            args=sft_config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        log("\n" + "=" * 50)
        log("TRAINING STARTED")
        log("=" * 50)
        trainer.train()

        progress(0.95, desc="Pushing to Hub...")
        log("\nPushing model to Hub...")
        trainer.push_to_hub()

        progress(1.0, desc="Complete!")
        log("\n" + "=" * 50)
        log("TRAINING COMPLETE!")
        log(f"Model saved to: https://huggingface.co/{output_repo}")
        log("=" * 50)
    except Exception as e:
        # Surface the full traceback in the UI log instead of crashing Gradio.
        log(f"\nERROR: {e}")
        log(traceback.format_exc())
    finally:
        # Reset the flag on every exit path (was duplicated in try/except).
        training_status["running"] = False
    return training_status["log"]
# Build the Gradio UI: input widgets on the left, the streaming log on the right.
with gr.Blocks(title="Agent Zero Trainer") as demo:
    gr.Markdown("""
# Agent Zero Music Workflow Trainer
**Intuition Labs** • terminals.tech
Fine-tune models for coherent multi-context orchestration.
Running on L40S GPU (48GB VRAM) - $1.80/hr
""")
    with gr.Row():
        with gr.Column():
            # Training hyper-parameters, pre-filled with sensible defaults.
            model_in = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model")
            dataset_in = gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID")
            epochs_in = gr.Slider(1, 10, value=3, step=1, label="Epochs")
            batch_in = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
            lr_in = gr.Number(value=2e-5, label="Learning Rate")
            rank_in = gr.Slider(8, 64, value=16, step=8, label="LoRA Rank")
            repo_in = gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo")
            start_btn = gr.Button("Start Training", variant="primary")
        with gr.Column():
            log_box = gr.Textbox(label="Training Log", lines=25, max_lines=50)

    # Wire the button to the trainer; the returned log fills the textbox.
    start_btn.click(
        fn=run_training,
        inputs=[model_in, dataset_in, epochs_in, batch_in, lr_in, rank_in, repo_in],
        outputs=log_box,
    )

if __name__ == "__main__":
    demo.launch()