# BitMamba-Zen-VLM-v3.4 / modeling_bitmamba.py
# rileyseaburg's picture
# Upload modeling_bitmamba.py with huggingface_hub
# 4252933 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import PreTrainedModel, PretrainedConfig
class BitMambaConfig(PretrainedConfig):
    """Hyper-parameters for the BitMamba vision-language model.

    Attributes:
        vocab_size: Number of rows in the token embedding table and number
            of output logits.
        hidden_dim: Width of the residual stream shared by every block.
        vision_dim: Feature width of the incoming image embeddings before
            they are projected to ``hidden_dim``.
        num_layers: Number of stacked BitMamba blocks.
    """

    model_type = "bitmamba"

    def __init__(
        self,
        vocab_size=151365,
        hidden_dim=768,
        vision_dim=1152,
        num_layers=12,
        **kwargs,
    ):
        # Everything not named above is handed to the base configuration.
        super().__init__(**kwargs)

        # Text side.
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # Vision side.
        self.vision_dim = vision_dim

        # Depth of the backbone.
        self.num_layers = num_layers
class STESign(torch.autograd.Function):
    """Ternary quantiser with a straight-through gradient.

    Forward rounds to the nearest integer and limits the result to the
    range [-1, 1]; backward passes the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        rounded = x.round()
        return rounded.clamp(-1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat the quantiser as the identity.
        return grad_output
class STERound(torch.autograd.Function):
    """Round-to-nearest with a straight-through gradient.

    Forward returns ``round(x)``; backward passes the incoming gradient
    through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat rounding as the identity.
        return grad_output
def ste_sign(x):
    """Run ``x`` through ``STESign`` (round, clamp to [-1, 1], identity gradient)."""
    quantized = STESign.apply(x)
    return quantized
def ste_round(x):
    """Run ``x`` through ``STERound`` (round to nearest, identity gradient)."""
    rounded = STERound.apply(x)
    return rounded
class BitLinear(nn.Module):
    """Linear layer with ternary weights and integer-valued activations.

    Weights are divided by their per-output-row mean absolute value and
    quantised to {-1, 0, 1}; activations are scaled by their absolute maximum
    over the last dimension to the range [-127, 127] and rounded. Both
    quantisers use straight-through gradients. The matmul result is rescaled
    by the two scales and by a learned per-output gain ``alpha``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: If True, add a learned bias (initialised to zero).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Learned per-output-channel gain applied after de-quantisation.
        self.alpha = nn.Parameter(torch.ones(out_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        """Apply the quantised linear map to ``x``.

        Args:
            x: Tensor whose last dimension is ``in_features``; any number of
                leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        eps = 1e-5

        # Per-output-row weight scale, shape (out_features, 1).
        w_scale = self.weight.abs().mean(dim=1, keepdim=True).clamp(min=eps)
        w_quant = ste_sign(self.weight / w_scale)

        # Per-vector activation scale, shape (..., 1).
        a_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
        a_scaled = x / a_scale * 127.0
        # Clamp in the forward pass only; the detached difference keeps the
        # gradient of the clamp equal to the identity.
        a_clamped = (torch.clamp(a_scaled, -128, 127) - a_scaled).detach() + a_scaled
        a_quant = ste_round(a_clamped)

        y = F.linear(a_quant, w_quant, None)

        # (out_features,) broadcasts against (..., 1) to (..., out_features)
        # for every input rank, so no per-rank reshaping is needed and a 1-D
        # input keeps its 1-D output shape.
        rescale = w_scale.squeeze(-1) * a_scale / 127.0
        y = y * rescale * self.alpha

        if self.bias is not None:
            y = y + self.bias
        return y
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Divides the input by ``sqrt(mean(x ** 2) + eps)`` computed along the last
    dimension and multiplies by a learned per-feature gain initialised to one.

    Args:
        dim: Size of the last dimension (length of the gain vector).
        eps: Constant added to the mean square before the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        normalised = x / rms
        return normalised * self.weight
class BitMambaBlock(nn.Module):
    """Pre-norm residual block: selective state-space mixer with BitLinear projections.

    Pipeline: RMSNorm -> BitLinear input projection split into a signal path
    and a gate -> depthwise causal Conv1d + SiLU -> input-dependent
    (dt, B, C) -> sequential state-space recurrence -> gate with SiLU(z) and
    scale by ``D`` -> BitLinear output projection -> residual add.

    Args:
        d_model: Width of the block's input and output.
        d_state: Size of the per-channel recurrent state.
        d_conv: Kernel size of the depthwise convolution.
        expand: Multiplier giving the inner width ``d_inner = expand * d_model``.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.d_state = d_state
        # Width of the low-rank bottleneck that produces the step size dt.
        dt_rank = math.ceil(d_model / 16)

        self.in_proj = BitLinear(d_model, self.d_inner * 2, bias=False)
        # Depthwise conv; padding of d_conv - 1 plus truncation in forward()
        # makes each output position depend only on current and past inputs.
        self.conv1d = nn.Conv1d(
            self.d_inner,
            self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
        )
        self.act = nn.SiLU()
        self.x_proj = nn.Linear(self.d_inner, dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(dt_rank, self.d_inner, bias=True)
        # log(A) with A[d, n] = n + 1; forward() uses -exp(log_A) so the
        # continuous-time decay rates are always negative.
        self.log_A = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1))
        )
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = BitLinear(self.d_inner, d_model, bias=False)
        self.norm = RMSNorm(d_model)
        nn.init.uniform_(self.dt_proj.bias, math.log(0.001), math.log(0.1))

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.
        """
        residual = x
        x = self.norm(x)
        x, z = self.in_proj(x).chunk(2, dim=-1)

        seq_len = x.shape[1]
        # Conv1d expects (batch, channels, time); drop the trailing padded
        # positions so the output length equals seq_len.
        x_conv = self.conv1d(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = self.act(x_conv)

        dt, B, C = torch.split(
            self.x_proj(x_conv),
            [self.dt_proj.in_features, self.d_state, self.d_state],
            dim=-1,
        )
        dt = F.softplus(self.dt_proj(dt))

        # Discretised transition and input terms, each (batch, seq, d_inner, d_state).
        dA = torch.exp(torch.einsum('bsd,dn->bsdn', dt, -torch.exp(self.log_A)))
        dB = torch.einsum('bsd,bsn->bsdn', dt, B)
        u = torch.einsum('bsdn,bsd->bsdn', dB, x_conv.float())

        # Sequential recurrence h_t = dA_t * h_{t-1} + u_t, kept in float32.
        h = torch.zeros(x_conv.shape[0], self.d_inner, self.d_state, device=x_conv.device)
        states = []
        for t in range(seq_len):
            h = dA[:, t] * h + u[:, t]
            states.append(h)
        stacked = torch.stack(states, dim=1)

        # The contraction over n requires both operands to share a dtype; the
        # states are float32 even when the module runs in half precision, so
        # C is cast to match (a no-op for float32 modules).
        y = torch.einsum('bsdn,bsn->bsd', stacked, C.to(stacked.dtype)).to(x_conv.dtype)

        # NOTE(review): D multiplies the gated output here, whereas the
        # reference Mamba formulation adds D * x as a skip term. Left as-is
        # so existing checkpoints keep their behaviour; confirm intent.
        return residual + self.out_proj(y * self.act(z) * self.D)
class BitMambaVLMStudent(PreTrainedModel):
    """Vision-language model: projected image embeddings prefix a BitMamba text stack.

    Token ids are embedded, optional image embeddings are mapped to the hidden
    width by a two-layer MLP and prepended along the sequence dimension, the
    combined sequence runs through ``config.num_layers`` BitMamba blocks, and
    logits are produced for the text positions only.
    """

    config_class = BitMambaConfig

    def __init__(self, config):
        super().__init__(config)
        self.hidden_dim = config.hidden_dim
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.vision_proj = nn.Sequential(
            nn.Linear(config.vision_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.layers = nn.ModuleList(
            [BitMambaBlock(d_model=config.hidden_dim) for _ in range(config.num_layers)]
        )
        self.norm = RMSNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, image_embeds=None, **kwargs):
        """Compute next-token logits for the text positions.

        Args:
            input_ids: Integer tensor of token ids, shape (batch, text_len).
            image_embeds: Optional tensor of image features with last
                dimension ``config.vision_dim``; it is concatenated with the
                token embeddings along dim 1, so it must be
                (batch, num_image_tokens, vision_dim).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dict with key ``'logits'`` holding a tensor of shape
            (batch, text_len, vocab_size).
        """
        x = self.embed_tokens(input_ids)

        num_image_tokens = 0
        if image_embeds is not None:
            image_tokens = self.vision_proj(image_embeds)
            num_image_tokens = image_tokens.shape[1]
            x = torch.cat([image_tokens, x], dim=1)

        for layer in self.layers:
            x = layer(x)

        # Slice from the end of the image prefix rather than with a negative
        # text length: ``x[:, -0:]`` would select the whole sequence when
        # text_len is 0 and return image positions as text logits.
        x_text = x[:, num_image_tokens:, :]
        return {'logits': self.lm_head(self.norm(x_text))}