wheattoast11 committed
Commit 4b4a154 · verified · 1 Parent(s): beb8000

Upload folder using huggingface_hub

Files changed (2)
  1. app.py +166 -35
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,10 +1,13 @@
  import gradio as gr
  import os
- import spaces
+ import torch
+ from threading import Thread
+ import time
 
- # Training will be triggered via this interface
- @spaces.GPU
- def launch_training(
+ # Training status
+ training_status = {"running": False, "log": "", "progress": 0}
+
+ def run_training(
      base_model: str,
      dataset_id: str,
      epochs: int,
@@ -12,38 +15,166 @@ def launch_training(
      learning_rate: float,
      lora_r: int,
      output_repo: str,
+     progress=gr.Progress()
  ):
-     """Configure and launch training."""
-     config_summary = f"""## Training Configuration
-
- - **Base Model**: {base_model}
- - **Dataset**: {dataset_id}
- - **Epochs**: {epochs}
- - **Batch Size**: {batch_size}
- - **Learning Rate**: {learning_rate}
- - **LoRA Rank**: {lora_r}
- - **Output Repo**: {output_repo}
-
- ### Status
- Training configured successfully with H200 ZeroGPU!
- """
-     return config_summary
+     global training_status
+     training_status["running"] = True
+     training_status["log"] = ""
+
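+     # Accumulate log lines in the module-level status dict and echo to stdout;
+     # the full log is returned to the UI textbox when the run ends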
+     def log(msg):
+         training_status["log"] += msg + "\n"
+         print(msg)
+
+     try:
+         log("=" * 50)
+         log("Agent Zero Music Workflow Trainer")
+         log("Intuition Labs • terminals.tech")
+         log("=" * 50)
+
+         progress(0.05, desc="Installing dependencies...")
+         log("\n[1/6] Installing dependencies...")
+         os.system("pip install -q transformers trl peft datasets accelerate bitsandbytes")
+
+         progress(0.1, desc="Loading libraries...")
+         log("[2/6] Loading libraries...")
+
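+         # Imports are deferred until after the runtime pip install above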
+         from datasets import load_dataset
+         from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
+         from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+         from trl import SFTTrainer
+
+         progress(0.15, desc="Loading tokenizer...")
+         log(f"[3/6] Loading tokenizer: {base_model}")
+         tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         progress(0.2, desc="Loading model with 4-bit quantization...")
+         log("[4/6] Loading model with 4-bit quantization...")
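+         # 4-bit NF4 weights with double quantization; matmuls run in bfloat16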
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+         )
+
+         model = AutoModelForCausalLM.from_pretrained(
+             base_model,
+             quantization_config=bnb_config,
+             device_map="auto",
+             trust_remote_code=True,
+             torch_dtype=torch.bfloat16,
+         )
+         model = prepare_model_for_kbit_training(model)
+
+         log(f"[4/6] Applying LoRA (r={lora_r})...")
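+         # LoRA adapters on the attention projections only; alpha = 2 * r is a common scaling default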
+         lora_config = LoraConfig(
+             r=lora_r,
+             lora_alpha=lora_r * 2,
+             lora_dropout=0.05,
+             target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+             bias="none",
+             task_type="CAUSAL_LM",
+         )
+         model = get_peft_model(model, lora_config)
+
+         progress(0.3, desc="Loading dataset...")
+         log(f"[5/6] Loading dataset: {dataset_id}")
+         dataset = load_dataset(dataset_id, split="train")
+
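+         # Normalize whatever columns the dataset has into a single ChatML-style "text" field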
+         def format_example(example):
+             if "instruction" in example and "response" in example:
+                 return {"text": f"<|im_start|>user\n{example['instruction']}<|im_end|>\n<|im_start|>assistant\n{example['response']}<|im_end|>"}
+             elif "text" in example:
+                 return {"text": example["text"]}
+             else:
+                 return {"text": " ".join(str(v) for v in example.values() if isinstance(v, str))}
+
+         dataset = dataset.map(format_example)
+         log(f"Dataset size: {len(dataset)} examples")
+
+         progress(0.4, desc="Setting up trainer...")
+         log(f"[6/6] Starting training: {epochs} epochs, batch={batch_size}, lr={learning_rate}")
+
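+         # Effective batch size is batch_size * 4 via gradient accumulation;
+         # push_to_hub uploads checkpoints to output_repo using the HF_TOKEN secret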
+         training_args = TrainingArguments(
+             output_dir="./outputs",
+             num_train_epochs=epochs,
+             per_device_train_batch_size=batch_size,
+             gradient_accumulation_steps=4,
+             learning_rate=learning_rate,
+             lr_scheduler_type="cosine",
+             warmup_ratio=0.1,
+             logging_steps=10,
+             save_steps=100,
+             bf16=True,
+             gradient_checkpointing=True,
+             push_to_hub=True,
+             hub_model_id=output_repo,
+             hub_token=os.environ.get("HF_TOKEN"),
+         )
+
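+         # tokenizer / max_seq_length / dataset_text_field match the trl 0.8.x SFTTrainer
+         # signature; newer trl releases moved these arguments into SFTConfig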
+         trainer = SFTTrainer(
+             model=model,
+             args=training_args,
+             train_dataset=dataset,
+             tokenizer=tokenizer,
+             max_seq_length=4096,
+             dataset_text_field="text",
+         )
+
+         log("\n" + "=" * 50)
+         log("TRAINING STARTED")
+         log("=" * 50)
+
+         trainer.train()
+
+         progress(0.95, desc="Pushing to Hub...")
+         log("\nPushing model to Hub...")
+         trainer.push_to_hub()
+
+         progress(1.0, desc="Complete!")
+         log("\n" + "=" * 50)
+         log("TRAINING COMPLETE!")
+         log(f"Model saved to: https://huggingface.co/{output_repo}")
+         log("=" * 50)
+
+         training_status["running"] = False
+         return training_status["log"]
+
+     except Exception as e:
+         log(f"\nERROR: {str(e)}")
+         import traceback
+         log(traceback.format_exc())
+         training_status["running"] = False
+         return training_status["log"]
 
- demo = gr.Interface(
-     fn=launch_training,
-     inputs=[
-         gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model"),
-         gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID"),
-         gr.Slider(1, 10, value=3, step=1, label="Epochs"),
-         gr.Slider(1, 8, value=2, step=1, label="Batch Size"),
-         gr.Number(value=2e-5, label="Learning Rate"),
-         gr.Slider(8, 64, value=16, step=8, label="LoRA Rank"),
-         gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo"),
-     ],
-     outputs=gr.Markdown(),
-     title="Agent Zero Music Workflow Trainer",
-     description="Fine-tune models for coherent multi-context orchestration. Intuition Labs + terminals.tech",
- )
+ with gr.Blocks(title="Agent Zero Trainer") as demo:
+     gr.Markdown("""
+     # Agent Zero Music Workflow Trainer
+     **Intuition Labs** • terminals.tech
+
+     Fine-tune models for coherent multi-context orchestration.
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             base_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct", label="Base Model")
+             dataset_id = gr.Textbox(value="wheattoast11/agent-zero-training-data", label="Dataset ID")
+             epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
+             batch_size = gr.Slider(1, 8, value=2, step=1, label="Batch Size")
+             learning_rate = gr.Number(value=2e-5, label="Learning Rate")
+             lora_r = gr.Slider(8, 64, value=16, step=8, label="LoRA Rank")
+             output_repo = gr.Textbox(value="wheattoast11/agent-zero-music-workflow", label="Output Repo")
+             submit_btn = gr.Button("Start Training", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(label="Training Log", lines=25, max_lines=50)
+
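+     # Gradio injects the gr.Progress() parameter automatically, so only these seven inputs are wired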
+     submit_btn.click(
+         fn=run_training,
+         inputs=[base_model, dataset_id, epochs, batch_size, learning_rate, lora_r, output_repo],
+         outputs=output,
+     )
 
  if __name__ == "__main__":
-     demo.launch()
+     demo.launch()

requirements.txt CHANGED
@@ -1,2 +1,9 @@
  gradio>=4.0.0
  huggingface_hub>=0.20.0
+ torch
+ transformers>=4.40.0
+ trl>=0.8.0
+ peft>=0.10.0
+ datasets>=2.18.0
+ accelerate>=0.27.0
+ bitsandbytes>=0.43.0