Commit 8f73121 (parent: 6b1c605): "Some Fixes"

Files changed:
- Model_Architecture/diognose_weights.py (+83 / -0)
- Model_Architecture/model.py (+35 / -38)
- Model_Architecture/train.py (+37 / -18)
Model_Architecture/diognose_weights.py (ADDED)
@@ -0,0 +1,83 @@
+def diagnose_checkpoint(checkpoint_path, config, device):
+    """Diagnose if the checkpoint has actually learned anything"""
+    import torch
+    import numpy as np
+
+    print("🔍 Diagnosing checkpoint...")
+
+    # Load checkpoint
+    ckpt = torch.load(checkpoint_path, map_location=device)
+
+    # Create model with fixes
+    from model import ismail, ModelArgs
+    args = ModelArgs(**config["model"])
+    model = ismail(args).to(device)
+
+    # Load weights
+    model.load_state_dict(ckpt["model_state_dict"], strict=False)
+    model.eval()
+
+    # Check expert weight statistics
+    print("\n📊 Expert Weight Analysis:")
+    for name, param in model.named_parameters():
+        if "experts" in name and "routed" in name:
+            expert_idx = int(name.split("experts.")[1].split(".")[0])
+            weight_std = param.std().item()
+            weight_mean = param.mean().item()
+            print(f"  Expert {expert_idx}: mean={weight_mean:.6f}, std={weight_std:.6f}")
+
+    # Check router weights
+    print("\n🎯 Router Weight Analysis:")
+    for name, param in model.named_parameters():
+        if "gate.weight" in name:
+            weight_std = param.std().item()
+            weight_range = (param.max() - param.min()).item()
+            print(f"  {name}: std={weight_std:.6f}, range={weight_range:.6f}")
+
+            # Check if router has learned to differentiate
+            router_weights = param.detach().float().cpu()
+            correlations = []
+            for i in range(min(5, router_weights.shape[0])):
+                for j in range(i + 1, min(5, router_weights.shape[0])):
+                    corr = torch.corrcoef(torch.stack([router_weights[i], router_weights[j]]))[0, 1].item()
+                    correlations.append(abs(corr))
+
+            if correlations:
+                avg_correlation = np.mean(correlations)
+                print(f"  Average correlation between experts: {avg_correlation:.4f}")
+                if avg_correlation < 0.9:
+                    print("  ✅ Experts show differentiation (good!)")
+                else:
+                    print("  ⚠️ Experts are too similar (potential issue)")
+
+    # Test with random input
+    print("\n🎲 Testing with random input:")
+    with torch.no_grad():
+        test_input = torch.randint(0, config["model"]["vocab_size"], (2, 128)).to(device)
+        output = model(test_input)
+        if isinstance(output, tuple):
+            output = output[0]
+
+        # Check output statistics
+        output_std = output.std().item()
+        output_mean = output.mean().item()
+        print(f"  Output mean: {output_mean:.6f}, std: {output_std:.6f}")
+
+        if output_std > 0.1:
+            print("  ✅ Model produces varied outputs")
+        else:
+            print("  ⚠️ Model outputs might be collapsed")
+
+    return ckpt["step"]
+
+
+if __name__ == "__main__":
+    import json
+
+    # Load config
+    with open("./config.json", "r") as f:
+        config = json.load(f)
+
+    # Run diagnostic
+    current_step = diagnose_checkpoint("./checkpoints/your_latest_checkpoint.pt", config, "cuda")
+    print(f"\n📍 Current step: {current_step}")
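Note: the pairwise loop in the router analysis above caps the comparison at the first five expert rows. torch.corrcoef also accepts a full matrix (rows as variables), so the same health check can be written in one call. A minimal sketch, not part of the commit; the module path in the usage comment is hypothetical:

import torch

def avg_offdiag_correlation(gate_weight: torch.Tensor) -> float:
    """Mean absolute off-diagonal correlation between expert rows of a router weight."""
    corr = torch.corrcoef(gate_weight.detach().float().cpu())  # [n_experts, n_experts]
    n = corr.shape[0]
    off_diag = corr[~torch.eye(n, dtype=torch.bool)]           # drop the diagonal of 1.0s
    return off_diag.abs().mean().item()

# Usage (hypothetical attribute path): values well below 0.9 suggest differentiated experts.
# print(avg_offdiag_correlation(model.layers[0].ffn.gate.weight))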
Model_Architecture/model.py (CHANGED)
@@ -395,33 +395,27 @@ class MoE(nn.Module):
         original_shape = x.size()
         x = x.view(-1, self.dim)
 
-        router_logits =
+        router_logits = linear(x, self.gate.weight, self.gate.bias)
         router_probs = router_logits.sigmoid()
         weights, indices = torch.topk(router_probs, self.n_activated_experts, dim=-1)
 
         # Normalize weights
-        weights = weights / weights.sum(dim=-1, keepdim=True)
+        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)  # Add epsilon for stability
         weights = weights * self.gate.route_scale
 
-        #
+        # CRITICAL FIX: Check training mode AND active expert
         if self.training and self.active_expert_idx is not None:
+            # Sequential training mode - only train one expert
            y = torch.zeros_like(x)
            i = self.active_expert_idx
 
            # Find tokens where expert i is in the top-k
-           idx = torch.where(mask.any(dim=1))[0]  # token indices
+           mask = (indices == i)
+           idx = torch.where(mask.any(dim=1))[0]
 
            if idx.numel() > 0:
-               # Get weights for expert i
-               expert_weights = weights[idx, top_positions].unsqueeze(-1)  # shape: [num_selected_tokens, 1]
-
-               # Forward pass ONLY for active expert
+               top_positions = torch.argmax(mask[idx].int(), dim=1)
+               expert_weights = weights[idx, top_positions].unsqueeze(-1)
                expert_out = self.experts[i](x[idx])
                y[idx] = expert_out * expert_weights
@@ -430,31 +424,32 @@ class MoE(nn.Module):
 
            # Shared experts
            z = self.shared_experts(x)
            return (y + z).view(original_shape), lb_loss
 
-       for i in range(self.n_routed_experts):
-           mask = (indices == i)
-           idx = torch.where(mask.any(dim=1))[0]
-
-           if idx.numel() == 0:
-               continue
+       else:
+           # Inference mode or all-experts training mode
+           y = torch.zeros_like(x)
+           for i in range(self.n_routed_experts):
+               mask = (indices == i)
+               idx = torch.where(mask.any(dim=1))[0]
+
+               if idx.numel() == 0:
+                   continue
+
+               top_positions = torch.argmax(mask[idx].int(), dim=1)
+               expert_weights = weights[idx, top_positions].unsqueeze(-1)
+               expert_out = self.experts[i](x[idx])
+               y[idx] += expert_out * expert_weights
+
+           z = self.shared_experts(x)
+           output = (y + z).view(original_shape)
+
+           # Only compute load balance loss during training
+           if self.training:
+               lb_loss = self.compute_load_balance_loss(router_probs, indices)
+               return output, lb_loss
+           else:
+               return output, None
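Note: the top_positions gather is the subtle part of the routing fix above. For each token that routed to expert i, argmax over the boolean mask finds which top-k slot holds that expert, so the matching routing weight can be pulled out with advanced indexing. A standalone toy run with assumed shapes, not taken from the repo:

import torch

# Toy routing state: 4 tokens, top-2 experts selected per token.
indices = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]])                 # expert ids per slot
weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.8, 0.2], [0.5, 0.5]])

i = 2                                                 # expert under consideration
mask = (indices == i)                                 # [4, 2] bool
idx = torch.where(mask.any(dim=1))[0]                 # tokens that picked expert 2 -> tensor([0, 2])
top_positions = torch.argmax(mask[idx].int(), dim=1)  # slot holding expert 2 -> tensor([1, 0])
expert_weights = weights[idx, top_positions].unsqueeze(-1)
print(expert_weights.squeeze(-1))                     # tensor([0.3000, 0.8000])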
@@ -536,6 +531,7 @@ class ismail(nn.Module):
         h = self.tok_embeddings(tokens).to(Linear.dtype)
         freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
 
+        # CRITICAL: Always clear caches at start_pos=0, regardless of training mode
         if start_pos == 0:
             for layer in self.layers:
                 if hasattr(layer.attn, 'kv_cache'):
@@ -545,9 +541,9 @@ class ismail(nn.Module):
 
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
+            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device, dtype=h.dtype)
             mask = torch.triu(mask, diagonal=1)
-            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask])
+            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device, dtype=h.dtype), mask])
 
         total_lb_loss = 0.0
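Note: the dtype= additions in the mask hunk above matter because h is cast to the model's compute dtype. Without them, torch.full and torch.zeros default to float32, and stacking a float32 block into the attention scores silently promotes the addition out of bf16. A quick illustration of the default behavior:

import torch

seqlen, start_pos = 4, 2

# Default: the mask comes out float32 regardless of the activations' dtype.
m32 = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
print(torch.hstack([torch.zeros((seqlen, start_pos)), m32]).dtype)  # torch.float32

# Pinning dtype to the activations (here bf16) keeps the attention math in one dtype.
m = torch.triu(torch.full((seqlen, seqlen), float("-inf"), dtype=torch.bfloat16), diagonal=1)
m = torch.hstack([torch.zeros((seqlen, start_pos), dtype=torch.bfloat16), m])
print(m.dtype)  # torch.bfloat16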
@@ -559,7 +555,8 @@ class ismail(nn.Module):
         h = self.norm(h)
         output = self.output(h)
 
+        # FIX: Only return load balance loss during training
         if self.training and total_lb_loss > 0:
             return output, total_lb_loss
+        else:
+            return output
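Note: the MoE else branch above calls self.compute_load_balance_loss(router_probs, indices), whose body is not part of this diff. A common choice for top-k routers is a Switch-Transformer-style auxiliary loss over per-expert load and mean router probability; a sketch under that assumption (the implementation is hypothetical, only the call signature comes from the hunk):

import torch

def compute_load_balance_loss(router_probs: torch.Tensor,
                              indices: torch.Tensor,
                              n_experts: int) -> torch.Tensor:
    """Hypothetical Switch-style auxiliary loss: n_experts * sum(load * importance)."""
    n_tokens = router_probs.shape[0]
    # Fraction of top-k assignments landing on each expert.
    one_hot = torch.zeros(n_tokens, n_experts, device=indices.device)
    one_hot.scatter_(1, indices, 1.0)
    load = one_hot.sum(dim=0) / one_hot.sum()
    # Mean router probability mass per expert.
    importance = router_probs.mean(dim=0)
    return n_experts * torch.sum(load * importance)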
Model_Architecture/train.py (CHANGED)
@@ -264,6 +264,17 @@ def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
     """
     model.eval()
 
+    # CRITICAL FIX: Store original gradient requirements for experts
+    original_expert_grads = {}
+    for name, param in model.named_parameters():
+        if "experts" in name:
+            original_expert_grads[name] = param.requires_grad
+
+    # Enable gradients for all experts during evaluation
+    for name, param in model.named_parameters():
+        if "experts" in name:
+            param.requires_grad = True
+
     # Clear caches...
     for layer in model.layers:
         if hasattr(layer.attn, 'kv_cache'):
@@ -273,11 +284,18 @@ def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
 
     # Set expert mode for validation
     if hasattr(model, 'set_active_expert'):
+        # CRITICAL: For validation, temporarily set to None (all experts)
+        # even if we're in sequential training mode
         if active_expert is not None:
             print(f" Validating with ONLY expert {active_expert}")
+            # Store the actual active expert but use all for forward pass
+            validation_expert = active_expert
         else:
             print(f" Validating with ALL experts")
+            validation_expert = None
+
+        # Always use all experts for validation forward pass
+        model.set_active_expert(None)
 
     total_loss = 0.0
     total_tokens = 0
@@ -297,21 +315,9 @@ def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
         input_ids = input_ids.to(device, non_blocking=True)
         target_ids = target_ids.to(device, non_blocking=True)
 
-        # Decode first 30 tokens (skip padding zeros)
-        non_zero_tokens = [t for t in sample_tokens[:30] if t > 0]
-        try:
-            sample_text = tokenizer.decode(non_zero_tokens)
-            # Truncate if too long
-            if len(sample_text) > 60:
-                sample_text = sample_text[:57] + "..."
-            print(f"\n📝 Sample Turkish text: '{sample_text}'")
-        except Exception as e:
-            print(f"\n⚠️ Decode failed: {e}\n   Tokens: {non_zero_tokens[:10]}...")
-
-        with torch.amp.autocast(device_type='cuda', enabled=(val_dtype == 'bf16')):
+        # CRITICAL: Use proper autocast settings based on dtype
+        use_autocast = val_dtype in ['bf16', 'fp16']
+        with torch.amp.autocast(device_type='cuda', enabled=use_autocast, dtype=torch.bfloat16 if val_dtype == 'bf16' else torch.float16):
             output = model(input_ids, start_pos=0)
             logits = output[0] if isinstance(output, tuple) else output
@@ -328,6 +334,16 @@ def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
             pbar.set_postfix({'loss': f'{loss.item():.3f}'})
 
     pbar.close()
+
+    # CRITICAL: Restore original gradient requirements
+    for name, param in model.named_parameters():
+        if name in original_expert_grads:
+            param.requires_grad = original_expert_grads[name]
+
+    # Restore the active expert if in sequential training mode
+    if hasattr(model, 'set_active_expert') and 'validation_expert' in locals():
+        model.set_active_expert(validation_expert)
+
     model.train()
 
     final_loss = total_loss / total_tokens
@@ -339,7 +355,6 @@ def evaluate(model, val_loader, device, config, tokenizer, active_expert=None):
 
     return final_loss
 
-
 def save_checkpoint(model, optimizer, step, config, expert_idx=None):
     """Save model checkpoint"""
     save_dir = Path(config["training"]["save_dir"])
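Note: evaluate only duck-types set_active_expert via hasattr; the method itself lives in model.py outside this diff. Given how active_expert_idx is read in the MoE forward, its body is probably a simple broadcast over the MoE layers. A sketch, with the body assumed rather than copied:

# Presumed helper on the model class (names taken from the MoE forward above).
def set_active_expert(self, expert_idx):
    """Select one routed expert for sequential training; None enables all experts."""
    for module in self.modules():
        if hasattr(module, "active_expert_idx"):
            module.active_expert_idx = expert_idx  # hit every MoE block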
@@ -392,7 +407,11 @@ def train_step(model, input_mb, target_mb, device, config, scaler=None):
     input_mb = input_mb.to(device, non_blocking=True)
     target_mb = target_mb.to(device, non_blocking=True)
 
+    training_dtype = config["training"]["dtype"].lower()
+    use_autocast = training_dtype in ['bf16', 'fp16']
+    autocast_dtype = torch.bfloat16 if training_dtype == 'bf16' else torch.float16
+    with torch.amp.autocast(device_type='cuda', enabled=use_autocast, dtype=autocast_dtype if use_autocast else None):
+
        output = model(input_mb, start_pos=0)
 
        if isinstance(output, tuple):
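Note: train_step accepts scaler=None, but the hunk above only shows the autocast entry. Under fp16, the backward pass normally runs through a GradScaler to avoid underflowing gradients (bf16 usually needs none). A minimal sketch of how the caller might wire this up, assuming standard torch.amp usage rather than code from this repo:

import torch

training_dtype = "fp16"  # assumed to come from config["training"]["dtype"]
scaler = torch.amp.GradScaler('cuda', enabled=(training_dtype == 'fp16'))

# Inside the training loop, after train_step returns the micro-batch loss:
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
# optimizer.zero_grad(set_to_none=True)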
|