rileyseaburg commited on
Commit
0b718e5
·
verified ·
1 Parent(s): 6a04338

Upload modeling_bitmamba.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_bitmamba.py +114 -0
modeling_bitmamba.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers import PreTrainedModel, PretrainedConfig
7
+
8
+ class BitMambaConfig(PretrainedConfig):
9
+ model_type = "bitmamba"
10
+ def __init__(self, vocab_size=151552, hidden_dim=768, vision_dim=1152, num_layers=12, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.vocab_size = vocab_size
13
+ self.hidden_dim = hidden_dim
14
+ self.vision_dim = vision_dim
15
+ self.num_layers = num_layers
16
+
17
+ class STESign(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x): return torch.clamp(torch.round(x), -1, 1)
20
+ @staticmethod
21
+ def backward(ctx, grad_output): return grad_output
22
+
23
+ class STERound(torch.autograd.Function):
24
+ @staticmethod
25
+ def forward(ctx, x): return torch.round(x)
26
+ @staticmethod
27
+ def backward(ctx, grad_output): return grad_output
28
+
29
+ def ste_sign(x): return STESign.apply(x)
30
+ def ste_round(x): return STERound.apply(x)
31
+
32
+ class BitLinear(nn.Module):
33
+ def __init__(self, in_features: int, out_features: int, bias: bool = False):
34
+ super().__init__()
35
+ self.in_features = in_features; self.out_features = out_features
36
+ self.weight = nn.Parameter(torch.empty(out_features, in_features))
37
+ self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
38
+ self.alpha = nn.Parameter(torch.ones(out_features))
39
+ # No init needed here if loading from weights, but good practice
40
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
41
+
42
+ def forward(self, x):
43
+ input_dim = x.dim(); eps = 1e-5
44
+ # Detach scales for inference stability (matches training)
45
+ w_scale = self.weight.abs().mean(dim=1, keepdim=True).clamp(min=eps).detach()
46
+ w_quant = ste_sign(self.weight / w_scale)
47
+ a_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps).detach()
48
+ a_scaled = x / a_scale * 127.0
49
+ a_quant = ste_round((torch.clamp(a_scaled, -128, 127) - a_scaled).detach() + a_scaled)
50
+ y = F.linear(a_quant, w_quant, None)
51
+ w_scale_flat = w_scale.squeeze(-1)
52
+ rescale = (w_scale_flat.view(1, 1, -1) if input_dim == 3 else w_scale_flat.view(1, -1)) * a_scale / 127.0
53
+ y = y * rescale * (self.alpha.view(1, 1, -1) if input_dim == 3 else self.alpha.view(1, -1))
54
+ if self.bias is not None: y = y + self.bias
55
+ return y
56
+
57
+ class RMSNorm(nn.Module):
58
+ def __init__(self, dim: int, eps: float = 1e-5):
59
+ super().__init__()
60
+ self.eps = eps; self.weight = nn.Parameter(torch.ones(dim))
61
+ def forward(self, x):
62
+ return (x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)) * self.weight
63
+
64
+ class BitMambaBlock(nn.Module):
65
+ def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
66
+ super().__init__()
67
+ self.d_model = d_model; self.d_inner = int(expand * d_model); self.d_state = d_state
68
+ self.in_proj = BitLinear(d_model, self.d_inner * 2, bias=False)
69
+ self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1)
70
+ self.act = nn.SiLU()
71
+ self.x_proj = nn.Linear(self.d_inner, math.ceil(d_model/16) + d_state * 2, bias=False)
72
+ self.dt_proj = nn.Linear(math.ceil(d_model/16), self.d_inner, bias=True)
73
+ self.log_A = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)))
74
+ self.D = nn.Parameter(torch.ones(self.d_inner))
75
+ self.out_proj = BitLinear(self.d_inner, d_model, bias=False)
76
+ self.norm = RMSNorm(d_model)
77
+
78
+ def forward(self, x):
79
+ residual = x; x = self.norm(x)
80
+ x, z = self.in_proj(x).chunk(2, dim=-1)
81
+ x_conv = self.act(self.conv1d(x.transpose(1, 2))[:, :, :x.shape[1]].transpose(1, 2))
82
+ dt, B, C = torch.split(self.x_proj(x_conv), [self.dt_proj.in_features, self.d_state, self.d_state], dim=-1)
83
+ dt = F.softplus(self.dt_proj(dt))
84
+
85
+ # Inference Scan (Standard)
86
+ dA = torch.exp(torch.einsum('bsd,dn->bsdn', dt, -torch.exp(self.log_A)))
87
+ dB = torch.einsum('bsd,bsn->bsdn', dt, B)
88
+ u = torch.einsum('bsdn,bsd->bsdn', dB, x_conv.float())
89
+ h = torch.zeros(x_conv.shape[0], self.d_inner, self.d_state, device=x_conv.device)
90
+ ys = []
91
+ for t in range(x_conv.shape[1]):
92
+ h = dA[:, t] * h + u[:, t]
93
+ ys.append(h)
94
+ y = torch.einsum('bsdn,bsn->bsd', torch.stack(ys, dim=1), C).to(x_conv.dtype)
95
+
96
+ return residual + self.out_proj(y * self.act(z) * self.D)
97
+
98
+ class BitMambaVLMStudent(PreTrainedModel):
99
+ config_class = BitMambaConfig
100
+ def __init__(self, config):
101
+ super().__init__(config)
102
+ self.hidden_dim = config.hidden_dim
103
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_dim)
104
+ self.vision_proj = nn.Sequential(nn.Linear(config.vision_dim, config.hidden_dim), nn.GELU(), nn.Linear(config.hidden_dim, config.hidden_dim))
105
+ self.layers = nn.ModuleList([BitMambaBlock(d_model=config.hidden_dim) for _ in range(config.num_layers)])
106
+ self.norm = RMSNorm(config.hidden_dim)
107
+ self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
108
+
109
+ def forward(self, input_ids, image_embeds=None, **kwargs):
110
+ x = self.embed_tokens(input_ids)
111
+ if image_embeds is not None: x = torch.cat([self.vision_proj(image_embeds), x], dim=1)
112
+ for layer in self.layers: x = layer(x)
113
+ x_text = x[:, -input_ids.shape[1]:, :] if image_embeds is not None else x
114
+ return {'logits': self.lm_head(self.norm(x_text))}