szxllm committed · verified
Commit 4084655 · 1 Parent(s): 7c4acc6

Update posttrain.py

Files changed (1):
  1. posttrain.py +534 -553
posttrain.py CHANGED
@@ -1,554 +1,535 @@
- # posttrain.py
- """
- Post-training script - instruction tuning and alignment
- """
- import os
- import torch
- import torch.nn.functional as F
- from transformers import AutoTokenizer
- from pathlib import Path
- import logging
- from tqdm import tqdm
- import json
- from datetime import datetime
- import copy
- from model import MultiModalDenseTransformer
-
- from data_loader import (
-     create_posttrain_dataloader,
-     create_preference_dataloader
- )
- from data_config import POSTTRAIN_MIX
- from reward_model import RewardModel, RewardModelTrainer
- from grpo import GRPOTrainer
- from typing import Optional
-
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
-
- class PostTrainer:
-     """Post-trainer - supervised fine-tuning."""
-     def __init__(
-         self,
-         model: MultiModalDenseTransformer,
-         tokenizer,
-         learning_rate: float = 1e-5,
-         weight_decay: float = 0.01,
-         num_epochs: int = 3,
-         gradient_accumulation_steps: int = 1,
-         max_grad_norm: float = 1.0,
-         log_interval: int = 10,
-         eval_interval: int = 500,
-         save_interval: int = 1000,
-         checkpoint_dir: str = "checkpoints/posttrain"
-     ):
-         self.model = model
-         self.tokenizer = tokenizer
-         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-         self.model.to(self.device)
-
-         # Optimizer
-         self.optimizer = torch.optim.AdamW(
-             model.parameters(),
-             lr=learning_rate,
-             weight_decay=weight_decay,
-             betas=(0.9, 0.95),
-             eps=1e-8
-         )
-
-         # Mixed precision
-         self.use_amp = torch.cuda.is_available()
-         self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
-
-         # Training parameters
-         self.num_epochs = num_epochs
-         self.gradient_accumulation_steps = gradient_accumulation_steps
-         self.max_grad_norm = max_grad_norm
-         self.log_interval = log_interval
-         self.eval_interval = eval_interval
-         self.save_interval = save_interval
-
-         # Checkpoint management
-         self.checkpoint_dir = Path(checkpoint_dir)
-         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
-
-         # Training state
-         self.global_step = 0
-         self.best_eval_loss = float('inf')
-
-         logger.info("PostTrainer initialized:")
-         logger.info(f"  Device: {self.device}")
-         logger.info(f"  Learning Rate: {learning_rate}")
-         logger.info(f"  Num Epochs: {num_epochs}")
-         logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")
-
-     def train_step(self, batch: dict) -> dict:
-         """Single training step."""
-         instruction_ids = batch['instruction'].to(self.device)
-         response_ids = batch['response'].to(self.device)
-
-         # 1. Fetch the masks (these were missing from the earlier version)
-         instruction_mask = batch['instruction_mask'].to(self.device)
-         response_mask = batch['response_mask'].to(self.device)
-
-         # 2. Concatenate the input IDs and masks
-         input_ids = torch.cat([instruction_ids, response_ids], dim=1)
-         attention_mask = torch.cat([instruction_mask, response_mask], dim=1)
-
-         batch_size, seq_len = input_ids.shape
-         position_ids = torch.zeros_like(input_ids)
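-
-         # Positions are counted over non-pad tokens only, so padding does not
-         # shift the position indices of real tokens.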
-         for i in range(batch_size):
-             non_pad_mask = attention_mask[i].bool()
-             if non_pad_mask.any():
-                 positions = torch.cumsum(non_pad_mask.long(), dim=0) - 1
-                 position_ids[i] = positions * non_pad_mask.long()
-
-         # 3. Build the labels
-         labels = input_ids.clone()
-
-         # Mask out the instruction part
-         instr_len = instruction_ids.shape[1]
-         labels[:, :instr_len] = -100
-
-         labels[attention_mask == 0] = -100
-
-         # Prepare the input payload
-         input_data = {
-             'segments': [{
-                 'type': 'text',
-                 'data': input_ids,
-                 'modality_id': 0
-             }]
-         }
-
-         # Forward pass
-         with torch.amp.autocast('cuda', enabled=self.use_amp):
-             # === Key change 2 ===
-             # attention_mask must be passed in; otherwise the transformer
-             # cannot tell which positions are padding
-             outputs = self.model(input_data, attention_mask=attention_mask,
-                                  position_ids=position_ids)
-
-             logits = outputs['logits']
-
-             # Compute the loss
-             shift_logits = logits[:, :-1, :].contiguous()
-             shift_labels = labels[:, 1:].contiguous()
-
-             loss = F.cross_entropy(
-                 shift_logits.view(-1, shift_logits.size(-1)),
-                 shift_labels.view(-1),
-                 ignore_index=-100
-             )
-
-         raw_loss = loss.item()
-         loss = loss / self.gradient_accumulation_steps
-
-         # Backward pass
-         self.scaler.scale(loss).backward()
-
-         return {
-             'loss': raw_loss
-         }
-
-     def optimizer_step(self):
-         """Optimizer step: unscale, clip, step, update the scaler."""
-         self.scaler.unscale_(self.optimizer)
-         grad_norm = torch.nn.utils.clip_grad_norm_(
-             self.model.parameters(),
-             self.max_grad_norm
-         )
-
-         self.scaler.step(self.optimizer)
-         self.scaler.update()
-         self.optimizer.zero_grad(set_to_none=True)
-         self.global_step += 1
-         return grad_norm.item()
-
-     @torch.no_grad()
-     def evaluate(self, dataloader, max_batches: int = 50) -> float:
-         """Evaluation."""
-         self.model.eval()
-         total_loss = 0.0
-         num_batches = 0
-
-         for i, batch in enumerate(dataloader):
-             if i >= max_batches:
-                 break
-
-             if batch is None:
-                 continue
-
-             instruction_ids = batch['instruction'].to(self.device)
-             response_ids = batch['response'].to(self.device)
-             input_ids = torch.cat([instruction_ids, response_ids], dim=1)
-
-             labels = input_ids.clone()
-             labels[:, :instruction_ids.shape[1]] = -100
-             labels[input_ids == self.tokenizer.pad_token_id] = -100
-
-             input_data = {
-                 'segments': [{
-                     'type': 'text',
-                     'data': input_ids,
-                     'modality_id': 0
-                 }]
-             }
-
-             with torch.amp.autocast('cuda', enabled=self.use_amp):
-                 outputs = self.model(input_data)
-                 logits = outputs['logits']
-
-                 shift_logits = logits[:, :-1, :].contiguous()
-                 shift_labels = labels[:, 1:].contiguous()
-
-                 loss = F.cross_entropy(
-                     shift_logits.view(-1, shift_logits.size(-1)),
-                     shift_labels.view(-1),
-                     ignore_index=-100
-                 )
-
-             total_loss += loss.item()
-             num_batches += 1
-
-         self.model.train()
-         return total_loss / max(num_batches, 1)
-
-     def train(
-         self,
-         train_dataloader,
-         eval_dataloader=None,
-         resume_from: Optional[str] = None
-     ):
-         """Training loop."""
-         logger.info("\n" + "="*80)
-         logger.info("Starting Post-Training (SFT)")
-         logger.info("="*80 + "\n")
-
-         if resume_from:
-             self.load_checkpoint(resume_from)
-
-         self.model.train()
-
-         for epoch in range(self.num_epochs):
-             logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")
-
-             progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
-             running_loss = 0.0
-             step_in_accumulation = 0
-
-             for batch_idx, batch in enumerate(progress_bar):
-                 if batch is None:
-                     continue
-
-                 # Training step
-                 stats = self.train_step(batch)
-                 running_loss += stats['loss']
-                 step_in_accumulation += 1
-
-                 # Optimizer update
-                 if step_in_accumulation == self.gradient_accumulation_steps:
-                     grad_norm = self.optimizer_step()
-                     step_in_accumulation = 0
-
-                 # Update the progress bar
-                 progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})
-
-                 # Logging
-                 if self.global_step % self.log_interval == 0:
-                     avg_loss = running_loss / self.log_interval
-                     logger.info(
-                         f"Step {self.global_step} | "
-                         f"Epoch {epoch+1} | "
-                         f"Loss: {avg_loss:.4f}"
-                     )
-                     running_loss = 0.0
-
-                 # Evaluation
-                 if eval_dataloader and self.global_step % self.eval_interval == 0:
-                     eval_loss = self.evaluate(eval_dataloader)
-                     logger.info(f"Eval Loss: {eval_loss:.4f}")
-
-                     if eval_loss < self.best_eval_loss:
-                         self.best_eval_loss = eval_loss
-                         self.save_checkpoint(
-                             self.checkpoint_dir / "best_model.pt",
-                             is_best=True
-                         )
-
-                 # Checkpointing
-                 if self.global_step % self.save_interval == 0:
-                     self.save_checkpoint(
-                         self.checkpoint_dir / f"step_{self.global_step}.pt"
-                     )
-
-             # End-of-epoch evaluation
-             if eval_dataloader:
-                 eval_loss = self.evaluate(eval_dataloader)
-                 logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")
-
-         logger.info("\n" + "="*80)
-         logger.info("Post-Training Complete!")
-         logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
-         logger.info("="*80 + "\n")
-
-         self.save_checkpoint(self.checkpoint_dir / "final_model.pt")
-
-     def save_checkpoint(self, path: Path, is_best: bool = False):
-         """Save a checkpoint."""
-         checkpoint = {
-             'model_state_dict': self.model.state_dict(),
-             'optimizer_state_dict': self.optimizer.state_dict(),
-             'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
-             'global_step': self.global_step,
-             'best_eval_loss': self.best_eval_loss,
-             'timestamp': datetime.now().isoformat()
-         }
-
-         torch.save(checkpoint, path)
-         logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))
-
-     def load_checkpoint(self, path: str):
-         """Load a checkpoint."""
-         checkpoint = torch.load(path, map_location=self.device)
-
-         self.model.load_state_dict(checkpoint['model_state_dict'])
-         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-
-         if self.use_amp and checkpoint.get('scaler_state_dict'):
-             self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
-
-         self.global_step = checkpoint['global_step']
-         self.best_eval_loss = checkpoint['best_eval_loss']
-
-         logger.info(f"Checkpoint loaded from {path}")
-
- def main():
-     """Main entry point."""
-     # Configuration
-     config = {
-         # Model config
-         'model_dim': 1536,
-         'vocab_size': 151665,
-         'n_layers': 12,
-         'n_heads': 12,
-         'n_kv_heads': 4,
-         'max_seq_len': 512,
-         'dropout': 0.0,
-         'use_moe': False,
-         # Training config
-         'batch_size': 2,
-         'gradient_accumulation_steps': 8,
-         'learning_rate': 1e-5,
-         'weight_decay': 0.01,
-         'num_epochs': 3,
-         'max_grad_norm': 1.0,
-
-         # Data config
-         'data_mix': 'simple_instruct',
-         'max_samples_train': 20000,
-         'max_samples_eval': 1000,
-         'max_length': 512,
-         'num_workers': 4,
-
-         # RLHF config
-         'do_rlhf': False,
-         'preference_dataset': 'hh_rlhf',
-         'grpo_iterations': 3,
-         'grpo_kl_coef': 0.04,
-         'grpo_group_size': 4,
-
-         # Paths
-         'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
-         'checkpoint_dir': 'checkpoints/posttrain',
-         'log_interval': 50,
-         'eval_interval': 500,
-         'save_interval': 1000,
-     }
-
-     logger.info("Configuration:")
-     logger.info(json.dumps(config, indent=2))
-
-     # Initialize the tokenizer
-     logger.info("\nInitializing tokenizer...")
-     tokenizer = AutoTokenizer.from_pretrained(
-         "Qwen/Qwen2.5-7B-Instruct",
-         use_fast=True,
-         trust_remote_code=True
-     )
-
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-         tokenizer.pad_token_id = tokenizer.eos_token_id
-
-     config['vocab_size'] = len(tokenizer)
-
-     # Initialize or load the model
-     logger.info("\nInitializing model...")
-     model = MultiModalDenseTransformer(
-         model_dim=config['model_dim'],
-         vocab_size=config['vocab_size'],
-         n_layers=config['n_layers'],
-         n_heads=config['n_heads'],
-         n_kv_heads=config['n_kv_heads'],
-         max_seq_len=config['max_seq_len'],
-         dropout=config['dropout'],
-         use_moe=config['use_moe'],
-         use_gradient_checkpointing=False,
-         rope_scaling_type="yarn",
-         use_multimodal_fusion=False,
-         use_contrastive=False
-     )
-
-     # Load the pretraining checkpoint (if any)
-     if config['pretrain_checkpoint']:
-         logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
-         checkpoint = torch.load(config['pretrain_checkpoint'])
-         model.load_state_dict(checkpoint['model_state_dict'])
-
-     # ===== Phase 1: Supervised Fine-Tuning =====
-     logger.info("\n" + "="*80)
-     logger.info("PHASE 1: Supervised Fine-Tuning")
-     logger.info("="*80)
-
-     # Create the dataloaders
-     train_dataloader = create_posttrain_dataloader(
-         mix_name=config['data_mix'],
-         tokenizer=tokenizer,
-         batch_size=config['batch_size'],
-         num_workers=config['num_workers'],
-         max_length=config['max_length'],
-         max_samples=config['max_samples_train'],
-         split='train',
-         shuffle=True
-     )
-
-     eval_dataloader = create_posttrain_dataloader(
-         mix_name=config['data_mix'],
-         tokenizer=tokenizer,
-         batch_size=config['batch_size'] * 2,
-         num_workers=config['num_workers'],
-         max_length=config['max_length'],
-         max_samples=config['max_samples_eval'],
-         split='train',  # use the tail of the train split for validation
-         shuffle=False
-     )
-
-     # Create the trainer
-     trainer = PostTrainer(
-         model=model,
-         tokenizer=tokenizer,
-         learning_rate=config['learning_rate'],
-         weight_decay=config['weight_decay'],
-         num_epochs=config['num_epochs'],
-         gradient_accumulation_steps=config['gradient_accumulation_steps'],
-         max_grad_norm=config['max_grad_norm'],
-         log_interval=config['log_interval'],
-         eval_interval=config['eval_interval'],
-         save_interval=config['save_interval'],
-         checkpoint_dir=config['checkpoint_dir']
-     )
-
-     # Start SFT training
-     trainer.train(train_dataloader, eval_dataloader)
-
-     # ===== Phase 2: RLHF with GRPO =====
-     if config['do_rlhf']:
-         logger.info("\n" + "="*80)
-         logger.info("PHASE 2: RLHF with GRPO")
-         logger.info("="*80)
-
-         try:
-             # Train the reward model
-             logger.info("\nTraining Reward Model...")
-
-             reward_base_model = copy.deepcopy(model)
-             reward_model = RewardModel(reward_base_model, use_value_head=True)
-
-             preference_dataloader = create_preference_dataloader(
-                 dataset_name=config['preference_dataset'],
-                 tokenizer=tokenizer,
-                 batch_size=config['batch_size'],
-                 num_workers=config['num_workers'],
-                 max_samples=5000,
-                 split='train'
-             )
-
-             reward_trainer = RewardModelTrainer(
-                 reward_model=reward_model,
-                 learning_rate=1e-5
-             )
-
-             reward_trainer.train(preference_dataloader, num_epochs=1)
-
-             # GRPO training
-             logger.info("\nStarting GRPO Training...")
-
-             ref_model = copy.deepcopy(model)
-             ref_model.eval()
-
-             grpo_trainer = GRPOTrainer(
-                 actor_model=model,
-                 reward_model=reward_model,
-                 ref_model=ref_model,
-                 tokenizer=tokenizer,
-                 learning_rate=1e-6,
-                 kl_coef=config['grpo_kl_coef'],
-                 group_size=config['grpo_group_size'],
-                 update_batch_size=2,
-                 use_amp=True
-             )
-
-             # Prepare prompts
-             prompt_dataloader = create_posttrain_dataloader(
-                 mix_name=config['data_mix'],
-                 tokenizer=tokenizer,
-                 batch_size=4,
-                 num_workers=2,
-                 max_samples=1000,
-                 split='train'
-             )
-
-             # Extract prompts
-             prompts = []
-             for batch in prompt_dataloader:
-                 if batch and batch.get('instruction') is not None:
-                     prompts.append(batch['instruction'])
-                 if len(prompts) >= 200:
-                     break
-
-             if prompts:
-                 prompt_tensor = torch.cat(prompts[:200], dim=0)
-                 from torch.utils.data import TensorDataset, DataLoader
-                 prompt_loader = DataLoader(
-                     TensorDataset(prompt_tensor),
-                     batch_size=4
-                 )
-
-                 grpo_trainer.train(
-                     prompt_loader,
-                     num_iterations=config['grpo_iterations'],
-                     max_gen_len=50,
-                     save_path=config['checkpoint_dir'] + "/grpo"
-                 )
-
-         except Exception as e:
-             logger.error(f"Error in RLHF: {e}")
-             import traceback
-             traceback.print_exc()
-
-     logger.info("\n" + "="*80)
-     logger.info("All Training Complete!")
-     logger.info("="*80)
-
- if __name__ == "__main__":
-     main()

+ import os
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ from pathlib import Path
+ import logging
+ from tqdm import tqdm
+ import json
+ from datetime import datetime
+ import copy
+ from model import MultiModalDenseTransformer
+
+ from data_loader import (
+     create_posttrain_dataloader,
+     create_preference_dataloader
+ )
+ from data_config import POSTTRAIN_MIX
+ from reward_model import RewardModel, RewardModelTrainer
+ from grpo import GRPOTrainer
+ from typing import Optional
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+
+ class PostTrainer:
+     def __init__(
+         self,
+         model: MultiModalDenseTransformer,
+         tokenizer,
+         learning_rate: float = 1e-5,
+         weight_decay: float = 0.01,
+         num_epochs: int = 3,
+         gradient_accumulation_steps: int = 1,
+         max_grad_norm: float = 1.0,
+         log_interval: int = 10,
+         eval_interval: int = 500,
+         save_interval: int = 1000,
+         checkpoint_dir: str = "checkpoints/posttrain"
+     ):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         self.model.to(self.device)
+
+         # Optimizer
+         self.optimizer = torch.optim.AdamW(
+             model.parameters(),
+             lr=learning_rate,
+             weight_decay=weight_decay,
+             betas=(0.9, 0.95),
+             eps=1e-8
+         )
+
+         # Mixed precision
+         self.use_amp = torch.cuda.is_available()
+         self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
+
+         # Training parameters
+         self.num_epochs = num_epochs
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+         self.max_grad_norm = max_grad_norm
+         self.log_interval = log_interval
+         self.eval_interval = eval_interval
+         self.save_interval = save_interval
+
+         # Checkpoint management
+         self.checkpoint_dir = Path(checkpoint_dir)
+         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+         # Training state
+         self.global_step = 0
+         self.best_eval_loss = float('inf')
+
+         logger.info("PostTrainer initialized:")
+         logger.info(f"  Device: {self.device}")
+         logger.info(f"  Learning Rate: {learning_rate}")
+         logger.info(f"  Num Epochs: {num_epochs}")
+         logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")
+
+     def train_step(self, batch: dict) -> dict:
+         """Single training step."""
+         instruction_ids = batch['instruction'].to(self.device)
+         response_ids = batch['response'].to(self.device)
+
+         instruction_mask = batch['instruction_mask'].to(self.device)
+         response_mask = batch['response_mask'].to(self.device)
+
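+         # Instruction and response are padded separately; join them into one
+         # sequence with a matching attention mask.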
+         input_ids = torch.cat([instruction_ids, response_ids], dim=1)
+         attention_mask = torch.cat([instruction_mask, response_mask], dim=1)
+
+         batch_size, seq_len = input_ids.shape
+         position_ids = torch.zeros_like(input_ids)
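+
+         # Per-sample positions counted over non-pad tokens only, so padding
+         # does not shift the position indices of real tokens.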
+         for i in range(batch_size):
+             non_pad_mask = attention_mask[i].bool()
+             if non_pad_mask.any():
+                 positions = torch.cumsum(non_pad_mask.long(), dim=0) - 1
+                 position_ids[i] = positions * non_pad_mask.long()
+
+         labels = input_ids.clone()
+
+         # Mask out the instruction part
+         instr_len = instruction_ids.shape[1]
+         labels[:, :instr_len] = -100
+
+         labels[attention_mask == 0] = -100
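+         # -100 is the ignore_index for F.cross_entropy below, so neither the
+         # instruction tokens nor the padding contribute to the SFT loss.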
+
+         # Prepare the input payload
+         input_data = {
+             'segments': [{
+                 'type': 'text',
+                 'data': input_ids,
+                 'modality_id': 0
+             }]
+         }
+
+         # Forward pass
+         with torch.amp.autocast('cuda', enabled=self.use_amp):
+             outputs = self.model(input_data, attention_mask=attention_mask,
+                                  position_ids=position_ids)
+
+             logits = outputs['logits']
+
+             # Compute the loss
+             shift_logits = logits[:, :-1, :].contiguous()
+             shift_labels = labels[:, 1:].contiguous()
+
+             loss = F.cross_entropy(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+                 ignore_index=-100
+             )
+
+         raw_loss = loss.item()
+         loss = loss / self.gradient_accumulation_steps
+
+         # Backward pass
+         self.scaler.scale(loss).backward()
+
+         return {
+             'loss': raw_loss
+         }
+
+     def optimizer_step(self):
+         """Optimizer step: unscale, clip, step, update the scaler."""
+         self.scaler.unscale_(self.optimizer)
+         grad_norm = torch.nn.utils.clip_grad_norm_(
+             self.model.parameters(),
+             self.max_grad_norm
+         )
+
+         self.scaler.step(self.optimizer)
+         self.scaler.update()
+         self.optimizer.zero_grad(set_to_none=True)
+         self.global_step += 1
+         return grad_norm.item()
+
+     @torch.no_grad()
+     def evaluate(self, dataloader, max_batches: int = 50) -> float:
+         """Evaluation."""
+         self.model.eval()
+         total_loss = 0.0
+         num_batches = 0
+
+         for i, batch in enumerate(dataloader):
+             if i >= max_batches:
+                 break
+
+             if batch is None:
+                 continue
+
+             instruction_ids = batch['instruction'].to(self.device)
+             response_ids = batch['response'].to(self.device)
+             input_ids = torch.cat([instruction_ids, response_ids], dim=1)
+
+             labels = input_ids.clone()
+             labels[:, :instruction_ids.shape[1]] = -100
+             labels[input_ids == self.tokenizer.pad_token_id] = -100
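+             # No attention mask is built here; padding is excluded from the
+             # loss purely via the pad_token_id -> -100 label masking.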
+
+             input_data = {
+                 'segments': [{
+                     'type': 'text',
+                     'data': input_ids,
+                     'modality_id': 0
+                 }]
+             }
+
+             with torch.amp.autocast('cuda', enabled=self.use_amp):
+                 outputs = self.model(input_data)
+                 logits = outputs['logits']
+
+                 shift_logits = logits[:, :-1, :].contiguous()
+                 shift_labels = labels[:, 1:].contiguous()
+
+                 loss = F.cross_entropy(
+                     shift_logits.view(-1, shift_logits.size(-1)),
+                     shift_labels.view(-1),
+                     ignore_index=-100
+                 )
+
+             total_loss += loss.item()
+             num_batches += 1
+
+         self.model.train()
+         return total_loss / max(num_batches, 1)
+
+     def train(
+         self,
+         train_dataloader,
+         eval_dataloader=None,
+         resume_from: Optional[str] = None
+     ):
+         """Training loop."""
+         logger.info("\n" + "="*80)
+         logger.info("Starting Post-Training (SFT)")
+         logger.info("="*80 + "\n")
+
+         if resume_from:
+             self.load_checkpoint(resume_from)
+
+         self.model.train()
+
+         for epoch in range(self.num_epochs):
+             logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")
+
+             progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
+             running_loss = 0.0
+             step_in_accumulation = 0
+
+             for batch_idx, batch in enumerate(progress_bar):
+                 if batch is None:
+                     continue
+
+                 # Training step
+                 stats = self.train_step(batch)
+                 running_loss += stats['loss']
+                 step_in_accumulation += 1
+
+                 # Optimizer update
+                 if step_in_accumulation == self.gradient_accumulation_steps:
+                     grad_norm = self.optimizer_step()
+                     step_in_accumulation = 0
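+                     # global_step advances only here, once every
+                     # gradient_accumulation_steps micro-batches.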
+
+                 # Update the progress bar
+                 progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})
+
+                 # Logging
+                 if self.global_step % self.log_interval == 0:
+                     avg_loss = running_loss / self.log_interval
+                     logger.info(
+                         f"Step {self.global_step} | "
+                         f"Epoch {epoch+1} | "
+                         f"Loss: {avg_loss:.4f}"
+                     )
+                     running_loss = 0.0
+
+                 # Evaluation
+                 if eval_dataloader and self.global_step % self.eval_interval == 0:
+                     eval_loss = self.evaluate(eval_dataloader)
+                     logger.info(f"Eval Loss: {eval_loss:.4f}")
+
+                     if eval_loss < self.best_eval_loss:
+                         self.best_eval_loss = eval_loss
+                         self.save_checkpoint(
+                             self.checkpoint_dir / "best_model.pt",
+                             is_best=True
+                         )
+
+                 # Checkpointing
+                 if self.global_step % self.save_interval == 0:
+                     self.save_checkpoint(
+                         self.checkpoint_dir / f"step_{self.global_step}.pt"
+                     )
+
+             # End-of-epoch evaluation
+             if eval_dataloader:
+                 eval_loss = self.evaluate(eval_dataloader)
+                 logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")
+
+         logger.info("\n" + "="*80)
+         logger.info("Post-Training Complete!")
+         logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
+         logger.info("="*80 + "\n")
+
+         self.save_checkpoint(self.checkpoint_dir / "final_model.pt")
+
+     def save_checkpoint(self, path: Path, is_best: bool = False):
+         """Save a checkpoint."""
+         checkpoint = {
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
+             'global_step': self.global_step,
+             'best_eval_loss': self.best_eval_loss,
+             'timestamp': datetime.now().isoformat()
+         }
+
+         torch.save(checkpoint, path)
+         logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))
+
+     def load_checkpoint(self, path: str):
+         """Load a checkpoint."""
+         checkpoint = torch.load(path, map_location=self.device)
+
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+         if self.use_amp and checkpoint.get('scaler_state_dict'):
+             self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+         self.global_step = checkpoint['global_step']
+         self.best_eval_loss = checkpoint['best_eval_loss']
+
+         logger.info(f"Checkpoint loaded from {path}")
+
+ def main():
+     """Main entry point."""
+     # Configuration
+     config = {
+         # Model config
+         'model_dim': 1536,
+         'vocab_size': 151665,
+         'n_layers': 12,
+         'n_heads': 12,
+         'n_kv_heads': 4,
+         'max_seq_len': 512,
+         'dropout': 0.0,
+         'use_moe': False,
+         # Training config
+         'batch_size': 2,
+         'gradient_accumulation_steps': 8,
+         'learning_rate': 1e-5,
+         'weight_decay': 0.01,
+         'num_epochs': 3,
+         'max_grad_norm': 1.0,
+
+         # Data config
+         'data_mix': 'simple_instruct',
+         'max_samples_train': 20000,
+         'max_samples_eval': 1000,
+         'max_length': 512,
+         'num_workers': 4,
+
+         # RLHF config
+         'do_rlhf': False,
+         'preference_dataset': 'hh_rlhf',
+         'grpo_iterations': 3,
+         'grpo_kl_coef': 0.04,
+         'grpo_group_size': 4,
+
+         # Paths
+         'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
+         'checkpoint_dir': 'checkpoints/posttrain',
+         'log_interval': 50,
+         'eval_interval': 500,
+         'save_interval': 1000,
+     }
+
+     logger.info("Configuration:")
+     logger.info(json.dumps(config, indent=2))
+
+     # Initialize the tokenizer
+     logger.info("\nInitializing tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(
+         "Qwen/Qwen2.5-7B-Instruct",
+         use_fast=True,
+         trust_remote_code=True
+     )
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.pad_token_id = tokenizer.eos_token_id
+
+     config['vocab_size'] = len(tokenizer)
+
+     # Initialize or load the model
+     logger.info("\nInitializing model...")
+     model = MultiModalDenseTransformer(
+         model_dim=config['model_dim'],
+         vocab_size=config['vocab_size'],
+         n_layers=config['n_layers'],
+         n_heads=config['n_heads'],
+         n_kv_heads=config['n_kv_heads'],
+         max_seq_len=config['max_seq_len'],
+         dropout=config['dropout'],
+         use_moe=config['use_moe'],
+         use_gradient_checkpointing=False,
+         rope_scaling_type="yarn",
+         use_multimodal_fusion=False,
+         use_contrastive=False
+     )
+
+     if config['pretrain_checkpoint']:
+         logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
+         checkpoint = torch.load(config['pretrain_checkpoint'])
+         model.load_state_dict(checkpoint['model_state_dict'])
+
+     logger.info("\n" + "="*80)
+     logger.info("PHASE 1: Supervised Fine-Tuning")
+     logger.info("="*80)
+
+     # Create the dataloaders
+     train_dataloader = create_posttrain_dataloader(
+         mix_name=config['data_mix'],
+         tokenizer=tokenizer,
+         batch_size=config['batch_size'],
+         num_workers=config['num_workers'],
+         max_length=config['max_length'],
+         max_samples=config['max_samples_train'],
+         split='train',
+         shuffle=True
+     )
+
+     eval_dataloader = create_posttrain_dataloader(
+         mix_name=config['data_mix'],
+         tokenizer=tokenizer,
+         batch_size=config['batch_size'] * 2,
+         num_workers=config['num_workers'],
+         max_length=config['max_length'],
+         max_samples=config['max_samples_eval'],
+         split='train',  # use the tail of the train split for validation
+         shuffle=False
+     )
+
+     # Create the trainer
+     trainer = PostTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         learning_rate=config['learning_rate'],
+         weight_decay=config['weight_decay'],
+         num_epochs=config['num_epochs'],
+         gradient_accumulation_steps=config['gradient_accumulation_steps'],
+         max_grad_norm=config['max_grad_norm'],
+         log_interval=config['log_interval'],
+         eval_interval=config['eval_interval'],
+         save_interval=config['save_interval'],
+         checkpoint_dir=config['checkpoint_dir']
+     )
+
+     trainer.train(train_dataloader, eval_dataloader)
+
+     if config['do_rlhf']:
+         logger.info("\n" + "="*80)
+         logger.info("PHASE 2: RLHF with GRPO")
+         logger.info("="*80)
+
+         try:
+             # Train the reward model
+             logger.info("\nTraining Reward Model...")
+
+             reward_base_model = copy.deepcopy(model)
+             reward_model = RewardModel(reward_base_model, use_value_head=True)
+
+             preference_dataloader = create_preference_dataloader(
+                 dataset_name=config['preference_dataset'],
+                 tokenizer=tokenizer,
+                 batch_size=config['batch_size'],
+                 num_workers=config['num_workers'],
+                 max_samples=5000,
+                 split='train'
+             )
+
+             reward_trainer = RewardModelTrainer(
+                 reward_model=reward_model,
+                 learning_rate=1e-5
+             )
+
+             reward_trainer.train(preference_dataloader, num_epochs=1)
+
+             # GRPO training
+             logger.info("\nStarting GRPO Training...")
+
+             ref_model = copy.deepcopy(model)
+             ref_model.eval()
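+             # ref_model is a frozen copy of the SFT policy; GRPOTrainer uses
+             # it (weighted by kl_coef) as the reference for the KL penalty.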
+
+             grpo_trainer = GRPOTrainer(
+                 actor_model=model,
+                 reward_model=reward_model,
+                 ref_model=ref_model,
+                 tokenizer=tokenizer,
+                 learning_rate=1e-6,
+                 kl_coef=config['grpo_kl_coef'],
+                 group_size=config['grpo_group_size'],
+                 update_batch_size=2,
+                 use_amp=True
+             )
+
+             # Prepare prompts
+             prompt_dataloader = create_posttrain_dataloader(
+                 mix_name=config['data_mix'],
+                 tokenizer=tokenizer,
+                 batch_size=4,
+                 num_workers=2,
+                 max_samples=1000,
+                 split='train'
+             )
+
+             # Extract prompts
+             prompts = []
+             for batch in prompt_dataloader:
+                 if batch and batch.get('instruction') is not None:
+                     prompts.append(batch['instruction'])
+                 if len(prompts) >= 200:
+                     break
+
+             if prompts:
+                 prompt_tensor = torch.cat(prompts[:200], dim=0)
+                 from torch.utils.data import TensorDataset, DataLoader
+                 prompt_loader = DataLoader(
+                     TensorDataset(prompt_tensor),
+                     batch_size=4
+                 )
+
+                 grpo_trainer.train(
+                     prompt_loader,
+                     num_iterations=config['grpo_iterations'],
+                     max_gen_len=50,
+                     save_path=config['checkpoint_dir'] + "/grpo"
+                 )
+
+         except Exception as e:
+             logger.error(f"Error in RLHF: {e}")
+             import traceback
+             traceback.print_exc()
+
+     logger.info("\n" + "="*80)
+     logger.info("All Training Complete!")
+     logger.info("="*80)
+
+ if __name__ == "__main__":
+     main()