ikaganacar committed
Commit 8f73121 · Parent: 6b1c605

Some Fixes
Model_Architecture/diognose_weights.py ADDED
@@ -0,0 +1,83 @@
+def diagnose_checkpoint(checkpoint_path, config, device):
+    """Diagnose if the checkpoint has actually learned anything"""
+    import torch
+    import numpy as np
+
+    print("🔍 Diagnosing checkpoint...")
+
+    # Load checkpoint
+    ckpt = torch.load(checkpoint_path, map_location=device)
+
+    # Create model with fixes
+    from model import ismail, ModelArgs
+    args = ModelArgs(**config["model"])
+    model = ismail(args).to(device)
+
+    # Load weights
+    model.load_state_dict(ckpt["model_state_dict"], strict=False)
+    model.eval()
+
+    # Check expert weight statistics
+    print("\n📊 Expert Weight Analysis:")
+    for name, param in model.named_parameters():
+        if "experts" in name and "routed" in name:
+            expert_idx = int(name.split("experts.")[1].split(".")[0])
+            weight_std = param.std().item()
+            weight_mean = param.mean().item()
+            print(f"  Expert {expert_idx}: mean={weight_mean:.6f}, std={weight_std:.6f}")
+
+    # Check router weights
+    print("\n🎯 Router Weight Analysis:")
+    for name, param in model.named_parameters():
+        if "gate.weight" in name:
+            weight_std = param.std().item()
+            weight_range = (param.max() - param.min()).item()
+            print(f"  {name}: std={weight_std:.6f}, range={weight_range:.6f}")
+
+            # Check if router has learned to differentiate
+            router_weights = param.detach().cpu()
+            correlations = []
+            for i in range(min(5, router_weights.shape[0])):
+                for j in range(i + 1, min(5, router_weights.shape[0])):
+                    corr = torch.corrcoef(torch.stack([router_weights[i], router_weights[j]]))[0, 1].item()
+                    correlations.append(abs(corr))
+
+            if correlations:
+                avg_correlation = np.mean(correlations)
+                print(f"  Average correlation between experts: {avg_correlation:.4f}")
+                if avg_correlation < 0.9:
+                    print("  ✅ Experts show differentiation (good!)")
+                else:
+                    print("  ⚠️ Experts are too similar (potential issue)")
+
+    # Test with random input
+    print("\n🎲 Testing with random input:")
+    with torch.no_grad():
+        test_input = torch.randint(0, config["model"]["vocab_size"], (2, 128)).to(device)
+        output = model(test_input)
+        if isinstance(output, tuple):
+            output = output[0]
+
+        # Check output statistics
+        output_std = output.std().item()
+        output_mean = output.mean().item()
+        print(f"  Output mean: {output_mean:.6f}, std: {output_std:.6f}")
+
+        if output_std > 0.1:
+            print("  ✅ Model produces varied outputs")
+        else:
+            print("  ⚠️ Model outputs might be collapsed")
+
+    return ckpt["step"]
+
+
+if __name__ == "__main__":
+    import json
+
+    # Load config
+    with open("./config.json", "r") as f:
+        config = json.load(f)
+
+    # Run diagnostic
+    current_step = diagnose_checkpoint("./checkpoints/your_latest_checkpoint.pt", config, "cuda")
+    print(f"\n📍 Current step: {current_step}")
Model_Architecture/model.py CHANGED
@@ -395,33 +395,27 @@ class MoE(nn.Module):
         original_shape = x.size()
         x = x.view(-1, self.dim)
 
-        router_logits = F.linear(x, self.gate.weight, self.gate.bias)  # Use bias directly
+        router_logits = linear(x, self.gate.weight, self.gate.bias)
         router_probs = router_logits.sigmoid()
         weights, indices = torch.topk(router_probs, self.n_activated_experts, dim=-1)
 
-
         # Normalize weights
-        weights = weights / weights.sum(dim=-1, keepdim=True)
+        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)  # Add epsilon for stability
         weights = weights * self.gate.route_scale
 
-        # FIX: Sequential Training Mode - correct indexing logic
+        # CRITICAL FIX: Check training mode AND active expert
        if self.training and self.active_expert_idx is not None:
+            # Sequential training mode - only train one expert
             y = torch.zeros_like(x)
             i = self.active_expert_idx
 
             # Find tokens where expert i is in the top-k
-            # indices shape: [num_tokens, top_k]
-            mask = (indices == i)  # shape: [num_tokens, top_k]
-            idx = torch.where(mask.any(dim=1))[0]  # token indices
+            mask = (indices == i)
+            idx = torch.where(mask.any(dim=1))[0]
 
             if idx.numel() > 0:
-                # For each token, find which position in top-k contains expert i
-                top_positions = torch.argmax(mask[idx].int(), dim=1)  # shape: [num_selected_tokens]
-
-                # Get weights for expert i
-                expert_weights = weights[idx, top_positions].unsqueeze(-1)  # shape: [num_selected_tokens, 1]
-
-                # Forward pass ONLY for active expert
+                top_positions = torch.argmax(mask[idx].int(), dim=1)
+                expert_weights = weights[idx, top_positions].unsqueeze(-1)
                 expert_out = self.experts[i](x[idx])
                 y[idx] = expert_out * expert_weights
 
@@ -430,31 +424,32 @@
 
             # Shared experts
             z = self.shared_experts(x)
-
             return (y + z).view(original_shape), lb_loss
 
-        # Normal MoE Mode
-        y = torch.zeros_like(x)
-        for i in range(self.n_routed_experts):
-            mask = (indices == i)
-            idx = torch.where(mask.any(dim=1))[0]
-
-            if idx.numel() == 0:
-                continue
-
-            top_positions = torch.argmax(mask[idx].int(), dim=1)
-            expert_weights = weights[idx, top_positions].unsqueeze(-1)
-            expert_out = self.experts[i](x[idx])
-            y[idx] += expert_out * expert_weights
-
-        z = self.shared_experts(x)
-        output = (y + z).view(original_shape)
-
-        if self.training:
-            lb_loss = self.compute_load_balance_loss(router_probs, indices)
-            return output, lb_loss
-        else:
-            return output, None
+        else:
+            # Inference mode or all-experts training mode
+            y = torch.zeros_like(x)
+            for i in range(self.n_routed_experts):
+                mask = (indices == i)
+                idx = torch.where(mask.any(dim=1))[0]
+
+                if idx.numel() == 0:
+                    continue
+
+                top_positions = torch.argmax(mask[idx].int(), dim=1)
+                expert_weights = weights[idx, top_positions].unsqueeze(-1)
+                expert_out = self.experts[i](x[idx])
+                y[idx] += expert_out * expert_weights
+
+            z = self.shared_experts(x)
+            output = (y + z).view(original_shape)
+
+            # Only compute load balance loss during training
+            if self.training:
+                lb_loss = self.compute_load_balance_loss(router_probs, indices)
+                return output, lb_loss
+            else:
+                return output, None
 
@@ -536,6 +531,7 @@ class ismail(nn.Module):
         h = self.tok_embeddings(tokens).to(Linear.dtype)
         freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
 
+        # CRITICAL: Always clear caches at start_pos=0, regardless of training mode
         if start_pos == 0:
             for layer in self.layers:
                 if hasattr(layer.attn, 'kv_cache'):
@@ -545,9 +541,9 @@
 
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
+            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device, dtype=h.dtype)
             mask = torch.triu(mask, diagonal=1)
-            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
+            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device, dtype=h.dtype), mask])
 
         total_lb_loss = 0.0
 
@@ -559,7 +555,8 @@
         h = self.norm(h)
         output = self.output(h)
 
+        # FIX: Only return load balance loss during training
         if self.training and total_lb_loss > 0:
             return output, total_lb_loss
-        return output
-
+        else:
+            return output
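
Editor's note (sketch with toy shapes, independent of the model classes): both branches of the rewritten MoE.forward rely on the same mask/argmax/gather idiom to recover, for each token routed to expert i, the top-k slot holding i and its epsilon-normalized weight. Standalone:

import torch

torch.manual_seed(0)
router_probs = torch.randn(6, 4).sigmoid()   # 6 tokens, 4 experts
weights, indices = torch.topk(router_probs, k=2, dim=-1)
weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)  # epsilon as in the fix

i = 2                                    # expert under inspection
mask = (indices == i)                    # [6, 2] bool: where i sits in each token's top-k
idx = torch.where(mask.any(dim=1))[0]    # tokens that selected expert i
# Each top-k row contains expert i at most once, so argmax picks that single slot
top_positions = torch.argmax(mask[idx].int(), dim=1)
expert_weights = weights[idx, top_positions].unsqueeze(-1)  # [n_selected, 1]
print(idx.tolist(), expert_weights.squeeze(-1).tolist())
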
Model_Architecture/train.py CHANGED
@@ -264,6 +264,17 @@ def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
     """
     model.eval()
 
+    # CRITICAL FIX: Store original gradient requirements for experts
+    original_expert_grads = {}
+    for name, param in model.named_parameters():
+        if "experts" in name:
+            original_expert_grads[name] = param.requires_grad
+
+    # Enable gradients for all experts during evaluation
+    for name, param in model.named_parameters():
+        if "experts" in name:
+            param.requires_grad = True
+
     # Clear caches...
     for layer in model.layers:
         if hasattr(layer.attn, 'kv_cache'):
@@ -273,11 +284,18 @@
 
     # Set expert mode for validation
     if hasattr(model, 'set_active_expert'):
-        model.set_active_expert(active_expert)
+        # CRITICAL: For validation, temporarily set to None (all experts)
+        # even if we're in sequential training mode
         if active_expert is not None:
             print(f"  Validating with ONLY expert {active_expert}")
+            # Store the actual active expert but use all for forward pass
+            validation_expert = active_expert
         else:
             print(f"  Validating with ALL experts")
+            validation_expert = None
+
+        # Always use all experts for validation forward pass
+        model.set_active_expert(None)
 
     total_loss = 0.0
     total_tokens = 0
@@ -297,21 +315,9 @@
         input_ids = input_ids.to(device, non_blocking=True)
         target_ids = target_ids.to(device, non_blocking=True)
 
-        # 🔥 VISUAL TURKISH SAMPLE: Show human-readable text
-        if i == 0:  # First batch only
-            sample_tokens = input_ids[0].cpu().tolist()
-            # Decode first 30 tokens (skip padding zeros)
-            non_zero_tokens = [t for t in sample_tokens[:30] if t > 0]
-            try:
-                sample_text = tokenizer.decode(non_zero_tokens)
-                # Truncate if too long
-                if len(sample_text) > 60:
-                    sample_text = sample_text[:57] + "..."
-                print(f"\n📝 Örnek Turkce metin: '{sample_text}'")
-            except Exception as e:
-                print(f"\n⚠️ Decode failed: {e}\n  Tokens: {non_zero_tokens[:10]}...")
-
-        with torch.amp.autocast(device_type='cuda', enabled=(val_dtype == 'bf16')):
+        # CRITICAL: Use proper autocast settings based on dtype
+        use_autocast = val_dtype in ['bf16', 'fp16']
+        with torch.amp.autocast(device_type='cuda', enabled=use_autocast, dtype=torch.bfloat16 if val_dtype == 'bf16' else torch.float16):
             output = model(input_ids, start_pos=0)
             logits = output[0] if isinstance(output, tuple) else output
 
@@ -328,6 +334,16 @@
         pbar.set_postfix({'loss': f'{loss.item():.3f}'})
 
     pbar.close()
+
+    # CRITICAL: Restore original gradient requirements
+    for name, param in model.named_parameters():
+        if name in original_expert_grads:
+            param.requires_grad = original_expert_grads[name]
+
+    # Restore the active expert if in sequential training mode
+    if hasattr(model, 'set_active_expert') and 'validation_expert' in locals():
+        model.set_active_expert(validation_expert)
+
     model.train()
 
     final_loss = total_loss / total_tokens
@@ -339,7 +355,6 @@
 
     return final_loss
 
-
 def save_checkpoint(model, optimizer, step, config, expert_idx=None):
     """Save model checkpoint"""
     save_dir = Path(config["training"]["save_dir"])
@@ -392,7 +407,11 @@ def train_step(model, input_mb, target_mb, device, config, scaler=None):
     input_mb = input_mb.to(device, non_blocking=True)
     target_mb = target_mb.to(device, non_blocking=True)
 
-    with torch.amp.autocast(device_type='cuda', enabled=(config["training"]["dtype"] == "bf16")):
+    training_dtype = config["training"]["dtype"].lower()
+    use_autocast = training_dtype in ['bf16', 'fp16']
+    autocast_dtype = torch.bfloat16 if training_dtype == 'bf16' else torch.float16
+    with torch.amp.autocast(device_type='cuda', enabled=use_autocast, dtype=autocast_dtype if use_autocast else None):
+
         output = model(input_mb, start_pos=0)
 
         if isinstance(output, tuple):
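
Editor's note (sketch): evaluate() and train_step() now apply the same dtype-to-autocast mapping. Factored into a hypothetical helper, assuming config dtype strings are 'bf16', 'fp16', or anything else meaning full precision:

import torch

def autocast_ctx(dtype_str):
    """Build the autocast context the way the updated code does (helper name is illustrative)."""
    dtype_str = dtype_str.lower()
    use_autocast = dtype_str in ('bf16', 'fp16')
    amp_dtype = torch.bfloat16 if dtype_str == 'bf16' else torch.float16
    return torch.amp.autocast(device_type='cuda', enabled=use_autocast,
                              dtype=amp_dtype if use_autocast else None)

# e.g. in train_step:
#   with autocast_ctx(config["training"]["dtype"]):
#       output = model(input_mb, start_pos=0)
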