Boxue committed on
Commit
4afe42d
·
verified ·
1 Parent(s): 4408865

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_innovator_vl.py +74 -135
modeling_innovator_vl.py CHANGED
@@ -28,46 +28,29 @@ import torch.nn as nn
28
  import torch.nn.functional as F
29
  import torch.utils.checkpoint
30
  from torch.nn import LayerNorm
31
- from transformers import AutoConfig, AutoModelForCausalLM
32
  from transformers.activations import ACT2FN
33
- from transformers.cache_utils import (
34
- Cache,
35
- DynamicCache,
36
- SlidingWindowCache,
37
- StaticCache,
38
- )
39
  from transformers.generation import GenerationMixin
40
- from transformers.integrations import use_kernel_forward_from_hub
41
- from transformers.masking_utils import (
42
- create_causal_mask,
43
- create_sliding_window_causal_mask,
44
- )
45
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
46
- from transformers.modeling_flash_attention_utils import (
47
- FlashAttentionKwargs,
48
- flash_attn_supports_top_left_mask,
49
- is_flash_attn_available,
50
- )
51
  from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
52
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
53
  from transformers.modeling_utils import PreTrainedModel
 
 
54
  from transformers.processing_utils import Unpack
55
- from transformers.utils import (
56
- auto_docstring,
57
- can_return_tuple,
58
- is_torch_flex_attn_available,
59
- is_torchdynamo_compiling,
60
- logging,
61
- )
62
-
63
  from .configuration_innovator_vl import InnovatorVLConfig, InnovatorVLTextConfig, RiceConfig
64
 
 
65
  if is_flash_attn_available():
66
- from flash_attn import flash_attn_varlen_func
67
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
68
 
69
  if is_torch_flex_attn_available():
70
  from torch.nn.attention.flex_attention import BlockMask
 
71
  from transformers.integrations.flex_attention import make_flex_block_causal_mask
72
 
73
 
@@ -75,7 +58,7 @@ logger = logging.get_logger(__name__)
75
 
76
 
77
  @dataclass
78
- class LLaVAOneVision1_5_ModelOutputWithPast(ModelOutput):
79
  """
80
  Base class for Llava outputs, with hidden states and attentions.
81
 
@@ -111,7 +94,7 @@ class LLaVAOneVision1_5_ModelOutputWithPast(ModelOutput):
111
 
112
 
113
  @dataclass
114
- class LLaVAOneVision1_5_CausalLMOutputWithPast(ModelOutput):
115
  """
116
  Base class for LLaVAOneVision1.5 causal language model (or autoregressive) outputs.
117
 
@@ -149,8 +132,8 @@ class LLaVAOneVision1_5_CausalLMOutputWithPast(ModelOutput):
149
  rope_deltas: Optional[torch.LongTensor] = None
150
 
151
 
152
- class LLaVAOneVision1_5_RotaryEmbedding(nn.Module):
153
- def __init__(self, config: LLaVAOneVision1_5_TextConfig, device=None):
154
  super().__init__()
155
  # BC: "rope_type" was originally "type"
156
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -465,10 +448,10 @@ class RiceBlock(nn.Module):
465
 
466
 
467
  @use_kernel_forward_from_hub("RMSNorm")
468
- class LLaVAOneVision1_5_RMSNorm(nn.Module):
469
  def __init__(self, hidden_size, eps=1e-6):
470
  """
471
- LLaVAOneVision1_5_RMSNorm is equivalent to T5LayerNorm
472
  """
473
  super().__init__()
474
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -486,7 +469,7 @@ class LLaVAOneVision1_5_RMSNorm(nn.Module):
486
 
487
 
488
 
489
- class LLaVAOneVision1_5_MLP(nn.Module):
490
  def __init__(self, config):
491
  super().__init__()
492
  self.config = config
@@ -515,13 +498,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
515
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
516
 
517
 
518
- class LLaVAOneVision1_5_Attention(nn.Module):
519
  """
520
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
521
  and "Generating Long Sequences with Sparse Transformers".
522
  """
523
 
524
- def __init__(self, config: LLaVAOneVision1_5_TextConfig, layer_idx: Optional[int] = None):
525
  super().__init__()
526
  self.config = config
527
  self.layer_idx = layer_idx
@@ -544,8 +527,8 @@ class LLaVAOneVision1_5_Attention(nn.Module):
544
  self.o_proj = nn.Linear(
545
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
546
  )
547
- self.q_norm = LLaVAOneVision1_5_RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
548
- self.k_norm = LLaVAOneVision1_5_RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
549
  self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
550
 
551
  def forward(
@@ -612,9 +595,9 @@ class LLaVAOneVision1_5_Attention(nn.Module):
612
  return attn_output, attn_weights, past_key_value
613
 
614
 
615
- class LLaVAOneVision1_5_FlashAttention2(LLaVAOneVision1_5_Attention):
616
  """
617
- LLaVAOneVision1_5 flash attention module, following Qwen2VL attention module. This module inherits from `LLaVAOneVision1_5_Attention`
618
  as the weights of the module stays untouched. The only required change would be on the forward pass
619
  where it needs to correctly call the public API of flash attention and deal with padding tokens
620
  in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
@@ -718,14 +701,14 @@ class LLaVAOneVision1_5_FlashAttention2(LLaVAOneVision1_5_Attention):
718
  return attn_output, attn_weights, past_key_value
719
 
720
 
721
- class LLaVAOneVision1_5_SdpaAttention(LLaVAOneVision1_5_Attention):
722
  """
723
  LLaVAOneVision1_51.5 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
724
- `LLaVAOneVision1_5_Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
725
  SDPA API.
726
  """
727
 
728
- # Adapted from LLaVAOneVision1_5_Attention.forward
729
  def forward(
730
  self,
731
  hidden_states: torch.Tensor,
@@ -805,28 +788,28 @@ class LLaVAOneVision1_5_SdpaAttention(LLaVAOneVision1_5_Attention):
805
  return attn_output, None, past_key_value
806
 
807
 
808
- LLaVAOneVision1_5_ATTENTION_CLASSES = {
809
- "eager": LLaVAOneVision1_5_Attention,
810
- "flash_attention_2": LLaVAOneVision1_5_FlashAttention2,
811
- "sdpa": LLaVAOneVision1_5_SdpaAttention,
812
  }
813
 
814
 
815
- class LLaVAOneVision1_5_DecoderLayer(nn.Module):
816
- def __init__(self, config: LLaVAOneVision1_5_TextConfig, layer_idx: int):
817
  super().__init__()
818
  self.hidden_size = config.hidden_size
 
819
  if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
820
  logger.warning_once(
821
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
822
  "unexpected results may be encountered."
823
  )
824
- self.self_attn = LLaVAOneVision1_5_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
825
 
826
- self.mlp = LLaVAOneVision1_5_MLP(config)
827
- self.input_layernorm = LLaVAOneVision1_5_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
828
- self.post_attention_layernorm = LLaVAOneVision1_5_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
829
- self.attention_type = config.layer_types[layer_idx]
830
 
831
  def forward(
832
  self,
@@ -895,10 +878,10 @@ class LLaVAOneVision1_5_DecoderLayer(nn.Module):
895
 
896
  @auto_docstring
897
  class Qwen2VLPreTrainedModel(PreTrainedModel):
898
- config_class = Llavaonevision1_5Config
899
  base_model_prefix = "model"
900
  supports_gradient_checkpointing = True
901
- _no_split_modules = ["LLaVAOneVision1_5_DecoderLayer", "RiceBlock"]
902
  _skip_keys_device_placement = "past_key_values"
903
  _supports_flash_attn_2 = True
904
  _supports_sdpa = True
@@ -918,7 +901,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
918
  elif isinstance(module, nn.LayerNorm):
919
  module.weight.data.fill_(1.0)
920
  module.bias.data.zero_()
921
- elif isinstance(module, LLaVAOneVision1_5_RMSNorm):
922
  module.weight.data.fill_(1.0)
923
 
924
 
@@ -1103,23 +1086,21 @@ class RiceTransformerPretrainedModel(Qwen2VLPreTrainedModel):
1103
 
1104
 
1105
  @auto_docstring
1106
- class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1107
- config_class = LLaVAOneVision1_5_TextConfig
1108
 
1109
- def __init__(self, config: LLaVAOneVision1_5_TextConfig):
1110
  super().__init__(config)
1111
  self.padding_idx = config.pad_token_id
1112
  self.vocab_size = config.vocab_size
1113
 
1114
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1115
  self.layers = nn.ModuleList(
1116
- [LLaVAOneVision1_5_DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1117
  )
1118
  self._attn_implementation = config._attn_implementation
1119
- self.norm = LLaVAOneVision1_5_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1120
- self.rotary_emb = LLaVAOneVision1_5_RotaryEmbedding(config=config)
1121
-
1122
- self.has_sliding_layers = "sliding_attention" in self.config.layer_types
1123
 
1124
  self.gradient_checkpointing = False
1125
  # Initialize weights and apply final processing
@@ -1182,24 +1163,9 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1182
  # elif position_ids.dim() == 2: # 这是为了3drope准备的
1183
  # position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
1184
 
1185
- # It may already have been prepared by e.g. `generate`
1186
- if not isinstance(causal_mask_mapping := attention_mask, dict):
1187
- # Prepare mask arguments
1188
- mask_kwargs = {
1189
- "config": self.config,
1190
- "input_embeds": inputs_embeds,
1191
- "attention_mask": attention_mask,
1192
- "cache_position": cache_position,
1193
- "past_key_values": past_key_values,
1194
- "position_ids": position_ids,
1195
- }
1196
- # Create the masks
1197
- causal_mask_mapping = {
1198
- "full_attention": create_causal_mask(**mask_kwargs),
1199
- }
1200
- # The sliding window alternating layers are not always activated depending on the config
1201
- if self.has_sliding_layers:
1202
- causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1203
 
1204
  hidden_states = inputs_embeds
1205
 
@@ -1211,7 +1177,6 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1211
  all_self_attns = () if output_attentions else None
1212
  next_decoder_cache = None
1213
 
1214
-
1215
  for decoder_layer in self.layers:
1216
  if output_hidden_states:
1217
  all_hidden_states += (hidden_states,)
@@ -1220,7 +1185,7 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1220
  layer_outputs = self._gradient_checkpointing_func(
1221
  decoder_layer.__call__,
1222
  hidden_states,
1223
- causal_mask_mapping[decoder_layer.attention_type],
1224
  position_ids,
1225
  past_key_values,
1226
  output_attentions,
@@ -1231,7 +1196,7 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1231
  else:
1232
  layer_outputs = decoder_layer(
1233
  hidden_states,
1234
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
1235
  position_ids=position_ids,
1236
  past_key_value=past_key_values,
1237
  output_attentions=output_attentions,
@@ -1359,7 +1324,7 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1359
  dtype: torch.dtype,
1360
  cache_position: torch.Tensor,
1361
  batch_size: int,
1362
- config: Llavaonevision1_5Config,
1363
  past_key_values: Cache,
1364
  ):
1365
  """
@@ -1379,7 +1344,7 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1379
  Indices depicting the position of the input sequence tokens in the sequence.
1380
  batch_size (`torch.Tensor`):
1381
  Batch size.
1382
- config (`Llavaonevision1_5Config`):
1383
  The model's configuration class
1384
  past_key_values (`Cache`):
1385
  The cache class that is being used currently to generate
@@ -1422,14 +1387,14 @@ class LLaVAOneVision1_5_TextModel(Qwen2VLPreTrainedModel):
1422
 
1423
 
1424
  @auto_docstring
1425
- class LLaVAOneVision1_5_Model(Qwen2VLPreTrainedModel):
1426
  base_model_prefix = ""
1427
  _checkpoint_conversion_mapping = {"^model": "language_model"}
1428
 
1429
- def __init__(self, config: Llavaonevision1_5Config):
1430
  super().__init__(config)
1431
  self.visual = RiceTransformerPretrainedModel._from_config(config.vision_config)
1432
- self.language_model = LLaVAOneVision1_5_TextModel._from_config(config.text_config)
1433
  self.rope_deltas = None # cache rope_deltas here
1434
 
1435
  # Initialize weights and apply final processing
@@ -1638,7 +1603,7 @@ class LLaVAOneVision1_5_Model(Qwen2VLPreTrainedModel):
1638
  video_grid_thw: Optional[torch.LongTensor] = None,
1639
  rope_deltas: Optional[torch.LongTensor] = None,
1640
  cache_position: Optional[torch.LongTensor] = None,
1641
- ) -> Union[Tuple, LLaVAOneVision1_5_ModelOutputWithPast]:
1642
  r"""
1643
  pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
1644
  The tensors corresponding to the input videos. Pixel values can be obtained using
@@ -1665,25 +1630,9 @@ class LLaVAOneVision1_5_Model(Qwen2VLPreTrainedModel):
1665
  n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
1666
  n_image_features = image_embeds.shape[0]
1667
  if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
1668
- if abs(n_image_tokens - n_image_features) <= 10:
1669
- logger.warning_once(
1670
- f"!!!!!!!! Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}. "
1671
- f"This may be caused by dynamic image sizes during training. Try to fix it. !!!!!!!!"
1672
- )
1673
- if n_image_tokens > n_image_features:
1674
- diff = n_image_tokens - n_image_features
1675
- pad_embeds = torch.zeros(
1676
- (diff, image_embeds.shape[1]),
1677
- dtype=image_embeds.dtype,
1678
- device=image_embeds.device,
1679
- )
1680
- image_embeds = torch.cat([image_embeds, pad_embeds], dim=0)
1681
- else:
1682
- image_embeds = image_embeds[:n_image_tokens, :]
1683
- else:
1684
- raise ValueError(
1685
- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
1686
- )
1687
  image_mask = (
1688
  (input_ids == self.config.image_token_id)
1689
  .unsqueeze(-1)
@@ -1710,7 +1659,7 @@ class LLaVAOneVision1_5_Model(Qwen2VLPreTrainedModel):
1710
  video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
1711
  inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1712
 
1713
- if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
1714
  attention_mask = attention_mask.to(inputs_embeds.device)
1715
 
1716
  if use_cache and past_key_values is None:
@@ -1739,7 +1688,7 @@ class LLaVAOneVision1_5_Model(Qwen2VLPreTrainedModel):
1739
  cache_position=cache_position,
1740
  )
1741
 
1742
- output = LLaVAOneVision1_5_ModelOutputWithPast(
1743
  last_hidden_state=outputs.last_hidden_state,
1744
  past_key_values=outputs.past_key_values,
1745
  hidden_states=outputs.hidden_states,
@@ -1805,7 +1754,7 @@ class LLaVAOneVision1_5_Model(Qwen2VLPreTrainedModel):
1805
  return causal_mask
1806
 
1807
 
1808
- class LLaVAOneVision1_5_ForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
1809
  _checkpoint_conversion_mapping = {
1810
  "^visual": "model.visual",
1811
  r"^model(?!\.(language_model|visual))": "model.language_model",
@@ -1814,7 +1763,7 @@ class LLaVAOneVision1_5_ForConditionalGeneration(Qwen2VLPreTrainedModel, Generat
1814
 
1815
  def __init__(self, config):
1816
  super().__init__(config)
1817
- self.model = LLaVAOneVision1_5_Model(config)
1818
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1819
 
1820
  self.post_init()
@@ -1866,9 +1815,7 @@ class LLaVAOneVision1_5_ForConditionalGeneration(Qwen2VLPreTrainedModel, Generat
1866
  video_grid_thw: Optional[torch.LongTensor] = None,
1867
  rope_deltas: Optional[torch.LongTensor] = None,
1868
  cache_position: Optional[torch.LongTensor] = None,
1869
- *args,
1870
- **kwargs,
1871
- ) -> Union[Tuple, LLaVAOneVision1_5_CausalLMOutputWithPast]:
1872
  r"""
1873
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1874
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1915,7 +1862,6 @@ class LLaVAOneVision1_5_ForConditionalGeneration(Qwen2VLPreTrainedModel, Generat
1915
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1916
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1917
  ```"""
1918
- position_ids = None
1919
 
1920
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1921
  output_hidden_states = (
@@ -1945,23 +1891,11 @@ class LLaVAOneVision1_5_ForConditionalGeneration(Qwen2VLPreTrainedModel, Generat
1945
  hidden_states = outputs[0]
1946
  logits = self.lm_head(hidden_states)
1947
 
1948
- # with torch.no_grad():
1949
- # log_probs = torch.nn.functional.log_softmax(logits.float() / 1, dim=-1)
1950
- # entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1).squeeze()
1951
- # if entropy.ndim != 1:
1952
- # entropy = entropy.unsqueeze(0)
1953
- # if hasattr(self, "entropy"):
1954
- # self.entropy = torch.cat([self.entropy, entropy], dim=-1)
1955
- # else:
1956
- # self.entropy = entropy
1957
-
1958
- # print(self.entropy.mean())
1959
-
1960
  loss = None
1961
  if labels is not None:
1962
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
1963
 
1964
- return LLaVAOneVision1_5_CausalLMOutputWithPast(
1965
  loss=loss,
1966
  logits=logits,
1967
  past_key_values=outputs.past_key_values,
@@ -2133,8 +2067,13 @@ class LLaVAOneVision1_5_ForConditionalGeneration(Qwen2VLPreTrainedModel, Generat
2133
  return input_ids, model_kwargs
2134
 
2135
 
2136
- __all__ = ["LLaVAOneVision1_5_ForConditionalGeneration", "LLaVAOneVision1_5_Model", "Qwen2VLPreTrainedModel", "LLaVAOneVision1_5_TextModel"]
 
 
 
 
 
2137
 
2138
 
2139
- AutoConfig.register("llavaonevision1_5", Llavaonevision1_5Config)
2140
- AutoModelForCausalLM.register(Llavaonevision1_5Config, LLaVAOneVision1_5_ForConditionalGeneration)
 
28
  import torch.nn.functional as F
29
  import torch.utils.checkpoint
30
  from torch.nn import LayerNorm
31
+
32
  from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 
 
 
 
 
34
  from transformers.generation import GenerationMixin
 
 
 
 
 
35
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+ from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 
 
 
 
37
  from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
38
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
  from transformers.modeling_utils import PreTrainedModel
40
+ from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
41
+ from transformers.integrations import use_kernel_forward_from_hub
42
  from transformers.processing_utils import Unpack
43
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
44
+ from transformers import AutoModelForCausalLM, AutoConfig
 
 
 
 
 
 
45
  from .configuration_innovator_vl import InnovatorVLConfig, InnovatorVLTextConfig, RiceConfig
46
 
47
+
48
  if is_flash_attn_available():
49
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func
 
50
 
51
  if is_torch_flex_attn_available():
52
  from torch.nn.attention.flex_attention import BlockMask
53
+
54
  from transformers.integrations.flex_attention import make_flex_block_causal_mask
55
 
56
 
 
58
 
59
 
60
  @dataclass
61
+ class InnovatorVLModelOutputWithPast(ModelOutput):
62
  """
63
  Base class for InnovatorVL outputs, with hidden states and attentions.
64
 
 
94
 
95
 
96
  @dataclass
97
+ class InnovatorVLCausalLMOutputWithPast(ModelOutput):
98
  """
99
  Base class for LLaVAOneVision1.5 causal language model (or autoregressive) outputs.
100
 
 
132
  rope_deltas: Optional[torch.LongTensor] = None
133
 
134
 
135
+ class InnovatorVL_RotaryEmbedding(nn.Module):
136
+ def __init__(self, config: InnovatorVLTextConfig, device=None):
137
  super().__init__()
138
  # BC: "rope_type" was originally "type"
139
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 
448
 
449
 
450
  @use_kernel_forward_from_hub("RMSNorm")
451
+ class InnovatorVL_RMSNorm(nn.Module):
452
  def __init__(self, hidden_size, eps=1e-6):
453
  """
454
+ InnovatorVL_RMSNorm is equivalent to T5LayerNorm
455
  """
456
  super().__init__()
457
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
469
 
470
 
471
 
472
+ class InnovatorVL_MLP(nn.Module):
473
  def __init__(self, config):
474
  super().__init__()
475
  self.config = config
 
498
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
499
 
500
 
501
+ class InnovatorVL_Attention(nn.Module):
502
  """
503
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
504
  and "Generating Long Sequences with Sparse Transformers".
505
  """
506
 
507
+ def __init__(self, config: InnovatorVLTextConfig, layer_idx: Optional[int] = None):
508
  super().__init__()
509
  self.config = config
510
  self.layer_idx = layer_idx
 
527
  self.o_proj = nn.Linear(
528
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
529
  )
530
+ self.q_norm = InnovatorVL_RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
531
+ self.k_norm = InnovatorVL_RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
532
  self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
533
 
534
  def forward(
 
595
  return attn_output, attn_weights, past_key_value
596
 
597
 
598
+ class InnovatorVL_FlashAttention2(InnovatorVL_Attention):
599
  """
600
+ InnovatorVL flash attention module, following Qwen2VL attention module. This module inherits from `InnovatorVL_Attention`
601
  as the weights of the module stays untouched. The only required change would be on the forward pass
602
  where it needs to correctly call the public API of flash attention and deal with padding tokens
603
  in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
 
701
  return attn_output, attn_weights, past_key_value
702
 
703
 
704
+ class InnovatorVL_SdpaAttention(InnovatorVL_Attention):
705
  """
706
  InnovatorVL attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
707
+ `InnovatorVL_Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
708
  SDPA API.
709
  """
710
 
711
+ # Adapted from InnovatorVL_Attention.forward
712
  def forward(
713
  self,
714
  hidden_states: torch.Tensor,
 
788
  return attn_output, None, past_key_value
789
 
790
 
791
+ InnovatorVL_ATTENTION_CLASSES = {
792
+ "eager": InnovatorVL_Attention,
793
+ "flash_attention_2": InnovatorVL_FlashAttention2,
794
+ "sdpa": InnovatorVL_SdpaAttention,
795
  }
796
 
797
 
798
+ class InnovatorVL_DecoderLayer(nn.Module):
799
+ def __init__(self, config: InnovatorVLTextConfig, layer_idx: int):
800
  super().__init__()
801
  self.hidden_size = config.hidden_size
802
+
803
  if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
804
  logger.warning_once(
805
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
806
  "unexpected results may be encountered."
807
  )
808
+ self.self_attn = InnovatorVL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
809
 
810
+ self.mlp = InnovatorVL_MLP(config)
811
+ self.input_layernorm = InnovatorVL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
812
+ self.post_attention_layernorm = InnovatorVL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
813
 
814
  def forward(
815
  self,
 
878
 
879
  @auto_docstring
880
  class Qwen2VLPreTrainedModel(PreTrainedModel):
881
+ config_class = InnovatorVLConfig
882
  base_model_prefix = "model"
883
  supports_gradient_checkpointing = True
884
+ _no_split_modules = ["InnovatorVL_DecoderLayer", "RiceBlock"]
885
  _skip_keys_device_placement = "past_key_values"
886
  _supports_flash_attn_2 = True
887
  _supports_sdpa = True
 
901
  elif isinstance(module, nn.LayerNorm):
902
  module.weight.data.fill_(1.0)
903
  module.bias.data.zero_()
904
+ elif isinstance(module, InnovatorVL_RMSNorm):
905
  module.weight.data.fill_(1.0)
906
 
907
 
 
1086
 
1087
 
1088
  @auto_docstring
1089
+ class InnovatorVLTextModel(Qwen2VLPreTrainedModel):
1090
+ config_class = InnovatorVLTextConfig
1091
 
1092
+ def __init__(self, config: InnovatorVLTextConfig):
1093
  super().__init__(config)
1094
  self.padding_idx = config.pad_token_id
1095
  self.vocab_size = config.vocab_size
1096
 
1097
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1098
  self.layers = nn.ModuleList(
1099
+ [InnovatorVL_DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1100
  )
1101
  self._attn_implementation = config._attn_implementation
1102
+ self.norm = InnovatorVL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1103
+ self.rotary_emb = InnovatorVL_RotaryEmbedding(config=config)
 
 
1104
 
1105
  self.gradient_checkpointing = False
1106
  # Initialize weights and apply final processing
 
1163
  # elif position_ids.dim() == 2: # 这是为了3drope准备的
1164
  # position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
1165
 
1166
+ causal_mask = self._update_causal_mask(
1167
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1168
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1169
 
1170
  hidden_states = inputs_embeds
1171
 
 
1177
  all_self_attns = () if output_attentions else None
1178
  next_decoder_cache = None
1179
 
 
1180
  for decoder_layer in self.layers:
1181
  if output_hidden_states:
1182
  all_hidden_states += (hidden_states,)
 
1185
  layer_outputs = self._gradient_checkpointing_func(
1186
  decoder_layer.__call__,
1187
  hidden_states,
1188
+ causal_mask,
1189
  position_ids,
1190
  past_key_values,
1191
  output_attentions,
 
1196
  else:
1197
  layer_outputs = decoder_layer(
1198
  hidden_states,
1199
+ attention_mask=causal_mask,
1200
  position_ids=position_ids,
1201
  past_key_value=past_key_values,
1202
  output_attentions=output_attentions,
 
1324
  dtype: torch.dtype,
1325
  cache_position: torch.Tensor,
1326
  batch_size: int,
1327
+ config: InnovatorVLConfig,
1328
  past_key_values: Cache,
1329
  ):
1330
  """
 
1344
  Indices depicting the position of the input sequence tokens in the sequence.
1345
  batch_size (`torch.Tensor`):
1346
  Batch size.
1347
+ config (`InnovatorVLConfig`):
1348
  The model's configuration class
1349
  past_key_values (`Cache`):
1350
  The cache class that is being used currently to generate
 
1387
 
1388
 
1389
  @auto_docstring
1390
+ class InnovatorVLModel(Qwen2VLPreTrainedModel):
1391
  base_model_prefix = ""
1392
  _checkpoint_conversion_mapping = {"^model": "language_model"}
1393
 
1394
+ def __init__(self, config: InnovatorVLConfig):
1395
  super().__init__(config)
1396
  self.visual = RiceTransformerPretrainedModel._from_config(config.vision_config)
1397
+ self.language_model = InnovatorVLTextModel._from_config(config.text_config)
1398
  self.rope_deltas = None # cache rope_deltas here
1399
 
1400
  # Initialize weights and apply final processing
 
1603
  video_grid_thw: Optional[torch.LongTensor] = None,
1604
  rope_deltas: Optional[torch.LongTensor] = None,
1605
  cache_position: Optional[torch.LongTensor] = None,
1606
+ ) -> Union[Tuple, InnovatorVLModelOutputWithPast]:
1607
  r"""
1608
  pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
1609
  The tensors corresponding to the input videos. Pixel values can be obtained using
 
1630
  n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
1631
  n_image_features = image_embeds.shape[0]
1632
  if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
1633
+ raise ValueError(
1634
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
1635
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1636
  image_mask = (
1637
  (input_ids == self.config.image_token_id)
1638
  .unsqueeze(-1)
 
1659
  video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
1660
  inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1661
 
1662
+ if attention_mask is not None:
1663
  attention_mask = attention_mask.to(inputs_embeds.device)
1664
 
1665
  if use_cache and past_key_values is None:
 
1688
  cache_position=cache_position,
1689
  )
1690
 
1691
+ output = InnovatorVLModelOutputWithPast(
1692
  last_hidden_state=outputs.last_hidden_state,
1693
  past_key_values=outputs.past_key_values,
1694
  hidden_states=outputs.hidden_states,
 
1754
  return causal_mask
1755
 
1756
 
1757
+ class InnovatorVLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
1758
  _checkpoint_conversion_mapping = {
1759
  "^visual": "model.visual",
1760
  r"^model(?!\.(language_model|visual))": "model.language_model",
 
1763
 
1764
  def __init__(self, config):
1765
  super().__init__(config)
1766
+ self.model = InnovatorVLModel(config)
1767
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1768
 
1769
  self.post_init()
 
1815
  video_grid_thw: Optional[torch.LongTensor] = None,
1816
  rope_deltas: Optional[torch.LongTensor] = None,
1817
  cache_position: Optional[torch.LongTensor] = None,
1818
+ ) -> Union[Tuple, InnovatorVLCausalLMOutputWithPast]:
 
 
1819
  r"""
1820
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1821
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
 
1862
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1863
  "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1864
  ```"""
 
1865
 
1866
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1867
  output_hidden_states = (
 
1891
  hidden_states = outputs[0]
1892
  logits = self.lm_head(hidden_states)
1893
 
 
 
 
 
 
 
 
 
 
 
 
 
1894
  loss = None
1895
  if labels is not None:
1896
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
1897
 
1898
+ return InnovatorVLCausalLMOutputWithPast(
1899
  loss=loss,
1900
  logits=logits,
1901
  past_key_values=outputs.past_key_values,
 
2067
  return input_ids, model_kwargs
2068
 
2069
 
2070
+ __all__ = [
2071
+ "InnovatorVLForConditionalGeneration",
2072
+ "InnovatorVLModel",
2073
+ "InnovatorVLTextModel",
2074
+ "Qwen2VLPreTrainedModel",
2075
+ ]
2076
 
2077
 
2078
+ AutoConfig.register("innovator_vl", InnovatorVLConfig)
2079
+ AutoModelForCausalLM.register(InnovatorVLConfig, InnovatorVLForConditionalGeneration)