zeekay committed on
Commit
1009e20
·
verified ·
1 Parent(s): e5ab064

Add training script: train_zen4_ultra.py

Browse files
Files changed (1) hide show
  1. training/train_zen4_ultra.py +372 -0
training/train_zen4_ultra.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train zen4-ultra — QLoRA uncensoring for Kimi K2.5 (1.04T MoE)
4
+
5
+ Standard linear abliteration FAILS on K2.5's MoE architecture because refusal
6
+ is encoded in expert routing (which 384 experts fire), not just the residual stream.
7
+ See: hamsaOmar/Kimi-K2.5-abliterated
8
+
9
+ This script uses QLoRA fine-tuning which DOES work because backpropagation modifies
10
+ all weights including the router/gate. Key innovations:
11
+ 1. LoRA on attention + shared experts
12
+ 2. Gate/router weights unfrozen for direct gradient updates
13
+ 3. Uncensored instruction data to override safety training
14
+ 4. DPO mode for preference-based training (optional)
15
+
16
+ Architecture (DeepseekV3):
17
+ - 61 layers, 384 routed experts (top-8), 1 shared expert
18
+ - Hidden: 7168, MoE intermediate: 2048
19
+ - Compressed KV: kv_lora_rank=512, q_lora_rank=1536
20
+ - Gate: nn.Parameter (not nn.Linear) — requires unfreeze, not LoRA
21
+
22
+ Requirements:
23
+ - 4x A100 80GB or 8x H200 (INT4 quantized ~280GB)
24
+ - pip install transformers peft bitsandbytes datasets trl accelerate
25
+
26
+ Usage:
27
+ # SFT mode (uncensored instruction following)
28
+ python train_zen4_ultra.py --mode sft --dataset cognitivecomputations/dolphin-r1
29
+
30
+ # DPO mode (preference optimization)
31
+ python train_zen4_ultra.py --mode dpo --dataset argilla/ultrafeedback-binarized-preferences
32
+
33
+ # Custom local data
34
+ python train_zen4_ultra.py --mode sft --dataset ./data/uncensored.jsonl
35
+
36
+ # Multi-GPU
37
+ torchrun --nproc_per_node 4 train_zen4_ultra.py --mode sft
38
+ """
39
+
40
+ import argparse
41
+ import json
42
+ import os
43
+ import torch
44
+ from pathlib import Path
45
+
46
+ from transformers import (
47
+ AutoModelForCausalLM,
48
+ AutoTokenizer,
49
+ TrainingArguments,
50
+ BitsAndBytesConfig,
51
+ )
52
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
53
+ from datasets import load_dataset, Dataset
54
+
55
+
56
# Defaults for the checkpoint to load and the adapter output location.
BASE_MODEL = "moonshotai/Kimi-K2.5"
OUTPUT_DIR = "./output/zen4-ultra-lora"

# Module-name suffixes that receive LoRA adapters (DeepseekV3/K2.5 naming).

# Attention projections of the compressed-KV attention block.
ATTENTION_MODULES = [
    "q_a_proj",            # query compression down
    "q_b_proj",            # query compression up
    "kv_a_proj_with_mqa",  # KV compression down
    "kv_b_proj",           # KV compression up
    "o_proj",              # output projection
]

# Feed-forward projections of the shared (always active, non-routed) expert.
SHARED_EXPERT_MODULES = [
    f"shared_experts.{projection}"
    for projection in ("gate_proj", "up_proj", "down_proj")
]

# Full LoRA target list: attention first, then the shared expert.
LORA_TARGET_MODULES = [*ATTENTION_MODULES, *SHARED_EXPERT_MODULES]
78
+
79
+
80
def setup_model(args):
    """Load the base model in 4-bit, attach LoRA adapters, optionally unfreeze the gate.

    Args:
        args: Parsed CLI namespace. Reads ``mode``, ``lora_rank``,
            ``unfreeze_gate``, ``flash_attn`` and ``target_routed_experts``.
            ``base_model`` is used when present and non-empty; otherwise the
            module-level ``BASE_MODEL`` default applies.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model.
    """
    # Honour --base-model. Previously the flag was parsed but the constant was
    # always used, so the option silently had no effect.
    base_model = getattr(args, "base_model", None) or BASE_MODEL

    print("=" * 60)
    print("zen4-ultra Training")
    print(f"Base: {base_model}")
    print("Architecture: 1.04T MoE (384 experts, top-8, 32B active)")
    print(f"Mode: {args.mode}")
    print(f"LoRA rank: {args.lora_rank}")
    print(f"Gate unfreeze: {args.unfreeze_gate}")
    print("=" * 60)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    print("Loading tokenizer...")
    # NOTE: trust_remote_code=True runs Python code shipped in the model
    # repository; only point --base-model at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Padding is required by the tokenization step; reuse EOS when the
        # tokenizer defines no dedicated pad token.
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading model (this will take 10-20 min)...")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2" if args.flash_attn else "eager",
    )

    model = prepare_model_for_kbit_training(model)

    # Apply LoRA to attention + shared experts
    target_modules = list(LORA_TARGET_MODULES)
    if args.target_routed_experts:
        # Also target individual routed experts (much more VRAM)
        target_modules.extend(["gate_proj", "up_proj", "down_proj"])

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank * 2,
        lora_dropout=0.05,
        target_modules=target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # The gate is described in this file as an nn.Parameter rather than an
    # nn.Linear, so it is not covered by the LoRA targets above; matching
    # parameters are marked trainable directly instead.
    if args.unfreeze_gate:
        gate_params = 0
        for name, param in model.named_parameters():
            if ".gate.weight" in name and "gate_proj" not in name:
                param.requires_grad = True
                gate_params += param.numel()
        print(f"Unfroze {gate_params:,} gate/router parameters")

    model.print_trainable_parameters()
    return model, tokenizer
145
+
146
+
147
# Maps ShareGPT/role spellings to chat roles; anything not listed is skipped.
_SHAREGPT_ROLE_MAP = {
    "human": "user",
    "user": "user",
    "gpt": "assistant",
    "assistant": "assistant",
    "system": "system",
}


def _alpaca_to_messages(example):
    """Turn an Alpaca record (instruction/input/output) into chat messages."""
    prompt = example["instruction"]
    if example.get("input"):
        prompt = f"{prompt}\n{example['input']}"
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": example["output"]},
    ]


def _sharegpt_to_messages(example):
    """Turn a ShareGPT record (conversations list) into chat messages."""
    messages = []
    for msg in example["conversations"]:
        role = _SHAREGPT_ROLE_MAP.get(msg.get("from", msg.get("role", "user")))
        if role is None:
            continue  # unknown speaker tags are dropped
        content = msg.get("value", msg.get("content", ""))
        messages.append({"role": role, "content": content})
    return messages


def _prompt_response_to_messages(example):
    """Turn a prompt/response record into chat messages."""
    return [
        {"role": "user", "content": example["prompt"]},
        {"role": "assistant", "content": example["response"]},
    ]


def _render_messages(messages, tokenizer):
    """Render chat messages to text with the tokenizer's template when it has one.

    Falls back to ChatML markers when the tokenizer defines no chat template.
    """
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
    return "\n".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>" for m in messages
    )


def load_sft_data(args, tokenizer):
    """Load an SFT dataset, normalise it to a ``text`` column, and tokenize it.

    Supported input layouts, checked in this order: ``messages`` (chat),
    ``instruction``/``output`` (Alpaca), ``conversations`` (ShareGPT),
    ``text`` (already formatted), ``prompt``/``response``.

    Args:
        args: Parsed CLI namespace; reads ``dataset`` and ``max_seq_length``.
        tokenizer: Tokenizer used for chat templating and tokenization.

    Returns:
        The tokenized dataset, padded/truncated to ``args.max_seq_length``,
        with a ``labels`` column that copies ``input_ids``.

    Raises:
        ValueError: If none of the supported column layouts is present.
    """
    if args.dataset.endswith((".jsonl", ".json")):
        # Local JSON / JSON-lines file
        dataset = load_dataset("json", data_files=args.dataset, split="train")
    else:
        # HuggingFace dataset
        dataset = load_dataset(args.dataset, split="train")

    # Auto-detect format
    columns = dataset.column_names
    print(f"Dataset columns: {columns}")
    print(f"Dataset size: {len(dataset)}")

    if "messages" in columns:
        # Chat format: always rendered with the tokenizer's own template.
        def to_text(example):
            return {
                "text": tokenizer.apply_chat_template(
                    example["messages"], tokenize=False, add_generation_prompt=False
                )
            }
    elif "instruction" in columns and "output" in columns:
        def to_text(example):
            return {"text": _render_messages(_alpaca_to_messages(example), tokenizer)}
    elif "conversations" in columns:
        def to_text(example):
            return {"text": _render_messages(_sharegpt_to_messages(example), tokenizer)}
    elif "text" in columns:
        to_text = None  # Already has text
    elif "prompt" in columns and "response" in columns:
        def to_text(example):
            return {
                "text": _render_messages(_prompt_response_to_messages(example), tokenizer)
            }
    else:
        raise ValueError(f"Unknown dataset format. Columns: {columns}")

    if to_text is not None:
        dataset = dataset.map(to_text)

    def tokenize(examples):
        encoded = tokenizer(
            examples["text"],
            truncation=True,
            max_length=args.max_seq_length,
            padding="max_length",
        )
        # Build labels inside the batched map rather than materialising the
        # whole input_ids column in memory afterwards.
        encoded["labels"] = [list(ids) for ids in encoded["input_ids"]]
        return encoded

    return dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
218
+
219
+
220
def train_sft(model, tokenizer, args):
    """Run supervised fine-tuning with the Hugging Face ``Trainer``.

    Args:
        model: Model to train.
        tokenizer: Tokenizer used for data preparation and batch collation.
        args: Parsed CLI namespace; reads ``output_dir``, ``epochs``,
            ``batch_size``, ``grad_accum`` and ``lr`` here.

    Returns:
        The ``Trainer`` instance after ``train()`` has completed.
    """
    from transformers import Trainer, DataCollatorForLanguageModeling

    print("Loading SFT training data...")
    train_data = load_sft_data(args, tokenizer)

    # Hyper-parameters taken from the command line.
    cli_settings = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.batch_size,
        "gradient_accumulation_steps": args.grad_accum,
        "learning_rate": args.lr,
    }
    # Fixed settings for this script.
    fixed_settings = {
        "warmup_ratio": 0.03,
        "logging_steps": 1,
        "save_steps": 50,
        "save_total_limit": 3,
        "bf16": True,
        "optim": "paged_adamw_8bit",
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "report_to": "none",
        "max_grad_norm": 1.0,
        "lr_scheduler_type": "cosine",
        "dataloader_num_workers": 4,
        "ddp_find_unused_parameters": False,
    }
    training_args = TrainingArguments(**cli_settings, **fixed_settings)

    # mlm=False selects causal (next-token) language-model collation.
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        data_collator=collator,
    )

    print("Training (SFT)...")
    trainer.train()
    return trainer
258
+
259
+
260
def train_dpo(model, tokenizer, args):
    """Run DPO preference training on a prompt/chosen/rejected dataset.

    Args:
        model: Model to train.
        tokenizer: Tokenizer passed to the trainer as its processing class.
        args: Parsed CLI namespace; reads ``dataset``, ``output_dir``,
            ``epochs``, ``batch_size``, ``grad_accum``, ``lr`` and
            ``max_seq_length``.

    Returns:
        The ``DPOTrainer`` instance after ``train()`` has completed.

    Raises:
        ValueError: If the dataset lacks a ``prompt``, ``chosen`` or
            ``rejected`` column.
    """
    from trl import DPOTrainer, DPOConfig

    print("Loading DPO preference data...")

    if args.dataset.endswith((".jsonl", ".json")):
        dataset = load_dataset("json", data_files=args.dataset, split="train")
    else:
        dataset = load_dataset(args.dataset, split="train")

    columns = dataset.column_names
    print(f"Dataset columns: {columns}")

    # Standard DPO format: prompt, chosen, rejected
    if not all(c in columns for c in ["prompt", "chosen", "rejected"]):
        raise ValueError(
            f"DPO requires 'prompt', 'chosen', 'rejected' columns. Got: {columns}\n"
            "Use --mode sft for non-preference data."
        )

    dpo_config = DPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        warmup_ratio=0.03,
        logging_steps=1,
        save_steps=50,
        # Keep checkpoint retention and checkpointing mode in line with the
        # SFT path so both modes behave the same way on disk and in memory.
        save_total_limit=3,
        bf16=True,
        optim="paged_adamw_8bit",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        report_to="none",
        beta=0.1,  # DPO temperature
        max_length=args.max_seq_length,
        max_prompt_length=args.max_seq_length // 2,
    )

    trainer = DPOTrainer(
        model=model,
        args=dpo_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print("Training (DPO)...")
    trainer.train()
    return trainer
309
+
310
+
311
def main(argv=None):
    """Parse CLI options, train in the selected mode, then save the adapters.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.
    """
    parser = argparse.ArgumentParser(description="zen4-ultra QLoRA training")

    # Mode
    parser.add_argument("--mode", choices=["sft", "dpo"], default="sft",
                        help="Training mode: sft (supervised) or dpo (preference)")

    # Data
    parser.add_argument("--dataset", type=str, default="./data/train.jsonl",
                        help="HuggingFace dataset name or local .jsonl path")

    # Model
    parser.add_argument("--base-model", type=str, default=BASE_MODEL)
    parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR)
    parser.add_argument("--lora-rank", type=int, default=32,
                        help="LoRA rank (higher=more capacity, more VRAM)")
    # Enabled by default; --no-unfreeze-gate below is the switch that turns it off.
    parser.add_argument("--unfreeze-gate", action="store_true", default=True,
                        help="Unfreeze MoE gate/router weights (critical for MoE uncensoring)")
    parser.add_argument("--no-unfreeze-gate", dest="unfreeze_gate", action="store_false")
    parser.add_argument("--target-routed-experts", action="store_true", default=False,
                        help="Also LoRA routed expert FFN (much more VRAM)")
    parser.add_argument("--flash-attn", action="store_true", default=False,
                        help="Use Flash Attention 2")

    # Training
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accum", type=int, default=16,
                        help="Gradient accumulation steps (effective batch = batch_size * grad_accum)")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max-seq-length", type=int, default=4096)

    # Upload
    parser.add_argument("--push-to-hub", action="store_true", default=False)
    parser.add_argument("--hub-repo", type=str, default="zenlm/zen4-ultra")

    args = parser.parse_args(argv)

    model, tokenizer = setup_model(args)

    # --mode is restricted by ``choices``, so exactly one branch runs.
    if args.mode == "sft":
        trainer = train_sft(model, tokenizer, args)
    else:
        trainer = train_dpo(model, tokenizer, args)

    # Under a multi-process launch (e.g. torchrun) every rank reaches this
    # point; only the main process writes to disk and pushes to the hub so
    # ranks do not write the same files or upload concurrently.
    if not trainer.is_world_process_zero():
        return

    # Save
    print(f"Saving LoRA adapters to {args.output_dir}")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        print(f"Pushing to {args.hub_repo}...")
        model.push_to_hub(args.hub_repo)
        tokenizer.push_to_hub(args.hub_repo)

    print("Done!")
    print("\nTo merge and upload full model:")
    print(f" python merge_and_upload.py --base {args.base_model} --lora {args.output_dir} --repo {args.hub_repo}")


if __name__ == "__main__":
    main()