|
|
import os
import subprocess
import sys

import gradio as gr
import torch
|
|
|
|
|
# Shared mutable state for the single training job: "running" flags an active
# run, "log" accumulates the transcript line by line so it can be read while
# training is still in progress.
training_status = {"running": False, "log": ""}
|
|
|
|
|
def run_training(
    base_model: str,
    dataset_id: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    lora_r: int,
    output_repo: str,
    progress=gr.Progress(),  # call-time default is the Gradio idiom for progress injection
):
    """Fine-tune ``base_model`` on ``dataset_id`` with QLoRA and push to the Hub.

    Installs the training dependencies at runtime, loads the model with 4-bit
    NF4 quantization, attaches a LoRA adapter, trains with TRL's ``SFTTrainer``,
    and pushes the result to ``output_repo`` on the Hugging Face Hub.

    Args:
        base_model: Hub id of the causal-LM checkpoint to fine-tune.
        dataset_id: Hub id of the training dataset (``train`` split is used).
        epochs: Number of training epochs.
        batch_size: Per-device train batch size (effective batch is 4x via
            gradient accumulation).
        learning_rate: Peak learning rate (cosine schedule, 10% warmup).
        lora_r: LoRA rank; ``lora_alpha`` is fixed at ``2 * lora_r``.
        output_repo: Hub repo id the trained model is pushed to.
        progress: Gradio progress tracker (supplied by the UI framework).

    Returns:
        The accumulated training log as a single string. On failure the log
        ends with the error message and traceback instead of raising.
    """
    training_status["running"] = True
    training_status["log"] = ""

    def log(msg):
        # Append to the shared status dict so the UI (or a poller) can read
        # the transcript while training is still running.
        training_status["log"] += msg + "\n"
        print(msg)

    try:
        log("=" * 50)
        log("Agent Zero Music Workflow Trainer")
        log("Intuition Labs • terminals.tech")
        log("=" * 50)

        progress(0.05, desc="Installing dependencies...")
        log("\n[1/6] Installing dependencies...")
        # Install via the *current* interpreter and without a shell; the old
        # os.system("pip install ...") could target the wrong environment and
        # always ignored the exit status.
        subprocess.run(
            [
                sys.executable, "-m", "pip", "install", "-q",
                "transformers", "trl", "peft", "datasets",
                "accelerate", "bitsandbytes",
            ],
            check=False,  # best-effort, matching the original behavior
        )

        progress(0.1, desc="Loading libraries...")
        log("[2/6] Loading libraries...")

        # Imported here rather than at module top because these packages may
        # only become available after the pip install above.
        from datasets import load_dataset
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from trl import SFTTrainer, SFTConfig

        progress(0.15, desc="Loading tokenizer...")
        log(f"[3/6] Loading tokenizer: {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Many chat LMs ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        progress(0.2, desc="Loading model with 4-bit quantization...")
        log("[4/6] Loading model with 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        # Required prep for training on a k-bit quantized base model.
        model = prepare_model_for_kbit_training(model)

        log(f"[4/6] Applying LoRA (r={lora_r})...")  # sub-step of model setup
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_r * 2,  # common 2x-rank heuristic
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        progress(0.3, desc="Loading dataset...")
        log(f"[5/6] Loading dataset: {dataset_id}")
        dataset = load_dataset(dataset_id, split="train")

        def format_example(example):
            # Normalize heterogeneous dataset schemas into a single "text" field.
            if "instruction" in example and "response" in example:
                # ChatML-style turn markers (Qwen-family convention).
                return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['response']}<|im_end|>"}
            elif "text" in example:
                return {"text": example["text"]}
            else:
                # Fallback: concatenate whatever string fields exist.
                return {"text": " ".join(str(v) for v in example.values() if isinstance(v, str))}

        dataset = dataset.map(format_example)
        log(f"Dataset size: {len(dataset)} examples")

        progress(0.4, desc="Setting up trainer...")
        log(f"[6/6] Starting training: {epochs} epochs, batch={batch_size}, lr={learning_rate}")

        sft_config = SFTConfig(
            output_dir="./outputs",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=100,
            bf16=True,
            gradient_checkpointing=True,
            push_to_hub=True,
            hub_model_id=output_repo,
            hub_token=os.environ.get("HF_TOKEN"),
            max_length=4096,
            dataset_text_field="text",
        )

        trainer = SFTTrainer(
            model=model,
            args=sft_config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        log("\n" + "=" * 50)
        log("TRAINING STARTED")
        log("=" * 50)

        trainer.train()

        progress(0.95, desc="Pushing to Hub...")
        log("\nPushing model to Hub...")
        trainer.push_to_hub()

        progress(1.0, desc="Complete!")
        log("\n" + "=" * 50)
        log("TRAINING COMPLETE!")
        log(f"Model saved to: https://huggingface.co/{output_repo}")
        log("=" * 50)

        return training_status["log"]

    except Exception as e:
        # Top-level boundary: report the failure in the UI log instead of
        # letting the exception escape into Gradio.
        log(f"\nERROR: {str(e)}")
        import traceback
        log(traceback.format_exc())
        return training_status["log"]

    finally:
        # Guarantee the flag is cleared on every exit path (the original
        # duplicated this reset in the success and error branches).
        training_status["running"] = False
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Left column: training hyperparameters; right column: the run log.
with gr.Blocks(title="Agent Zero Trainer") as demo:
    gr.Markdown("""
    # Agent Zero Music Workflow Trainer
    **Intuition Labs** • terminals.tech

    Fine-tune models for coherent multi-context orchestration.
    Running on L40S GPU (48GB VRAM) - $1.80/hr
    """)

    with gr.Row():
        with gr.Column():
            # Training configuration inputs, passed positionally to
            # run_training() in the click() binding below.
            base_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model")
            dataset_id = gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID")
            epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
            learning_rate = gr.Number(value=2e-5, label="Learning Rate")
            lora_r = gr.Slider(8, 64, value=16, step=8, label="LoRA Rank")
            output_repo = gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo")
            submit_btn = gr.Button("Start Training", variant="primary")

        with gr.Column():
            # Displays run_training()'s returned log string when the job ends.
            output = gr.Textbox(label="Training Log", lines=25, max_lines=50)

    # Start a training run; the handler blocks until training completes or
    # fails, then its log string fills the output textbox.
    submit_btn.click(
        fn=run_training,
        inputs=[base_model, dataset_id, epochs, batch_size, learning_rate, lora_r, output_repo],
        outputs=output,
    )
|
|
|
|
|
# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|