sunjuice committed
Commit e1b0f53 · 1 Parent(s): 89352af

Duplicate from allenai/Molmo2-8B

Co-authored-by: Sangho Lee <sanghol@users.noreply.huggingface.co>

(cherry picked from commit 3d666677251284f0f13befe3bbd34062af653bbc)

Files changed (1)
  1. modeling_molmo2.py +18 -35
modeling_molmo2.py CHANGED
@@ -7,8 +7,6 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-import bitsandbytes as bnb
-
 from transformers.models.auto import AutoModelForImageTextToText
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -88,10 +86,9 @@ class Molmo2ModelOutputWithPast(BaseModelOutputWithPast):
 class ViTMLP(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device: Union[str, torch.device] = None):
         super().__init__()
-        self.w1 = bnb.nn.Linear4bit(dim, hidden_dim, bias=True, quant_type="nf4", device=device)
-
+        self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
         self.act = ACT2FN[hidden_act]
-        self.w2 = bnb.nn.Linear4bit(hidden_dim, dim, bias=True, quant_type="nf4", device=device)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(self.act(self.w1(x)))
@@ -124,36 +121,28 @@ class ViTMultiHeadDotProductAttention(nn.Module):
 
         input_dim = input_dim or hidden_size
 
-        self.wq = bnb.nn.Linear4bit(
+        self.wq = nn.Linear(
             input_dim,
             self.num_heads * self.head_dim,
             bias=use_bias,
-            quant_type="nf4",
             device=device,
         )
-
-        self.wk = bnb.nn.Linear4bit(
+        self.wk = nn.Linear(
             input_dim,
             self.num_key_value_heads * self.head_dim,
             bias=use_bias,
-            quant_type="nf4",
             device=device,
         )
-
-        self.wv = bnb.nn.Linear4bit(
+        self.wv = nn.Linear(
             input_dim,
             self.num_key_value_heads * self.head_dim,
             bias=use_bias,
-            quant_type="nf4",
             device=device,
         )
-
-        self.wo = bnb.nn.Linear4bit(
+        self.wo = nn.Linear(
             self.num_heads * self.head_dim,
             self.hidden_size,
-            quant_type="nf4",
         )
-
         self.float32_attention = float32_attention
         self.attention_dropout = attention_dropout
         self.residual_dropout = nn.Dropout(residual_dropout)
@@ -305,12 +294,10 @@ class Molmo2VisionTransformer(nn.Module):
         )
 
         image_patch_size = config.image_patch_size
-
-        self.patch_embedding = bnb.nn.Linear4bit(
+        self.patch_embedding = nn.Linear(
             image_patch_size * image_patch_size * 3,
             config.hidden_size,
             bias=True,
-            quant_type="nf4",
             device=device,
         )
 
@@ -367,10 +354,9 @@ class ImageProjectorMLP(nn.Module):
         device: Union[str, torch.device] = None,
     ):
         super().__init__()
-
-        self.w1 = bnb.nn.Linear4bit(input_dim, hidden_dim, bias=False, quant_type="nf4", device=device)
-        self.w2 = bnb.nn.Linear4bit(hidden_dim, output_dim, bias=False, quant_type="nf4", device=device)
-        self.w3 = bnb.nn.Linear4bit(input_dim, hidden_dim, bias=False, quant_type="nf4", device=device)
+        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
+        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
+        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -649,12 +635,11 @@ class Molmo2Attention(nn.Module):
             config.head_dim * config.num_key_value_heads,
             config.head_dim * config.num_key_value_heads,
         )
-        self.att_proj = bnb.nn.Linear4bit(
+        self.att_proj = nn.Linear(
             config.hidden_size,
             sum(self.fused_dims),
             bias=config.qkv_bias,
-            quant_type="nf4",
-        )
+        )
 
         # Layer norms.
         self.k_norm: Optional[Molmo2RMSNorm] = None
@@ -676,12 +661,11 @@ class Molmo2Attention(nn.Module):
         self.qk_norm_type = config.qk_norm_type
 
         self.attention_dropout = config.attention_dropout
-
-        self.attn_out = bnb.nn.Linear4bit(
+
+        self.attn_out = nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            bias=False,
-           quant_type="nf4",
        )
 
    def forward(
@@ -752,9 +736,8 @@ class LanguageModelMLP(nn.Module):
        device: Union[str, torch.device] = None,
    ):
        super().__init__()
-        self.ff_proj = bnb.nn.Linear4bit(input_dim, intermediate_size * 2, bias=False, quant_type="nf4", device=device)
-        self.ff_out = bnb.nn.Linear4bit(intermediate_size, input_dim, bias=False, quant_type="nf4", device=device)
-
+        self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False, device=device)
+        self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False, device=device)
        self.act = ACT2FN[hidden_act]
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -920,7 +903,7 @@ class Molmo2PreTrainedModel(PreTrainedModel):
 
    def _init_weights(self, module):
        std = self.config.initializer_range
-        if isinstance(module, (bnb.nn.Linear4bit,)):
+        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
@@ -1592,7 +1575,7 @@ class Molmo2ForConditionalGeneration(Molmo2PreTrainedModel, GenerationMixin):
        super().__init__(config)
 
        self.model = Molmo2Model(config)
-        self.lm_head = bnb.nn.Linear4bit(config.hidden_size, config.vocab_size, bias=False, quant_type="nf4")
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.vocab_size = config.vocab_size
 
        # Initialize weights and apply final processing
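
Net effect of the diff: every hardcoded bnb.nn.Linear4bit layer is reverted to a plain nn.Linear, restoring the upstream allenai/Molmo2-8B modeling code and dropping the hard dependency on bitsandbytes. For anyone who still wants NF4 4-bit weights, a minimal sketch of the usual load-time route via transformers' BitsAndBytesConfig (the trust_remote_code=True and device_map="auto" arguments are assumptions, not part of this commit):

import torch
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig

# Request NF4 4-bit quantization at load time instead of hardcoding
# bnb.nn.Linear4bit modules inside the modeling file.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # same quant_type the removed layers used
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForImageTextToText.from_pretrained(
    "allenai/Molmo2-8B",                   # repo named in the commit message
    quantization_config=quant_config,
    trust_remote_code=True,                # assumption: repo ships custom modeling code
    device_map="auto",
)

This keeps the modeling file quantization-agnostic: the same checkpoint can be loaded in full precision or 4-bit, with bitsandbytes pulled in only when a quantization config asks for it.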