File size: 6,427 Bytes
3c1fb2e
 
4b4a154
3c1fb2e
a13a4a1
4b4a154
 
beb8000
 
 
 
 
 
 
4b4a154
3c1fb2e
4b4a154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a13a4a1
4b4a154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a13a4a1
 
4b4a154
 
 
 
 
 
 
 
 
 
 
 
 
 
fa1982f
a13a4a1
4b4a154
 
 
 
a13a4a1
4b4a154
a13a4a1
4b4a154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
beb8000
4b4a154
 
 
 
 
 
a13a4a1
4b4a154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c1fb2e
 
4b4a154
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import gradio as gr
import os
import torch

# Module-level state shared between the training function and the Gradio UI:
# "running" flags whether a run is active; "log" accumulates the text shown
# in the Training Log textbox.
training_status = {"running": False, "log": ""}

def run_training(
    base_model: str,
    dataset_id: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    lora_r: int,
    output_repo: str,
    progress=gr.Progress()
):
    """Fine-tune *base_model* on *dataset_id* with 4-bit QLoRA and push to the Hub.

    Heavy libraries (transformers/peft/trl/datasets) are pip-installed and
    imported lazily inside the function so the Gradio app starts quickly.

    Args:
        base_model: HF model id to load (with ``trust_remote_code=True``).
        dataset_id: HF dataset id; its ``train`` split is used.
        epochs: Number of training epochs.
        batch_size: Per-device train batch size.
        learning_rate: Peak LR for the cosine schedule.
        lora_r: LoRA rank; alpha is fixed at ``2 * lora_r``.
        output_repo: Hub repo id the trained adapter is pushed to
            (authenticated via the ``HF_TOKEN`` env var).
        progress: Gradio progress tracker injected by the UI.

    Returns:
        The accumulated log text — also on failure (the traceback is
        appended to the log instead of being raised to the UI).
    """
    # Guard against overlapping runs (e.g. double-clicking the submit button);
    # previously the "running" flag was set but never checked.
    if training_status["running"]:
        return "A training run is already in progress.\n" + training_status["log"]

    # No `global` needed: the dict is mutated in place, never rebound.
    training_status["running"] = True
    training_status["log"] = ""

    def log(msg):
        # Append to the shared log so the UI textbox reflects progress.
        training_status["log"] += msg + "\n"
        print(msg)

    try:
        log("=" * 50)
        log("Agent Zero Music Workflow Trainer")
        log("Intuition Labs • terminals.tech")
        log("=" * 50)

        progress(0.05, desc="Installing dependencies...")
        log("\n[1/6] Installing dependencies...")
        # NOTE(review): runtime pip install via os.system is fragile (no error
        # check) — acceptable for a Space, but a pinned requirements.txt is safer.
        os.system("pip install -q transformers trl peft datasets accelerate bitsandbytes")

        progress(0.1, desc="Loading libraries...")
        log("[2/6] Loading libraries...")

        from datasets import load_dataset
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from trl import SFTTrainer, SFTConfig

        progress(0.15, desc="Loading tokenizer...")
        log(f"[3/6] Loading tokenizer: {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # Models like Qwen may ship without a pad token; reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token

        progress(0.2, desc="Loading model with 4-bit quantization...")
        log("[4/6] Loading model with 4-bit quantization...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
        model = prepare_model_for_kbit_training(model)

        log(f"[4/6] Applying LoRA (r={lora_r})...")
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_r * 2,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        progress(0.3, desc="Loading dataset...")
        log(f"[5/6] Loading dataset: {dataset_id}")
        dataset = load_dataset(dataset_id, split="train")

        def format_example(example):
            # Normalize rows to a single "text" field in ChatML format when
            # instruction/response columns exist; otherwise fall back to raw
            # text or a join of all string-valued columns.
            if "instruction" in example and "response" in example:
                return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['response']}<|im_end|>"}
            elif "text" in example:
                return {"text": example["text"]}
            else:
                return {"text": " ".join(str(v) for v in example.values() if isinstance(v, str))}

        dataset = dataset.map(format_example)
        log(f"Dataset size: {len(dataset)} examples")

        progress(0.4, desc="Setting up trainer...")
        log(f"[6/6] Starting training: {epochs} epochs, batch={batch_size}, lr={learning_rate}")

        # Use SFTConfig instead of TrainingArguments for newer TRL
        sft_config = SFTConfig(
            output_dir="./outputs",
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=100,
            bf16=True,
            gradient_checkpointing=True,
            push_to_hub=True,
            hub_model_id=output_repo,
            hub_token=os.environ.get("HF_TOKEN"),
            max_length=4096,
            dataset_text_field="text",
        )

        trainer = SFTTrainer(
            model=model,
            args=sft_config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        log("\n" + "=" * 50)
        log("TRAINING STARTED")
        log("=" * 50)

        trainer.train()

        progress(0.95, desc="Pushing to Hub...")
        log("\nPushing model to Hub...")
        trainer.push_to_hub()

        progress(1.0, desc="Complete!")
        log("\n" + "=" * 50)
        log("TRAINING COMPLETE!")
        log(f"Model saved to: https://huggingface.co/{output_repo}")
        log("=" * 50)

        return training_status["log"]

    except Exception as e:
        # Surface the failure in the UI log rather than raising into Gradio.
        log(f"\nERROR: {str(e)}")
        import traceback
        log(traceback.format_exc())
        return training_status["log"]
    finally:
        # Always clear the flag — previously it was reset separately on the
        # success and error paths, so an unexpected failure could leave it
        # stuck True and (with the new guard) block all future runs.
        training_status["running"] = False

# Build the Gradio UI. The context-manager nesting determines the rendered
# layout: a two-column row with inputs on the left and the log on the right.
with gr.Blocks(title="Agent Zero Trainer") as demo:
    gr.Markdown("""
    # Agent Zero Music Workflow Trainer
    **Intuition Labs** • terminals.tech
    
    Fine-tune models for coherent multi-context orchestration.
    Running on L40S GPU (48GB VRAM) - $1.80/hr
    """)
    
    with gr.Row():
        with gr.Column():
            # Training hyperparameters; defaults target a 7B Qwen QLoRA run.
            base_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model")
            dataset_id = gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID")
            epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
            batch_size = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
            learning_rate = gr.Number(value=2e-5, label="Learning Rate")
            lora_r = gr.Slider(8, 64, value=16, step=8, label="LoRA Rank")
            output_repo = gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo")
            submit_btn = gr.Button("Start Training", variant="primary")
        
        with gr.Column():
            # run_training's return value (the accumulated log) lands here.
            output = gr.Textbox(label="Training Log", lines=25, max_lines=50)
    
    # Input order must match run_training's positional parameters.
    submit_btn.click(
        fn=run_training,
        inputs=[base_model, dataset_id, epochs, batch_size, learning_rate, lora_r, output_repo],
        outputs=output,
    )

if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    demo.launch()