| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """HyperCLOVAX-SEED Vision Encoder model. |
| |
| A spatio-temporal vision transformer using the Qwen2.5-VL ViT architecture |
| (window attention, 3D patch embedding, 2D RoPE), trained with SigLIP-style |
| sigmoid contrastive loss. |
| |
| Code-level modifications over the base Qwen2.5-VL ViT: |
| - transformers 5.x compatibility: RotaryEmbedding recomputes inv_freq on-the-fly |
| to handle no_init_weights() zeroing (persistent=False register_buffer) |
- Float16 numerical stability: float32 autocast paths in PatchMerger and in the
  MLPs of the full-attention blocks and the last transformer block
| - disable_merger option: skips PatchMerger and returns raw patch features |
| with window index for external merging |
| |
| Acknowledgements: |
| - Architecture adapted from Qwen2.5-VL ViT |
| (https://github.com/QwenLM/Qwen2.5-VL), Apache-2.0 License. |
| - Training objective based on SigLIP |
| (https://github.com/google-research/big_vision), Apache-2.0 License. |
| """ |
|
|
| from collections.abc import Callable |
| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import AutoModel |
| from transformers.activations import ACT2FN |
| from transformers.modeling_utils import PreTrainedModel |
|
|
# transformers moved GradientCheckpointingLayer into modeling_layers in newer
# releases; on older versions fall back to a plain nn.Module base so
# HyperCLOVAXSeedVisionBlock can still be constructed.
try:
    from transformers.modeling_layers import GradientCheckpointingLayer
except ImportError:
    class GradientCheckpointingLayer(nn.Module):
        pass


# Registry of fused attention backends ("sdpa", "flash_attention_2", ...);
# absent on old transformers, in which case only the eager path is usable.
try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except ImportError:
    ALL_ATTENTION_FUNCTIONS = {}
|
|
| from .configuration_hyperclovax_seed_vision_encoder import HyperCLOVAXSeedVisionEncoderConfig |
|
|
|
|
class HyperCLOVAXSeedVisionRMSNorm(nn.Module):
    """Root-mean-square layer normalisation (no mean subtraction, no bias)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-channel scale, initialised to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalise in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into two halves and return ``(-second, first)`` concatenated."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
def apply_rotary_pos_emb_vision(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to vision query/key tensors.

    The rotation is carried out in float32 for numerical stability and the
    results are cast back to the original dtypes of ``q`` and ``k``.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    # Broadcast cos/sin across the head dimension and promote to float32.
    cos_f = cos.unsqueeze(-2).float()
    sin_f = sin.unsqueeze(-2).float()
    q_f, k_f = q.float(), k.float()
    rotated_q = q_f * cos_f + rotate_half(q_f) * sin_f
    rotated_k = k_f * cos_f + rotate_half(k_f) * sin_f
    return rotated_q.to(q_dtype), rotated_k.to(k_dtype)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    hidden_states: (batch, num_key_value_heads, seqlen, head_dim)
    -> (batch, num_attention_heads, seqlen, head_dim)
    """
    # Unpack first so a malformed (non-4D) input always raises, as before.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain matmul + softmax attention, used when no fused backend is selected.

    ``module`` only supplies ``num_key_value_groups`` and ``training``.
    Returns the context transposed to (batch, seq, heads, head_dim) together
    with the post-softmax attention weights.
    """
    expanded_key = repeat_kv(key, module.num_key_value_groups)
    expanded_value = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask may be pre-allocated larger; slice to the key length.
        scores = scores + attention_mask[:, :, :, : expanded_key.shape[-2]]

    # Softmax in float32, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, expanded_value).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
class HyperCLOVAXSeedVisionMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward network used inside each vision block."""

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig, bias: bool = False):
        super().__init__()
        hidden, intermediate = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate
        self.gate_proj = nn.Linear(hidden, intermediate, bias=bias)
        self.up_proj = nn.Linear(hidden, intermediate, bias=bias)
        self.down_proj = nn.Linear(intermediate, hidden, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # down( act(gate(x)) * up(x) )
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
|
|
|
|
class HyperCLOVAXSeedVisionPatchEmbed(nn.Module):
    """Embeds spatio-temporal patches with a single strided Conv3d."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Kernel == stride => non-overlapping (t, h, w) patches.
        patch_shape = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_shape, stride=patch_shape, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map flattened patch pixels to ``(num_patches, embed_dim)`` embeddings."""
        weight_dtype = self.proj.weight.dtype
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        # Cast to the projection's dtype so mixed-precision inputs still work.
        return self.proj(patches.to(dtype=weight_dtype)).view(-1, self.embed_dim)
|
|
|
|
class HyperCLOVAXSeedVisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding table for vision patches.

    ``inv_freq`` is registered as a non-persistent buffer, but ``forward``
    recomputes it from scratch every call so that ``no_init_weights()`` in
    transformers 5.x (which can zero non-persistent buffers) cannot corrupt
    the frequencies; the buffer is kept only to track the module's device.
    """

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        frequencies = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", frequencies, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, dim // 2)`` table of position * frequency products."""
        device = self.inv_freq.device
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
|
|
|
|
class HyperCLOVAXSeedVisionPatchMerger(nn.Module):
    """Merges each spatial group of patches and projects to the output hidden size."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        # Each merged token concatenates spatial_merge_size**2 patch features.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = HyperCLOVAXSeedVisionRMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def _project(features: torch.Tensor) -> torch.Tensor:
            # Normalise per patch, then flatten each merge group into one row.
            return self.mlp(self.ln_q(features).view(-1, self.hidden_size))

        if self.mlp[0].weight.dtype != torch.float16:
            return _project(x)
        # fp16 weights: run the merger in float32 for stability, then cast back.
        with torch.amp.autocast(device_type=x.device.type, dtype=torch.float32):
            out = _project(x)
            out = out.to(self.mlp[0].weight.dtype)
        return out
|
|
|
|
class HyperCLOVAXSeedVisionAttention(nn.Module):
    """Multi-head self-attention with 2D RoPE, supporting flash-attention and SDPA.

    Operates on an unbatched, packed token sequence of shape
    ``(seq_len, hidden_size)``; ``cu_seqlens`` delimits the independent
    segments (windows or whole frames) inside that packed sequence.
    """

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        # No grouped-query attention; kept at 1 for eager_attention_forward's repeat_kv.
        self.num_key_value_groups = 1
        # Fused projection producing query, key and value in a single matmul.
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Self-attention over a packed sequence.

        Args:
            hidden_states: Packed tokens of shape ``(seq_len, hidden_size)``.
            cu_seqlens: Cumulative segment lengths (starting at 0), used either
                for flash-attention varlen kwargs or to split segments apart.
            rotary_pos_emb: Unused here; kept for interface compatibility.
            position_embeddings: Precomputed (cos, sin) tensors for RoPE.
                Required — this implementation does not fall back to
                ``rotary_pos_emb`` when it is None.

        Returns:
            Tensor of shape ``(seq_len, hidden_size)``.
        """
        seq_length = hidden_states.shape[0]
        # (seq, 3*dim) -> 3 tensors of (seq, num_heads, head_dim).
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # Reshape to (1, num_heads, seq, head_dim) — one packed pseudo-batch.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        # Resolve the configured fused backend; default to the eager fallback.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "flash_attention_2":
            # Varlen flash-attention processes all segments in one fused call.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # SDPA/eager path: split the packed sequence into segments, attend
            # within each segment independently, then re-concatenate along seq.
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]
            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        # Flatten heads back to (seq, hidden_size) and apply the output projection.
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
|
|
|
|
class HyperCLOVAXSeedVisionBlock(GradientCheckpointingLayer):
    """Pre-norm transformer block (window or full attention) with an fp16-safe MLP path."""

    def __init__(
        self,
        config: HyperCLOVAXSeedVisionEncoderConfig,
        is_fullatt: bool = False,
        is_last: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.attn = HyperCLOVAXSeedVisionAttention(config=config)
        self.mlp = HyperCLOVAXSeedVisionMLP(config, bias=True)
        self.is_fullatt = is_fullatt
        self.is_last = is_last

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attention + MLP with residual connections.

        Under float16, the full-attention blocks and the final block run their
        MLP inside a float32 autocast; the final block additionally returns its
        output in float32 instead of casting back.
        """
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        needs_fp32_mlp = (self.is_fullatt or self.is_last) and hidden_states.dtype == torch.float16
        if not needs_fp32_mlp:
            return hidden_states + self.mlp(self.norm2(hidden_states))

        running_dtype = hidden_states.dtype
        with torch.amp.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
            mlp_out = self.mlp(self.norm2(hidden_states))
        # Final block hands off float32; intermediate full-attention blocks
        # return to the running (fp16) dtype.
        target_dtype = torch.float32 if self.is_last else running_dtype
        return (hidden_states + mlp_out).to(target_dtype)
|
|
|
|
class HyperCLOVAXSeedVisionEncoder(PreTrainedModel):
    """HyperCLOVAX-SEED Vision Encoder.

    A spatio-temporal vision transformer that encodes images and videos into
    token sequences suitable for a causal language model,
    using window-based and full attention blocks.

    The encoder outputs merged patch embeddings of shape
    ``(total_merged_patches, out_hidden_size)`` where
    ``total_merged_patches = sum(t * h * w / spatial_merge_size^2 for each input)``.

    When ``config.disable_merger`` is set, the PatchMerger is skipped and
    ``forward`` returns the raw per-patch features (still in window order)
    together with the window permutation index.
    """

    config_class = HyperCLOVAXSeedVisionEncoderConfig
    _no_split_modules = ["HyperCLOVAXSeedVisionBlock"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        # Number of raw patches folded into one merged token (merge_size ** 2).
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.disable_merger = config.disable_merger

        self.patch_embed = HyperCLOVAXSeedVisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        # 2D RoPE: half the head dim is split between the height and width axes.
        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = HyperCLOVAXSeedVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            HyperCLOVAXSeedVisionBlock(
                config,
                is_fullatt=(_block_idx in config.fullatt_block_indexes),
                is_last=(_block_idx == config.depth - 1),
            )
            for _block_idx in range(config.depth)
        ])
        self.merger = None
        if not self.disable_merger:
            self.merger = HyperCLOVAXSeedVisionPatchMerger(
                dim=config.out_hidden_size,
                context_dim=config.hidden_size,
                spatial_merge_size=config.spatial_merge_size,
            )
        self.gradient_checkpointing = False
        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute 2D rotary position embeddings for all patches in the batch.

        For each (t, h, w) grid, the (row, col) index of every patch is laid
        out so that each merge group of ``spatial_merge_size**2`` patches is
        contiguous, then repeated across the t frames.

        Returns:
            Tensor of shape ``(total_patches, head_dim // 2)`` holding the
            concatenated height/width rotary frequencies per patch.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row indices, regrouped so each merge window's patches are contiguous.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column indices, with the same regrouping as above.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # Identical spatial positions for every temporal frame.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One shared frequency table, sized by the largest spatial extent.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> tuple[torch.Tensor, list]:
        """Build a flat index that reorders tokens into non-overlapping windows.

        Returns:
            window_index: permutation indices to gather tokens in window order
            cu_window_seqlens: cumulative window sequence lengths for varlen attention
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # Window edge length measured in merged-token units.
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            # Grid dimensions after spatial merging (the "LLM token" grid).
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad up to a window-size multiple; -100 marks padding slots that
            # are filtered out below (a full extra window is padded when the
            # grid already divides evenly — its slots are all -100).
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            # Gather the two window axes together: (t, win_h * win_w, ws, ws).
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Count of real (non-padding) merged tokens per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Scale to raw-patch units and offset by the running total.
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                Flattened patch pixels (output of patch embedding pipeline before this call).
                In practice this is the raw pixel tensor passed to `patch_embed` internally.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Temporal, height and width grid dimensions for each input item.

        Returns:
            `torch.Tensor` of shape `(total_merged_patches, out_hidden_size)`, or —
            when the merger is disabled — a tuple of the raw per-patch features
            (in window order) and the `window_index` permutation needed to undo
            that ordering externally.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Padding-only windows create duplicate boundaries; collapse them.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (and their rotary embeddings) into window order,
        # keeping each merge group of `spatial_merge_unit` patches contiguous.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # Duplicate the frequencies so cos/sin cover the full head dimension.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Per-frame cumulative lengths, used by the full-attention layers.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # Full-attention layers attend within whole frames; others within windows.
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens

            if self.gradient_checkpointing and self.training:
                # NOTE(review): extra **kwargs are not forwarded on this path —
                # confirm callers never rely on them under checkpointing.
                hidden_states = self._gradient_checkpointing_func(
                    blk,
                    hidden_states,
                    cu_seqlens_now,
                    None,
                    position_embeddings,
                )
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        if self.merger is not None:
            # Merge patch groups, then restore the original (pre-window) order.
            hidden_states = self.merger(hidden_states)
            reverse_indices = torch.argsort(window_index)
            hidden_states = hidden_states[reverse_indices, :]
            return hidden_states
        else:
            # disable_merger: hand back raw features plus the window permutation
            # so the caller can merge / reorder externally.
            return hidden_states, window_index
|
|
|
|
# Make the encoder discoverable through transformers' AutoModel for this config.
AutoModel.register(HyperCLOVAXSeedVisionEncoderConfig, HyperCLOVAXSeedVisionEncoder)

__all__ = ["HyperCLOVAXSeedVisionEncoder"]
|
|