gate369 commited on
Commit
d0e6871
·
verified ·
1 Parent(s): 46c8f97

Create Initial_Train_MoR.py

Browse files
Files changed (1) hide show
  1. Initial_Train_MoR.py +545 -0
Initial_Train_MoR.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################
2
+ #Mixture of Recursions w/ Expert Choice Routing#
3
+ ################################################
4
+
5
+ # This is the code I used to initially train this model. I continued training with 'Continue_Training_MoR.py'.
6
+ from re import M
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import math
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from torch.utils.checkpoint import checkpoint
13
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers
14
+ from tqdm import tqdm
15
+ import matplotlib.pyplot as plt
16
+ from torch.cuda.amp import autocast, GradScaler
17
+ import numpy as np
18
+ import os
19
+ from safetensors.torch import save_file
20
+ import json
21
+ import os
22
+ from transformers import PreTrainedTokenizerFast
23
# Debugging aid: request synchronous CUDA kernel launches so that device-side
# errors are reported at the call that caused them.  ``setdefault`` keeps any
# value the user already exported (e.g. CUDA_LAUNCH_BLOCKING=0) instead of
# silently overriding it.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', '1')
25
def save_huggingface_model(model, tokenizer, folder_path="MoR-v1"):
    """Write weights, config, tokenizer and a weight index into ``folder_path``.

    Files produced:
        * ``model.safetensors``            - the model's ``state_dict``
        * ``config.json``                  - hyper-parameters describing the model
        * tokenizer files                  - via ``PreTrainedTokenizerFast.save_pretrained``
        * ``model.safetensors.index.json`` - maps every weight name to the single shard

    Args:
        model: Module whose ``state_dict()`` is saved.
        tokenizer: A ``tokenizers.Tokenizer`` wrapped for saving.
        folder_path: Output directory; created if it does not exist.
    """
    os.makedirs(folder_path, exist_ok=True)

    # 1. Model weights in safetensors format
    weights = model.state_dict()
    save_file(weights, os.path.join(folder_path, "model.safetensors"))

    # 2. config.json
    config = {
        "vocab_size": VOCAB_SIZE,
        "dim": DIM,
        "num_layers": NUM_LAYERS,
        "num_heads": HEADS,
        "max_recursion": MAX_RECURSIONS,
        # The expert count is its own setting; it was previously written out
        # as MAX_RECURSIONS, which mis-described the saved model.
        "num_experts": getattr(model, "num_experts", NUM_EXPERTS),
        "ffn_expansion": 4,
        "max_position_embeddings": 2048,
        "model_type": "MoR",
        "architecture": "MixtureOfRecursions",
        "hidden_act": "gelu",
    }
    with open(os.path.join(folder_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # 3. Tokenizer files
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="[UNK]",
        pad_token="[PAD]",
        bos_token="[BOS]",
        eos_token="[EOS]",
    )
    hf_tokenizer.save_pretrained(folder_path)

    # 4. Index file; the size is computed from the tensors that were actually
    #    written so it stays in step with ``weight_map``.
    index = {
        "metadata": {
            "total_size": sum(t.numel() * t.element_size() for t in weights.values())
        },
        "weight_map": {name: "model.safetensors" for name in weights},
    }
    index_path = os.path.join(folder_path, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)

    print(f"Model saved in Hugging Face format to {folder_path}/")
64
+
65
# --- Model shape -----------------------------------------------------------
VOCAB_SIZE = 10_000   # tokenizer vocabulary size / output logits width
DIM = 1536            # hidden width of every block
NUM_LAYERS = 6
HEADS = 8             # attention heads per block
MAX_RECURSIONS = 4
NUM_EXPERTS = 12

# --- Data and optimisation -------------------------------------------------
SEQ_LEN = 512         # tokens per training sample
BATCH_SIZE = 32
GRAD_ACCUM_STEPS = 4  # batches accumulated before each optimizer step
EPOCHS = 3
learn_rate = 5e-5     # peak learning rate fed to the warmup/cosine schedule
76
+
77
+ # ----------------------
78
+ # Byte-Level BPE Tokenizer
79
+ # ----------------------
80
def train_tokenizer(file_path, vocab_size=VOCAB_SIZE):
    """Train a byte-level BPE tokenizer on the text file at ``file_path``.

    Files larger than 100 MB are passed to the tokenizers library by path and
    are trained on unmodified.  Smaller files are read into memory, stripped
    of every character whose code point is below 32, and trained on from that
    string.

    Fixes relative to the earlier version:
        * the text is read regardless of CUDA availability (it used to be
          bound only on CUDA machines, so CPU runs on small files raised
          ``NameError``);
        * the filtering is done on the string itself; the previous round-trip
          through an int32 tensor and ``tobytes().decode('utf-8')`` produced
          text interleaved with NUL bytes.

    Args:
        file_path: Path to the training corpus, read as UTF-8.
        vocab_size: Vocabulary size passed to ``BpeTrainer``.

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    print("Training tokenizer...")
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
        min_frequency=2,
    )

    if os.path.getsize(file_path) > 100 * 1024 * 1024:  # > 100MB
        # Large corpus: let the library read the file itself.
        print("Using memory-mapped files for large dataset...")
        tokenizer.train([file_path], trainer=trainer)
    else:
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Same threshold as the original filter (code point >= 32).
        # NOTE(review): this also drops newlines and tabs - confirm that is
        # intended, since prepare_datasets encodes the unfiltered file.
        text = ''.join(ch for ch in text if ord(ch) >= 32)
        # ``length`` is the number of items in the iterator (one string here).
        tokenizer.train_from_iterator([text], trainer=trainer, length=1)

    print("Tokenizer successfully trained")
    return tokenizer
115
+
116
def prepare_datasets(file_path, tokenizer, seq_len=SEQ_LEN, val_split=0.05):
    """Tokenize ``file_path`` and split the ids into train/validation datasets.

    The file is read fully into memory.  With CUDA available the text is cut
    into 500,000-character pieces, each piece is encoded by ``tokenizer`` and
    stored as a CUDA tensor, and the pieces are concatenated; otherwise the
    whole text is encoded at once into a CPU tensor.  The trailing
    ``val_split`` fraction of the token stream becomes the validation set.

    Returns:
        ``(train_dataset, val_dataset)`` as ``TextDataset`` instances.
    """
    print("Preparing datasets with GPU acceleration...")
    with open(file_path, 'r') as handle:
        corpus = handle.read()

    if not torch.cuda.is_available():
        # CPU path: one encode call over the whole text.
        token_ids = torch.tensor(tokenizer.encode(corpus).ids, device='cpu')
    else:
        print("Using GPU for tokenization pipeline...")
        piece_len = 500000  # characters per piece
        pieces = [
            corpus[lo:lo + piece_len]
            for lo in range(0, len(corpus), piece_len)
        ]
        # Encoding itself runs through the tokenizer; only the resulting ids
        # are placed on the CUDA device.
        on_device = [
            torch.tensor(tokenizer.encode(piece).ids, device='cuda')
            for piece in tqdm(pieces, desc="Tokenizing on GPU")
        ]
        token_ids = torch.cat(on_device)

    total_tokens = len(token_ids)
    cut = int(total_tokens * (1 - val_split))

    train_dataset = TextDataset(token_ids[:cut], seq_len)
    val_dataset = TextDataset(token_ids[cut:], seq_len)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Total tokens: {total_tokens}")
    return train_dataset, val_dataset
150
+
151
class TextDataset(Dataset):
    """Fixed-length next-token samples cut from one flat tensor of token ids.

    Sample ``i`` covers tokens ``[i * seq_len, i * seq_len + seq_len]``
    (``seq_len + 1`` tokens); the input is the first ``seq_len`` of them and
    the target is the same window shifted by one.
    """

    def __init__(self, encoded_data, seq_len=SEQ_LEN):
        # The data stays on whatever device the caller created it on.
        self.encoded = encoded_data
        self.seq_len = seq_len
        self.device = encoded_data.device

    def __len__(self):
        # Every sample needs seq_len + 1 tokens, so one token is reserved for
        # the final target.  The previous ``len // seq_len`` produced a last
        # sample that was one token short whenever the token count was an
        # exact multiple of seq_len, which breaks batch collation.
        return max(0, (len(self.encoded) - 1) // self.seq_len)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f"sample index {idx} out of range")
        start = idx * self.seq_len
        segment = self.encoded[start:start + self.seq_len + 1]
        # (inputs, targets): targets are the inputs shifted left by one.
        return segment[:-1], segment[1:]
167
+
168
+ # ----------------------
169
+ # MoR Model Components
170
+ # ----------------------
171
print("Defining components...")


class ExpertChoiceRouter(nn.Module):
    """Linear gate that returns the top-``k`` experts for each position.

    Note: despite the class name, the selection is made along the expert
    axis, i.e. every token picks its ``k`` highest-scoring experts.
    """

    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.k = k
        self.num_experts = num_experts

    def forward(self, x):
        """Return ``(weights, indices)`` of the ``k`` best experts.

        ``x`` has ``dim`` as its last axis; the gate maps it to
        ``num_experts`` scores.  ``weights`` is a softmax over just the ``k``
        selected scores, ``indices`` are the matching expert ids.
        """
        logits = self.gate(x)
        top_scores, top_experts = logits.topk(self.k, dim=-1)
        mixing = F.softmax(top_scores, dim=-1)
        return mixing, top_experts
185
+
186
+ # ----------------------
187
+ # 4-bit Quantization Utilities
188
+ # ----------------------
189
+ # Improved Quantization with gradient scaling
190
class Quantizer4Bit(nn.Module):
    """Symmetric per-tensor quantization to the 4-bit range ``[-8, 7]``.

    The quantized values are stored in ``int8`` tensors.  Rounding and the
    integer cast are not differentiable, so callers that need gradients have
    to add their own straight-through estimator around these helpers.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def quantize(tensor):
        """Return ``(quantized, scale)`` for ``tensor``.

        ``scale`` is ``max(|tensor|) / 7.5``, or ``1`` when the tensor is
        (numerically) all zeros.  It is always returned as a 0-dim tensor and
        is selected with ``torch.where`` rather than a Python ``if``, so no
        host/device synchronisation is needed and the return type no longer
        depends on the data.
        """
        max_val = tensor.abs().max()
        scale = torch.where(max_val > 1e-8, max_val / 7.5, torch.ones_like(max_val))
        quantized = torch.clamp(torch.round(tensor / scale), -8, 7)
        return quantized.to(torch.int8), scale

    @staticmethod
    def dequantize(quantized, scale):
        """Map quantized integers back to floats: ``quantized * scale``."""
        return quantized.float() * scale
207
+
208
+ # Weight initialization function
209
def init_weights(module):
    """Initialise one module in place; intended for ``model.apply``.

    * ``nn.Linear``    - Xavier-uniform weight, zero bias (when a bias exists)
    * ``nn.Embedding`` - normal(0, 0.02) weight
    * ``nn.LayerNorm`` - weight of ones, bias of zeros

    Any other module type is left untouched.
    """
    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
        return

    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        return

    if not isinstance(module, nn.Linear):
        return

    nn.init.xavier_uniform_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
219
+
220
+ # ----------------------
221
+ # MoR Model Components with Quantization
222
+ # ----------------------
223
class QuantizedRecursiveTransformerBlock(nn.Module):
    """Pre-norm transformer block with causal attention and fake-quantized K/V.

    Keys and values are passed through ``Quantizer4Bit`` (quantize then
    dequantize) before attention.  The whole block runs under gradient
    checkpointing.

    Fixes relative to the earlier version:
        * attention is now causally masked.  The model is trained on
          next-token prediction, and without the mask every position could
          attend to the token it is supposed to predict;
        * the quantization round-trip uses a straight-through estimator.
          Previously the int8 cast cut the graph, so ``k_proj``/``v_proj``
          received gradient only through the per-tensor scale.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        ffn_expansion: Hidden-width multiplier of the feed-forward network.
    """

    def __init__(self, dim, num_heads, ffn_expansion=4):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Attention layers
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.attn_out = nn.Linear(dim, dim)
        # FFN layers
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_expansion * dim),
            nn.GELU(),
            nn.Linear(ffn_expansion * dim, dim)
        )
        # Normalization
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Recompute activations in backward instead of storing them.
        return checkpoint(self._forward, x, use_reentrant=False)

    @staticmethod
    def _fake_quantize(t):
        """Quantize/dequantize ``t`` while passing gradients straight through.

        The forward value equals the dequantized tensor; the backward pass
        treats the operation as the identity with respect to ``t``.
        """
        quantized, scale = Quantizer4Bit.quantize(t)
        restored = Quantizer4Bit.dequantize(quantized, scale).to(t.dtype)
        return restored.detach() + (t - t.detach())

    def _forward(self, x):
        # x: (batch, seq_len, dim)
        residual = x
        x = self.norm1(x)

        q = self.q_proj(x)
        k = self._fake_quantize(self.k_proj(x))
        v = self._fake_quantize(self.v_proj(x))

        # Split into heads: (batch, heads, seq_len, head_dim)
        B, T, _ = q.shape
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        # Causal mask: position i may only attend to positions <= i.  The
        # diagonal is kept, so every row retains at least one finite score.
        future = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=attn.device), diagonal=1
        )
        attn = attn.masked_fill(future, float("-inf"))
        attn = attn.softmax(dim=-1)

        attn_out = (attn @ v).transpose(1, 2).contiguous().view(B, T, self.dim)
        attn_out = self.attn_out(attn_out)

        # Residual connections around attention and the FFN
        x = residual + attn_out
        x = x + self.ffn(self.norm2(x))
        return x
277
+
278
class RecursionDepthRouter(nn.Module):
    """Two-layer MLP producing one probability vector over recursion depths.

    The input is averaged over its first two axes before routing, so a single
    distribution is produced for the whole batch.
    """

    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.max_depth = max_depth

        hidden = nn.Linear(dim, dim)
        to_depth = nn.Linear(dim, max_depth)
        self.router = nn.Sequential(hidden, nn.ReLU(), to_depth)

        # Xavier-uniform weights and zero biases for both linear layers.
        for linear in (hidden, to_depth):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return softmax probabilities of shape ``(max_depth,)``."""
        pooled = x.mean(dim=(0, 1))          # average over batch and sequence
        return F.softmax(self.router(pooled), dim=-1)
300
+
301
+ # ----------------------
302
+ # Main MoR Architecture (with Quantization)
303
+ # ----------------------
304
class QuantizedMoRModel(nn.Module):
    """Language model with unique first/last blocks and a shared middle cycle.

    Layout: token + learned position embeddings, two unique initial blocks,
    ``num_layers - 4`` routed middle stages, two unique final blocks, a final
    LayerNorm and the vocabulary head.

    Each middle stage samples one recursion depth for the whole batch, uses
    that depth to pick an expert router and one of the ``cycle_depth`` shared
    blocks, and runs that block once per expert on the gate-weighted input.

    Args:
        vocab_size: Size of the token vocabulary.
        dim: Model width.
        num_layers: Total layer budget; ``num_layers - 4`` middle stages.
        num_heads: Attention heads per block.
        max_recursion: Number of depths the depth routers choose between.
        num_experts: Width of every expert gate.
    """

    def __init__(self, vocab_size, dim=DIM, num_layers=NUM_LAYERS,
                 num_heads=HEADS, max_recursion=MAX_RECURSIONS, num_experts=NUM_EXPERTS):
        super().__init__()
        self.dim = dim
        self.max_recursion = max_recursion
        self.num_experts = num_experts
        # Embedding layers (unique parameters)
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(2048, dim)
        # Initial unique layers
        self.init_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        # Middle-cycle shared layers
        self.cycle_depth = 3
        self.recursive_blocks = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(self.cycle_depth)
        ])
        # Recursion routers: one per middle stage
        self.recursion_routers = nn.ModuleList([
            RecursionDepthRouter(dim, max_depth=max_recursion)
            for _ in range(num_layers - 4)
        ])
        # Expert gates: one per possible recursion depth
        self.expert_routers = nn.ModuleList([
            ExpertChoiceRouter(dim, num_experts)
            for _ in range(max_recursion)
        ])
        # Final unique layers
        self.final_layers = nn.ModuleList([
            QuantizedRecursiveTransformerBlock(dim, num_heads)
            for _ in range(2)
        ])
        # Output head
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional embedding table.
        """
        seq_len = x.shape[1]
        max_positions = self.pos_embed.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup.
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_positions} "
                "positions supported by the positional embedding"
            )

        # Embedding with scaling
        pos = torch.arange(0, seq_len, device=x.device)
        x = self.embedding(x) * 0.02  # Scale embeddings
        x = x + self.pos_embed(pos)

        for layer in self.init_layers:
            x = layer(x) * 0.8  # Scale residual

        batch_size = x.shape[0]
        recursion_outputs = []

        for router in self.recursion_routers:
            depth_probs = router(x)  # (max_depth)
            # One depth is sampled for the entire batch.  ``.item()`` turns it
            # into a plain int, so this choice carries no gradient back to
            # the depth router.
            depth = torch.multinomial(depth_probs, 1).item()

            expert_weights, expert_indices = self.expert_routers[depth](x)

            # Dense (batch, seq_len, num_experts) gate matrix, zero for the
            # experts a token did not select.  Allocated in the gate's own
            # dtype so scatter_ also works when it is not the default dtype.
            full_weights = torch.zeros(
                (batch_size, seq_len, self.num_experts),
                device=x.device,
                dtype=expert_weights.dtype,
            )
            full_weights.scatter_(2, expert_indices, expert_weights)

            block = self.recursive_blocks[depth % self.cycle_depth]
            combined = None
            for expert_idx in range(self.num_experts):
                expert_x = x * full_weights[:, :, expert_idx].unsqueeze(-1)
                out = block(expert_x)
                # Running sum, so per-expert outputs are not all kept alive
                # in a list.
                combined = out if combined is None else combined + out

            x = combined
            recursion_outputs.append(x)

        # Average the outputs of all middle stages
        if recursion_outputs:
            x = torch.stack(recursion_outputs).mean(dim=0)

        for layer in self.final_layers:
            x = layer(x)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits
394
+
395
+ # ----------------------
396
+ # Learning Rate Scheduler
397
+ # ----------------------
398
def get_lr(current_step, total_steps, warmup_steps, max_lr):
    """Return the learning rate for ``current_step``: linear warmup, then cosine decay.

    During the first ``warmup_steps`` steps the rate grows linearly from 0 to
    ``max_lr``.  Afterwards it follows half a cosine from ``max_lr`` down to
    0 at ``total_steps`` and stays at 0 for any later step.

    Args:
        current_step: Zero-based global step.
        total_steps: Total number of scheduled steps.
        warmup_steps: Number of warmup steps.
        max_lr: Peak learning rate.
    """
    if current_step < warmup_steps:
        return max_lr * (current_step / warmup_steps)

    # Guard against a zero-length decay phase (total_steps == warmup_steps),
    # which used to raise ZeroDivisionError.
    decay_span = max(1, total_steps - warmup_steps)
    # Clamp so steps past total_steps do not make the cosine climb again.
    decay_ratio = min(1.0, (current_step - warmup_steps) / decay_span)
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
405
+
406
+ # ----------------------
407
+ # Training Loop with Validation
408
+ # ----------------------
409
def train_model():
    """Train the tokenizer and the MoR model on ``input.txt`` and save the results.

    Steps: train the tokenizer, build train/validation datasets, train for
    ``EPOCHS`` epochs with mixed precision, gradient accumulation, gradient
    clipping and a warmup/cosine learning-rate schedule, validate after every
    epoch, and plot the loss curves to ``training_validation_loss.png``.

    Checkpoints: the best-validation model is written to ``MoR-v1``.  The
    end-of-training model is written to ``MoR-v1-final`` so that it no longer
    overwrites the best checkpoint; only if no best checkpoint was ever
    written does the final model go to ``MoR-v1``.

    Raises:
        ValueError: If ``input.txt`` yields no training samples.
    """
    # Config
    LR = learn_rate

    # Tokenizer and datasets
    tokenizer = train_tokenizer("input.txt", VOCAB_SIZE)
    train_dataset, val_dataset = prepare_datasets("input.txt", tokenizer, SEQ_LEN, val_split=0.05)
    if len(train_dataset) == 0:
        raise ValueError(
            "input.txt is too small to form a single training sample of "
            f"{SEQ_LEN} tokens"
        )
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)

    # Model
    model = QuantizedMoRModel(
        vocab_size=VOCAB_SIZE,
        dim=DIM,
        num_layers=NUM_LAYERS,
        num_heads=HEADS
    )
    model.apply(init_weights)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model Parameters: {total_params/1e6:.2f}M")

    # Optimizer and device
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Mixed precision training
    scaler = GradScaler()

    # Schedule
    total_steps = EPOCHS * num_train_batches
    warmup_steps = int(0.1 * total_steps)  # 10% warmup
    print(f"Total training steps: {total_steps}, Warmup steps: {warmup_steps}")

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    saved_best = False

    for epoch in range(EPOCHS):
        # ---- Training phase ----
        model.train()
        epoch_train_loss = 0
        accumulated_loss = 0
        optimizer.zero_grad()

        for step, (inputs, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1} Training")):
            global_step = epoch * num_train_batches + step
            current_lr = get_lr(global_step, total_steps, warmup_steps, LR)

            # The schedule is applied by hand, once per batch.
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            inputs, targets = inputs.to(device), targets.to(device)

            with autocast():
                logits = model(inputs)
                loss = F.cross_entropy(
                    logits.view(-1, VOCAB_SIZE),
                    targets.view(-1),
                    ignore_index=0  # Ignore padding index
                ) / GRAD_ACCUM_STEPS

            scaler.scale(loss).backward()
            # Undo the accumulation division so logged values are per-batch losses.
            accumulated_loss += loss.item() * GRAD_ACCUM_STEPS

            # Print every 100 batches (not update steps)
            if step % 100 == 0:
                print(f"Step {global_step}: Batch Loss={accumulated_loss:.4f}, LR={current_lr:.2e}")

            # Step the optimizer every GRAD_ACCUM_STEPS batches and on the
            # last batch of the epoch.
            if (step + 1) % GRAD_ACCUM_STEPS == 0 or step == num_train_batches - 1:
                # Gradients must be unscaled before clipping.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                epoch_train_loss += accumulated_loss
                accumulated_loss = 0

        avg_train_loss = epoch_train_loss / num_train_batches
        train_losses.append(avg_train_loss)

        # ---- Validation phase ----
        model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                inputs, targets = inputs.to(device), targets.to(device)
                with autocast():
                    logits = model(inputs)
                    loss = F.cross_entropy(
                        logits.view(-1, VOCAB_SIZE),
                        targets.view(-1),
                        ignore_index=0
                    )
                epoch_val_loss += loss.item()

        # An empty validation set yields NaN instead of a ZeroDivisionError;
        # NaN never compares as "better", so no best checkpoint is written.
        if num_val_batches:
            avg_val_loss = epoch_val_loss / num_val_batches
        else:
            avg_val_loss = float('nan')
        val_losses.append(avg_val_loss)

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_huggingface_model(model, tokenizer, "MoR-v1")
            saved_best = True
            print(f"Saved new best model with val loss: {best_val_loss:.4f}")

        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | LR: {current_lr:.2e}")

    # Plot training and validation
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title("Training and Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("training_validation_loss.png")
    plt.close()  # release the figure once it is on disk

    # Save final model without clobbering the best checkpoint in "MoR-v1".
    final_dir = "MoR-v1-final" if saved_best else "MoR-v1"
    save_huggingface_model(model, tokenizer, final_dir)
    print("Training complete. Models saved.")


# ----------------------
# Execution
# ----------------------
if __name__ == "__main__":
    train_model()