File size: 3,401 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .config import AetherisConfig
from .modules import SSMBlock, SparseMoELayer

class HybridMambaMoE(nn.Module):
    """Language model that alternates ``SSMBlock`` and ``SparseMoELayer`` layers.

    Layers at even indices are ``SSMBlock`` instances and layers at odd
    indices are ``SparseMoELayer`` instances. The output projection
    (``lm_head``) shares its weight tensor with the token embedding.
    """

    # Multiplier applied to the summed MoE auxiliary loss before it is added
    # to the cross-entropy loss, so the auxiliary term does not dominate.
    AUX_LOSS_WEIGHT = 0.01

    def __init__(self, config: "AetherisConfig"):
        """Build the embedding, the alternating layer stack and the tied head.

        Args:
            config: Model hyperparameters. This constructor reads
                ``vocab_size``, ``d_model``, ``n_layer`` and
                ``gradient_checkpointing``; the same object is handed to
                every layer constructor.
        """
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Even positions hold an SSM block, odd positions a sparse MoE layer.
        self.layers = nn.ModuleList(
            [
                SSMBlock(config) if idx % 2 == 0 else SparseMoELayer(config)
                for idx in range(config.n_layer)
            ]
        )

        self.final_norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Weight tying

        # Targets equal to -1 do not contribute to the loss.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
        self.gradient_checkpointing = config.gradient_checkpointing

        # Small-scale init for the embedding; the tensor is shared with
        # ``lm_head``, so this also determines the output projection.
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

    def _init_weights(self, module):
        """Initialize a single sub-module in place according to its type.

        ``nn.Linear`` gets Xavier-uniform weights (gain 0.5) and zero bias,
        ``nn.Embedding`` gets N(0, 0.02) weights, and ``nn.LayerNorm`` gets
        unit weight and zero bias. Other module types are left untouched.

        This helper is not invoked by ``__init__``; it is shaped for use with
        ``nn.Module.apply``. NOTE(review): applying it to the whole model
        would initialize the tied embedding / ``lm_head`` tensor twice, and
        whichever module is visited last would win.

        Args:
            module: The sub-module to initialize.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # Affine parameters are None when the norm was built without them.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Run the model and, when labels are given, compute the training loss.

        Args:
            input_ids: Integer token ids fed to the embedding.
            labels: Optional targets. They are shifted by one position along
                the last dimension so that position ``t`` predicts token
                ``t + 1``; entries equal to ``-1`` are ignored by the loss.

        Returns:
            ``{"logits": ...}`` when ``labels`` is ``None``. Otherwise a dict
            with ``"loss"`` (cross-entropy plus ``AUX_LOSS_WEIGHT`` times the
            summed auxiliary loss), ``"ce_loss"``, ``"aux_loss"`` and
            ``"logits"``.
        """
        x = self.embedding(input_ids)
        total_aux_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        # Activation checkpointing is only useful while gradients are needed.
        use_checkpoint = self.gradient_checkpointing and self.training

        for layer in self.layers:
            if use_checkpoint:
                # Checkpoint ALL layers for maximum memory savings
                out = checkpoint(layer, x, use_reentrant=False)
            else:
                out = layer(x)

            if isinstance(layer, SparseMoELayer):
                # MoE layers return (hidden_states, auxiliary_loss).
                x, aux_loss = out
                total_aux_loss = total_aux_loss + aux_loss
            else:
                x = out

        x = self.final_norm(x)
        logits = self.lm_head(x)

        if labels is None:
            return {"logits": logits}

        # Next-token objective: drop the last prediction and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        ce_loss = self.loss_fn(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1),
        )

        # Scale down aux loss to prevent it from dominating
        total_loss = ce_loss + self.AUX_LOSS_WEIGHT * total_aux_loss

        return {
            "loss": total_loss,
            "ce_loss": ce_loss,
            "aux_loss": total_aux_loss,
            "logits": logits,
        }