sunjuice committed
Commit 435cd2d · 1 Parent(s): 034e1f0

changed nn.Linear to use 4-bit quant

Files changed (1)
  1. modeling_molmo2.py +36 -20
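The change is the same substitution applied at every projection in the file: each nn.Linear becomes a bitsandbytes Linear4bit with NF4 quantization. Below is a minimal sketch of that pattern, assuming bitsandbytes is installed and a CUDA device is available; the layer sizes are illustrative and not taken from the Molmo2 config.

import torch
from torch import nn
import bitsandbytes as bnb

dim, hidden_dim = 1024, 4096  # illustrative sizes only

# Before: the full-precision projection used throughout the original file.
fp_proj = nn.Linear(dim, hidden_dim, bias=True)

# After: the same projection declared as a 4-bit layer. Weights are stored in
# NF4 and are quantized when the module is moved to a CUDA device.
q_proj = bnb.nn.Linear4bit(
    dim,
    hidden_dim,
    bias=True,
    quant_type="nf4",
    compute_dtype=torch.float16,
)
q_proj = q_proj.to("cuda")  # quantization happens on this move

x = torch.randn(2, dim, dtype=torch.float16, device="cuda")
print(q_proj(x).shape)  # torch.Size([2, 4096])

When converting a pretrained checkpoint rather than building fresh modules, the original full-precision weights have to be copied into the Linear4bit layer before it is moved to the GPU, since that move is what triggers quantization.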
modeling_molmo2.py CHANGED
@@ -7,6 +7,8 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+import bitsandbytes as bnb
+
 from transformers.models.auto import AutoModelForImageTextToText
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -86,9 +88,10 @@ class Molmo2ModelOutputWithPast(BaseModelOutputWithPast):
 class ViTMLP(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device: Union[str, torch.device] = None):
         super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
+        self.w1 = bnb.nn.Linear4bit(dim, hidden_dim, bias=True, quant_type="nf4", device=device)
+
         self.act = ACT2FN[hidden_act]
-        self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)
+        self.w2 = bnb.nn.Linear4bit(hidden_dim, dim, bias=True, quant_type="nf4", device=device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(self.act(self.w1(x)))
@@ -121,28 +124,36 @@ class ViTMultiHeadDotProductAttention(nn.Module):
 
         input_dim = input_dim or hidden_size
 
-        self.wq = nn.Linear(
+        self.wq = bnb.nn.Linear4bit(
             input_dim,
             self.num_heads * self.head_dim,
             bias=use_bias,
+            quant_type="nf4",
             device=device,
         )
-        self.wk = nn.Linear(
+
+        self.wk = bnb.nn.Linear4bit(
             input_dim,
             self.num_key_value_heads * self.head_dim,
             bias=use_bias,
+            quant_type="nf4",
             device=device,
         )
-        self.wv = nn.Linear(
+
+        self.wv = bnb.nn.Linear4bit(
             input_dim,
             self.num_key_value_heads * self.head_dim,
             bias=use_bias,
+            quant_type="nf4",
             device=device,
         )
-        self.wo = nn.Linear(
+
+        self.wo = bnb.nn.Linear4bit(
             self.num_heads * self.head_dim,
             self.hidden_size,
+            quant_type="nf4",
         )
+
         self.float32_attention = float32_attention
         self.attention_dropout = attention_dropout
         self.residual_dropout = nn.Dropout(residual_dropout)
@@ -247,7 +258,7 @@ class Molmo2VisionBlock(nn.Module):
             num_heads=config.num_attention_heads,
             num_key_value_heads=config.num_key_value_heads,
             head_dim=config.head_dim,
-            float32_attention=config.float32_attention,
+            float32_attention=False,
             attention_dropout=config.attention_dropout,
             residual_dropout=config.residual_dropout,
             device=device,
@@ -258,7 +269,6 @@ class Molmo2VisionBlock(nn.Module):
         self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        print("dtype before norm:", x.dtype)
         x = x + self.attention(self.attention_norm(x))
         x = x + self.feed_forward(self.ffn_norm(x))
         return x
@@ -295,10 +305,12 @@ class Molmo2VisionTransformer(nn.Module):
         )
 
         image_patch_size = config.image_patch_size
-        self.patch_embedding = nn.Linear(
+
+        self.patch_embedding = bnb.nn.Linear4bit(
             image_patch_size * image_patch_size * 3,
             config.hidden_size,
             bias=True,
+            quant_type="nf4",
             device=device,
         )
 
@@ -355,9 +367,10 @@ class ImageProjectorMLP(nn.Module):
         device: Union[str, torch.device] = None,
     ):
         super().__init__()
-        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
-        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
-        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
+
+        self.w1 = bnb.nn.Linear4bit(input_dim, hidden_dim, bias=False, quant_type="nf4", device=device)
+        self.w2 = bnb.nn.Linear4bit(hidden_dim, output_dim, bias=False, quant_type="nf4", device=device)
+        self.w3 = bnb.nn.Linear4bit(input_dim, hidden_dim, bias=False, quant_type="nf4", device=device)
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -636,11 +649,12 @@ class Molmo2Attention(nn.Module):
             config.head_dim * config.num_key_value_heads,
             config.head_dim * config.num_key_value_heads,
         )
-        self.att_proj = nn.Linear(
+        self.att_proj = bnb.nn.Linear4bit(
             config.hidden_size,
             sum(self.fused_dims),
             bias=config.qkv_bias,
-        )
+            quant_type="nf4",
+        )
 
         # Layer norms.
         self.k_norm: Optional[Molmo2RMSNorm] = None
@@ -662,11 +676,12 @@ class Molmo2Attention(nn.Module):
         self.qk_norm_type = config.qk_norm_type
 
         self.attention_dropout = config.attention_dropout
-
-        self.attn_out = nn.Linear(
+
+        self.attn_out = bnb.nn.Linear4bit(
             config.head_dim * config.num_attention_heads,
             config.hidden_size,
             bias=False,
+            quant_type="nf4",
         )
 
     def forward(
@@ -737,8 +752,9 @@ class LanguageModelMLP(nn.Module):
         device: Union[str, torch.device] = None,
     ):
         super().__init__()
-        self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device)
-        self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device)
+        self.ff_proj = bnb.nn.Linear4bit(input_dim, intermediate_size * 2, bias=False, quant_type="nf4", device=device)
+        self.ff_out = bnb.nn.Linear4bit(intermediate_size, input_dim, bias=False, quant_type="nf4", device=device)
+
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -904,7 +920,7 @@ class Molmo2PreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
        std = self.config.initializer_range
-        if isinstance(module, (nn.Linear,)):
+        if isinstance(module, (bnb.nn.Linear4bit,)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
@@ -1576,7 +1592,7 @@ class Molmo2ForConditionalGeneration(Molmo2PreTrainedModel, GenerationMixin):
         super().__init__(config)
 
         self.model = Molmo2Model(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = bnb.nn.Linear4bit(config.hidden_size, config.vocab_size, bias=False, quant_type="nf4")
         self.vocab_size = config.vocab_size
 
         # Initialize weights and apply final processing
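To confirm the remote-code change end to end, here is a hedged loading sketch; the repository id below is a placeholder, and AutoModelForImageTextToText is the auto class the modified file itself imports.

import torch
import bitsandbytes as bnb
from transformers import AutoModelForImageTextToText

# Placeholder repo id; trust_remote_code=True pulls in the modified modeling_molmo2.py.
model = AutoModelForImageTextToText.from_pretrained(
    "sunjuice/molmo2-4bit",  # hypothetical repository name
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
)

# Count how many projections were built as 4-bit layers versus plain nn.Linear.
n_4bit = sum(isinstance(m, bnb.nn.Linear4bit) for m in model.modules())
n_fp = sum(type(m) is torch.nn.Linear for m in model.modules())
print(f"Linear4bit modules: {n_4bit}, remaining nn.Linear modules: {n_fp}")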