Update best.py
Browse files
best.py
CHANGED
|
@@ -261,9 +261,13 @@ class GroupedQueryAttention(nn.Module):
|
|
| 261 |
query_states = (query_states * cos_full) + (rotate_half(query_states) * sin_full)
|
| 262 |
key_states = (key_states * cos_full) + (rotate_half(key_states) * sin_full)
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
present_kv = (key_states, value_states) if use_cache else None
|
| 265 |
|
| 266 |
-
#
|
| 267 |
if self.num_key_value_groups > 1:
|
| 268 |
key_states = key_states .repeat_interleave(self.num_key_value_groups, dim=1)
|
| 269 |
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
|
@@ -353,10 +357,10 @@ class DecoderLayer(nn.Module):
|
|
| 353 |
|
| 354 |
class LabelSmoothingCrossEntropy(nn.Module):
|
| 355 |
"""
|
| 356 |
-
Cross-entropy with label smoothing
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
"""
|
| 361 |
|
| 362 |
def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
|
|
@@ -364,26 +368,17 @@ class LabelSmoothingCrossEntropy(nn.Module):
|
|
| 364 |
self.vocab_size = vocab_size
|
| 365 |
self.smoothing = smoothing
|
| 366 |
self.ignore_index = ignore_index
|
| 367 |
-
self.confidence = 1.0 - smoothing
|
| 368 |
|
| 369 |
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 370 |
# logits: [N, V] targets: [N]
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
# Smooth target distribution
|
| 381 |
-
smooth_val = self.smoothing / (self.vocab_size - 1)
|
| 382 |
-
smooth_dist = torch.full_like(log_probs, smooth_val)
|
| 383 |
-
smooth_dist.scatter_(1, targets.unsqueeze(1), self.confidence)
|
| 384 |
-
|
| 385 |
-
loss = -(smooth_dist * log_probs).sum(dim=-1).mean()
|
| 386 |
-
return loss
|
| 387 |
|
| 388 |
|
| 389 |
# ============================================================================
|
|
@@ -900,7 +895,6 @@ def train_model(
|
|
| 900 |
betas=(config.adam_beta1, config.adam_beta2),
|
| 901 |
eps=config.adam_epsilon,
|
| 902 |
weight_decay=config.weight_decay,
|
| 903 |
-
fused=True if (device.type == 'cuda' and hasattr(torch.optim.AdamW, '__init__')) else False,
|
| 904 |
)
|
| 905 |
|
| 906 |
total_steps = sum(
|
|
|
|
| 261 |
query_states = (query_states * cos_full) + (rotate_half(query_states) * sin_full)
|
| 262 |
key_states = (key_states * cos_full) + (rotate_half(key_states) * sin_full)
|
| 263 |
|
| 264 |
+
# Store pre-expand KV in cache (shape [B, num_kv_heads, T, D]).
|
| 265 |
+
# Must happen BEFORE repeat_interleave — otherwise cached keys have
|
| 266 |
+
# num_heads channels instead of num_kv_heads, and every decode step
|
| 267 |
+
# re-expands them again, corrupting attention.
|
| 268 |
present_kv = (key_states, value_states) if use_cache else None
|
| 269 |
|
| 270 |
+
# Expand KV heads for full attention computation
|
| 271 |
if self.num_key_value_groups > 1:
|
| 272 |
key_states = key_states .repeat_interleave(self.num_key_value_groups, dim=1)
|
| 273 |
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
|
|
|
| 357 |
|
| 358 |
class LabelSmoothingCrossEntropy(nn.Module):
|
| 359 |
"""
|
| 360 |
+
Cross-entropy with label smoothing.
|
| 361 |
+
Filters ignore_index=-100 first, then uses F.cross_entropy with smoothing.
|
| 362 |
+
This keeps the exact same loss scale as the original nn.CrossEntropyLoss
|
| 363 |
+
so the LR schedule pacing is unchanged.
|
| 364 |
"""
|
| 365 |
|
| 366 |
def __init__(self, vocab_size: int, smoothing: float = 0.1, ignore_index: int = -100):
|
|
|
|
| 368 |
self.vocab_size = vocab_size
|
| 369 |
self.smoothing = smoothing
|
| 370 |
self.ignore_index = ignore_index
|
|
|
|
| 371 |
|
| 372 |
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 373 |
# logits: [N, V] targets: [N]
|
| 374 |
+
# F.cross_entropy with label_smoothing and ignore_index is correct in
|
| 375 |
+
# PyTorch >= 1.10 — it does NOT distribute to ignored positions.
|
| 376 |
+
return F.cross_entropy(
|
| 377 |
+
logits,
|
| 378 |
+
targets,
|
| 379 |
+
ignore_index=self.ignore_index,
|
| 380 |
+
label_smoothing=self.smoothing,
|
| 381 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
|
| 384 |
# ============================================================================
|
|
|
|
| 895 |
betas=(config.adam_beta1, config.adam_beta2),
|
| 896 |
eps=config.adam_epsilon,
|
| 897 |
weight_decay=config.weight_decay,
|
|
|
|
| 898 |
)
|
| 899 |
|
| 900 |
total_steps = sum(
|