xiangan committed on
Commit
5486249
·
verified ·
1 Parent(s): b7cefb0

Upload folder using huggingface_hub

Browse files
configuration_onevision_encoder.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers.configuration_utils import PretrainedConfig
2
  from transformers.utils import logging
3
 
 
4
  logger = logging.get_logger(__name__)
5
 
6
 
@@ -26,7 +27,7 @@ class OneVisionEncoderConfig(PretrainedConfig):
26
  The number of input channels.
27
  image_size (`int`, *optional*, defaults to 448):
28
  The size (resolution) of each image.
29
- patch_size (`int`, *optional*, defaults to 16):
30
  The size (resolution) of each patch.
31
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
32
  The non-linear activation function (function or string) in the encoder and pooler.
@@ -70,7 +71,7 @@ class OneVisionEncoderConfig(PretrainedConfig):
70
  num_attention_heads=16,
71
  num_channels=3,
72
  image_size=448,
73
- patch_size=16,
74
  hidden_act="gelu",
75
  layer_norm_eps=1e-6,
76
  layer_norm_type="layer_norm",
 
1
  from transformers.configuration_utils import PretrainedConfig
2
  from transformers.utils import logging
3
 
4
+
5
  logger = logging.get_logger(__name__)
6
 
7
 
 
27
  The number of input channels.
28
  image_size (`int`, *optional*, defaults to 448):
29
  The size (resolution) of each image.
30
+ patch_size (`int`, *optional*, defaults to 14):
31
  The size (resolution) of each patch.
32
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
33
  The non-linear activation function (function or string) in the encoder and pooler.
 
71
  num_attention_heads=16,
72
  num_channels=3,
73
  image_size=448,
74
+ patch_size=14,
75
  hidden_act="gelu",
76
  layer_norm_eps=1e-6,
77
  layer_norm_type="layer_norm",
modeling_onevision_encoder.py CHANGED
@@ -2,18 +2,23 @@ from typing import Optional, Tuple, Union
2
 
3
  import torch
4
  import torch.nn as nn
5
- from transformers.modeling_outputs import (BaseModelOutput,
6
- BaseModelOutputWithPooling)
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.models.siglip.modeling_siglip import SiglipMLP
9
- from transformers.utils import (add_start_docstrings,
10
- add_start_docstrings_to_model_forward, logging,
11
- replace_return_docstrings)
 
 
 
12
 
13
  from .configuration_onevision_encoder import OneVisionEncoderConfig
14
 
 
15
  try:
16
  from flash_attn import flash_attn_func
 
17
  _flash_attn_available = True
18
  except ImportError:
19
  _flash_attn_available = False
@@ -61,6 +66,7 @@ ONEVISION_ENCODER_INPUTS_DOCSTRING = r"""
61
  # Helper Functions & Layers
62
  # ---------------------------------------------------------------------------
63
 
 
64
  def get_norm_layer(config):
65
  if config.layer_norm_type == "rms_norm":
66
  return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -100,6 +106,7 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
100
  """
101
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
102
  """
 
103
  def __init__(self, config: OneVisionEncoderConfig):
104
  super().__init__()
105
  head_dim = config.hidden_size // config.num_attention_heads
@@ -118,12 +125,25 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
118
  self.h_size = 6 * unit
119
  self.w_size = 6 * unit
120
 
121
- self.register_buffer("inv_freq_t", 1.0 / (base ** (torch.arange(self.t_size, dtype=torch.float32) / self.t_size)), persistent=False)
122
- self.register_buffer("inv_freq_h", 1.0 / (base ** (torch.arange(self.h_size, dtype=torch.float32) / self.h_size)), persistent=False)
123
- self.register_buffer("inv_freq_w", 1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)), persistent=False)
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def forward(self, t: int, h: int, w: int, device=None):
126
- if device is None: device = self.inv_freq_t.device
 
127
 
128
  inv_t = self.inv_freq_t.to(device=device)
129
  inv_h = self.inv_freq_h.to(device=device)
@@ -140,11 +160,38 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
140
  freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
141
  return freqs
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  class Siglip2MultiheadAttentionPoolingHead(nn.Module):
145
  """
146
  Multi-Head Attention Pooling with a learned probe (PMA-style).
147
  """
 
148
  def __init__(self, config: OneVisionEncoderConfig):
149
  super().__init__()
150
  self.embed_dim = config.hidden_size
@@ -170,6 +217,7 @@ class Siglip2MultiheadAttentionPoolingHead(nn.Module):
170
  # Modeling Components
171
  # ---------------------------------------------------------------------------
172
 
 
173
  class OneVisionEncoderEmbeddings(nn.Module):
174
  def __init__(self, config: OneVisionEncoderConfig):
175
  super().__init__()
@@ -189,7 +237,7 @@ class OneVisionEncoderEmbeddings(nn.Module):
189
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
190
  # Handle 4D (B, C, H, W) or 5D (B, C, T, H, W) inputs
191
  if pixel_values.dim() == 4:
192
- pixel_values = pixel_values.unsqueeze(2) # (B, C, 1, H, W)
193
 
194
  batch_size, channels, t_frames, height, width = pixel_values.shape
195
 
@@ -198,7 +246,7 @@ class OneVisionEncoderEmbeddings(nn.Module):
198
 
199
  # Patch Embed
200
  embeddings = self.patch_embedding(x_2d) # (B*T, C, Hp, Wp)
201
- embeddings = embeddings.flatten(2).transpose(1, 2) # (B*T, L_frame, C)
202
 
203
  # Flatten all patches
204
  total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
@@ -209,6 +257,7 @@ class OneVisionEncoderEmbeddings(nn.Module):
209
 
210
  class OneVisionEncoderAttention(nn.Module):
211
  """Multi-headed attention with RoPE support"""
 
212
  def __init__(self, config: OneVisionEncoderConfig):
213
  super().__init__()
214
  self.config = config
@@ -216,7 +265,7 @@ class OneVisionEncoderAttention(nn.Module):
216
  self.num_heads = config.num_attention_heads
217
  self.head_dim = self.embed_dim // self.num_heads
218
  if self.head_dim * self.num_heads != self.embed_dim:
219
- raise ValueError(
220
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
221
  )
222
 
@@ -228,7 +277,6 @@ class OneVisionEncoderAttention(nn.Module):
228
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
229
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
230
 
231
-
232
  def forward(
233
  self,
234
  hidden_states: torch.Tensor,
@@ -236,7 +284,6 @@ class OneVisionEncoderAttention(nn.Module):
236
  rotary_pos_emb: Optional[torch.Tensor] = None,
237
  output_attentions: bool = False,
238
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
239
-
240
  batch_size, q_len, _ = hidden_states.size()
241
 
242
  query_states = self.q_proj(hidden_states)
@@ -257,7 +304,7 @@ class OneVisionEncoderAttention(nn.Module):
257
  if attention_mask is not None:
258
  if attention_mask.size() != (batch_size, 1, q_len, q_len):
259
  if attention_mask.dim() == 3:
260
- attention_mask = attention_mask.unsqueeze(1)
261
  attn_weights = attn_weights + attention_mask
262
 
263
  # FIX: Remove dtype=torch.float32 to stay in original dtype (bf16/fp16)
@@ -280,6 +327,7 @@ class OneVisionEncoderFlashAttention2(nn.Module):
280
  This module implements the same attention mechanism as OneVisionEncoderAttention but uses
281
  Flash Attention for improved performance and memory efficiency.
282
  """
 
283
  def __init__(self, config: OneVisionEncoderConfig):
284
  super().__init__()
285
  self.config = config
@@ -386,7 +434,6 @@ class OneVisionEncoderEncoderLayer(nn.Module):
386
  rotary_pos_emb: Optional[torch.Tensor] = None,
387
  output_attentions: bool = False,
388
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
389
-
390
  residual = hidden_states
391
  hidden_states = self.layer_norm1(hidden_states)
392
 
@@ -421,8 +468,7 @@ class OneVisionEncoderEncoder(nn.Module):
421
  output_attentions: bool = False,
422
  output_hidden_states: bool = False,
423
  return_dict: bool = True,
424
- ) -> Union[Tuple, BaseModelOutput]:
425
-
426
  all_hidden_states = () if output_hidden_states else None
427
  all_self_attentions = () if output_attentions else None
428
 
@@ -459,6 +505,7 @@ class OneVisionEncoderEncoder(nn.Module):
459
  # Main Models
460
  # ---------------------------------------------------------------------------
461
 
 
462
  @add_start_docstrings(
463
  "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
464
  ONEVISION_ENCODER_START_DOCSTRING,
@@ -484,7 +531,7 @@ class OneVisionEncoderPreTrainedModel(PreTrainedModel):
484
  elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
485
  # Fix: RMSNorm doesn't have bias, must check hasattr first
486
  module.weight.data.fill_(1.0)
487
- if hasattr(module, 'bias') and module.bias is not None:
488
  module.bias.data.zero_()
489
 
490
 
@@ -503,25 +550,25 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
503
  self.video_rope = VideoRotaryEmbeddingSplit466(config)
504
 
505
  if config.use_head:
506
- self.layernorm_post = get_norm_layer(config)
507
- self.head = Siglip2MultiheadAttentionPoolingHead(config)
508
  else:
509
- self.layernorm_post = None
510
- self.head = None
511
 
512
  self.post_init()
513
 
514
-
515
  @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
516
  @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
517
  def forward(
518
  self,
519
  pixel_values: torch.Tensor,
520
  visible_indices: Optional[torch.Tensor] = None,
 
521
  output_attentions: Optional[bool] = None,
522
  output_hidden_states: Optional[bool] = None,
523
  return_dict: Optional[bool] = None,
524
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
525
  r"""
526
  Returns:
527
 
@@ -549,14 +596,16 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
549
  # Determine video dimensions for RoPE
550
  # Note: pixel_values passed to embeddings can be 4D or 5D
551
  if pixel_values.dim() == 5:
552
- # Use config.rope_temporal_size if set, otherwise use actual frame count
553
- t_frames = self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
554
- height = pixel_values.shape[3]
555
- width = pixel_values.shape[4]
 
 
556
  else:
557
- t_frames = 1
558
- height = pixel_values.shape[2]
559
- width = pixel_values.shape[3]
560
 
561
  # 1. Embeddings
562
  hidden_states = self.embeddings(pixel_values)
@@ -564,16 +613,21 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
564
 
565
  # 2. Visible Indices Handling
566
  if visible_indices is None:
567
- visible_indices = torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
 
 
568
 
569
  # 3. RoPE Construction
570
- freqs_full = self.video_rope(
571
- t=t_frames,
572
- h=height // self.config.patch_size,
573
- w=width // self.config.patch_size,
574
- device=pixel_values.device
575
- )
576
- freqs_visible = freqs_full[visible_indices]
 
 
 
577
 
578
  # Concatenate D/2 + D/2 -> D for applying rope
579
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
 
2
 
3
  import torch
4
  import torch.nn as nn
5
+
6
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.models.siglip.modeling_siglip import SiglipMLP
9
+ from transformers.utils import (
10
+ add_start_docstrings,
11
+ add_start_docstrings_to_model_forward,
12
+ logging,
13
+ replace_return_docstrings,
14
+ )
15
 
16
  from .configuration_onevision_encoder import OneVisionEncoderConfig
17
 
18
+
19
  try:
20
  from flash_attn import flash_attn_func
21
+
22
  _flash_attn_available = True
23
  except ImportError:
24
  _flash_attn_available = False
 
66
  # Helper Functions & Layers
67
  # ---------------------------------------------------------------------------
68
 
69
+
70
  def get_norm_layer(config):
71
  if config.layer_norm_type == "rms_norm":
72
  return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
 
106
  """
107
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
108
  """
109
+
110
  def __init__(self, config: OneVisionEncoderConfig):
111
  super().__init__()
112
  head_dim = config.hidden_size // config.num_attention_heads
 
125
  self.h_size = 6 * unit
126
  self.w_size = 6 * unit
127
 
128
+ self.register_buffer(
129
+ "inv_freq_t",
130
+ 1.0 / (base ** (torch.arange(self.t_size, dtype=torch.float32) / self.t_size)),
131
+ persistent=False,
132
+ )
133
+ self.register_buffer(
134
+ "inv_freq_h",
135
+ 1.0 / (base ** (torch.arange(self.h_size, dtype=torch.float32) / self.h_size)),
136
+ persistent=False,
137
+ )
138
+ self.register_buffer(
139
+ "inv_freq_w",
140
+ 1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)),
141
+ persistent=False,
142
+ )
143
 
144
  def forward(self, t: int, h: int, w: int, device=None):
145
+ if device is None:
146
+ device = self.inv_freq_t.device
147
 
148
  inv_t = self.inv_freq_t.to(device=device)
149
  inv_h = self.inv_freq_h.to(device=device)
 
160
  freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
161
  return freqs
162
 
163
+ def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
164
+ """
165
+ Compute rotary position embeddings from explicit patch positions.
166
+
167
+ Args:
168
+ patch_positions: [batch_size, seq_len, 3] tensor with [t, h, w] positions for each patch
169
+
170
+ Returns:
171
+ freqs: [batch_size, seq_len, half] tensor of position frequencies
172
+ """
173
+ device = patch_positions.device
174
+ inv_t = self.inv_freq_t.to(device=device)
175
+ inv_h = self.inv_freq_h.to(device=device)
176
+ inv_w = self.inv_freq_w.to(device=device)
177
+
178
+ t_pos = patch_positions[..., 0].float() # [batch_size, seq_len]
179
+ h_pos = patch_positions[..., 1].float() # [batch_size, seq_len]
180
+ w_pos = patch_positions[..., 2].float() # [batch_size, seq_len]
181
+
182
+ # Use einsum for batched outer product: [batch_size, seq_len] x [dim] -> [batch_size, seq_len, dim]
183
+ ft = torch.einsum("bs,d->bsd", t_pos, inv_t)
184
+ fh = torch.einsum("bs,d->bsd", h_pos, inv_h)
185
+ fw = torch.einsum("bs,d->bsd", w_pos, inv_w)
186
+
187
+ return torch.cat([ft, fh, fw], dim=-1)
188
+
189
 
190
  class Siglip2MultiheadAttentionPoolingHead(nn.Module):
191
  """
192
  Multi-Head Attention Pooling with a learned probe (PMA-style).
193
  """
194
+
195
  def __init__(self, config: OneVisionEncoderConfig):
196
  super().__init__()
197
  self.embed_dim = config.hidden_size
 
217
  # Modeling Components
218
  # ---------------------------------------------------------------------------
219
 
220
+
221
  class OneVisionEncoderEmbeddings(nn.Module):
222
  def __init__(self, config: OneVisionEncoderConfig):
223
  super().__init__()
 
237
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
238
  # Handle 4D (B, C, H, W) or 5D (B, C, T, H, W) inputs
239
  if pixel_values.dim() == 4:
240
+ pixel_values = pixel_values.unsqueeze(2) # (B, C, 1, H, W)
241
 
242
  batch_size, channels, t_frames, height, width = pixel_values.shape
243
 
 
246
 
247
  # Patch Embed
248
  embeddings = self.patch_embedding(x_2d) # (B*T, C, Hp, Wp)
249
+ embeddings = embeddings.flatten(2).transpose(1, 2) # (B*T, L_frame, C)
250
 
251
  # Flatten all patches
252
  total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
 
257
 
258
  class OneVisionEncoderAttention(nn.Module):
259
  """Multi-headed attention with RoPE support"""
260
+
261
  def __init__(self, config: OneVisionEncoderConfig):
262
  super().__init__()
263
  self.config = config
 
265
  self.num_heads = config.num_attention_heads
266
  self.head_dim = self.embed_dim // self.num_heads
267
  if self.head_dim * self.num_heads != self.embed_dim:
268
+ raise ValueError(
269
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
270
  )
271
 
 
277
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
279
 
 
280
  def forward(
281
  self,
282
  hidden_states: torch.Tensor,
 
284
  rotary_pos_emb: Optional[torch.Tensor] = None,
285
  output_attentions: bool = False,
286
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
287
  batch_size, q_len, _ = hidden_states.size()
288
 
289
  query_states = self.q_proj(hidden_states)
 
304
  if attention_mask is not None:
305
  if attention_mask.size() != (batch_size, 1, q_len, q_len):
306
  if attention_mask.dim() == 3:
307
+ attention_mask = attention_mask.unsqueeze(1)
308
  attn_weights = attn_weights + attention_mask
309
 
310
  # FIX: Remove dtype=torch.float32 to stay in original dtype (bf16/fp16)
 
327
  This module implements the same attention mechanism as OneVisionEncoderAttention but uses
328
  Flash Attention for improved performance and memory efficiency.
329
  """
330
+
331
  def __init__(self, config: OneVisionEncoderConfig):
332
  super().__init__()
333
  self.config = config
 
434
  rotary_pos_emb: Optional[torch.Tensor] = None,
435
  output_attentions: bool = False,
436
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
437
  residual = hidden_states
438
  hidden_states = self.layer_norm1(hidden_states)
439
 
 
468
  output_attentions: bool = False,
469
  output_hidden_states: bool = False,
470
  return_dict: bool = True,
471
+ ) -> Union[tuple, BaseModelOutput]:
 
472
  all_hidden_states = () if output_hidden_states else None
473
  all_self_attentions = () if output_attentions else None
474
 
 
505
  # Main Models
506
  # ---------------------------------------------------------------------------
507
 
508
+
509
  @add_start_docstrings(
510
  "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
511
  ONEVISION_ENCODER_START_DOCSTRING,
 
531
  elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
532
  # Fix: RMSNorm doesn't have bias, must check hasattr first
533
  module.weight.data.fill_(1.0)
534
+ if hasattr(module, "bias") and module.bias is not None:
535
  module.bias.data.zero_()
536
 
537
 
 
550
  self.video_rope = VideoRotaryEmbeddingSplit466(config)
551
 
552
  if config.use_head:
553
+ self.layernorm_post = get_norm_layer(config)
554
+ self.head = Siglip2MultiheadAttentionPoolingHead(config)
555
  else:
556
+ self.layernorm_post = None
557
+ self.head = None
558
 
559
  self.post_init()
560
 
 
561
  @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
562
  @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
563
  def forward(
564
  self,
565
  pixel_values: torch.Tensor,
566
  visible_indices: Optional[torch.Tensor] = None,
567
+ patch_positions: Optional[torch.Tensor] = None,
568
  output_attentions: Optional[bool] = None,
569
  output_hidden_states: Optional[bool] = None,
570
  return_dict: Optional[bool] = None,
571
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
572
  r"""
573
  Returns:
574
 
 
596
  # Determine video dimensions for RoPE
597
  # Note: pixel_values passed to embeddings can be 4D or 5D
598
  if pixel_values.dim() == 5:
599
+ # Use config.rope_temporal_size if set, otherwise use actual frame count
600
+ t_frames = (
601
+ self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
602
+ )
603
+ height = pixel_values.shape[3]
604
+ width = pixel_values.shape[4]
605
  else:
606
+ t_frames = 1
607
+ height = pixel_values.shape[2]
608
+ width = pixel_values.shape[3]
609
 
610
  # 1. Embeddings
611
  hidden_states = self.embeddings(pixel_values)
 
613
 
614
  # 2. Visible Indices Handling
615
  if visible_indices is None:
616
+ visible_indices = (
617
+ torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
618
+ )
619
 
620
  # 3. RoPE Construction
621
+ if patch_positions is not None:
622
+ freqs_visible = self.video_rope.forward_from_positions(patch_positions)
623
+ else:
624
+ freqs_full = self.video_rope(
625
+ t=t_frames,
626
+ h=height // self.config.patch_size,
627
+ w=width // self.config.patch_size,
628
+ device=pixel_values.device,
629
+ )
630
+ freqs_visible = freqs_full[visible_indices]
631
 
632
  # Concatenate D/2 + D/2 -> D for applying rope
633
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)