szxllm committed on
Commit 121c049 · verified · 1 Parent(s): 4084655

Update pretrain.py

Files changed (1)
  1. pretrain.py +459 -501
pretrain.py CHANGED
@@ -1,502 +1,460 @@
1
- # pretrain.py - fully fixed version
2
-
3
- import os
4
- import torch
5
- import torch.nn.functional as F
6
- from transformers import AutoTokenizer
7
- from pathlib import Path
8
- import logging
9
- from tqdm import tqdm
10
- import json
11
- from datetime import datetime
12
- from model import MultiModalDenseTransformer
13
- from data_loader import create_pretrain_dataloader
14
-
15
- logging.basicConfig(
16
- level=logging.INFO,
17
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
18
- )
19
- logger = logging.getLogger(__name__)
20
- os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
21
-
22
-
23
- class PreTrainer:
24
- """Pre-trainer - fully fixed version"""
25
- def __init__(
26
- self,
27
- model: MultiModalDenseTransformer,
28
- tokenizer,
29
- learning_rate: float = 3e-4,
30
- weight_decay: float = 0.1,
31
- warmup_steps: int = 1000,
32
- max_steps: int = 100000,
33
- gradient_accumulation_steps: int = 16,
34
- max_grad_norm: float = 1.0,
35
- log_interval: int = 10,
36
- save_interval: int = 1000,
37
- checkpoint_dir: str = "checkpoints/pretrain",
38
- loss_log_file: str = "checkpoints/pretrain/train_loss.log"
39
- ):
40
- self.model = model
41
- self.tokenizer = tokenizer
42
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
-
44
- self.model.to(self.device)
45
-
46
- # Optimizer config - standard AdamW parameters
47
- self.optimizer = torch.optim.AdamW(
48
- model.parameters(),
49
- lr=learning_rate,
50
- weight_decay=weight_decay,
51
- betas=(0.9, 0.95),
52
- eps=1e-8
53
- )
54
-
55
- # 🔧 Fix: use a simpler learning-rate scheduler
56
- from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
57
-
58
- # Warmup + Cosine Decay
59
- self.warmup_steps = warmup_steps
60
- self.max_lr = learning_rate
61
- self.min_lr = learning_rate * 0.1
62
- self.current_step = 0
63
-
64
- # Mixed precision
65
- self.use_amp = torch.cuda.is_available()
66
- self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
67
-
68
- # Training parameters
69
- self.gradient_accumulation_steps = gradient_accumulation_steps
70
- self.max_grad_norm = max_grad_norm
71
- self.max_steps = max_steps
72
- self.log_interval = log_interval
73
- self.save_interval = save_interval
74
-
75
- # Checkpoint management
76
- self.checkpoint_dir = Path(checkpoint_dir)
77
- self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
78
-
79
- # Loss log
80
- self.loss_log_file = Path(loss_log_file)
81
- self.loss_log_file.parent.mkdir(parents=True, exist_ok=True)
82
-
83
- # Training state
84
- self.global_step = 0
85
- self.tokens_seen = 0
86
- self.running_loss = 0.0
87
- self.best_loss = float('inf')
88
-
89
- logger.info(f"PreTrainer initialized:")
90
- logger.info(f" Device: {self.device}")
91
- logger.info(f" Learning Rate: {learning_rate}")
92
- logger.info(f" Max Steps: {max_steps}")
93
- logger.info(f" Gradient Accumulation: {gradient_accumulation_steps}")
94
- logger.info(f" Effective Batch Size: {gradient_accumulation_steps}")
95
- logger.info(f" Mixed Precision: {self.use_amp}")
96
-
97
- def _get_lr(self) -> float:
98
- """Manually compute the learning rate (warmup + cosine)"""
99
- if self.current_step < self.warmup_steps:
100
- # Linear warmup
101
- return self.max_lr * (self.current_step / self.warmup_steps)
102
- else:
103
- # Cosine decay
104
- progress = (self.current_step - self.warmup_steps) / (self.max_steps - self.warmup_steps)
105
- return self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + torch.cos(torch.tensor(progress * 3.14159)))
106
-
107
- def _set_lr(self, lr: float):
108
- """Set the learning rate"""
109
- for param_group in self.optimizer.param_groups:
110
- param_group['lr'] = lr
111
-
112
- def train_step(self, batch: dict) -> dict:
113
- """
114
- 🔧 Fully fixed training step
115
- Key point: do not divide by gradient_accumulation_steps when computing the loss
116
- """
117
- input_ids = batch['input_ids'].to(self.device)
118
- attention_mask = batch['attention_mask'].to(self.device)
119
- batch_size, seq_len = input_ids.shape
120
- position_ids= torch.zeros_like(input_ids)
121
-
122
- for i in range(batch_size):
123
- non_pad_mask = attention_mask[i].bool()
124
- if non_pad_mask.any():
125
- positions = torch.cumsum(non_pad_mask.long(), dim=0) -1
126
- position_ids[i]=positions * non_pad_mask.long()
127
-
128
-
129
-
130
- # Prepare inputs
131
- input_data = {
132
- 'segments': [{
133
- 'type': 'text',
134
- 'data': input_ids,
135
- 'modality_id': 0
136
- }]
137
- }
138
-
139
- # Forward pass
140
- with torch.amp.autocast('cuda', enabled=self.use_amp):
141
- outputs = self.model(
142
- input_data,
143
- attention_mask=attention_mask,
144
- position_ids=position_ids)
145
- logits = outputs['logits']
146
-
147
- # Compute the loss (standard autoregressive)
148
- shift_logits = logits[:, :-1, :].contiguous()
149
- shift_labels = input_ids[:, 1:].contiguous()
150
- shift_attention_mask = attention_mask[:, 1:].contiguous()
151
-
152
- # 🔧 Key fix: compute the mean loss directly; do not divide by gradient_accumulation_steps here
153
- loss = F.cross_entropy(
154
- shift_logits.view(-1, shift_logits.size(-1)),
155
- shift_labels.view(-1),
156
- reduction='none'
157
- )
158
-
159
- # Apply the mask
160
- loss = (loss * shift_attention_mask.view(-1)).sum() / (shift_attention_mask.sum() + 1e-8)
161
-
162
- # 🔧 Important: handle gradient accumulation manually here for numerical stability
163
- # Approach: scale the loss for backprop, but log the original loss
164
- loss_for_backward = loss / self.gradient_accumulation_steps
165
-
166
- # Backward pass (with the scaled loss)
167
- self.scaler.scale(loss_for_backward).backward()
168
-
169
- # 🔧 Key fix: do not accumulate the loss here; accumulate it in optimizer_step instead
170
- # self.running_loss += loss.item() # ❌ removed
171
- self.tokens_seen += attention_mask.sum().item()
172
-
173
- return {
174
- 'loss': loss.item(), # return the real, unscaled loss
175
- 'lr': self.optimizer.param_groups[0]['lr']
176
- }
177
-
178
- def optimizer_step(self):
179
- """Optimizer step"""
180
- # Unscale gradients
181
- self.scaler.unscale_(self.optimizer)
182
-
183
- # Gradient clipping
184
- grad_norm = torch.nn.utils.clip_grad_norm_(
185
- self.model.parameters(),
186
- self.max_grad_norm
187
- )
188
-
189
- # Update parameters
190
- self.scaler.step(self.optimizer)
191
- self.scaler.update()
192
- self.optimizer.zero_grad(set_to_none=True)
193
-
194
- # Update the learning rate
195
- self.current_step += 1
196
- self.global_step += 1
197
- lr = self._get_lr()
198
- self._set_lr(lr)
199
-
200
- return grad_norm.item()
201
-
202
- def _write_loss_to_txt(self, step, avg_loss, lr, tokens_seen):
203
- """Write the loss log"""
204
- log_content = (
205
- f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
206
- f"Step: {step}/{self.max_steps}, "
207
- f"Average Loss: {avg_loss:.4f}, "
208
- f"Learning Rate: {lr:.2e}, "
209
- f"Tokens Seen: {tokens_seen/1e9:.2f}B\n"
210
- )
211
- with open(self.loss_log_file, 'a', encoding='utf-8') as f:
212
- f.write(log_content)
213
-
214
- def train(self, dataloader, resume_from=None):
215
- """Training loop"""
216
- logger.info("\n" + "="*80)
217
- logger.info("Starting Pre-Training (Fixed Version)")
218
- logger.info("="*80 + "\n")
219
-
220
- # Resume training
221
- if resume_from:
222
- self.load_checkpoint(resume_from)
223
-
224
- # Initialize the log file
225
- if not self.loss_log_file.exists():
226
- with open(self.loss_log_file, 'w', encoding='utf-8') as f:
227
- f.write("🚀 Fixed Training Log (Real Loss Values)\n")
228
- f.write("="*80 + "\n")
229
-
230
- self.model.train()
231
- progress_bar = tqdm(total=self.max_steps, initial=self.global_step)
232
-
233
- step_in_accumulation = 0
234
- accumulated_loss = 0.0 # 🔧 accumulates the loss over one full optimizer step
235
-
236
- batches_to_skip = self.global_step * self.gradient_accumulation_steps
237
-
238
- logger.info(f"Current Global Step: {self.global_step}")
239
- if batches_to_skip > 0:
240
- logger.info(f"🔄 Resuming: Need to skip {batches_to_skip} batches to restore data state...")
241
- logger.info("This might take a while depending on network/disk speed...")
242
-
243
- # Create the data iterator
244
- data_iterator = iter(dataloader)
245
-
246
- # 1. Skip already-consumed batches
247
- skipped = 0
248
- if batches_to_skip > 0:
249
- with tqdm(total=batches_to_skip, desc="Skipping trained batches", unit="batch") as skip_pbar:
250
- while skipped < batches_to_skip:
251
- try:
252
- # Fetch data only; no model forward, no gradients
253
- _ = next(data_iterator)
254
- skipped += 1
255
- skip_pbar.update(1)
256
- except StopIteration:
257
- logger.error("Dataset exhausted during skipping! Check your dataset size or max_steps.")
258
- return
259
-
260
- logger.info("✅ Data fast-forward complete. Resuming training...")
261
-
262
- # 2. Main training loop
263
- try:
264
- # Note: we cannot use "for batch in dataloader" here, because the iterator has already been partially consumed
265
- # Keep using the data_iterator created above
266
- while True:
267
- try:
268
- batch = next(data_iterator)
269
- except StopIteration:
270
- break # dataset exhausted
271
-
272
- if batch is None or batch['input_ids'].size(0) == 0:
273
- continue
274
- #print("Sample input:", self.tokenizer.decode(batch['input_ids'][0][:50]))
275
- # Training step
276
- stats = self.train_step(batch)
277
- step_in_accumulation += 1
278
- accumulated_loss += stats['loss'] # 🔧 accumulate the current micro-batch loss
279
-
280
- # Gradient accumulation complete; run the optimizer update
281
- if step_in_accumulation >= self.gradient_accumulation_steps:
282
- # 🔧 Compute the average loss for this full step
283
- avg_step_loss = accumulated_loss / self.gradient_accumulation_steps
284
-
285
- grad_norm = self.optimizer_step()
286
- stats['grad_norm'] = grad_norm
287
- stats['loss'] = avg_step_loss # 🔧 update to the average loss
288
-
289
- # 🔧 Accumulate into running_loss (for logging)
290
- self.running_loss += avg_step_loss
291
-
292
- step_in_accumulation = 0
293
- accumulated_loss = 0.0 # 🔧 reset the accumulator
294
-
295
- # Update the progress bar
296
- progress_bar.update(1)
297
- progress_bar.set_postfix({
298
- 'loss': f"{stats['loss']:.4f}",
299
- 'lr': f"{stats['lr']:.2e}",
300
- 'tokens': f"{self.tokens_seen/1e9:.2f}B",
301
- 'grad': f"{grad_norm:.2f}"
302
- })
303
-
304
- # Logging
305
- if self.global_step % self.log_interval == 0:
306
- avg_loss = self.running_loss / self.log_interval
307
-
308
- logger.info(
309
- f"Step {self.global_step}/{self.max_steps} | "
310
- f"Loss: {avg_loss:.4f} | "
311
- f"LR: {stats['lr']:.2e} | "
312
- f"GradNorm: {grad_norm:.2f} | "
313
- f"Tokens: {self.tokens_seen/1e9:.2f}B"
314
- )
315
-
316
- # 🔧 Detect training anomalies
317
- if avg_loss > 10.0 and self.global_step > 100:
318
- logger.warning(f"⚠️ Loss is abnormally high ({avg_loss:.2f}); something may be wrong!")
319
-
320
- if avg_loss < self.best_loss:
321
- self.best_loss = avg_loss
322
- logger.info(f"✨ New best loss: {self.best_loss:.4f}")
323
-
324
- self._write_loss_to_txt(
325
- step=self.global_step,
326
- avg_loss=avg_loss,
327
- lr=stats['lr'],
328
- tokens_seen=self.tokens_seen
329
- )
330
- self.running_loss = 0.0
331
-
332
- # Save checkpoint
333
- if self.global_step % self.save_interval == 0:
334
- self.save_checkpoint(
335
- self.checkpoint_dir / f"step_{self.global_step}.pt"
336
- )
337
-
338
- # Finish training
339
- if self.global_step >= self.max_steps:
340
- break
341
-
342
- except KeyboardInterrupt:
343
- logger.info("\n⚠️ Training interrupted by user")
344
- self.save_checkpoint(
345
- self.checkpoint_dir / f"interrupted_step_{self.global_step}.pt"
346
- )
347
-
348
- finally:
349
- progress_bar.close()
350
-
351
- logger.info("\n" + "="*80)
352
- logger.info("Pre-Training Complete!")
353
- logger.info(f" Total Steps: {self.global_step}")
354
- logger.info(f" Total Tokens: {self.tokens_seen/1e9:.2f}B")
355
- logger.info(f" Best Loss: {self.best_loss:.4f}")
356
- logger.info("="*80 + "\n")
357
-
358
- # Save the final model
359
- self.save_checkpoint(self.checkpoint_dir / "final_model.pt")
360
-
361
- def save_checkpoint(self, path: Path):
362
- """Save a checkpoint"""
363
- checkpoint = {
364
- 'model_state_dict': self.model.state_dict(),
365
- 'optimizer_state_dict': self.optimizer.state_dict(),
366
- 'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
367
- 'global_step': self.global_step,
368
- 'current_step': self.current_step,
369
- 'tokens_seen': self.tokens_seen,
370
- 'best_loss': self.best_loss,
371
- 'timestamp': datetime.now().isoformat()
372
- }
373
-
374
- torch.save(checkpoint, path)
375
- logger.info(f"💾 Checkpoint saved to {path}")
376
-
377
- def load_checkpoint(self, path: str):
378
- """Load a checkpoint"""
379
- checkpoint = torch.load(path, map_location=self.device, weights_only=True)
380
-
381
- self.model.load_state_dict(checkpoint['model_state_dict'])
382
- self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
383
-
384
- if self.use_amp and checkpoint.get('scaler_state_dict'):
385
- self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
386
-
387
- self.global_step = checkpoint['global_step']
388
- self.current_step = checkpoint.get('current_step', self.global_step)
389
- self.tokens_seen = checkpoint['tokens_seen']
390
- self.best_loss = checkpoint.get('best_loss', float('inf'))
391
-
392
- logger.info(f"📂 Checkpoint loaded from {path}")
393
- logger.info(f" Resuming from step {self.global_step}")
394
- logger.info(f" Tokens seen: {self.tokens_seen/1e9:.2f}B")
395
-
396
-
397
- def main():
398
- """Main entry point"""
399
- # 🔧 Tuned configuration
400
- config = {
401
- # Model config
402
- 'model_dim': 1536,
403
- 'vocab_size': 151665,
404
- 'n_layers': 12,
405
- 'n_heads': 12,
406
- 'n_kv_heads': 4,
407
- 'max_seq_len': 512, # 🔧 reduced for speed
408
- 'dropout': 0.1,
409
- 'use_moe': False,
410
-
411
- # 🔧 Training config (key fixes)
412
- 'batch_size': 4, # increased
413
- 'gradient_accumulation_steps': 8, # reduced
414
- 'learning_rate': 3e-4, # standard value
415
- 'weight_decay': 0.1,
416
- 'warmup_steps': 500, # faster warmup
417
- 'max_steps': 10000,
418
- 'max_grad_norm': 1.0,
419
-
420
- # Data config
421
- 'data_mix': 'text_only',
422
- 'max_length': 512, # 🔧 matches max_seq_len
423
- 'num_workers': 2, # 🔧 reduced to avoid network issues
424
-
425
- # Logging and saving
426
- 'log_interval': 10,
427
- 'save_interval': 500, # 🔧 save more frequently
428
- 'checkpoint_dir': 'checkpoints/pretrain_fixed',
429
- 'loss_log_file': 'checkpoints/pretrain_fixed/train_loss.log'
430
- }
431
-
432
- logger.info("="*80)
433
- logger.info("🔧 Fixed Configuration:")
434
- logger.info(json.dumps(config, indent=2))
435
- logger.info("="*80 + "\n")
436
-
437
- # Initialize tokenizer
438
- logger.info("Initializing tokenizer...")
439
- tokenizer = AutoTokenizer.from_pretrained(
440
- "Qwen/Qwen2.5-7B-Instruct",
441
- use_fast=True,
442
- trust_remote_code=True
443
- )
444
-
445
- if tokenizer.pad_token is None:
446
- tokenizer.pad_token = tokenizer.eos_token
447
- tokenizer.pad_token_id = tokenizer.eos_token_id
448
-
449
- config['vocab_size'] = len(tokenizer)
450
- logger.info(f"Vocab size: {config['vocab_size']}\n")
451
-
452
- # Initialize model
453
- logger.info("Initializing model...")
454
- model = MultiModalDenseTransformer(
455
- model_dim=config['model_dim'],
456
- vocab_size=config['vocab_size'],
457
- n_layers=config['n_layers'],
458
- n_heads=config['n_heads'],
459
- n_kv_heads=config['n_kv_heads'],
460
- max_seq_len=config['max_seq_len'],
461
- dropout=config['dropout'],
462
- use_moe=config['use_moe'],
463
- use_gradient_checkpointing=True,
464
- rope_scaling_type="yarn",
465
- use_multimodal_fusion=False,
466
- use_contrastive=False
467
- )
468
-
469
- # Create dataloader
470
- logger.info(f"\nCreating dataloader (mix: {config['data_mix']})...")
471
- dataloader = create_pretrain_dataloader(
472
- mix_name=config['data_mix'],
473
- tokenizer=tokenizer,
474
- batch_size=config['batch_size'],
475
- num_workers=config['num_workers'],
476
- max_length=config['max_length']
477
- )
478
-
479
- # Create trainer
480
- trainer = PreTrainer(
481
- model=model,
482
- tokenizer=tokenizer,
483
- learning_rate=config['learning_rate'],
484
- weight_decay=config['weight_decay'],
485
- warmup_steps=config['warmup_steps'],
486
- max_steps=config['max_steps'],
487
- gradient_accumulation_steps=config['gradient_accumulation_steps'],
488
- max_grad_norm=config['max_grad_norm'],
489
- log_interval=config['log_interval'],
490
- save_interval=config['save_interval'],
491
- checkpoint_dir=config['checkpoint_dir'],
492
- loss_log_file=config['loss_log_file']
493
- )
494
-
495
- # 🔧 Start training (from scratch; do not reuse an old checkpoint)
496
- logger.info("\n🚀 Starting fresh training with fixes...\n")
497
- trainer.train(dataloader, resume_from="/root/step_6500.pt")
498
- #trainer.train(dataloader)
499
-
500
-
501
- if __name__ == "__main__":
502
  main()
 
1
+ import os
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer
5
+ from pathlib import Path
6
+ import logging
7
+ from tqdm import tqdm
8
+ import json
9
+ from datetime import datetime
10
+ from model import MultiModalDenseTransformer
11
+ from data_loader import create_pretrain_dataloader
12
+
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
19
+
20
+
21
+ class PreTrainer:
22
+ def __init__(
23
+ self,
24
+ model: MultiModalDenseTransformer,
25
+ tokenizer,
26
+ learning_rate: float = 3e-4,
27
+ weight_decay: float = 0.1,
28
+ warmup_steps: int = 1000,
29
+ max_steps: int = 100000,
30
+ gradient_accumulation_steps: int = 16,
31
+ max_grad_norm: float = 1.0,
32
+ log_interval: int = 10,
33
+ save_interval: int = 1000,
34
+ checkpoint_dir: str = "checkpoints/pretrain",
35
+ loss_log_file: str = "checkpoints/pretrain/train_loss.log"
36
+ ):
37
+ self.model = model
38
+ self.tokenizer = tokenizer
39
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
+
41
+ self.model.to(self.device)
42
+
43
+ self.optimizer = torch.optim.AdamW(
44
+ model.parameters(),
45
+ lr=learning_rate,
46
+ weight_decay=weight_decay,
47
+ betas=(0.9, 0.95),
48
+ eps=1e-8
49
+ )
50
+
51
+ # LR is scheduled manually below (linear warmup + cosine decay); no torch scheduler object is used
52
+
53
+ self.warmup_steps = warmup_steps
54
+ self.max_lr = learning_rate
55
+ self.min_lr = learning_rate * 0.1
56
+ self.current_step = 0
57
+
58
+ # Mixed precision
59
+ self.use_amp = torch.cuda.is_available()
60
+ self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
61
+
62
+ # Training parameters
63
+ self.gradient_accumulation_steps = gradient_accumulation_steps
64
+ self.max_grad_norm = max_grad_norm
65
+ self.max_steps = max_steps
66
+ self.log_interval = log_interval
67
+ self.save_interval = save_interval
68
+
69
+ # Checkpoint management
70
+ self.checkpoint_dir = Path(checkpoint_dir)
71
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
72
+
73
+ # Loss log
74
+ self.loss_log_file = Path(loss_log_file)
75
+ self.loss_log_file.parent.mkdir(parents=True, exist_ok=True)
76
+
77
+ # Training state
78
+ self.global_step = 0
79
+ self.tokens_seen = 0
80
+ self.running_loss = 0.0
81
+ self.best_loss = float('inf')
82
+
83
+ logger.info(f"PreTrainer initialized:")
84
+ logger.info(f" Device: {self.device}")
85
+ logger.info(f" Learning Rate: {learning_rate}")
86
+ logger.info(f" Max Steps: {max_steps}")
87
+ logger.info(f" Gradient Accumulation: {gradient_accumulation_steps}")
88
+ logger.info(f" Effective Batch Size: batch_size x {gradient_accumulation_steps}")
89
+ logger.info(f" Mixed Precision: {self.use_amp}")
90
+
91
+ def _get_lr(self) -> float:
92
+ """Manually compute the learning rate (warmup + cosine)"""
93
+ if self.current_step < self.warmup_steps:
94
+ # Linear warmup
95
+ return self.max_lr * (self.current_step / self.warmup_steps)
96
+ else:
97
+ # Cosine decay
98
+ progress = (self.current_step - self.warmup_steps) / (self.max_steps - self.warmup_steps)
99
+ return self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + torch.cos(torch.tensor(progress * 3.14159)).item())
100
+
101
+ def _set_lr(self, lr: float):
102
+ """Set the learning rate"""
103
+ for param_group in self.optimizer.param_groups:
104
+ param_group['lr'] = lr
105
+
106
+ def train_step(self, batch: dict) -> dict:
107
+ input_ids = batch['input_ids'].to(self.device)
108
+ attention_mask = batch['attention_mask'].to(self.device)
109
+ batch_size, seq_len = input_ids.shape
110
+ position_ids = torch.zeros_like(input_ids)
111
+
112
+ for i in range(batch_size):
113
+ non_pad_mask = attention_mask[i].bool()
114
+ if non_pad_mask.any():
115
+ positions = torch.cumsum(non_pad_mask.long(), dim=0) - 1
116
+ position_ids[i] = positions * non_pad_mask.long()
117
+
118
+
119
+
120
+ # Prepare inputs
121
+ input_data = {
122
+ 'segments': [{
123
+ 'type': 'text',
124
+ 'data': input_ids,
125
+ 'modality_id': 0
126
+ }]
127
+ }
128
+
129
+ # Forward pass
130
+ with torch.amp.autocast('cuda', enabled=self.use_amp):
131
+ outputs = self.model(
132
+ input_data,
133
+ attention_mask=attention_mask,
134
+ position_ids=position_ids)
135
+ logits = outputs['logits']
136
+
137
+ # Compute the loss (standard autoregressive)
138
+ shift_logits = logits[:, :-1, :].contiguous()
139
+ shift_labels = input_ids[:, 1:].contiguous()
140
+ shift_attention_mask = attention_mask[:, 1:].contiguous()
141
+
142
+ loss = F.cross_entropy(
143
+ shift_logits.view(-1, shift_logits.size(-1)),
144
+ shift_labels.view(-1),
145
+ reduction='none'
146
+ )
147
+
148
+ # Apply the mask
149
+ loss = (loss * shift_attention_mask.view(-1)).sum() / (shift_attention_mask.sum() + 1e-8)
150
+ loss_for_backward = loss / self.gradient_accumulation_steps
151
+
152
+ self.scaler.scale(loss_for_backward).backward()
153
+ self.tokens_seen += attention_mask.sum().item()
154
+
155
+ return {
156
+ 'loss': loss.item(), # return the real, unscaled loss
157
+ 'lr': self.optimizer.param_groups[0]['lr']
158
+ }
159
+
160
+ def optimizer_step(self):
161
+ """Optimizer step"""
162
+ # Unscale gradients
163
+ self.scaler.unscale_(self.optimizer)
164
+
165
+ # Gradient clipping
166
+ grad_norm = torch.nn.utils.clip_grad_norm_(
167
+ self.model.parameters(),
168
+ self.max_grad_norm
169
+ )
170
+
171
+ # Update parameters
172
+ self.scaler.step(self.optimizer)
173
+ self.scaler.update()
174
+ self.optimizer.zero_grad(set_to_none=True)
175
+
176
+ # Update the learning rate
177
+ self.current_step += 1
178
+ self.global_step += 1
179
+ lr = self._get_lr()
180
+ self._set_lr(lr)
181
+
182
+ return grad_norm.item()
183
+
184
+ def _write_loss_to_txt(self, step, avg_loss, lr, tokens_seen):
185
+ """Write the loss log"""
186
+ log_content = (
187
+ f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
188
+ f"Step: {step}/{self.max_steps}, "
189
+ f"Average Loss: {avg_loss:.4f}, "
190
+ f"Learning Rate: {lr:.2e}, "
191
+ f"Tokens Seen: {tokens_seen/1e9:.2f}B\n"
192
+ )
193
+ with open(self.loss_log_file, 'a', encoding='utf-8') as f:
194
+ f.write(log_content)
195
+
196
+ def train(self, dataloader, resume_from=None):
197
+ """Training loop"""
198
+ logger.info("\n" + "="*80)
199
+ logger.info("Starting Pre-Training (Fixed Version)")
200
+ logger.info("="*80 + "\n")
201
+
202
+ # Resume training
203
+ if resume_from:
204
+ self.load_checkpoint(resume_from)
205
+
206
+ # Initialize the log file
207
+ if not self.loss_log_file.exists():
208
+ with open(self.loss_log_file, 'w', encoding='utf-8') as f:
209
+ f.write(" Fixed Training Log (Real Loss Values)\n")
210
+ f.write("="*80 + "\n")
211
+
212
+ self.model.train()
213
+ progress_bar = tqdm(total=self.max_steps, initial=self.global_step)
214
+
215
+ step_in_accumulation = 0
216
+ accumulated_loss = 0.0
217
+
218
+ batches_to_skip = self.global_step * self.gradient_accumulation_steps
219
+
220
+ logger.info(f"Current Global Step: {self.global_step}")
221
+ if batches_to_skip > 0:
222
+ logger.info(f" Resuming: Need to skip {batches_to_skip} batches to restore data state...")
223
+ logger.info("This might take a while depending on network/disk speed...")
224
+
225
+ # Create the data iterator
226
+ data_iterator = iter(dataloader)
227
+
228
+ skipped = 0
229
+ if batches_to_skip > 0:
230
+ with tqdm(total=batches_to_skip, desc="Skipping trained batches", unit="batch") as skip_pbar:
231
+ while skipped < batches_to_skip:
232
+ try:
233
+ # Fetch data only; no model forward, no gradients
234
+ _ = next(data_iterator)
235
+ skipped += 1
236
+ skip_pbar.update(1)
237
+ except StopIteration:
238
+ logger.error("Dataset exhausted during skipping! Check your dataset size or max_steps.")
239
+ return
240
+
241
+ logger.info(" Data fast-forward complete. Resuming training...")
242
+
243
+ try:
244
+ while True:
245
+ try:
246
+ batch = next(data_iterator)
247
+ except StopIteration:
248
+ break
249
+
250
+ if batch is None or batch['input_ids'].size(0) == 0:
251
+ continue
252
+ stats = self.train_step(batch)
253
+ step_in_accumulation += 1
254
+ accumulated_loss += stats['loss']
255
+
256
+ if step_in_accumulation >= self.gradient_accumulation_steps:
257
+ avg_step_loss = accumulated_loss / self.gradient_accumulation_steps
258
+ grad_norm = self.optimizer_step()
259
+ stats['grad_norm'] = grad_norm
260
+ stats['loss'] = avg_step_loss
261
+ self.running_loss += avg_step_loss
262
+
263
+ step_in_accumulation = 0
264
+ accumulated_loss = 0.0
265
+ progress_bar.update(1)
266
+ progress_bar.set_postfix({
267
+ 'loss': f"{stats['loss']:.4f}",
268
+ 'lr': f"{stats['lr']:.2e}",
269
+ 'tokens': f"{self.tokens_seen/1e9:.2f}B",
270
+ 'grad': f"{grad_norm:.2f}"
271
+ })
272
+
273
+ # Logging
274
+ if self.global_step % self.log_interval == 0:
275
+ avg_loss = self.running_loss / self.log_interval
276
+
277
+ logger.info(
278
+ f"Step {self.global_step}/{self.max_steps} | "
279
+ f"Loss: {avg_loss:.4f} | "
280
+ f"LR: {stats['lr']:.2e} | "
281
+ f"GradNorm: {grad_norm:.2f} | "
282
+ f"Tokens: {self.tokens_seen/1e9:.2f}B"
283
+ )
284
+
285
+ if avg_loss < self.best_loss:
286
+ self.best_loss = avg_loss
287
+ logger.info(f" New best loss: {self.best_loss:.4f}")
288
+
289
+ self._write_loss_to_txt(
290
+ step=self.global_step,
291
+ avg_loss=avg_loss,
292
+ lr=stats['lr'],
293
+ tokens_seen=self.tokens_seen
294
+ )
295
+ self.running_loss = 0.0
296
+
297
+ # Save checkpoint
298
+ if self.global_step % self.save_interval == 0:
299
+ self.save_checkpoint(
300
+ self.checkpoint_dir / f"step_{self.global_step}.pt"
301
+ )
302
+
303
+ # Finish training
304
+ if self.global_step >= self.max_steps:
305
+ break
306
+
307
+ except KeyboardInterrupt:
308
+ self.save_checkpoint(
309
+ self.checkpoint_dir / f"interrupted_step_{self.global_step}.pt"
310
+ )
311
+
312
+ finally:
313
+ progress_bar.close()
314
+
315
+ logger.info("\n" + "="*80)
316
+ logger.info("Pre-Training Complete!")
317
+ logger.info(f" Total Steps: {self.global_step}")
318
+ logger.info(f" Total Tokens: {self.tokens_seen/1e9:.2f}B")
319
+ logger.info(f" Best Loss: {self.best_loss:.4f}")
320
+ logger.info("="*80 + "\n")
321
+
322
+ # Save the final model
323
+ self.save_checkpoint(self.checkpoint_dir / "final_model.pt")
324
+
325
+ def save_checkpoint(self, path: Path):
326
+ """Save a checkpoint"""
327
+ checkpoint = {
328
+ 'model_state_dict': self.model.state_dict(),
329
+ 'optimizer_state_dict': self.optimizer.state_dict(),
330
+ 'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
331
+ 'global_step': self.global_step,
332
+ 'current_step': self.current_step,
333
+ 'tokens_seen': self.tokens_seen,
334
+ 'best_loss': self.best_loss,
335
+ 'timestamp': datetime.now().isoformat()
336
+ }
337
+
338
+ torch.save(checkpoint, path)
339
+ logger.info(f" Checkpoint saved to {path}")
340
+
341
+ def load_checkpoint(self, path: str):
342
+ """Load a checkpoint"""
343
+ checkpoint = torch.load(path, map_location=self.device, weights_only=True)
344
+
345
+ self.model.load_state_dict(checkpoint['model_state_dict'])
346
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
347
+
348
+ if self.use_amp and checkpoint.get('scaler_state_dict'):
349
+ self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
350
+
351
+ self.global_step = checkpoint['global_step']
352
+ self.current_step = checkpoint.get('current_step', self.global_step)
353
+ self.tokens_seen = checkpoint['tokens_seen']
354
+ self.best_loss = checkpoint.get('best_loss', float('inf'))
355
+
356
+ logger.info(f" Checkpoint loaded from {path}")
357
+ logger.info(f" Resuming from step {self.global_step}")
358
+ logger.info(f" Tokens seen: {self.tokens_seen/1e9:.2f}B")
359
+
360
+
361
+ def main():
362
+ config = {
363
+ # Model config
364
+ 'model_dim': 1536,
365
+ 'vocab_size': 151665,
366
+ 'n_layers': 12,
367
+ 'n_heads': 12,
368
+ 'n_kv_heads': 4,
369
+ 'max_seq_len': 512,
370
+ 'dropout': 0.1,
371
+ 'use_moe': False,
372
+ 'batch_size': 4,
373
+ 'gradient_accumulation_steps': 8,
374
+ 'learning_rate': 3e-4,
375
+ 'weight_decay': 0.1,
376
+ 'warmup_steps': 500,
377
+ 'max_steps': 10000,
378
+ 'max_grad_norm': 1.0,
379
+
380
+ # Data config
381
+ 'data_mix': 'text_only',
382
+ 'max_length': 512,
383
+ 'num_workers': 2,
384
+
385
+ # Logging and saving
386
+ 'log_interval': 10,
387
+ 'save_interval': 500,
388
+ 'checkpoint_dir': 'checkpoints/pretrain_fixed',
389
+ 'loss_log_file': 'checkpoints/pretrain_fixed/train_loss.log'
390
+ }
391
+
392
+ logger.info("="*80)
393
+ logger.info(json.dumps(config, indent=2))
394
+ logger.info("="*80 + "\n")
395
+
396
+ # Initialize tokenizer
397
+ logger.info("Initializing tokenizer...")
398
+ tokenizer = AutoTokenizer.from_pretrained(
399
+ "Qwen/Qwen2.5-7B-Instruct",
400
+ use_fast=True,
401
+ trust_remote_code=True
402
+ )
403
+
404
+ if tokenizer.pad_token is None:
405
+ tokenizer.pad_token = tokenizer.eos_token
406
+ tokenizer.pad_token_id = tokenizer.eos_token_id
407
+
408
+ config['vocab_size'] = len(tokenizer)
409
+ logger.info(f"Vocab size: {config['vocab_size']}\n")
410
+
411
+ # Initialize model
412
+ logger.info("Initializing model...")
413
+ model = MultiModalDenseTransformer(
414
+ model_dim=config['model_dim'],
415
+ vocab_size=config['vocab_size'],
416
+ n_layers=config['n_layers'],
417
+ n_heads=config['n_heads'],
418
+ n_kv_heads=config['n_kv_heads'],
419
+ max_seq_len=config['max_seq_len'],
420
+ dropout=config['dropout'],
421
+ use_moe=config['use_moe'],
422
+ use_gradient_checkpointing=True,
423
+ rope_scaling_type="yarn",
424
+ use_multimodal_fusion=False,
425
+ use_contrastive=False
426
+ )
427
+
428
+ # Create dataloader
429
+ logger.info(f"\nCreating dataloader (mix: {config['data_mix']})...")
430
+ dataloader = create_pretrain_dataloader(
431
+ mix_name=config['data_mix'],
432
+ tokenizer=tokenizer,
433
+ batch_size=config['batch_size'],
434
+ num_workers=config['num_workers'],
435
+ max_length=config['max_length']
436
+ )
437
+
438
+ # Create trainer
439
+ trainer = PreTrainer(
440
+ model=model,
441
+ tokenizer=tokenizer,
442
+ learning_rate=config['learning_rate'],
443
+ weight_decay=config['weight_decay'],
444
+ warmup_steps=config['warmup_steps'],
445
+ max_steps=config['max_steps'],
446
+ gradient_accumulation_steps=config['gradient_accumulation_steps'],
447
+ max_grad_norm=config['max_grad_norm'],
448
+ log_interval=config['log_interval'],
449
+ save_interval=config['save_interval'],
450
+ checkpoint_dir=config['checkpoint_dir'],
451
+ loss_log_file=config['loss_log_file']
452
+ )
453
+
454
+ logger.info("\n Starting fresh training with fixes...\n")
455
+ trainer.train(dataloader, resume_from="/root/step_6500.pt")
456
+ #trainer.train(dataloader)
457
+
458
+
459
+ if __name__ == "__main__":
460
  main()
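
For reference, a minimal standalone sketch (not part of the commit) of the schedule that PreTrainer._get_lr implements: linear warmup to the peak rate, then cosine decay to 10% of it. The helper name lr_at and the default values, copied from the config in main(), are illustrative only.

import math

def lr_at(step, warmup_steps=500, max_steps=10000, max_lr=3e-4, min_lr=3e-5):
    # Linear warmup: 0 -> max_lr over the first warmup_steps updates
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Cosine decay: max_lr -> min_lr over the remaining updates
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# e.g. lr_at(0) == 0.0, lr_at(500) == 3e-4, lr_at(10000) == 3e-5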