| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from transformers import PreTrainedModel |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from configuration_ultrabase import UltraBaseConfig |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x ** 2) + eps)`` computed along the
    last axis and then multiplied element-wise by a learnable ``weight``
    vector that is initialised to ones.

    Args:
        dim: Size of the last dimension of the input (length of ``weight``).
        eps: Constant added to the mean square before the reciprocal square
            root, to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` normalised by its RMS over the last dimension.

        Args:
            x: Tensor whose last dimension matches ``weight``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        input_dtype = x.dtype
        # Squaring float16 values overflows above ~255 (max finite value is
        # 65504) and bfloat16 has very few mantissa bits, so the statistics
        # for those dtypes are computed in float32. Wider dtypes are left
        # untouched so float32/float64 results are unchanged.
        if input_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(variance + self.eps)
        return normed.to(input_dtype) * self.weight
|
|
class MLA(nn.Module):
    """Causal multi-head self-attention with a low-rank key/value bottleneck.

    Queries are projected directly from the input. Keys and values are
    produced by first projecting the input down to ``latent_dim`` and then
    up to ``n_heads * head_dim`` through two separate linear layers.

    Args:
        config: Object exposing ``n_heads``, ``head_dim``, ``latent_dim`` and
            ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.latent_dim = config.latent_dim
        self.d_model = config.d_model

        self.kv_down_proj = nn.Linear(config.d_model, config.latent_dim, bias=False)
        self.kv_up_proj_k = nn.Linear(config.latent_dim, config.n_heads * config.head_dim, bias=False)
        self.kv_up_proj_v = nn.Linear(config.latent_dim, config.n_heads * config.head_dim, bias=False)

        self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)

    def _split_heads(self, t, batch, seq):
        """Reshape ``(batch, seq, n_heads * head_dim)`` to ``(batch, n_heads, seq, head_dim)``."""
        return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        B, S, _ = x.shape
        q = self._split_heads(self.q_proj(x), B, S)

        latent_kv = self.kv_down_proj(x)
        k = self._split_heads(self.kv_up_proj_k(latent_kv), B, S)
        v = self._split_heads(self.kv_up_proj_v(latent_kv), B, S)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # future[i, j] is True when key j lies after query i. A boolean mask
        # with masked_fill keeps the scores in their own dtype; adding a
        # float32 mask would promote fp16/bf16 scores to float32 and make the
        # later matmul with ``v`` fail on mismatched dtypes.
        positions = torch.arange(S, device=x.device)
        future = positions.unsqueeze(0) > positions.unsqueeze(1)
        attn_scores = attn_scores.masked_fill(future, float("-inf"))

        # Half-precision scores are soft-maxed in float32 for stability and
        # cast back; other dtypes are soft-maxed as-is.
        low_precision = attn_scores.dtype in (torch.float16, torch.bfloat16)
        softmax_dtype = torch.float32 if low_precision else None
        attn_weights = F.softmax(attn_scores, dim=-1, dtype=softmax_dtype).to(q.dtype)

        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, S, -1)

        return self.o_proj(context)
|
|
class Expert(nn.Module):
    """Bias-free two-layer feed-forward network with a SiLU activation.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Attribute names and creation order are kept so parameter names and
        # random initialisation order stay the same.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``w2(silu(w1(x)))``."""
        hidden = self.w1(x)
        activated = self.act(hidden)
        return self.w2(activated)
|
|
class SSPMoE(nn.Module):
    """Mixture-of-experts layer with one shared expert and top-1 private routing.

    Every token goes through ``shared_expert``. In addition, a linear router
    scores the private experts, the scores are soft-maxed, and each token is
    sent to its single highest-probability private expert; that expert's
    output is scaled by the routing probability and added to the shared
    output.

    Args:
        config: Object exposing ``num_private_experts``, ``d_model`` and
            ``d_ff``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_private = config.num_private_experts
        self.shared_expert = Expert(config.d_model, config.d_ff)
        self.private_experts = nn.ModuleList([
            Expert(config.d_model, config.d_ff) for _ in range(self.num_private)
        ])
        self.router = nn.Linear(config.d_model, self.num_private, bias=False)

    def forward(self, x):
        """Combine the shared expert with the top-1 routed private expert.

        Args:
            x: Tensor whose last dimension is ``d_model``; any leading
                dimensions are treated as independent tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        shared_out = self.shared_expert(x)

        routing_weights = F.softmax(self.router(x), dim=-1)
        top1_weights, top1_indices = torch.topk(routing_weights, k=1, dim=-1)

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are flattened instead of raising.
        feature_dim = x.shape[-1]
        flat_x = x.reshape(-1, feature_dim)
        flat_indices = top1_indices.reshape(-1)
        flat_weights = top1_weights.reshape(-1, 1)

        private_out = torch.zeros_like(flat_x)
        for expert_id, expert in enumerate(self.private_experts):
            mask = flat_indices == expert_id
            # Skip experts that received no tokens.
            if mask.any():
                private_out[mask] = expert(flat_x[mask]) * flat_weights[mask]

        return shared_out + private_out.reshape(x.shape)
|
|
class DecoderLayer(nn.Module):
    """Decoder layer that runs only a routed subset of tokens through its block.

    A linear router scores every token. The ``int(S * active_rate)`` highest
    scoring tokens of each sequence (at least one) are passed through
    pre-norm, attention, mixture-of-experts and post-norm; the remaining
    tokens are copied to the output unchanged. Sequences shorter than two
    tokens skip routing and are processed in full.

    Args:
        config: Object exposing ``bypass_rate`` and ``d_model``, plus whatever
            ``MLA`` and ``SSPMoE`` read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.active_rate = 1.0 - config.bypass_rate
        self.mod_router = nn.Linear(config.d_model, 1, bias=False)

        self.pre_rmsnorm = RMSNorm(config.d_model)
        self.mla_block = MLA(config)
        self.ssp_moe_layer = SSPMoE(config)
        self.post_rmsnorm = RMSNorm(config.d_model)

    def _run_block(self, x):
        """Apply pre-norm, attention, MoE and post-norm to ``x``."""
        h = self.pre_rmsnorm(x)
        h = h + self.mla_block(h)
        h = h + self.ssp_moe_layer(h)
        return self.post_rmsnorm(h)

    def forward(self, x):
        """Process the routed tokens of ``x`` and pass the rest through.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of the same shape; selected positions hold the block
            output, all other positions equal the input.
        """
        B, S, C = x.shape
        if S < 2:
            return self._run_block(x)

        router_logits = self.mod_router(x).squeeze(-1)
        k = max(1, min(int(S * self.active_rate), S))

        # NOTE(review): only the indices of the top-k are used, so no gradient
        # reaches ``mod_router`` through this path -- confirm this is intended.
        _, topk_indices = torch.topk(router_logits, k, dim=-1)

        # topk orders indices by score. Restore ascending position order so
        # the causal mask inside the attention block lines up with the real
        # token order.
        active_idx, _ = torch.sort(topk_indices, dim=-1)

        # Gather the selected tokens of every sequence in one batched call
        # (each sequence selects the same number k), process them together,
        # and write them back into a copy of the input.
        gather_idx = active_idx.unsqueeze(-1).expand(-1, -1, C)
        h = self._run_block(x.gather(1, gather_idx))
        return x.scatter(1, gather_idx, h.to(x.dtype))
|
|
class UltraBasePreTrainedModel(PreTrainedModel):
    """Base class tying UltraBase models to ``UltraBaseConfig`` and weight init."""

    config_class = UltraBaseConfig
    base_model_prefix = "model"
    # NOTE(review): nothing visible in this file implements activation
    # checkpointing in a forward pass -- confirm this flag is accurate.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from N(0, 0.02**2); linear
        biases, when present, are set to zero. Other module types are left
        untouched.
        """
        is_linear = isinstance(module, nn.Linear)
        if not (is_linear or isinstance(module, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if is_linear and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
|
|
class UltraBaseForCausalLM(UltraBasePreTrainedModel, GenerationMixin):
    """Causal language model: token embedding, decoder layers and an LM head.

    Inherits from ``UltraBasePreTrainedModel`` so that its ``config_class``
    and ``_init_weights`` apply to this model.

    Args:
        config: Configuration exposing ``vocab_size``, ``d_model`` and
            ``n_layers``, plus whatever ``DecoderLayer`` reads from it.
    """

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the output projection (LM head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection (LM head)."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute next-token logits and, when labels are given, the LM loss.

        Args:
            input_ids: Token ids fed to the embedding layer.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted by one position so that position ``t`` predicts token
                ``t + 1``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``logits`` and ``loss``
            (``None`` when no labels are given). No cache is returned.
        """
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on a different device than the logits (e.g.
            # when the model is sharded); move them before the loss.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is forwarded; all other keyword arguments are
        dropped, so the full sequence is re-processed at every step.
        """
        return {"input_ids": input_ids}
|
|