silveroxides committed on
Commit
c96e937
·
verified ·
1 Parent(s): feba260

Update modeling_florence2.py

Browse files
Files changed (1) hide show
  1. modeling_florence2.py +24 -26
modeling_florence2.py CHANGED
@@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
26
  from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange
29
- from timm.models.layers import DropPath, trunc_normal_
30
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.generation.utils import GenerationMixin
@@ -610,29 +610,10 @@ class DaViT(nn.Module):
610
  self.avgpool = nn.AdaptiveAvgPool1d(1)
611
  self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
612
 
613
- self.apply(self._init_weights)
614
-
615
  @property
616
  def dim_out(self):
617
  return self.embed_dims[-1]
618
 
619
- def _init_weights(self, m):
620
- if isinstance(m, nn.Linear):
621
- trunc_normal_(m.weight, std=0.02)
622
- if m.bias is not None:
623
- nn.init.constant_(m.bias, 0)
624
- elif isinstance(m, nn.Conv2d):
625
- nn.init.normal_(m.weight, std=0.02)
626
- for name, _ in m.named_parameters():
627
- if name in ['bias']:
628
- nn.init.constant_(m.bias, 0)
629
- elif isinstance(m, nn.LayerNorm):
630
- nn.init.constant_(m.weight, 1.0)
631
- nn.init.constant_(m.bias, 0)
632
- elif isinstance(m, nn.BatchNorm2d):
633
- nn.init.constant_(m.weight, 1.0)
634
- nn.init.constant_(m.bias, 0)
635
-
636
  def forward_features_unpool(self, x):
637
  """
638
  forward until avg pooling
@@ -1451,6 +1432,17 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
1451
  module.weight.data.normal_(mean=0.0, std=std)
1452
  if module.padding_idx is not None:
1453
  module.weight.data[module.padding_idx].zero_()
 
 
 
 
 
 
 
 
 
 
 
1454
 
1455
  @property
1456
  def dummy_inputs(self):
@@ -2074,14 +2066,20 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2074
  # Initialize weights and apply final processing
2075
  self.post_init()
2076
 
 
 
 
 
 
 
2077
  def get_encoder(self):
2078
  return self.model.get_encoder()
2079
 
2080
  def get_decoder(self):
2081
  return self.model.get_decoder()
2082
 
2083
- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
2084
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
2085
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2086
  return new_embeddings
2087
 
@@ -2531,6 +2529,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2531
  FLORENCE2_START_DOCSTRING,
2532
  )
2533
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
 
 
2534
  def __init__(self, config: Florence2Config):
2535
  super().__init__(config)
2536
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
@@ -2545,8 +2545,6 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2545
 
2546
  language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
2547
 
2548
- if language_model._tied_weights_keys is not None:
2549
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
2550
  self.language_model = language_model
2551
 
2552
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -2589,8 +2587,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2589
  def get_input_embeddings(self):
2590
  return self.language_model.get_input_embeddings()
2591
 
2592
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
2593
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
2594
  # update vocab size
2595
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2596
  self.config.vocab_size = model_embeds.num_embeddings
 
26
  from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange
29
+ from timm.layers import DropPath, trunc_normal_
30
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.generation.utils import GenerationMixin
 
610
  self.avgpool = nn.AdaptiveAvgPool1d(1)
611
  self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
612
 
 
 
613
  @property
614
  def dim_out(self):
615
  return self.embed_dims[-1]
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  def forward_features_unpool(self, x):
618
  """
619
  forward until avg pooling
 
1432
  module.weight.data.normal_(mean=0.0, std=std)
1433
  if module.padding_idx is not None:
1434
  module.weight.data[module.padding_idx].zero_()
1435
+ elif isinstance(module, nn.Conv2d):
1436
+ nn.init.normal_(module.weight, std=0.02)
1437
+ for name, _ in module.named_parameters():
1438
+ if name == "bias":
1439
+ nn.init.constant_(module.bias, 0)
1440
+ elif isinstance(module, nn.LayerNorm):
1441
+ nn.init.constant_(module.weight, 1.0)
1442
+ nn.init.constant_(module.bias, 0)
1443
+ elif isinstance(module, nn.BatchNorm2d):
1444
+ nn.init.constant_(module.weight, 1.0)
1445
+ nn.init.constant_(module.bias, 0)
1446
 
1447
  @property
1448
  def dummy_inputs(self):
 
2066
  # Initialize weights and apply final processing
2067
  self.post_init()
2068
 
2069
+ def _tie_weights(self):
2070
+ if self.config.tie_word_embeddings:
2071
+ self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
2072
+ self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
2073
+ self._tie_or_clone_weights(self.lm_head, self.model.shared)
2074
+
2075
  def get_encoder(self):
2076
  return self.model.get_encoder()
2077
 
2078
  def get_decoder(self):
2079
  return self.model.get_decoder()
2080
 
2081
+ def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, **kwargs) -> nn.Embedding:
2082
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
2083
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2084
  return new_embeddings
2085
 
 
2529
  FLORENCE2_START_DOCSTRING,
2530
  )
2531
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2532
+ _tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
2533
+
2534
  def __init__(self, config: Florence2Config):
2535
  super().__init__(config)
2536
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
 
2545
 
2546
  language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
2547
 
 
 
2548
  self.language_model = language_model
2549
 
2550
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
 
2587
  def get_input_embeddings(self):
2588
  return self.language_model.get_input_embeddings()
2589
 
2590
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, **kwargs) -> nn.Embedding:
2591
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
2592
  # update vocab size
2593
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2594
  self.config.vocab_size = model_embeds.num_embeddings