szxllm committed on
Commit 7c4acc6 · verified · 1 Parent(s): 3d1c312

Delete post.py

Files changed (1)
  1. post.py +0 -532
post.py DELETED
@@ -1,532 +0,0 @@
# post.py
"""
Post-training script - instruction tuning and alignment.
"""
import os
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from pathlib import Path
import logging
from tqdm import tqdm
import json
from datetime import datetime
import copy
from model import MultiModalDenseTransformer

from data_loader import (
    create_posttrain_dataloader,
    create_preference_dataloader
)
from data_config import POSTTRAIN_MIX
from reward_model import RewardModel, RewardModelTrainer
from grpo import GRPOTrainer
from typing import Optional

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Route Hugging Face Hub downloads through the hf-mirror.com mirror.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

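# The script runs in two phases: supervised fine-tuning via PostTrainer, then
# an optional RLHF phase (reward model + GRPO) gated by config['do_rlhf'];
# see main() at the bottom of the file.
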
class PostTrainer:
    """Post-trainer - Supervised Fine-Tuning."""
    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyperparameters
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        # Checkpoint management
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info("PostTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Num Epochs: {num_epochs}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")

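    # Division of labor: train_step() only computes the loss and accumulates
    # gradients; optimizer_step() below applies them. The train() loop calls
    # optimizer_step() once per gradient_accumulation_steps micro-batches.
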
    def train_step(self, batch: dict) -> dict:
        """Single training step."""
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)

        # Masks returned by the DataLoader
        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)
        # Concatenate instruction and response into one input sequence
        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1).float()
        # Build labels (loss is computed on the response part only)
        labels = input_ids.clone()
        instr_len = instruction_ids.shape[1]
        labels[:, :instr_len] = -100
        labels[attention_mask == 0] = -100

        # Prepare model input
        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        # Forward pass
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(input_data, attention_mask=attention_mask)
            logits = outputs['logits']

        # Compute loss (shift logits/labels for next-token prediction)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )
        raw_loss = loss.item()
        # Scale down so gradients accumulated over the window sum to one step
        loss = loss / self.gradient_accumulation_steps

        # Backward pass
        self.scaler.scale(loss).backward()

        return {
            'loss': raw_loss
        }

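    # Worked example of the masking above: for instruction tokens [I1 I2 I3]
    # and response tokens [R1 R2 <pad>], labels are
    # [-100, -100, -100, R1, R2, -100]; after the one-position shift,
    # cross-entropy is computed only where the target is a response token.
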
    def optimizer_step(self):
        """Optimizer step."""
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

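    # AMP bookkeeping: unscale_() must run before clip_grad_norm_ so clipping
    # sees true gradient magnitudes; scaler.step() skips the update if any
    # gradient overflowed, and scaler.update() re-tunes the loss scale.
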
    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Evaluation: average loss over up to max_batches batches."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break

            if batch is None:
                continue

            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            # Mask out the instruction and padding, as in train_step
            labels = input_ids.clone()
            labels[:, :instruction_ids.shape[1]] = -100
            labels[input_ids == self.tokenizer.pad_token_id] = -100

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data)
                logits = outputs['logits']

            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

            total_loss += loss.item()
            num_batches += 1

        self.model.train()
        return total_loss / max(num_batches, 1)

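    # Note: unlike train_step(), evaluate() infers padding from pad_token_id
    # rather than the DataLoader masks, and passes no attention mask to the
    # model, so padded positions are excluded only from the loss.
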
    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Training loop."""
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0
            step_in_accumulation = 0

            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue

                # Training step (accumulates gradients)
                stats = self.train_step(batch)
                running_loss += stats['loss']
                step_in_accumulation += 1

                # Update progress bar
                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                # Optimizer update once a full accumulation window is done
                if step_in_accumulation == self.gradient_accumulation_steps:
                    self.optimizer_step()
                    step_in_accumulation = 0

                    # Logging
                    if self.global_step % self.log_interval == 0:
                        avg_loss = running_loss / self.log_interval
                        logger.info(
                            f"Step {self.global_step} | "
                            f"Epoch {epoch+1} | "
                            f"Loss: {avg_loss:.4f}"
                        )
                        running_loss = 0.0

                    # Evaluation
                    if eval_dataloader and self.global_step % self.eval_interval == 0:
                        eval_loss = self.evaluate(eval_dataloader)
                        logger.info(f"Eval Loss: {eval_loss:.4f}")

                        if eval_loss < self.best_eval_loss:
                            self.best_eval_loss = eval_loss
                            self.save_checkpoint(
                                self.checkpoint_dir / "best_model.pt",
                                is_best=True
                            )

                    # Periodic save
                    if self.global_step % self.save_interval == 0:
                        self.save_checkpoint(
                            self.checkpoint_dir / f"step_{self.global_step}.pt"
                        )

            # Flush any leftover accumulated gradients at epoch end
            if step_in_accumulation > 0:
                self.optimizer_step()

            # End-of-epoch evaluation
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

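    # Checkpoint policy: best_model.pt tracks the lowest eval loss,
    # step_<N>.pt is written every save_interval optimizer steps, and
    # final_model.pt is always written when training finishes.
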
    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Load a checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']

        logger.info(f"Checkpoint loaded from {path}")

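# Usage sketch for resuming an interrupted SFT run (paths and dataloader
# names are hypothetical):
#     trainer = PostTrainer(model, tokenizer)
#     trainer.train(train_dl, eval_dl,
#                   resume_from="checkpoints/posttrain/step_1000.pt")
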
def main():
    """Main function."""
    # Configuration
    config = {
        # Model config
        'model_dim': 1536,
        'vocab_size': 151665,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,
        # Training config
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_epochs': 1,
        'max_grad_norm': 1.0,

        # Data config
        'data_mix': 'debug_mix',
        'max_samples_train': 1000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # RLHF config
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # Paths
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

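    # With batch_size=2 and gradient_accumulation_steps=8, each optimizer
    # update sees an effective batch of 16 sequences.
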
    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    # Initialize tokenizer
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    config['vocab_size'] = len(tokenizer)

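    # Overriding vocab_size with len(tokenizer) keeps the embedding table and
    # output head aligned with the tokenizer, including any added special tokens.
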
    # Initialize or load the model
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Load the pretrain checkpoint (if one is configured)
    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

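    # Only the model weights are restored from the pretrain checkpoint;
    # optimizer and scaler state are left fresh, since SFT starts a new
    # optimization run with its own AdamW.
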
    # ===== Phase 1: Supervised Fine-Tuning =====
    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)

    # Create dataloaders
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )

    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # use the tail of the train split for validation
        shuffle=False
    )

    # Create the trainer
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )

    # Run SFT training
    trainer.train(train_dataloader, eval_dataloader)

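    # Phase 2 below reuses the SFT weights twice: one deep copy becomes the
    # reward-model backbone and another becomes the frozen reference model
    # for the KL penalty.
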
    # ===== Phase 2: RLHF with GRPO =====
    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)

        try:
            # Train the reward model
            logger.info("\nTraining Reward Model...")

            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )

            reward_trainer.train(preference_dataloader, num_epochs=1)

            # GRPO training
            logger.info("\nStarting GRPO Training...")

            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            # Prepare prompts
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            # Extract prompts (assumes the loader pads every instruction batch
            # to the same width, so the tensors can be concatenated on dim 0)
            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )

                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )

        except Exception as e:
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)

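# GRPO background (assuming grpo.GRPOTrainer follows the standard recipe):
# for each prompt the actor samples group_size candidate responses, the reward
# model scores them, and each response's advantage is its reward relative to
# the group mean; a KL term weighted by grpo_kl_coef keeps the policy close to
# the frozen reference model.
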
if __name__ == "__main__":
    main()