# UltraBase / modeling_ultrabase.py
# Retrieved from the Hugging Face Hub file view (uploader: 56m; commit 2144393 "Upload 6 files", verified; 6.97 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from configuration_ultrabase import UltraBaseConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Computes ``x / sqrt(mean(x**2, -1) + eps) * weight``.

    The mean-square statistic is computed in at least float32. In float16,
    ``x.pow(2)`` overflows to ``inf`` once ``|x|`` exceeds ~256, which drives
    ``rsqrt`` to 0 and zeroes the whole output. float32/float64 inputs are
    computed exactly as before, and the result dtype follows the same type
    promotion as the original ``x * rsqrt(...) * weight``.

    Args:
        dim: Size of the normalized (last) dimension; length of ``weight``.
        eps: Added to the mean square for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply the learned gain."""
        input_dtype = x.dtype
        # Upcast half/bfloat16 for the statistic; never downcast float64.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_c = x.to(compute_dtype)
        variance = x_c.pow(2).mean(-1, keepdim=True)
        normed = x_c * torch.rsqrt(variance + self.eps)
        return normed.to(input_dtype) * self.weight
class MLA(nn.Module):
    """Causal multi-head self-attention with low-rank (latent) key/value projections.

    Queries are projected directly from the input. Keys and values are both
    derived from one shared ``latent_dim`` bottleneck (``kv_down_proj``),
    followed by separate up-projections (``kv_up_proj_k`` / ``kv_up_proj_v``).
    This block applies no positional encoding and keeps no KV cache.

    Args:
        config: Object exposing ``n_heads``, ``head_dim``, ``latent_dim`` and
            ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.latent_dim = config.latent_dim
        self.d_model = config.d_model
        self.kv_down_proj = nn.Linear(config.d_model, config.latent_dim, bias=False)
        self.kv_up_proj_k = nn.Linear(config.latent_dim, config.n_heads * config.head_dim, bias=False)
        self.kv_up_proj_v = nn.Linear(config.latent_dim, config.n_heads * config.head_dim, bias=False)
        self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)

    def forward(self, x):
        """Attend causally over ``x`` of shape ``(B, S, d_model)``; returns ``(B, S, d_model)``."""
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        latent_kv = self.kv_down_proj(x)
        k = self.kv_up_proj_k(latent_kv).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.kv_up_proj_v(latent_kv).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Boolean mask + masked_fill keeps attn_scores in its own dtype. An
        # additive float32 mask would promote half/bfloat16 scores to float32
        # and make the matmul with ``v`` below fail on a dtype mismatch.
        future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(diagonal=1)
        attn_scores = attn_scores.masked_fill(future, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, S, -1)
        return self.o_proj(context)
class Expert(nn.Module):
    """Bias-free two-layer feed-forward expert computing ``w2(SiLU(w1(x)))``.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width of the expansion layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Both layers are built in the order w1, w2 and registered in that
        # order. This keeps the state_dict key order and the sequence of
        # default-initializer random draws unchanged.
        self.w1, self.w2 = (
            nn.Linear(d_model, d_ff, bias=False),
            nn.Linear(d_ff, d_model, bias=False),
        )
        self.act = nn.SiLU()

    def forward(self, x):
        """Expand to ``d_ff``, apply SiLU, and project back to ``d_model``."""
        hidden = self.act(self.w1(x))
        return self.w2(hidden)
class SSPMoE(nn.Module):
    """Shared-plus-private mixture of experts with top-1 routing.

    Every token goes through ``shared_expert``. A bias-free linear ``router``
    produces a softmax over the ``num_private_experts`` private experts. Each
    token is also sent to its single highest-probability private expert, and
    that expert's output is scaled by the (not renormalized) softmax
    probability. The shared and private contributions are summed.

    Args:
        config: Object exposing ``num_private_experts``, ``d_model`` and ``d_ff``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_private = config.num_private_experts
        self.shared_expert = Expert(config.d_model, config.d_ff)
        self.private_experts = nn.ModuleList([
            Expert(config.d_model, config.d_ff) for _ in range(self.num_private)
        ])
        self.router = nn.Linear(config.d_model, self.num_private, bias=False)

    def forward(self, x):
        """Apply the mixture to ``x`` of shape ``(..., d_model)``; returns the same shape.

        Uses ``reshape`` rather than ``view``, so non-contiguous inputs and any
        number of leading dimensions are accepted.
        """
        shared_out = self.shared_expert(x)
        routing_weights = F.softmax(self.router(x), dim=-1)
        top1_weights, top1_indices = torch.topk(routing_weights, k=1, dim=-1)

        flat_x = x.reshape(-1, x.shape[-1])
        flat_indices = top1_indices.reshape(-1)
        flat_weights = top1_weights.reshape(-1, 1)
        private_out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.private_experts):
            mask = flat_indices == i
            # Skip experts that received no tokens in this batch.
            if mask.any():
                private_out[mask] = expert(flat_x[mask]) * flat_weights[mask]
        return shared_out + private_out.reshape(x.shape)
class DecoderLayer(nn.Module):
    """Decoder layer with Mixture-of-Depths-style token routing.

    A scalar ``mod_router`` scores every token. In each sequence only the top
    ``max(1, int(S * (1 - bypass_rate)))`` tokens are processed. Processing is:
    pre-RMSNorm -> latent attention (residual) -> shared/private MoE (residual)
    -> post-RMSNorm. All other tokens are passed through unchanged.
    Sequences shorter than 2 tokens are processed in full.

    Args:
        config: Object exposing ``bypass_rate``, ``d_model`` and the fields
            required by ``MLA`` and ``SSPMoE``.
    """

    def __init__(self, config):
        super().__init__()
        self.active_rate = 1.0 - config.bypass_rate
        self.mod_router = nn.Linear(config.d_model, 1, bias=False)
        self.pre_rmsnorm = RMSNorm(config.d_model)
        self.mla_block = MLA(config)
        self.ssp_moe_layer = SSPMoE(config)
        self.post_rmsnorm = RMSNorm(config.d_model)

    def _process(self, h):
        """Run the norm/attention/MoE/norm stack on ``h``, with every token active."""
        h = self.pre_rmsnorm(h)
        # Residuals are taken from the normalized stream, as in the original layer.
        h = h + self.mla_block(h)
        h = h + self.ssp_moe_layer(h)
        return self.post_rmsnorm(h)

    def forward(self, x):
        """Route the top-scoring tokens of ``x`` ``(B, S, d_model)`` through the layer.

        Returns a tensor of the same shape. Routed positions hold the layer
        output; bypassed positions hold their input unchanged.
        """
        B, S, C = x.shape
        if S < 2:
            return self._process(x)

        # NOTE(review): router_logits are only used to pick indices and are
        # never multiplied into the output, so mod_router gets no gradient.
        # Top-k over the full sequence also lets future tokens influence which
        # earlier tokens are selected. Confirm both are intended.
        router_logits = self.mod_router(x).squeeze(-1)
        k = max(1, min(int(S * self.active_rate), S))
        topk_indices = torch.topk(router_logits, k, dim=-1).indices
        # topk orders indices by score, but MLA's causal mask is positional
        # within the gathered sub-sequence. Restore the original token order so
        # no token can attend to a later position.
        active_idx = torch.sort(topk_indices, dim=-1).values
        gather_idx = active_idx.unsqueeze(-1).expand(B, k, C)
        x_active = torch.gather(x, 1, gather_idx)
        h = self._process(x_active).to(x.dtype)
        # Out-of-place scatter: routed positions get h, the rest keep x.
        return x.scatter(1, gather_idx, h)
class UltraBasePreTrainedModel(PreTrainedModel):
    """Shared ``transformers`` base class for UltraBase models.

    Binds ``UltraBaseConfig`` as the config class and defines the weight
    initializer that the transformers initialization machinery (e.g.
    ``post_init``) applies to each submodule.
    """

    config_class = UltraBaseConfig
    base_model_prefix = "model"
    # NOTE(review): this flag only holds if subclasses actually checkpoint
    # their layers when gradient checkpointing is enabled.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and Linear biases to 0."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
class UltraBaseForCausalLM(UltraBasePreTrainedModel, GenerationMixin):
    """Decoder-only causal language model built from ``DecoderLayer`` blocks.

    Pipeline: ``embed`` (token embedding) -> ``config.n_layers`` decoder
    layers -> ``lm_head`` projection to ``config.vocab_size`` logits.

    Inheriting from ``UltraBasePreTrainedModel`` does three things:
    - sets ``config_class``, which ``from_pretrained`` needs;
    - makes ``post_init`` apply the project's ``_init_weights``;
    - inherits the ``supports_gradient_checkpointing`` flag, which ``forward``
      honours.

    There is no KV cache: ``prepare_inputs_for_generation`` feeds the full
    ``input_ids`` on every generation step.
    """

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Toggled by transformers' gradient_checkpointing_enable()/disable().
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute next-token logits and, optionally, the language-modeling loss.

        Args:
            input_ids: Token ids, presumably ``(batch, seq_len)``; the decoder
                layers unpack a 3-D ``(B, S, C)`` embedding.
            labels: Optional ids aligned with ``input_ids``. Logits at position
                ``t`` are scored against ``labels[..., t + 1]``. Entries equal
                to -100 are ignored (the ``cross_entropy`` default).
            **kwargs: Accepted for ``generate``/Trainer compatibility and
                ignored. In particular ``attention_mask`` is not applied, so
                padded positions are attended to.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (``None`` without labels)
            and ``logits``.
        """
        x = self.embed(input_ids)
        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing:
            from torch.utils.checkpoint import checkpoint
        for layer in self.layers:
            if use_checkpointing:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Compute the loss in float32 so half-precision logits stay stable.
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the full sequence each step; there is no cache to reuse."""
        return {"input_ids": input_ids}