"""UltraBase causal language model.

Defines the building blocks (RMSNorm, latent-compressed attention, a
shared-plus-private mixture of experts, and a token-routing decoder layer)
and the Hugging Face wrapper classes that expose them as a causal LM.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_ultrabase import UltraBaseConfig


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-6):
        """Create the layer.

        Args:
            dim: Size of the last (normalized) dimension.
            eps: Constant added to the mean square before the reciprocal sqrt.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled by the reciprocal RMS of its last dimension, times ``weight``."""
        # Squaring fp16/bf16 activations can overflow or lose precision, so the
        # statistics are computed in at least float32 and cast back afterwards.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_hp = x.to(compute_dtype)
        mean_square = x_hp.pow(2).mean(-1, keepdim=True)
        normed = x_hp * torch.rsqrt(mean_square + self.eps)
        return normed.to(x.dtype) * self.weight


class MLA(nn.Module):
    """Causal multi-head self-attention whose keys and values pass through a shared latent.

    The input is first projected down to ``config.latent_dim`` and then
    projected up separately into keys and values; queries are projected
    directly from the input.
    """

    def __init__(self, config):
        """Build the projections from ``config``.

        Reads ``n_heads``, ``head_dim``, ``latent_dim`` and ``d_model``.
        """
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.latent_dim = config.latent_dim
        self.d_model = config.d_model

        inner_dim = config.n_heads * config.head_dim
        self.kv_down_proj = nn.Linear(config.d_model, config.latent_dim, bias=False)
        self.kv_up_proj_k = nn.Linear(config.latent_dim, inner_dim, bias=False)
        self.kv_up_proj_v = nn.Linear(config.latent_dim, inner_dim, bias=False)
        self.q_proj = nn.Linear(config.d_model, inner_dim, bias=False)
        self.o_proj = nn.Linear(inner_dim, config.d_model, bias=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        B, S, _ = x.shape

        def split_heads(t):
            # (B, S, n_heads * head_dim) -> (B, n_heads, S, head_dim)
            return t.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        latent_kv = self.kv_down_proj(x)
        k = split_heads(self.kv_up_proj_k(latent_kv))
        v = split_heads(self.kv_up_proj_v(latent_kv))

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # A boolean mask with masked_fill keeps the scores in their own dtype.
        # Adding a default-float32 mask would promote fp16/bf16 scores to
        # float32 and make the matmul with ``v`` below fail on mismatched dtypes.
        future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(diagonal=1)
        attn_scores = attn_scores.masked_fill(future, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, S, -1)
        return self.o_proj(context)


class Expert(nn.Module):
    """Two-layer bias-free feed-forward network with a SiLU activation."""

    def __init__(self, d_model, d_ff):
        """Create the expert.

        Args:
            d_model: Input and output feature size.
            d_ff: Hidden feature size.
        """
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.SiLU()

    def forward(self, x):
        """Return ``w2(silu(w1(x)))``."""
        return self.w2(self.act(self.w1(x)))


class SSPMoE(nn.Module):
    """Mixture of experts: one always-on shared expert plus top-1 routed private experts."""

    def __init__(self, config):
        """Build the experts and router from ``config``.

        Reads ``num_private_experts``, ``d_model`` and ``d_ff``.
        """
        super().__init__()
        self.num_private = config.num_private_experts
        self.shared_expert = Expert(config.d_model, config.d_ff)
        self.private_experts = nn.ModuleList(
            [Expert(config.d_model, config.d_ff) for _ in range(self.num_private)]
        )
        self.router = nn.Linear(config.d_model, self.num_private, bias=False)

    def forward(self, x):
        """Combine the shared expert with each token's top-1 private expert.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of the same shape: shared expert output plus the selected
            private expert's output scaled by its softmax routing weight.
        """
        shared_out = self.shared_expert(x)

        routing_weights = F.softmax(self.router(x), dim=-1)
        top1_weights, top1_indices = torch.topk(routing_weights, k=1, dim=-1)

        B, S, C = x.shape
        # reshape (rather than view) so non-contiguous inputs are accepted.
        flat_x = x.reshape(-1, C)
        flat_indices = top1_indices.reshape(-1)
        flat_weights = top1_weights.reshape(-1, 1)

        private_out = torch.zeros_like(flat_x)
        for expert_id, expert in enumerate(self.private_experts):
            selected = flat_indices == expert_id
            if not selected.any():
                continue
            weighted = expert(flat_x[selected]) * flat_weights[selected]
            private_out[selected] = weighted.to(private_out.dtype)

        return shared_out + private_out.view(B, S, C)


class DecoderLayer(nn.Module):
    """Decoder layer that runs attention + MoE on a routed subset of tokens.

    A linear router scores every token; the top ``active_rate`` fraction is
    processed by the block and the remaining tokens pass through unchanged.
    """

    def __init__(self, config):
        """Build the layer from ``config``.

        Reads ``bypass_rate`` and ``d_model`` here; ``MLA`` and ``SSPMoE``
        read their own fields.
        """
        super().__init__()
        self.active_rate = 1.0 - config.bypass_rate
        self.mod_router = nn.Linear(config.d_model, 1, bias=False)
        self.pre_rmsnorm = RMSNorm(config.d_model)
        self.mla_block = MLA(config)
        self.ssp_moe_layer = SSPMoE(config)
        self.post_rmsnorm = RMSNorm(config.d_model)

    def _apply_block(self, x):
        """Run norm -> attention -> MoE -> norm on ``x`` of shape (batch, seq, d_model)."""
        h = self.pre_rmsnorm(x)
        h = h + self.mla_block(h)
        h = h + self.ssp_moe_layer(h)
        return self.post_rmsnorm(h)

    def forward(self, x):
        """Process the routed tokens and leave the others untouched.

        Args:
            x: Tensor of shape ``(batch, seq, d_model)``.

        Returns:
            Tensor of the same shape. Selected positions hold the block
            output; all other positions are copied from ``x``.
        """
        B, S, C = x.shape
        k = max(1, min(int(S * self.active_rate), S))

        # With fewer than two tokens, or when every token is selected, routing
        # is a no-op and the block can run on the whole sequence directly.
        if S < 2 or k == S:
            return self._apply_block(x)

        # NOTE(review): only the top-k *indices* of the router are used, so
        # mod_router appears to receive no gradient from this path - confirm.
        router_logits = self.mod_router(x).squeeze(-1)
        _, topk_indices = torch.topk(router_logits, k, dim=-1)

        # topk returns indices ordered by score. The attention block applies a
        # causal mask by position within the tensor it is given, so the
        # selected tokens must be restored to temporal order first; otherwise
        # a token could attend to tokens that come later in the sequence.
        # NOTE(review): this changes outputs relative to checkpoints that were
        # trained with the unsorted ordering.
        active_idx, _ = torch.sort(topk_indices, dim=-1)

        # k is the same for every batch row, so gather/scatter handles the
        # whole batch at once instead of a Python loop over rows.
        gather_idx = active_idx.unsqueeze(-1).expand(-1, -1, C)
        x_active = torch.gather(x, 1, gather_idx)
        h = self._apply_block(x_active)
        return x.scatter(1, gather_idx, h.to(x.dtype))


class UltraBasePreTrainedModel(PreTrainedModel):
    """Shared Hugging Face base class: config binding and weight initialization."""

    config_class = UltraBaseConfig
    base_model_prefix = "model"
    # NOTE(review): no module in this file implements gradient checkpointing;
    # confirm this flag is intended.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


# Inherit from UltraBasePreTrainedModel (not PreTrainedModel directly) so that
# config_class, base_model_prefix and the custom _init_weights actually apply.
class UltraBaseForCausalLM(UltraBasePreTrainedModel, GenerationMixin):
    """Token embedding, a stack of ``DecoderLayer`` and a linear LM head."""

    def __init__(self, config):
        """Build the model from ``config``.

        Reads ``vocab_size``, ``d_model`` and ``n_layers`` here; the layers
        read their own fields.
        """
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.n_layers)]
        )
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: Integer tensor of token ids, ``(batch, seq)``.
            labels: Optional integer tensor shaped like ``input_ids``. The
                loss compares the logits at position ``t`` with the label at
                ``t + 1``.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``) and ``logits``.
        """
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the forward kwargs for a generation step: the full ``input_ids`` only."""
        return {"input_ids": input_ids}