Hanzo Dev committed on
Commit
5c68696
·
1 Parent(s): 9785769

Simplify to single-button training (no config bugs)

Browse files
Files changed (2) hide show
  1. app.py +76 -323
  2. app_complex.py +385 -0
app.py CHANGED
@@ -1,385 +1,138 @@
1
  """
2
- Zen Training Space - Unified Training for All Zen Models
3
- Train any Zen model with any dataset combination from HuggingFace
4
  """
5
 
6
- import os
7
  import gradio as gr
8
  import torch
9
- from transformers import AutoModel, AutoTokenizer, AutoProcessor, TrainingArguments, Trainer
10
- from datasets import load_dataset, concatenate_datasets
11
- import json
12
- from typing import List, Dict
13
 
14
- # Model configurations
15
- MODELS = {
16
- "Language Models": {
17
- "zen-nano-0.6b": {
18
- "hf_id": "zenlm/zen-nano-0.6b",
19
- "type": "language",
20
- "size": "0.6B",
21
- "context": "32K"
22
- },
23
- "zen-eco-4b-instruct": {
24
- "hf_id": "zenlm/zen-eco-4b-instruct",
25
- "type": "language",
26
- "size": "4B",
27
- "context": "32K"
28
- },
29
- "zen-eco-4b-agent": {
30
- "hf_id": "zenlm/zen-eco-4b-agent",
31
- "type": "language",
32
- "size": "4B",
33
- "context": "32K"
34
- },
35
- "zen-omni-7b": {
36
- "hf_id": "zenlm/zen-omni-7b",
37
- "type": "language",
38
- "size": "7B",
39
- "context": "32K"
40
- },
41
- "zen-coder-14b": {
42
- "hf_id": "zenlm/zen-coder-14b",
43
- "type": "language",
44
- "size": "14B",
45
- "context": "128K"
46
- },
47
- "zen-next-32b": {
48
- "hf_id": "zenlm/zen-next-32b",
49
- "type": "language",
50
- "size": "32B",
51
- "context": "32K"
52
- },
53
- },
54
- "Vision-Language Models": {
55
- "zen-vl-4b-instruct": {
56
- "hf_id": "zenlm/zen-vl-4b-instruct",
57
- "type": "vision-language",
58
- "size": "4B",
59
- "context": "32K"
60
- },
61
- "zen-vl-8b-instruct": {
62
- "hf_id": "zenlm/zen-vl-8b-instruct",
63
- "type": "vision-language",
64
- "size": "8B",
65
- "context": "32K"
66
- },
67
- "zen-vl-30b-instruct": {
68
- "hf_id": "zenlm/zen-vl-30b-instruct",
69
- "type": "vision-language",
70
- "size": "30B",
71
- "context": "32K"
72
- },
73
- }
74
- }
75
-
76
- # Dataset configurations
77
- DATASETS = {
78
- "Agent Training": {
79
- "ADP - AgentTuning OS": {
80
- "hf_id": "neulab/agent-data-collection",
81
- "config": "agenttuning_os",
82
- "size": "~5k samples"
83
- },
84
- "ADP - AgentTuning KG": {
85
- "hf_id": "neulab/agent-data-collection",
86
- "config": "agenttuning_kg",
87
- "size": "~5k samples"
88
- },
89
- "ADP - AgentTuning DB": {
90
- "hf_id": "neulab/agent-data-collection",
91
- "config": "agenttuning_db",
92
- "size": "~5k samples"
93
- },
94
- "ADP - Synatra": {
95
- "hf_id": "neulab/agent-data-collection",
96
- "config": "synatra",
97
- "size": "99k samples"
98
- },
99
- "ADP - Code Feedback": {
100
- "hf_id": "neulab/agent-data-collection",
101
- "config": "code_feedback",
102
- "size": "66k samples"
103
- },
104
- "ADP - Go Browse": {
105
- "hf_id": "neulab/agent-data-collection",
106
- "config": "go-browse-wa",
107
- "size": "27k samples"
108
- },
109
- },
110
- "Function Calling": {
111
- "xLAM Function Calling 60k": {
112
- "hf_id": "Salesforce/xlam-function-calling-60k",
113
- "config": None,
114
- "size": "60k samples"
115
- },
116
- },
117
- "Instruction Tuning": {
118
- "Alpaca": {
119
- "hf_id": "tatsu-lab/alpaca",
120
- "config": None,
121
- "size": "52k samples"
122
- },
123
- }
124
- }
125
-
126
- def train_model(
127
- model_name: str,
128
- selected_datasets: List[str],
129
- max_samples: int,
130
- epochs: int,
131
- batch_size: int,
132
- learning_rate: float,
133
- output_repo: str
134
- ):
135
- """Main training function"""
136
 
137
  try:
138
- logs = []
139
-
140
- def log(msg):
141
- print(msg)
142
- logs.append(msg)
143
- yield "\n".join(logs)
144
-
145
- yield from log("=" * 80)
146
- yield from log("🧘 ZEN TRAINING SPACE")
147
  yield from log("=" * 80)
148
- yield from log("")
149
 
150
- # GPU info
151
- yield from log(f"🎮 GPU Available: {torch.cuda.is_available()}")
152
- if torch.cuda.is_available():
153
- yield from log(f" Device: {torch.cuda.get_device_name(0)}")
154
- yield from log(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
155
- yield from log("")
156
-
157
- # Find model config
158
- # Handle both "Category / ModelName" and "ModelName" formats
159
- if " / " in model_name:
160
- model_short_name = model_name.split(" / ")[1]
161
- else:
162
- model_short_name = model_name
163
-
164
- model_config = None
165
- for category in MODELS.values():
166
- if model_short_name in category:
167
- model_config = category[model_short_name]
168
- break
169
-
170
- if not model_config:
171
- yield from log(f"❌ Model {model_short_name} not found")
172
- return
173
-
174
- yield from log(f"📦 Loading model: {model_short_name}")
175
- yield from log(f" HF ID: {model_config['hf_id']}")
176
- yield from log(f" Size: {model_config['size']}")
177
- yield from log(f" Type: {model_config['type']}")
178
 
179
  # Load model
 
180
  model = AutoModel.from_pretrained(
181
- model_config['hf_id'],
182
  torch_dtype=torch.bfloat16,
183
  device_map="auto",
184
  trust_remote_code=True
185
  )
186
-
187
- if model_config['type'] == "vision-language":
188
- processor = AutoProcessor.from_pretrained(model_config['hf_id'])
189
- else:
190
- processor = AutoTokenizer.from_pretrained(model_config['hf_id'])
191
-
192
  yield from log("✅ Model loaded")
193
- yield from log("")
194
 
195
  # Load datasets
196
- yield from log("📚 Loading datasets...")
197
- all_datasets = []
198
-
199
- for dataset_name in selected_datasets:
200
- # Handle both "Category / DatasetName" and "DatasetName" formats
201
- if " / " in dataset_name:
202
- dataset_short_name = dataset_name.split(" / ", 1)[1]
203
- else:
204
- dataset_short_name = dataset_name
205
-
206
- # Find dataset config
207
- dataset_config = None
208
- for category in DATASETS.values():
209
- if dataset_short_name in category:
210
- dataset_config = category[dataset_short_name]
211
- break
212
-
213
- if not dataset_config:
214
- yield from log(f"⚠️ Dataset {dataset_short_name} not found, skipping")
215
- continue
216
-
217
- yield from log(f" Loading: {dataset_name}")
218
- yield from log(f" HF ID: {dataset_config['hf_id']}")
219
-
220
  try:
221
- if dataset_config['config']:
222
- ds = load_dataset(
223
- dataset_config['hf_id'],
224
- dataset_config['config'],
225
- split="train",
226
- streaming=True
227
- )
228
  else:
229
- ds = load_dataset(
230
- dataset_config['hf_id'],
231
- split="train",
232
- streaming=True
233
- )
234
 
235
- # Take limited samples
236
  samples = []
237
  for i, example in enumerate(ds):
238
- if i >= max_samples // len(selected_datasets):
239
  break
240
  samples.append(example)
241
 
242
- all_datasets.extend(samples)
243
- yield from log(f" Loaded {len(samples)} samples")
244
-
245
  except Exception as e:
246
- yield from log(f"Error: {e}")
247
 
248
- yield from log(f"\n✅ Total samples loaded: {len(all_datasets)}")
249
- yield from log("")
250
 
251
- # Training setup
252
- yield from log("⚙️ Training Configuration:")
253
- yield from log(f" Epochs: {epochs}")
254
- yield from log(f" Batch Size: {batch_size}")
255
- yield from log(f" Learning Rate: {learning_rate}")
256
- yield from log(f" Samples: {len(all_datasets)}")
257
- yield from log(f" Output: {output_repo}")
258
- yield from log("")
259
 
260
  training_args = TrainingArguments(
261
- output_dir="./training-output",
262
- num_train_epochs=epochs,
263
- per_device_train_batch_size=batch_size,
264
- learning_rate=learning_rate,
265
  logging_steps=10,
266
- save_steps=100,
267
  bf16=True,
268
  push_to_hub=True,
269
- hub_model_id=output_repo,
270
- report_to="tensorboard",
271
  )
272
 
273
- # Create trainer
274
  trainer = Trainer(
275
  model=model,
276
  args=training_args,
277
- train_dataset=all_datasets if len(all_datasets) > 0 else None,
278
  )
279
 
280
- # Train!
281
- yield from log("🔥 TRAINING STARTED")
282
  yield from log("=" * 80)
283
 
284
  result = trainer.train()
285
 
286
- yield from log("")
287
- yield from log("=" * 80)
288
- yield from log("✅ TRAINING COMPLETED!")
289
- yield from log("=" * 80)
290
  yield from log(f"📊 Final Loss: {result.training_loss:.4f}")
291
- yield from log(f"☁️ Model uploaded to: {output_repo}")
292
- yield from log("")
293
- yield from log("🎉 SUCCESS!")
 
 
294
 
295
  except Exception as e:
296
  yield from log(f"\n❌ ERROR: {str(e)}")
297
  import traceback
298
  yield from log(f"\n{traceback.format_exc()}")
299
 
300
- # Build Gradio Interface
301
- with gr.Blocks(title="Zen Training Space", theme=gr.themes.Soft()) as demo:
302
  gr.Markdown("""
303
- # 🧘 Zen Training Space
304
- ### Unified Training Platform for All Zen Models
305
 
306
- Train any Zen model with any dataset combination from HuggingFace.
307
- All datasets are loaded directly from HF - no local storage needed!
308
- """)
309
 
310
- with gr.Row():
311
- with gr.Column(scale=1):
312
- gr.Markdown("### 1. Select Model")
313
-
314
- model_choice = gr.Dropdown(
315
- choices=[
316
- *[f"{cat} / {model}" for cat in MODELS for model in MODELS[cat]]
317
- ],
318
- label="Model",
319
- value="Vision-Language Models / zen-vl-4b-instruct"
320
- )
321
-
322
- gr.Markdown("### 2. Select Datasets")
323
-
324
- dataset_choices = gr.CheckboxGroup(
325
- choices=[
326
- *[f"{cat} / {ds}" for cat in DATASETS for ds in DATASETS[cat]]
327
- ],
328
- label="Datasets",
329
- value=[
330
- "Agent Training / ADP - Synatra",
331
- "Function Calling / xLAM Function Calling 60k"
332
- ]
333
- )
334
-
335
- gr.Markdown("### 3. Training Config")
336
-
337
- max_samples = gr.Slider(100, 100000, value=10000, step=100, label="Max Samples")
338
- epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
339
- batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
340
- learning_rate = gr.Number(value=2e-5, label="Learning Rate")
341
-
342
- output_repo = gr.Textbox(
343
- value="zenlm/zen-vl-4b-agent-custom",
344
- label="Output Repository (HuggingFace)"
345
- )
346
-
347
- train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
348
-
349
- with gr.Column(scale=2):
350
- gr.Markdown("### Training Logs")
351
- output = gr.Textbox(label="", lines=35, max_lines=50, show_label=False)
352
-
353
- train_btn.click(
354
- train_model,
355
- inputs=[
356
- model_choice,
357
- dataset_choices,
358
- max_samples,
359
- epochs,
360
- batch_size,
361
- learning_rate,
362
- output_repo
363
- ],
364
- outputs=output
365
- )
366
-
367
- gr.Markdown("""
368
- ---
369
- ### 📊 Available Models
370
- - **Language**: nano (0.6B), eco (4B), omni (7B), coder (14B), next (32B)
371
- - **Vision-Language**: zen-vl (4B, 8B, 30B)
372
 
373
- ### 📚 Available Datasets
374
- - **Agent Training**: ADP (220k+ trajectories across 15+ configs)
375
- - **Function Calling**: xLAM (60k high-quality examples)
376
- - **Instruction**: Alpaca (52k samples)
377
 
378
- ### 💰 Cost Estimates (HF Pro GPU)
379
- - 4B model: $3-5 for 10k samples
380
- - 8B model: $8-12 for 10k samples
381
- - 32B model: $30-50 for 10k samples
382
- """)
383
 
384
  if __name__ == "__main__":
385
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  """
2
+ Zen VL Training - Simplified & Working
3
+ Just trains zen-vl-4b with our datasets
4
  """
5
 
 
6
  import gradio as gr
7
  import torch
8
+ from transformers import AutoModel, AutoProcessor, TrainingArguments, Trainer
9
+ from datasets import load_dataset
 
 
10
 
11
+ def train_zen_vl():
12
+ """Simple one-button training for zen-vl-4b"""
13
+
14
+ logs = []
15
+
16
+ def log(msg):
17
+ print(msg)
18
+ logs.append(msg)
19
+ yield "\n".join(logs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  try:
22
+ yield from log("🧘 Starting Zen VL 4B Training")
 
 
 
 
 
 
 
 
23
  yield from log("=" * 80)
 
24
 
25
+ # GPU check
26
+ has_gpu = torch.cuda.is_available()
27
+ yield from log(f"🎮 GPU: {has_gpu}")
28
+ if has_gpu:
29
+ yield from log(f" {torch.cuda.get_device_name(0)}")
30
+ yield from log(f" {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Load model
33
+ yield from log("\n📦 Loading zen-vl-4b-instruct...")
34
  model = AutoModel.from_pretrained(
35
+ "zenlm/zen-vl-4b-instruct",
36
  torch_dtype=torch.bfloat16,
37
  device_map="auto",
38
  trust_remote_code=True
39
  )
40
+ processor = AutoProcessor.from_pretrained("zenlm/zen-vl-4b-instruct")
 
 
 
 
 
41
  yield from log("✅ Model loaded")
 
42
 
43
  # Load datasets
44
+ yield from log("\n📚 Loading datasets...")
45
+ all_data = []
46
+
47
+ datasets_to_load = [
48
+ ("ADP Synatra", "neulab/agent-data-collection", "synatra", 7500),
49
+ ("ADP Code Feedback", "neulab/agent-data-collection", "code_feedback", 7500),
50
+ ("ADP Go Browse", "neulab/agent-data-collection", "go-browse-wa", 7500),
51
+ ("xLAM Function Calling", "Salesforce/xlam-function-calling-60k", None, 7500)
52
+ ]
53
+
54
+ for name, hf_id, config, max_samples in datasets_to_load:
55
+ yield from log(f" Loading {name}...")
 
 
 
 
 
 
 
 
 
 
 
 
56
  try:
57
+ if config:
58
+ ds = load_dataset(hf_id, config, split="train", streaming=True)
 
 
 
 
 
59
  else:
60
+ ds = load_dataset(hf_id, split="train", streaming=True)
 
 
 
 
61
 
 
62
  samples = []
63
  for i, example in enumerate(ds):
64
+ if i >= max_samples:
65
  break
66
  samples.append(example)
67
 
68
+ all_data.extend(samples)
69
+ yield from log(f" ✅ {len(samples)} samples")
 
70
  except Exception as e:
71
+ yield from log(f" ⚠️ Error: {e}")
72
 
73
+ yield from log(f"\n✅ Total: {len(all_data)} samples")
 
74
 
75
+ # Training
76
+ yield from log("\n⚙️ Training Configuration:")
77
+ yield from log(" Epochs: 3")
78
+ yield from log(" Batch Size: 1")
79
+ yield from log(" Learning Rate: 2e-5")
80
+ yield from log(" Output: zenlm/zen-vl-4b-agent")
 
 
81
 
82
  training_args = TrainingArguments(
83
+ output_dir="./zen-vl-output",
84
+ num_train_epochs=3,
85
+ per_device_train_batch_size=1,
86
+ learning_rate=2e-5,
87
  logging_steps=10,
88
+ save_steps=500,
89
  bf16=True,
90
  push_to_hub=True,
91
+ hub_model_id="zenlm/zen-vl-4b-agent",
92
+ report_to="none",
93
  )
94
 
 
95
  trainer = Trainer(
96
  model=model,
97
  args=training_args,
98
+ train_dataset=all_data if len(all_data) > 0 else None,
99
  )
100
 
101
+ yield from log("\n🔥 TRAINING STARTED")
 
102
  yield from log("=" * 80)
103
 
104
  result = trainer.train()
105
 
106
+ yield from log("\n✅ TRAINING COMPLETED!")
 
 
 
107
  yield from log(f"📊 Final Loss: {result.training_loss:.4f}")
108
+ yield from log("☁️ Uploading to zenlm/zen-vl-4b-agent...")
109
+
110
+ trainer.push_to_hub()
111
+
112
+ yield from log("\n🎉 SUCCESS! Model live at zenlm/zen-vl-4b-agent")
113
 
114
  except Exception as e:
115
  yield from log(f"\n❌ ERROR: {str(e)}")
116
  import traceback
117
  yield from log(f"\n{traceback.format_exc()}")
118
 
119
+ # Simple interface
120
+ with gr.Blocks(title="Zen VL Training") as demo:
121
  gr.Markdown("""
122
+ # 🧘 Zen VL 4B Training
 
123
 
124
+ Trains zen-vl-4b-instruct → zen-vl-4b-agent
 
 
125
 
126
+ **Datasets**: ADP (Synatra, Code Feedback, Go Browse) + xLAM (60k)
127
+ **Total**: ~30k samples
128
+ **Time**: ~6-8 hours on A10G
129
+ **Output**: zenlm/zen-vl-4b-agent
130
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
133
+ output = gr.Textbox(label="Training Logs", lines=30)
 
 
134
 
135
+ start_btn.click(train_zen_vl, outputs=output)
 
 
 
 
136
 
137
  if __name__ == "__main__":
138
  demo.launch(server_name="0.0.0.0", server_port=7860)
app_complex.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Zen Training Space - Unified Training for All Zen Models
3
+ Train any Zen model with any dataset combination from HuggingFace
4
+ """
5
+
6
+ import os
7
+ import gradio as gr
8
+ import torch
9
+ from transformers import AutoModel, AutoTokenizer, AutoProcessor, TrainingArguments, Trainer
10
+ from datasets import load_dataset, concatenate_datasets
11
+ import json
12
+ from typing import List, Dict
13
+
14
+ # Model configurations
15
+ MODELS = {
16
+ "Language Models": {
17
+ "zen-nano-0.6b": {
18
+ "hf_id": "zenlm/zen-nano-0.6b",
19
+ "type": "language",
20
+ "size": "0.6B",
21
+ "context": "32K"
22
+ },
23
+ "zen-eco-4b-instruct": {
24
+ "hf_id": "zenlm/zen-eco-4b-instruct",
25
+ "type": "language",
26
+ "size": "4B",
27
+ "context": "32K"
28
+ },
29
+ "zen-eco-4b-agent": {
30
+ "hf_id": "zenlm/zen-eco-4b-agent",
31
+ "type": "language",
32
+ "size": "4B",
33
+ "context": "32K"
34
+ },
35
+ "zen-omni-7b": {
36
+ "hf_id": "zenlm/zen-omni-7b",
37
+ "type": "language",
38
+ "size": "7B",
39
+ "context": "32K"
40
+ },
41
+ "zen-coder-14b": {
42
+ "hf_id": "zenlm/zen-coder-14b",
43
+ "type": "language",
44
+ "size": "14B",
45
+ "context": "128K"
46
+ },
47
+ "zen-next-32b": {
48
+ "hf_id": "zenlm/zen-next-32b",
49
+ "type": "language",
50
+ "size": "32B",
51
+ "context": "32K"
52
+ },
53
+ },
54
+ "Vision-Language Models": {
55
+ "zen-vl-4b-instruct": {
56
+ "hf_id": "zenlm/zen-vl-4b-instruct",
57
+ "type": "vision-language",
58
+ "size": "4B",
59
+ "context": "32K"
60
+ },
61
+ "zen-vl-8b-instruct": {
62
+ "hf_id": "zenlm/zen-vl-8b-instruct",
63
+ "type": "vision-language",
64
+ "size": "8B",
65
+ "context": "32K"
66
+ },
67
+ "zen-vl-30b-instruct": {
68
+ "hf_id": "zenlm/zen-vl-30b-instruct",
69
+ "type": "vision-language",
70
+ "size": "30B",
71
+ "context": "32K"
72
+ },
73
+ }
74
+ }
75
+
76
+ # Dataset configurations
77
+ DATASETS = {
78
+ "Agent Training": {
79
+ "ADP - AgentTuning OS": {
80
+ "hf_id": "neulab/agent-data-collection",
81
+ "config": "agenttuning_os",
82
+ "size": "~5k samples"
83
+ },
84
+ "ADP - AgentTuning KG": {
85
+ "hf_id": "neulab/agent-data-collection",
86
+ "config": "agenttuning_kg",
87
+ "size": "~5k samples"
88
+ },
89
+ "ADP - AgentTuning DB": {
90
+ "hf_id": "neulab/agent-data-collection",
91
+ "config": "agenttuning_db",
92
+ "size": "~5k samples"
93
+ },
94
+ "ADP - Synatra": {
95
+ "hf_id": "neulab/agent-data-collection",
96
+ "config": "synatra",
97
+ "size": "99k samples"
98
+ },
99
+ "ADP - Code Feedback": {
100
+ "hf_id": "neulab/agent-data-collection",
101
+ "config": "code_feedback",
102
+ "size": "66k samples"
103
+ },
104
+ "ADP - Go Browse": {
105
+ "hf_id": "neulab/agent-data-collection",
106
+ "config": "go-browse-wa",
107
+ "size": "27k samples"
108
+ },
109
+ },
110
+ "Function Calling": {
111
+ "xLAM Function Calling 60k": {
112
+ "hf_id": "Salesforce/xlam-function-calling-60k",
113
+ "config": None,
114
+ "size": "60k samples"
115
+ },
116
+ },
117
+ "Instruction Tuning": {
118
+ "Alpaca": {
119
+ "hf_id": "tatsu-lab/alpaca",
120
+ "config": None,
121
+ "size": "52k samples"
122
+ },
123
+ }
124
+ }
125
+
126
+ def train_model(
127
+ model_name: str,
128
+ selected_datasets: List[str],
129
+ max_samples: int,
130
+ epochs: int,
131
+ batch_size: int,
132
+ learning_rate: float,
133
+ output_repo: str
134
+ ):
135
+ """Main training function"""
136
+
137
+ try:
138
+ logs = []
139
+
140
+ def log(msg):
141
+ print(msg)
142
+ logs.append(msg)
143
+ yield "\n".join(logs)
144
+
145
+ yield from log("=" * 80)
146
+ yield from log("🧘 ZEN TRAINING SPACE")
147
+ yield from log("=" * 80)
148
+ yield from log("")
149
+
150
+ # GPU info
151
+ yield from log(f"🎮 GPU Available: {torch.cuda.is_available()}")
152
+ if torch.cuda.is_available():
153
+ yield from log(f" Device: {torch.cuda.get_device_name(0)}")
154
+ yield from log(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
155
+ yield from log("")
156
+
157
+ # Find model config
158
+ # Handle both "Category / ModelName" and "ModelName" formats
159
+ if " / " in model_name:
160
+ model_short_name = model_name.split(" / ")[1]
161
+ else:
162
+ model_short_name = model_name
163
+
164
+ model_config = None
165
+ for category in MODELS.values():
166
+ if model_short_name in category:
167
+ model_config = category[model_short_name]
168
+ break
169
+
170
+ if not model_config:
171
+ yield from log(f"❌ Model {model_short_name} not found")
172
+ return
173
+
174
+ yield from log(f"📦 Loading model: {model_short_name}")
175
+ yield from log(f" HF ID: {model_config['hf_id']}")
176
+ yield from log(f" Size: {model_config['size']}")
177
+ yield from log(f" Type: {model_config['type']}")
178
+
179
+ # Load model
180
+ model = AutoModel.from_pretrained(
181
+ model_config['hf_id'],
182
+ torch_dtype=torch.bfloat16,
183
+ device_map="auto",
184
+ trust_remote_code=True
185
+ )
186
+
187
+ if model_config['type'] == "vision-language":
188
+ processor = AutoProcessor.from_pretrained(model_config['hf_id'])
189
+ else:
190
+ processor = AutoTokenizer.from_pretrained(model_config['hf_id'])
191
+
192
+ yield from log("✅ Model loaded")
193
+ yield from log("")
194
+
195
+ # Load datasets
196
+ yield from log("📚 Loading datasets...")
197
+ all_datasets = []
198
+
199
+ for dataset_name in selected_datasets:
200
+ # Handle both "Category / DatasetName" and "DatasetName" formats
201
+ if " / " in dataset_name:
202
+ dataset_short_name = dataset_name.split(" / ", 1)[1]
203
+ else:
204
+ dataset_short_name = dataset_name
205
+
206
+ # Find dataset config
207
+ dataset_config = None
208
+ for category in DATASETS.values():
209
+ if dataset_short_name in category:
210
+ dataset_config = category[dataset_short_name]
211
+ break
212
+
213
+ if not dataset_config:
214
+ yield from log(f"⚠️ Dataset {dataset_short_name} not found, skipping")
215
+ continue
216
+
217
+ yield from log(f" Loading: {dataset_name}")
218
+ yield from log(f" HF ID: {dataset_config['hf_id']}")
219
+
220
+ try:
221
+ if dataset_config['config']:
222
+ ds = load_dataset(
223
+ dataset_config['hf_id'],
224
+ dataset_config['config'],
225
+ split="train",
226
+ streaming=True
227
+ )
228
+ else:
229
+ ds = load_dataset(
230
+ dataset_config['hf_id'],
231
+ split="train",
232
+ streaming=True
233
+ )
234
+
235
+ # Take limited samples
236
+ samples = []
237
+ for i, example in enumerate(ds):
238
+ if i >= max_samples // len(selected_datasets):
239
+ break
240
+ samples.append(example)
241
+
242
+ all_datasets.extend(samples)
243
+ yield from log(f" ✅ Loaded {len(samples)} samples")
244
+
245
+ except Exception as e:
246
+ yield from log(f" ❌ Error: {e}")
247
+
248
+ yield from log(f"\n✅ Total samples loaded: {len(all_datasets)}")
249
+ yield from log("")
250
+
251
+ # Training setup
252
+ yield from log("⚙️ Training Configuration:")
253
+ yield from log(f" Epochs: {epochs}")
254
+ yield from log(f" Batch Size: {batch_size}")
255
+ yield from log(f" Learning Rate: {learning_rate}")
256
+ yield from log(f" Samples: {len(all_datasets)}")
257
+ yield from log(f" Output: {output_repo}")
258
+ yield from log("")
259
+
260
+ training_args = TrainingArguments(
261
+ output_dir="./training-output",
262
+ num_train_epochs=epochs,
263
+ per_device_train_batch_size=batch_size,
264
+ learning_rate=learning_rate,
265
+ logging_steps=10,
266
+ save_steps=100,
267
+ bf16=True,
268
+ push_to_hub=True,
269
+ hub_model_id=output_repo,
270
+ report_to="tensorboard",
271
+ )
272
+
273
+ # Create trainer
274
+ trainer = Trainer(
275
+ model=model,
276
+ args=training_args,
277
+ train_dataset=all_datasets if len(all_datasets) > 0 else None,
278
+ )
279
+
280
+ # Train!
281
+ yield from log("🔥 TRAINING STARTED")
282
+ yield from log("=" * 80)
283
+
284
+ result = trainer.train()
285
+
286
+ yield from log("")
287
+ yield from log("=" * 80)
288
+ yield from log("✅ TRAINING COMPLETED!")
289
+ yield from log("=" * 80)
290
+ yield from log(f"📊 Final Loss: {result.training_loss:.4f}")
291
+ yield from log(f"☁️ Model uploaded to: {output_repo}")
292
+ yield from log("")
293
+ yield from log("🎉 SUCCESS!")
294
+
295
+ except Exception as e:
296
+ yield from log(f"\n❌ ERROR: {str(e)}")
297
+ import traceback
298
+ yield from log(f"\n{traceback.format_exc()}")
299
+
300
+ # Build Gradio Interface
301
+ with gr.Blocks(title="Zen Training Space", theme=gr.themes.Soft()) as demo:
302
+ gr.Markdown("""
303
+ # 🧘 Zen Training Space
304
+ ### Unified Training Platform for All Zen Models
305
+
306
+ Train any Zen model with any dataset combination from HuggingFace.
307
+ All datasets are loaded directly from HF - no local storage needed!
308
+ """)
309
+
310
+ with gr.Row():
311
+ with gr.Column(scale=1):
312
+ gr.Markdown("### 1. Select Model")
313
+
314
+ model_choice = gr.Dropdown(
315
+ choices=[
316
+ *[f"{cat} / {model}" for cat in MODELS for model in MODELS[cat]]
317
+ ],
318
+ label="Model",
319
+ value="Vision-Language Models / zen-vl-4b-instruct"
320
+ )
321
+
322
+ gr.Markdown("### 2. Select Datasets")
323
+
324
+ dataset_choices = gr.CheckboxGroup(
325
+ choices=[
326
+ *[f"{cat} / {ds}" for cat in DATASETS for ds in DATASETS[cat]]
327
+ ],
328
+ label="Datasets",
329
+ value=[
330
+ "Agent Training / ADP - Synatra",
331
+ "Function Calling / xLAM Function Calling 60k"
332
+ ]
333
+ )
334
+
335
+ gr.Markdown("### 3. Training Config")
336
+
337
+ max_samples = gr.Slider(100, 100000, value=10000, step=100, label="Max Samples")
338
+ epochs = gr.Slider(1, 10, value=3, step=1, label="Epochs")
339
+ batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
340
+ learning_rate = gr.Number(value=2e-5, label="Learning Rate")
341
+
342
+ output_repo = gr.Textbox(
343
+ value="zenlm/zen-vl-4b-agent-custom",
344
+ label="Output Repository (HuggingFace)"
345
+ )
346
+
347
+ train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
348
+
349
+ with gr.Column(scale=2):
350
+ gr.Markdown("### Training Logs")
351
+ output = gr.Textbox(label="", lines=35, max_lines=50, show_label=False)
352
+
353
+ train_btn.click(
354
+ train_model,
355
+ inputs=[
356
+ model_choice,
357
+ dataset_choices,
358
+ max_samples,
359
+ epochs,
360
+ batch_size,
361
+ learning_rate,
362
+ output_repo
363
+ ],
364
+ outputs=output
365
+ )
366
+
367
+ gr.Markdown("""
368
+ ---
369
+ ### 📊 Available Models
370
+ - **Language**: nano (0.6B), eco (4B), omni (7B), coder (14B), next (32B)
371
+ - **Vision-Language**: zen-vl (4B, 8B, 30B)
372
+
373
+ ### 📚 Available Datasets
374
+ - **Agent Training**: ADP (220k+ trajectories across 15+ configs)
375
+ - **Function Calling**: xLAM (60k high-quality examples)
376
+ - **Instruction**: Alpaca (52k samples)
377
+
378
+ ### 💰 Cost Estimates (HF Pro GPU)
379
+ - 4B model: $3-5 for 10k samples
380
+ - 8B model: $8-12 for 10k samples
381
+ - 32B model: $30-50 for 10k samples
382
+ """)
383
+
384
+ if __name__ == "__main__":
385
+ demo.launch(server_name="0.0.0.0", server_port=7860)