import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Any, Dict, Optional

from .config import AetherisConfig
from .modules import SSMBlock, SparseMoELayer
class HybridMambaMoE(nn.Module):
    """Hybrid decoder that alternates SSM blocks with sparse MoE layers."""

    def __init__(self, config: AetherisConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Alternate SSM blocks (even indices) with sparse MoE layers (odd indices).
        self.layers = nn.ModuleList()
        for i in range(config.n_layer):
            if i % 2 == 0:
                self.layers.append(SSMBlock(config))
            else:
                self.layers.append(SparseMoELayer(config))

        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Weight tying

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
        self.gradient_checkpointing = config.gradient_checkpointing

        # Apply module-wise initialization, then re-initialize the (tied)
        # embedding / lm_head matrix with a smaller scale.
        self.apply(self._init_weights)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
    def _init_weights(self, module):
        """Apply proper weight initialization."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        x = self.embedding(input_ids)
        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Checkpoint ALL layers for maximum memory savings.
                if isinstance(layer, SparseMoELayer):
                    # The MoE layer returns (hidden_states, aux_loss); checkpoint
                    # passes tuple outputs through unchanged.
                    x, aux_loss = torch.utils.checkpoint.checkpoint(
                        layer, x, use_reentrant=False
                    )
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        layer, x, use_reentrant=False
                    )
            else:
                if isinstance(layer, SparseMoELayer):
                    x, aux_loss = layer(x)
                    total_aux_loss = total_aux_loss + aux_loss
                else:
                    x = layer(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)

        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = self.loss_fn(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
            # Scale down the aux loss so it does not dominate the LM objective.
            total_loss = ce_loss + 0.01 * total_aux_loss
            return {
                "loss": total_loss,
                "ce_loss": ce_loss,
                "aux_loss": total_aux_loss,
                "logits": logits,
            }

        return {"logits": logits}