Upload modeling_bitmamba.py with huggingface_hub
modeling_bitmamba.py (ADDED): +114 -0
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import PreTrainedModel, PretrainedConfig

class BitMambaConfig(PretrainedConfig):
    model_type = "bitmamba"
    def __init__(self, vocab_size=151552, hidden_dim=768, vision_dim=1152, num_layers=12, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.vision_dim = vision_dim
        self.num_layers = num_layers

class STESign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x): return torch.clamp(torch.round(x), -1, 1)
    @staticmethod
    def backward(ctx, grad_output): return grad_output

class STERound(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x): return torch.round(x)
    @staticmethod
    def backward(ctx, grad_output): return grad_output

def ste_sign(x): return STESign.apply(x)
def ste_round(x): return STERound.apply(x)
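The two autograd functions above are straight-through estimators (STE): the forward pass quantizes, while the backward pass returns the incoming gradient unchanged so the quantized layers stay trainable. A minimal sketch of that behavior, using the helpers defined above (the tensor values are illustrative only):

x = torch.tensor([-0.7, 0.2, 1.4], requires_grad=True)
y = ste_sign(x)        # forward quantizes: tensor([-1., 0., 1.])
y.sum().backward()     # backward is identity, so x.grad is tensor([1., 1., 1.])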
class BitLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features; self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.alpha = nn.Parameter(torch.ones(out_features))
        # No init needed here if loading from weights, but good practice
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        input_dim = x.dim(); eps = 1e-5
        # Detach scales for inference stability (matches training)
        w_scale = self.weight.abs().mean(dim=1, keepdim=True).clamp(min=eps).detach()
        w_quant = ste_sign(self.weight / w_scale)
        a_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps).detach()
        a_scaled = x / a_scale * 127.0
        a_quant = ste_round((torch.clamp(a_scaled, -128, 127) - a_scaled).detach() + a_scaled)
        y = F.linear(a_quant, w_quant, None)
        w_scale_flat = w_scale.squeeze(-1)
        rescale = (w_scale_flat.view(1, 1, -1) if input_dim == 3 else w_scale_flat.view(1, -1)) * a_scale / 127.0
        y = y * rescale * (self.alpha.view(1, 1, -1) if input_dim == 3 else self.alpha.view(1, -1))
        if self.bias is not None: y = y + self.bias
        return y
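BitLinear quantizes weights to the ternary set {-1, 0, +1} with a per-output-row absmean scale, quantizes activations to an int8-style grid via per-token absmax scaling, and then rescales the matmul output (including the learned per-channel alpha). From the caller's side it is a drop-in replacement for nn.Linear; a minimal shape-level sketch, with illustrative sizes:

layer = BitLinear(in_features=768, out_features=1536)
h = torch.randn(2, 10, 768)   # (batch, seq, features)
print(layer(h).shape)         # torch.Size([2, 10, 1536])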
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps; self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return (x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)) * self.weight

class BitMambaBlock(nn.Module):
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model; self.d_inner = int(expand * d_model); self.d_state = d_state
        self.in_proj = BitLinear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1)
        self.act = nn.SiLU()
        self.x_proj = nn.Linear(self.d_inner, math.ceil(d_model/16) + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(math.ceil(d_model/16), self.d_inner, bias=True)
        self.log_A = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = BitLinear(self.d_inner, d_model, bias=False)
        self.norm = RMSNorm(d_model)

    def forward(self, x):
        residual = x; x = self.norm(x)
        x, z = self.in_proj(x).chunk(2, dim=-1)
        x_conv = self.act(self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]].transpose(1, 2))
        dt, B, C = torch.split(self.x_proj(x_conv), [self.dt_proj.in_features, self.d_state, self.d_state], dim=-1)
        dt = F.softplus(self.dt_proj(dt))

        # Inference Scan (Standard)
        dA = torch.exp(torch.einsum('bsd,dn->bsdn', dt, -torch.exp(self.log_A)))
        dB = torch.einsum('bsd,bsn->bsdn', dt, B)
        u = torch.einsum('bsdn,bsd->bsdn', dB, x_conv.float())
        h = torch.zeros(x_conv.shape[0], self.d_inner, self.d_state, device=x_conv.device)
        ys = []
        for t in range(x_conv.shape[1]):
            h = dA[:, t] * h + u[:, t]
            ys.append(h)
        y = torch.einsum('bsdn,bsn->bsd', torch.stack(ys, dim=1), C).to(x_conv.dtype)

        return residual + self.out_proj(y * self.act(z) * self.D)
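The block pre-norms its input, runs a depthwise causal convolution and a sequential selective-state-space scan, gates the result with SiLU(z) and the learned D, and adds the residual, so the output shape matches the input. A minimal sketch with illustrative sizes:

block = BitMambaBlock(d_model=768)
h = torch.randn(2, 10, 768)   # (batch, seq, d_model)
print(block(h).shape)         # torch.Size([2, 10, 768])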
class BitMambaVLMStudent(PreTrainedModel):
    config_class = BitMambaConfig
    def __init__(self, config):
        super().__init__(config)
        self.hidden_dim = config.hidden_dim
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.vision_proj = nn.Sequential(nn.Linear(config.vision_dim, config.hidden_dim), nn.GELU(), nn.Linear(config.hidden_dim, config.hidden_dim))
        self.layers = nn.ModuleList([BitMambaBlock(d_model=config.hidden_dim) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, image_embeds=None, **kwargs):
        x = self.embed_tokens(input_ids)
        if image_embeds is not None: x = torch.cat([self.vision_proj(image_embeds), x], dim=1)
        for layer in self.layers: x = layer(x)
        x_text = x[:, -input_ids.shape[1]:, :] if image_embeds is not None else x
        return {'logits': self.lm_head(self.norm(x_text))}
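Finally, a minimal end-to-end sketch of instantiating the model and running text-only and image+text forward passes. The config values and random inputs below are illustrative placeholders, not the uploaded checkpoint's settings:

config = BitMambaConfig(vocab_size=1000, hidden_dim=256, vision_dim=1152, num_layers=2)
model = BitMambaVLMStudent(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
image_embeds = torch.randn(1, 16, config.vision_dim)   # e.g. 16 vision tokens

text_out = model(input_ids)
vl_out = model(input_ids, image_embeds=image_embeds)
print(text_out['logits'].shape)   # torch.Size([1, 8, 1000])
print(vl_out['logits'].shape)     # torch.Size([1, 8, 1000])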