FaiziRBLX commited on
Commit
f040bfd
·
verified ·
1 Parent(s): 110b8ce

Update best.py

Browse files
Files changed (1) hide show
  1. best.py +17 -23
best.py CHANGED
@@ -261,9 +261,13 @@ class GroupedQueryAttention(nn.Module):
261
  query_states = (query_states * cos_full) + (rotate_half(query_states) * sin_full)
262
  key_states = (key_states * cos_full) + (rotate_half(key_states) * sin_full)
263
 
 
 
 
 
264
  present_kv = (key_states, value_states) if use_cache else None
265
 
266
- # FIX: only expand KV if groups > 1 (skip no-op repeat when groups==1)
267
  if self.num_key_value_groups > 1:
268
  key_states = key_states .repeat_interleave(self.num_key_value_groups, dim=1)
269
  value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
@@ -353,10 +357,10 @@ class DecoderLayer(nn.Module):
353
 
354
  class LabelSmoothingCrossEntropy(nn.Module):
355
  """
356
- Cross-entropy with label smoothing that correctly ignores positions
357
- where label == ignore_index. PyTorch's built-in label_smoothing
358
- distributes probability mass to ALL vocab entries including padding
359
- this implementation does not.
360
  """
361
 
362
  def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
@@ -364,26 +368,17 @@ class LabelSmoothingCrossEntropy(nn.Module):
364
  self.vocab_size = vocab_size
365
  self.smoothing = smoothing
366
  self.ignore_index = ignore_index
367
- self.confidence = 1.0 - smoothing
368
 
369
  def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
370
  # logits: [N, V] targets: [N]
371
- mask = targets != self.ignore_index
372
- if mask.sum() == 0:
373
- return logits.sum() * 0.0 # differentiable zero
374
-
375
- logits = logits[mask]
376
- targets = targets[mask]
377
-
378
- log_probs = F.log_softmax(logits, dim=-1) # [N_valid, V]
379
-
380
- # Smooth target distribution
381
- smooth_val = self.smoothing / (self.vocab_size - 1)
382
- smooth_dist = torch.full_like(log_probs, smooth_val)
383
- smooth_dist.scatter_(1, targets.unsqueeze(1), self.confidence)
384
-
385
- loss = -(smooth_dist * log_probs).sum(dim=-1).mean()
386
- return loss
387
 
388
 
389
  # ============================================================================
@@ -900,7 +895,6 @@ def train_model(
900
  betas=(config.adam_beta1, config.adam_beta2),
901
  eps=config.adam_epsilon,
902
  weight_decay=config.weight_decay,
903
- fused=True if (device.type == 'cuda' and hasattr(torch.optim.AdamW, '__init__')) else False,
904
  )
905
 
906
  total_steps = sum(
 
261
  query_states = (query_states * cos_full) + (rotate_half(query_states) * sin_full)
262
  key_states = (key_states * cos_full) + (rotate_half(key_states) * sin_full)
263
 
264
+ # Store pre-expand KV in cache (shape [B, num_kv_heads, T, D]).
265
+ # Must happen BEFORE repeat_interleave — otherwise cached keys have
266
+ # num_heads channels instead of num_kv_heads, and every decode step
267
+ # re-expands them again, corrupting attention.
268
  present_kv = (key_states, value_states) if use_cache else None
269
 
270
+ # Expand KV heads for full attention computation
271
  if self.num_key_value_groups > 1:
272
  key_states = key_states .repeat_interleave(self.num_key_value_groups, dim=1)
273
  value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
 
357
 
358
  class LabelSmoothingCrossEntropy(nn.Module):
359
  """
360
+ Cross-entropy with label smoothing.
361
+ Filters ignore_index=-100 first, then uses F.cross_entropy with smoothing.
362
+ This keeps the exact same loss scale as the original nn.CrossEntropyLoss
363
+ so the LR schedule pacing is unchanged.
364
  """
365
 
366
  def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
 
368
  self.vocab_size = vocab_size
369
  self.smoothing = smoothing
370
  self.ignore_index = ignore_index
 
371
 
372
  def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
373
  # logits: [N, V] targets: [N]
374
+ # F.cross_entropy with label_smoothing and ignore_index is correct in
375
+ # PyTorch >= 1.10 — it does NOT distribute to ignored positions.
376
+ return F.cross_entropy(
377
+ logits,
378
+ targets,
379
+ ignore_index=self.ignore_index,
380
+ label_smoothing=self.smoothing,
381
+ )
 
 
 
 
 
 
 
 
382
 
383
 
384
  # ============================================================================
 
895
  betas=(config.adam_beta1, config.adam_beta2),
896
  eps=config.adam_epsilon,
897
  weight_decay=config.weight_decay,
 
898
  )
899
 
900
  total_steps = sum(