xiangan commited on
Commit
a2882ce
·
1 Parent(s): bb40be1

Upload folder using huggingface_hub

Browse files
configuration_onevision_encoder.py CHANGED
@@ -1,6 +1,7 @@
1
  from transformers.configuration_utils import PretrainedConfig
2
  from transformers.utils import logging
3
 
 
4
  logger = logging.get_logger(__name__)
5
 
6
 
@@ -26,7 +27,7 @@ class OneVisionEncoderConfig(PretrainedConfig):
26
  The number of input channels.
27
  image_size (`int`, *optional*, defaults to 224):
28
  The size (resolution) of each image.
29
- patch_size (`int`, *optional*, defaults to 16):
30
  The size (resolution) of each patch.
31
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
32
  The non-linear activation function (function or string) in the encoder and pooler.
@@ -70,7 +71,7 @@ class OneVisionEncoderConfig(PretrainedConfig):
70
  num_attention_heads=16,
71
  num_channels=3,
72
  image_size=448,
73
- patch_size=16,
74
  hidden_act="gelu",
75
  layer_norm_eps=1e-6,
76
  layer_norm_type="layer_norm",
 
1
  from transformers.configuration_utils import PretrainedConfig
2
  from transformers.utils import logging
3
 
4
+
5
  logger = logging.get_logger(__name__)
6
 
7
 
 
27
  The number of input channels.
28
  image_size (`int`, *optional*, defaults to 224):
29
  The size (resolution) of each image.
30
+ patch_size (`int`, *optional*, defaults to 14):
31
  The size (resolution) of each patch.
32
  hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
33
  The non-linear activation function (function or string) in the encoder and pooler.
 
71
  num_attention_heads=16,
72
  num_channels=3,
73
  image_size=448,
74
+ patch_size=14,
75
  hidden_act="gelu",
76
  layer_norm_eps=1e-6,
77
  layer_norm_type="layer_norm",
modeling_onevision_encoder.py CHANGED
@@ -2,18 +2,23 @@ from typing import Optional, Tuple, Union
2
 
3
  import torch
4
  import torch.nn as nn
5
- from transformers.modeling_outputs import (BaseModelOutput,
6
- BaseModelOutputWithPooling)
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.models.siglip.modeling_siglip import SiglipMLP
9
- from transformers.utils import (add_start_docstrings,
10
- add_start_docstrings_to_model_forward, logging,
11
- replace_return_docstrings)
 
 
 
12
 
13
  from .configuration_onevision_encoder import OneVisionEncoderConfig
14
 
 
15
  try:
16
  from flash_attn import flash_attn_func
 
17
  _flash_attn_available = True
18
  except ImportError:
19
  _flash_attn_available = False
@@ -61,6 +66,7 @@ ONEVISION_ENCODER_INPUTS_DOCSTRING = r"""
61
  # Helper Functions & Layers
62
  # ---------------------------------------------------------------------------
63
 
 
64
  def get_norm_layer(config):
65
  if config.layer_norm_type == "rms_norm":
66
  return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -100,6 +106,7 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
100
  """
101
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
102
  """
 
103
  def __init__(self, config: OneVisionEncoderConfig):
104
  super().__init__()
105
  head_dim = config.hidden_size // config.num_attention_heads
@@ -118,12 +125,25 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
118
  self.h_size = 6 * unit
119
  self.w_size = 6 * unit
120
 
121
- self.register_buffer("inv_freq_t", 1.0 / (base ** (torch.arange(self.t_size, dtype=torch.float32) / self.t_size)), persistent=False)
122
- self.register_buffer("inv_freq_h", 1.0 / (base ** (torch.arange(self.h_size, dtype=torch.float32) / self.h_size)), persistent=False)
123
- self.register_buffer("inv_freq_w", 1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)), persistent=False)
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def forward(self, t: int, h: int, w: int, device=None):
126
- if device is None: device = self.inv_freq_t.device
 
127
 
128
  inv_t = self.inv_freq_t.to(device=device)
129
  inv_h = self.inv_freq_h.to(device=device)
@@ -140,11 +160,37 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
140
  freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
141
  return freqs
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  class Siglip2MultiheadAttentionPoolingHead(nn.Module):
145
  """
146
  Multi-Head Attention Pooling with a learned probe (PMA-style).
147
  """
 
148
  def __init__(self, config: OneVisionEncoderConfig):
149
  super().__init__()
150
  self.embed_dim = config.hidden_size
@@ -170,6 +216,7 @@ class Siglip2MultiheadAttentionPoolingHead(nn.Module):
170
  # Modeling Components
171
  # ---------------------------------------------------------------------------
172
 
 
173
  class OneVisionEncoderEmbeddings(nn.Module):
174
  def __init__(self, config: OneVisionEncoderConfig):
175
  super().__init__()
@@ -189,7 +236,7 @@ class OneVisionEncoderEmbeddings(nn.Module):
189
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
190
  # Handle 4D (B, C, H, W) or 5D (B, C, T, H, W) inputs
191
  if pixel_values.dim() == 4:
192
- pixel_values = pixel_values.unsqueeze(2) # (B, C, 1, H, W)
193
 
194
  batch_size, channels, t_frames, height, width = pixel_values.shape
195
 
@@ -198,7 +245,7 @@ class OneVisionEncoderEmbeddings(nn.Module):
198
 
199
  # Patch Embed
200
  embeddings = self.patch_embedding(x_2d) # (B*T, C, Hp, Wp)
201
- embeddings = embeddings.flatten(2).transpose(1, 2) # (B*T, L_frame, C)
202
 
203
  # Flatten all patches
204
  total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
@@ -209,6 +256,7 @@ class OneVisionEncoderEmbeddings(nn.Module):
209
 
210
  class OneVisionEncoderAttention(nn.Module):
211
  """Multi-headed attention with RoPE support"""
 
212
  def __init__(self, config: OneVisionEncoderConfig):
213
  super().__init__()
214
  self.config = config
@@ -216,7 +264,7 @@ class OneVisionEncoderAttention(nn.Module):
216
  self.num_heads = config.num_attention_heads
217
  self.head_dim = self.embed_dim // self.num_heads
218
  if self.head_dim * self.num_heads != self.embed_dim:
219
- raise ValueError(
220
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
221
  )
222
 
@@ -228,7 +276,6 @@ class OneVisionEncoderAttention(nn.Module):
228
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
229
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
230
 
231
-
232
  def forward(
233
  self,
234
  hidden_states: torch.Tensor,
@@ -236,7 +283,6 @@ class OneVisionEncoderAttention(nn.Module):
236
  rotary_pos_emb: Optional[torch.Tensor] = None,
237
  output_attentions: bool = False,
238
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
239
-
240
  batch_size, q_len, _ = hidden_states.size()
241
 
242
  query_states = self.q_proj(hidden_states)
@@ -257,7 +303,7 @@ class OneVisionEncoderAttention(nn.Module):
257
  if attention_mask is not None:
258
  if attention_mask.size() != (batch_size, 1, q_len, q_len):
259
  if attention_mask.dim() == 3:
260
- attention_mask = attention_mask.unsqueeze(1)
261
  attn_weights = attn_weights + attention_mask
262
 
263
  # FIX: Remove dtype=torch.float32 to stay in original dtype (bf16/fp16)
@@ -280,6 +326,7 @@ class OneVisionEncoderFlashAttention2(nn.Module):
280
  This module implements the same attention mechanism as OneVisionEncoderAttention but uses
281
  Flash Attention for improved performance and memory efficiency.
282
  """
 
283
  def __init__(self, config: OneVisionEncoderConfig):
284
  super().__init__()
285
  self.config = config
@@ -386,7 +433,6 @@ class OneVisionEncoderEncoderLayer(nn.Module):
386
  rotary_pos_emb: Optional[torch.Tensor] = None,
387
  output_attentions: bool = False,
388
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
389
-
390
  residual = hidden_states
391
  hidden_states = self.layer_norm1(hidden_states)
392
 
@@ -421,8 +467,7 @@ class OneVisionEncoderEncoder(nn.Module):
421
  output_attentions: bool = False,
422
  output_hidden_states: bool = False,
423
  return_dict: bool = True,
424
- ) -> Union[Tuple, BaseModelOutput]:
425
-
426
  all_hidden_states = () if output_hidden_states else None
427
  all_self_attentions = () if output_attentions else None
428
 
@@ -459,6 +504,7 @@ class OneVisionEncoderEncoder(nn.Module):
459
  # Main Models
460
  # ---------------------------------------------------------------------------
461
 
 
462
  @add_start_docstrings(
463
  "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
464
  ONEVISION_ENCODER_START_DOCSTRING,
@@ -484,7 +530,7 @@ class OneVisionEncoderPreTrainedModel(PreTrainedModel):
484
  elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
485
  # Fix: RMSNorm doesn't have bias, must check hasattr first
486
  module.weight.data.fill_(1.0)
487
- if hasattr(module, 'bias') and module.bias is not None:
488
  module.bias.data.zero_()
489
 
490
 
@@ -503,25 +549,25 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
503
  self.video_rope = VideoRotaryEmbeddingSplit466(config)
504
 
505
  if config.use_head:
506
- self.layernorm_post = get_norm_layer(config)
507
- self.head = Siglip2MultiheadAttentionPoolingHead(config)
508
  else:
509
- self.layernorm_post = None
510
- self.head = None
511
 
512
  self.post_init()
513
 
514
-
515
  @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
516
  @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
517
  def forward(
518
  self,
519
  pixel_values: torch.Tensor,
 
520
  visible_indices: Optional[torch.Tensor] = None,
521
  output_attentions: Optional[bool] = None,
522
  output_hidden_states: Optional[bool] = None,
523
  return_dict: Optional[bool] = None,
524
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
525
  r"""
526
  Returns:
527
 
@@ -549,14 +595,16 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
549
  # Determine video dimensions for RoPE
550
  # Note: pixel_values passed to embeddings can be 4D or 5D
551
  if pixel_values.dim() == 5:
552
- # Use config.rope_temporal_size if set, otherwise use actual frame count
553
- t_frames = self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
554
- height = pixel_values.shape[3]
555
- width = pixel_values.shape[4]
 
 
556
  else:
557
- t_frames = 1
558
- height = pixel_values.shape[2]
559
- width = pixel_values.shape[3]
560
 
561
  # 1. Embeddings
562
  hidden_states = self.embeddings(pixel_values)
@@ -564,16 +612,21 @@ class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
564
 
565
  # 2. Visible Indices Handling
566
  if visible_indices is None:
567
- visible_indices = torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
 
 
568
 
569
  # 3. RoPE Construction
570
- freqs_full = self.video_rope(
571
- t=t_frames,
572
- h=height // self.config.patch_size,
573
- w=width // self.config.patch_size,
574
- device=pixel_values.device
575
- )
576
- freqs_visible = freqs_full[visible_indices]
 
 
 
577
 
578
  # Concatenate D/2 + D/2 -> D for applying rope
579
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
 
2
 
3
  import torch
4
  import torch.nn as nn
5
+
6
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
7
  from transformers.modeling_utils import PreTrainedModel
8
  from transformers.models.siglip.modeling_siglip import SiglipMLP
9
+ from transformers.utils import (
10
+ add_start_docstrings,
11
+ add_start_docstrings_to_model_forward,
12
+ logging,
13
+ replace_return_docstrings,
14
+ )
15
 
16
  from .configuration_onevision_encoder import OneVisionEncoderConfig
17
 
18
+
19
  try:
20
  from flash_attn import flash_attn_func
21
+
22
  _flash_attn_available = True
23
  except ImportError:
24
  _flash_attn_available = False
 
66
  # Helper Functions & Layers
67
  # ---------------------------------------------------------------------------
68
 
69
+
70
  def get_norm_layer(config):
71
  if config.layer_norm_type == "rms_norm":
72
  return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
 
106
  """
107
  3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
108
  """
109
+
110
  def __init__(self, config: OneVisionEncoderConfig):
111
  super().__init__()
112
  head_dim = config.hidden_size // config.num_attention_heads
 
125
  self.h_size = 6 * unit
126
  self.w_size = 6 * unit
127
 
128
+ self.register_buffer(
129
+ "inv_freq_t",
130
+ 1.0 / (base ** (torch.arange(self.t_size, dtype=torch.float32) / self.t_size)),
131
+ persistent=False,
132
+ )
133
+ self.register_buffer(
134
+ "inv_freq_h",
135
+ 1.0 / (base ** (torch.arange(self.h_size, dtype=torch.float32) / self.h_size)),
136
+ persistent=False,
137
+ )
138
+ self.register_buffer(
139
+ "inv_freq_w",
140
+ 1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)),
141
+ persistent=False,
142
+ )
143
 
144
  def forward(self, t: int, h: int, w: int, device=None):
145
+ if device is None:
146
+ device = self.inv_freq_t.device
147
 
148
  inv_t = self.inv_freq_t.to(device=device)
149
  inv_h = self.inv_freq_h.to(device=device)
 
160
  freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
161
  return freqs
162
 
163
+ def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
164
+ """
165
+ Compute rotary position embeddings from explicit patch positions.
166
+
167
+ Args:
168
+ patch_positions: [seq_len, 3] tensor with [t, h, w] positions for each patch
169
+
170
+ Returns:
171
+ freqs: [seq_len, half] tensor of position frequencies
172
+ """
173
+ device = patch_positions.device
174
+ inv_t = self.inv_freq_t.to(device=device)
175
+ inv_h = self.inv_freq_h.to(device=device)
176
+ inv_w = self.inv_freq_w.to(device=device)
177
+
178
+ t_pos = patch_positions[:, 0].float()
179
+ h_pos = patch_positions[:, 1].float()
180
+ w_pos = patch_positions[:, 2].float()
181
+
182
+ ft = torch.outer(t_pos, inv_t)
183
+ fh = torch.outer(h_pos, inv_h)
184
+ fw = torch.outer(w_pos, inv_w)
185
+
186
+ return torch.cat([ft, fh, fw], dim=-1)
187
+
188
 
189
  class Siglip2MultiheadAttentionPoolingHead(nn.Module):
190
  """
191
  Multi-Head Attention Pooling with a learned probe (PMA-style).
192
  """
193
+
194
  def __init__(self, config: OneVisionEncoderConfig):
195
  super().__init__()
196
  self.embed_dim = config.hidden_size
 
216
  # Modeling Components
217
  # ---------------------------------------------------------------------------
218
 
219
+
220
  class OneVisionEncoderEmbeddings(nn.Module):
221
  def __init__(self, config: OneVisionEncoderConfig):
222
  super().__init__()
 
236
  def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
237
  # Handle 4D (B, C, H, W) or 5D (B, C, T, H, W) inputs
238
  if pixel_values.dim() == 4:
239
+ pixel_values = pixel_values.unsqueeze(2) # (B, C, 1, H, W)
240
 
241
  batch_size, channels, t_frames, height, width = pixel_values.shape
242
 
 
245
 
246
  # Patch Embed
247
  embeddings = self.patch_embedding(x_2d) # (B*T, C, Hp, Wp)
248
+ embeddings = embeddings.flatten(2).transpose(1, 2) # (B*T, L_frame, C)
249
 
250
  # Flatten all patches
251
  total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
 
256
 
257
  class OneVisionEncoderAttention(nn.Module):
258
  """Multi-headed attention with RoPE support"""
259
+
260
  def __init__(self, config: OneVisionEncoderConfig):
261
  super().__init__()
262
  self.config = config
 
264
  self.num_heads = config.num_attention_heads
265
  self.head_dim = self.embed_dim // self.num_heads
266
  if self.head_dim * self.num_heads != self.embed_dim:
267
+ raise ValueError(
268
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
269
  )
270
 
 
276
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
 
 
279
  def forward(
280
  self,
281
  hidden_states: torch.Tensor,
 
283
  rotary_pos_emb: Optional[torch.Tensor] = None,
284
  output_attentions: bool = False,
285
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
286
  batch_size, q_len, _ = hidden_states.size()
287
 
288
  query_states = self.q_proj(hidden_states)
 
303
  if attention_mask is not None:
304
  if attention_mask.size() != (batch_size, 1, q_len, q_len):
305
  if attention_mask.dim() == 3:
306
+ attention_mask = attention_mask.unsqueeze(1)
307
  attn_weights = attn_weights + attention_mask
308
 
309
  # FIX: Remove dtype=torch.float32 to stay in original dtype (bf16/fp16)
 
326
  This module implements the same attention mechanism as OneVisionEncoderAttention but uses
327
  Flash Attention for improved performance and memory efficiency.
328
  """
329
+
330
  def __init__(self, config: OneVisionEncoderConfig):
331
  super().__init__()
332
  self.config = config
 
433
  rotary_pos_emb: Optional[torch.Tensor] = None,
434
  output_attentions: bool = False,
435
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
 
436
  residual = hidden_states
437
  hidden_states = self.layer_norm1(hidden_states)
438
 
 
467
  output_attentions: bool = False,
468
  output_hidden_states: bool = False,
469
  return_dict: bool = True,
470
+ ) -> Union[tuple, BaseModelOutput]:
 
471
  all_hidden_states = () if output_hidden_states else None
472
  all_self_attentions = () if output_attentions else None
473
 
 
504
  # Main Models
505
  # ---------------------------------------------------------------------------
506
 
507
+
508
  @add_start_docstrings(
509
  "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
510
  ONEVISION_ENCODER_START_DOCSTRING,
 
530
  elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
531
  # Fix: RMSNorm doesn't have bias, must check hasattr first
532
  module.weight.data.fill_(1.0)
533
+ if hasattr(module, "bias") and module.bias is not None:
534
  module.bias.data.zero_()
535
 
536
 
 
549
  self.video_rope = VideoRotaryEmbeddingSplit466(config)
550
 
551
  if config.use_head:
552
+ self.layernorm_post = get_norm_layer(config)
553
+ self.head = Siglip2MultiheadAttentionPoolingHead(config)
554
  else:
555
+ self.layernorm_post = None
556
+ self.head = None
557
 
558
  self.post_init()
559
 
 
560
  @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
561
  @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
562
  def forward(
563
  self,
564
  pixel_values: torch.Tensor,
565
+ patch_postions: Optional[torch.Tensor] = None,
566
  visible_indices: Optional[torch.Tensor] = None,
567
  output_attentions: Optional[bool] = None,
568
  output_hidden_states: Optional[bool] = None,
569
  return_dict: Optional[bool] = None,
570
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
571
  r"""
572
  Returns:
573
 
 
595
  # Determine video dimensions for RoPE
596
  # Note: pixel_values passed to embeddings can be 4D or 5D
597
  if pixel_values.dim() == 5:
598
+ # Use config.rope_temporal_size if set, otherwise use actual frame count
599
+ t_frames = (
600
+ self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
601
+ )
602
+ height = pixel_values.shape[3]
603
+ width = pixel_values.shape[4]
604
  else:
605
+ t_frames = 1
606
+ height = pixel_values.shape[2]
607
+ width = pixel_values.shape[3]
608
 
609
  # 1. Embeddings
610
  hidden_states = self.embeddings(pixel_values)
 
612
 
613
  # 2. Visible Indices Handling
614
  if visible_indices is None:
615
+ visible_indices = (
616
+ torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
617
+ )
618
 
619
  # 3. RoPE Construction
620
+ if patch_postions is not None:
621
+ freqs_visible = self.video_rope.forward_from_positions(patch_postions)
622
+ else:
623
+ freqs_full = self.video_rope(
624
+ t=t_frames,
625
+ h=height // self.config.patch_size,
626
+ w=width // self.config.patch_size,
627
+ device=pixel_values.device,
628
+ )
629
+ freqs_visible = freqs_full[visible_indices]
630
 
631
  # Concatenate D/2 + D/2 -> D for applying rope
632
  freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)