import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Dict, Any, Optional

from .config import AetherisConfig
from .modules import SSMBlock, SparseMoELayer


class HybridMambaMoE(nn.Module):
    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList()
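        # Alternate layer types: SSM blocks at even indices, sparse
        # mixture-of-experts blocks at odd indices.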
        for i in range(config.n_layer):
            if i % 2 == 0:
                self.layers.append(SSMBlock(config))
            else:
                self.layers.append(SparseMoELayer(config))
        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Weight tying
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
        self.gradient_checkpointing = config.gradient_checkpointing
        # Initialize embeddings with smaller scale
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

    def _init_weights(self, module):
        """Apply proper weight initialization"""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        x = self.embedding(input_ids)
        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
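        # Run the layer stack; sparse-MoE layers also return a load-balancing
        # auxiliary loss, which is accumulated across layers.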
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Checkpoint ALL layers for maximum memory savings
                if isinstance(layer, SparseMoELayer):
                    def moe_forward(module, inp):
                        return module(inp)

                    x, aux_loss = torch.utils.checkpoint.checkpoint(
                        moe_forward, layer, x, use_reentrant=False
                    )
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        layer, x, use_reentrant=False
                    )
            else:
                if isinstance(layer, SparseMoELayer):
                    x, aux_loss = layer(x)
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = layer(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)
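        # With labels, compute the next-token (causal LM) loss: logits and
        # labels are offset by one position so each token predicts the next.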
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = self.loss_fn(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
            # Scale down aux loss to prevent it from dominating
            total_loss = ce_loss + 0.01 * total_aux_loss
            return {
                "loss": total_loss,
                "ce_loss": ce_loss,
                "aux_loss": total_aux_loss,
                "logits": logits,
            }
        return {"logits": logits}
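

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal smoke test of the forward contract above. The keyword-argument
# constructor for AetherisConfig is an assumption here; the model itself only
# reads vocab_size, d_model, n_layer, and gradient_checkpointing from it, and
# the sizes below are arbitrary small values. Because this module uses relative
# imports, run it as part of its package (e.g. `python -m <package>.<module>`)
# rather than as a standalone script.
if __name__ == "__main__":
    config = AetherisConfig(
        vocab_size=32000,
        d_model=256,
        n_layer=4,
        gradient_checkpointing=False,
    )
    model = HybridMambaMoE(config)

    # Dummy batch: 2 sequences of 16 token ids; reusing the inputs as labels
    # exercises the shifted cross-entropy path.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids, labels=input_ids)

    # Expect a scalar loss and logits of shape (batch, seq_len, vocab_size).
    print(outputs["loss"].item(), tuple(outputs["logits"].shape))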