import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Dict, Any, Optional

from .config import AetherisConfig
from .modules import SSMBlock, SparseMoELayer

class HybridMambaMoE(nn.Module):
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Alternate block types: SSM blocks at even layer indices, sparse-MoE layers at odd ones
        self.layers = nn.ModuleList()
        for i in range(config.n_layer):
            if i % 2 == 0:
                self.layers.append(SSMBlock(config))
            else:
                self.layers.append(SparseMoELayer(config))

        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Weight tying
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
        self.gradient_checkpointing = config.gradient_checkpointing

        # Initialize embeddings with smaller scale
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
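
    # Note: this weight-init hook is not invoked anywhere in this file (e.g. via
    # self.apply(self._init_weights)); only the embedding init above runs by default.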
    def _init_weights(self, module):
        """Apply proper weight initialization"""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        x = self.embedding(input_ids)
        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        for i, layer in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                # Checkpoint ALL layers for maximum memory savings
                if isinstance(layer, SparseMoELayer):
                    def moe_forward(module, inp):
                        return module(inp)
                    x, aux_loss = torch.utils.checkpoint.checkpoint(
                        moe_forward, layer, x, use_reentrant=False
                    )
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        layer, x, use_reentrant=False
                    )
            else:
                if isinstance(layer, SparseMoELayer):
                    x, aux_loss = layer(x)
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = layer(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)

        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = self.loss_fn(shift_logits.view(-1, self.config.vocab_size),
                                   shift_labels.view(-1))
            # Scale down aux loss to prevent it from dominating
            total_loss = ce_loss + 0.01 * total_aux_loss
            return {
                "loss": total_loss,
                "ce_loss": ce_loss,
                "aux_loss": total_aux_loss,
                "logits": logits
            }

        return {"logits": logits}
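

# Example usage (sketch only): assumes AetherisConfig can be constructed with the
# fields referenced above (vocab_size, d_model, n_layer, gradient_checkpointing);
# the real constructor lives in .config and may take different arguments.
#
#   config = AetherisConfig(vocab_size=32000, d_model=512, n_layer=8,
#                           gradient_checkpointing=False)
#   model = HybridMambaMoE(config)
#   input_ids = torch.randint(0, config.vocab_size, (2, 128))
#   out = model(input_ids, labels=input_ids)
#   out["loss"].backward()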