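"""BitMamba VLM student: a Mamba-style language model whose input/output
projections use ternary BitLinear layers, plus an MLP projector that maps
vision-encoder features into the token embedding space."""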
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class BitMambaConfig(PretrainedConfig):
    model_type = "bitmamba"

    def __init__(self, vocab_size=151365, hidden_dim=768, vision_dim=1152, num_layers=12, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.vision_dim = vision_dim
        self.num_layers = num_layers


class STESign(torch.autograd.Function):
    """Round-and-clip to the ternary set {-1, 0, 1} with a straight-through gradient.

    Despite the name, this is a ternary round-clip (as in BitNet b1.58) rather
    than a strict sign function.
    """

    @staticmethod
    def forward(ctx, x):
        return torch.clamp(torch.round(x), -1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad_output


class STERound(torch.autograd.Function):
    """Round-to-nearest with a straight-through gradient."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


def ste_sign(x):
    return STESign.apply(x)


def ste_round(x):
    return STERound.apply(x)
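
# Illustrative check: gradients pass through the quantizers unchanged.
#   x = torch.randn(3, requires_grad=True)
#   ste_sign(x).sum().backward()
#   assert torch.all(x.grad == 1.0)  # identity gradient despite round/clamp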


class BitLinear(nn.Module):
    """Linear layer with ternary weights and 8-bit absmax-quantized activations."""

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Learnable per-output-channel gain applied after dequantization.
        self.alpha = nn.Parameter(torch.ones(out_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        eps = 1e-5
        # Weights: per-row absmean scale, then round-clip to {-1, 0, 1}.
        w_scale = self.weight.abs().mean(dim=1, keepdim=True).clamp(min=eps)
        w_quant = ste_sign(self.weight / w_scale)
        # Activations: per-token absmax scale into the int8 range, rounded with STE.
        a_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
        a_scaled = x / a_scale * 127.0
        a_quant = ste_round((torch.clamp(a_scaled, -128, 127) - a_scaled).detach() + a_scaled)
        y = F.linear(a_quant, w_quant, None)
        # Undo both scalings; the (out_features,) tensors broadcast over the last
        # dim, so this handles inputs of any rank.
        y = y * (w_scale.squeeze(-1) * self.alpha) * (a_scale / 127.0)
        if self.bias is not None:
            y = y + self.bias
        return y
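
# Illustrative check: BitLinear's quantized weights land in {-1, 0, 1}.
#   layer = BitLinear(8, 4)
#   w_q = ste_sign(layer.weight / layer.weight.abs().mean(dim=1, keepdim=True))
#   assert set(w_q.unique().tolist()) <= {-1.0, 0.0, 1.0}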


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return (x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)) * self.weight


class BitMambaBlock(nn.Module):
    """Mamba-style selective-scan (S6) block with ternary in/out projections."""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.d_state = d_state
        self.in_proj = BitLinear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv; padding plus the trim in forward() makes it causal.
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
                                groups=self.d_inner, padding=d_conv - 1)
        self.act = nn.SiLU()
        dt_rank = math.ceil(d_model / 16)
        # Projects the conv output to the scan parameters (dt, B, C).
        self.x_proj = nn.Linear(self.d_inner, dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(dt_rank, self.d_inner, bias=True)
        # S4D-real initialization A_n = -n, stored as log(|A|) to keep A negative.
        self.log_A = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = BitLinear(self.d_inner, d_model, bias=False)
        self.norm = RMSNorm(d_model)
        # Bias init so that softplus(dt_proj(...)) starts roughly in [1e-3, 1e-1].
        nn.init.uniform_(self.dt_proj.bias, math.log(0.001), math.log(0.1))

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x, z = self.in_proj(x).chunk(2, dim=-1)
        # Causal depthwise conv: trim the output back to the sequence length.
        x_conv = self.act(self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]].transpose(1, 2))
        dt, B, C = torch.split(self.x_proj(x_conv), [self.dt_proj.in_features, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt))
        # Discretization: dA = exp(dt * A) with A = -exp(log_A); dB = dt * B.
        dA = torch.exp(torch.einsum('bsd,dn->bsdn', dt, -torch.exp(self.log_A)))
        dB = torch.einsum('bsd,bsn->bsdn', dt, B)
        u = torch.einsum('bsdn,bsd->bsdn', dB, x_conv.float())
        # Sequential selective scan: h_t = dA_t * h_{t-1} + dB_t * x_t.
        h = torch.zeros(x_conv.shape[0], self.d_inner, self.d_state, device=x_conv.device)
        ys = []
        for t in range(x_conv.shape[1]):
            h = dA[:, t] * h + u[:, t]
            ys.append(h)
        y = torch.einsum('bsdn,bsn->bsd', torch.stack(ys, dim=1), C).to(x_conv.dtype)
        # Additive skip D * x and SiLU(z) gating, per the standard Mamba
        # formulation (the original multiplied y by D instead of adding D * x).
        y = (y + x_conv * self.D) * self.act(z)
        return residual + self.out_proj(y)


class BitMambaVLMStudent(PreTrainedModel):
    config_class = BitMambaConfig

    def __init__(self, config):
        super().__init__(config)
        self.hidden_dim = config.hidden_dim
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
        # Two-layer MLP projecting vision-encoder features into the LM embedding space.
        self.vision_proj = nn.Sequential(
            nn.Linear(config.vision_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.layers = nn.ModuleList([BitMambaBlock(d_model=config.hidden_dim) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, image_embeds=None, **kwargs):
        x = self.embed_tokens(input_ids)
        # Prepend projected image tokens so the causal scan sees vision context first.
        if image_embeds is not None:
            x = torch.cat([self.vision_proj(image_embeds), x], dim=1)
        for layer in self.layers:
            x = layer(x)
        # Drop the image positions; logits are returned for text tokens only.
        x_text = x[:, -input_ids.shape[1]:, :] if image_embeds is not None else x
        return {'logits': self.lm_head(self.norm(x_text))}
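

# Minimal smoke test: an illustrative sketch, assuming random tensors stand in
# for real token ids and vision-encoder features; sizes are chosen small.
if __name__ == "__main__":
    cfg = BitMambaConfig(vocab_size=1000, hidden_dim=64, vision_dim=32, num_layers=2)
    model = BitMambaVLMStudent(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, 8))
    image_embeds = torch.randn(2, 4, cfg.vision_dim)  # (batch, image_tokens, vision_dim)
    out = model(input_ids, image_embeds=image_embeds)
    print(out['logits'].shape)  # torch.Size([2, 8, 1000]), text positions only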