Instructions to use internlm/Intern-S2-Preview-FP8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
  - Transformers

How to use internlm/Intern-S2-Preview-FP8 with Transformers:

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="internlm/Intern-S2-Preview-FP8", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
pipe(text=messages)
```

```python
# Load model directly
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("internlm/Intern-S2-Preview-FP8", trust_remote_code=True, dtype="auto")
```
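To run generation with the directly loaded model you also need the matching processor. A minimal sketch, assuming the checkpoint follows the standard `AutoProcessor` image-text-to-text API (the pipeline above handles these steps for you):

```python
# Sketch: chat-style generation with the processor + model.generate
# (assumes the standard image-text-to-text processing API; details may differ for this preview checkpoint).
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "internlm/Intern-S2-Preview-FP8"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(model_id, trust_remote_code=True, dtype="auto", device_map="auto")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"},
        ]
    },
]

# Tokenize the chat, generate, and decode only the newly generated tokens.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(output_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])
```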
- Notebooks
  - Google Colab
  - Kaggle
- Local Apps
  - vLLM
How to use internlm/Intern-S2-Preview-FP8 with vLLM:
Install from pip and serve the model:

```sh
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "internlm/Intern-S2-Preview-FP8"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "internlm/Intern-S2-Preview-FP8",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "Describe this image in one sentence."
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
            }
          }
        ]
      }
    ]
  }'
```
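The server speaks the OpenAI-compatible API, so you can also call it from Python with the official `openai` client instead of curl. A minimal sketch, assuming `pip install openai` and the server running on the default port 8000:

```python
# Query the local vLLM server through its OpenAI-compatible endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # vLLM does not check the key by default

response = client.chat.completions.create(
    model="internlm/Intern-S2-Preview-FP8",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {"type": "image_url", "image_url": {"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```

The same client pattern works against any OpenAI-compatible server, including the SGLang server below (change the port to 30000).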
  - SGLang
How to use internlm/Intern-S2-Preview-FP8 with SGLang:
Install from pip and serve the model:

```sh
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "internlm/Intern-S2-Preview-FP8" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "internlm/Intern-S2-Preview-FP8",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "Describe this image in one sentence."
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
            }
          }
        ]
      }
    ]
  }'
```
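The same request can be sent from Python. A minimal sketch using `requests` that mirrors the curl call above (assumes the server is reachable on port 30000):

```python
# Post the same OpenAI-compatible chat request to the local SGLang server.
import requests

payload = {
    "model": "internlm/Intern-S2-Preview-FP8",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {"type": "image_url", "image_url": {"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"}},
            ],
        }
    ],
}
response = requests.post("http://localhost:30000/v1/chat/completions", json=payload, timeout=300)
print(response.json()["choices"][0]["message"]["content"])
```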
Use Docker images:

```sh
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "internlm/Intern-S2-Preview-FP8" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "internlm/Intern-S2-Preview-FP8",
    "messages": [
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "Describe this image in one sentence."
          },
          {
            "type": "image_url",
            "image_url": {
              "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
            }
          }
        ]
      }
    ]
  }'
```
  - Docker Model Runner

How to use internlm/Intern-S2-Preview-FP8 with Docker Model Runner:
```sh
docker model run hf.co/internlm/Intern-S2-Preview-FP8
```
| # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 | |
| # This file was automatically generated from src/transformers/models/interns2_preview/modular_interns2_preview.py. | |
| # Do NOT edit this file manually as any edits will be overwritten by the generation of | |
| # the file from the modular. If any change should be done, please apply the change to the | |
| # modular_interns2_preview.py file directly. One of our CI enforces this. | |
| # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 | |
| # Copyright 2026 HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| from typing import Any, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import initialization as init | |
| from transformers.activations import ACT2FN | |
| from transformers.cache_utils import Cache, EncoderDecoderCache | |
| from transformers.generation import GenerationMixin | |
| from transformers.integrations import use_experts_implementation, use_kernelized_func | |
| from transformers.masking_utils import create_causal_mask | |
| from transformers.modeling_flash_attention_utils import FlashAttentionKwargs | |
| from transformers.modeling_layers import GradientCheckpointingLayer | |
| from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput | |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update | |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel | |
| from transformers.processing_utils import Unpack | |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check | |
| from transformers.utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults | |
| from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available | |
| from transformers.utils.output_capturing import OutputRecorder, capture_outputs | |
| from .configuration_interns2_preview import ( | |
| InternS2PreviewConfig, | |
| InternS2PreviewTextConfig, | |
| InternS2PreviewTimeSeriesConfig, | |
| InternS2PreviewVisionConfig, | |
| ) | |
| if is_causal_conv1d_available(): | |
| from causal_conv1d import causal_conv1d_fn, causal_conv1d_update | |
| else: | |
| causal_conv1d_update, causal_conv1d_fn = None, None | |
| if is_flash_linear_attention_available(): | |
| from fla.modules import FusedRMSNormGated | |
| from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule | |
| else: | |
| chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None | |
| FusedRMSNormGated = None | |
| logger = logging.get_logger(__name__) | |
| class InternS2PreviewVisionRotaryEmbedding(nn.Module): | |
| inv_freq: torch.Tensor # fix linting for `register_buffer` | |
| def __init__(self, dim: int, theta: float = 10000.0) -> None: | |
| super().__init__() | |
| self.dim = dim | |
| self.theta = theta | |
| inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) | |
| self.register_buffer("inv_freq", inv_freq, persistent=False) | |
| def forward(self, seqlen: int) -> torch.Tensor: | |
| seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) | |
| freqs = torch.outer(seq, self.inv_freq) | |
| return freqs | |
| class InternS2PreviewDynamicCache: | |
| """ | |
| A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention | |
| cache (which has a constant shape regardless of seq_len). | |
| This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` | |
| and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor | |
| For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, | |
| while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). | |
| For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), | |
| while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, | |
| and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`. | |
| """ | |
| is_compileable = False | |
| def __init__(self, config: InternS2PreviewConfig): | |
| super().__init__() | |
| self.layer_types = config.layer_types | |
| self.transformer_layers = [ | |
| i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention" | |
| ] | |
| self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention") | |
| # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference | |
| self.conv_states = [None for _ in range(config.num_hidden_layers)] | |
| self.recurrent_states = [None for _ in range(config.num_hidden_layers)] | |
| self.key_cache = [None for _ in range(config.num_hidden_layers)] | |
| self.value_cache = [None for _ in range(config.num_hidden_layers)] | |
| def __len__(self): | |
| return len(self.layer_types) | |
| def update( | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| layer_idx: int, | |
| cache_kwargs: dict[str, Any] | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| if self.key_cache[layer_idx] is None: | |
| self.key_cache[layer_idx] = key_states | |
| self.value_cache[layer_idx] = value_states | |
| else: | |
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) | |
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) | |
| return self.key_cache[layer_idx], self.value_cache[layer_idx] | |
| def reorder_cache(self, beam_idx: torch.LongTensor): | |
| """Reorders the cache for beam search, given the selected beam indices.""" | |
| for layer_idx in range(len(self.key_cache)): | |
| if self.key_cache[layer_idx] is not None: | |
| device = self.key_cache[layer_idx].device | |
| beam_idx = beam_idx.to(device) | |
| self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) | |
| self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) | |
| if self.conv_states[layer_idx] is not None: | |
| device = self.conv_states[layer_idx].device | |
| beam_idx = beam_idx.to(device) | |
| self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) | |
| self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx) | |
| def get_seq_length(self, layer_idx: int | None = 0) -> int: | |
| """Returns the sequence length of the cached states. A layer index can be optionally passed.""" | |
| # take any layer that contains cache and not empty tensor | |
| layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx | |
| if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: | |
| return 0 | |
| return self.key_cache[layer_idx].shape[-2] | |
| def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: | |
| """ | |
| Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for | |
| the given layer at `layer_idx`. | |
| The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer. | |
| """ | |
| kv_offset = 0 | |
| query_length = cache_position.shape[0] | |
| past_seen_tokens = self.get_seq_length(layer_idx) | |
| kv_length = query_length + past_seen_tokens | |
| return kv_length, kv_offset | |
| def has_previous_state(self): | |
| """We have a previous state if the last linear (conv) layer was already updated.""" | |
| return self.conv_states[self.last_linear_layer] is not None | |
| class InternS2PreviewRMSNormGated(nn.Module): | |
| def __init__(self, hidden_size, eps=1e-6, **kwargs): | |
| super().__init__() | |
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |
| self.variance_epsilon = eps | |
| def forward(self, hidden_states, gate=None): | |
| input_dtype = hidden_states.dtype | |
| hidden_states = hidden_states.to(torch.float32) | |
| variance = hidden_states.pow(2).mean(-1, keepdim=True) | |
| # Norm before gate | |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) | |
| hidden_states = self.weight * hidden_states.to(input_dtype) | |
| hidden_states = hidden_states * F.silu(gate.to(torch.float32)) | |
| return hidden_states.to(input_dtype) | |
| def apply_mask_to_padding_states(hidden_states, attention_mask): | |
| """ | |
| Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 | |
| """ | |
| # NOTE: attention mask is a 2D boolean tensor | |
| if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: | |
| dtype = hidden_states.dtype | |
| hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) | |
| return hidden_states | |
| is_fast_path_available = all( | |
| (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) | |
| ) | |
| def torch_causal_conv1d_update( | |
| hidden_states, | |
| conv_state, | |
| weight, | |
| bias=None, | |
| activation=None, | |
| ): | |
| _, hidden_size, seq_len = hidden_states.shape | |
| state_len = conv_state.shape[-1] | |
| hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) | |
| conv_state.copy_(hidden_states_new[:, :, -state_len:]) | |
| out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size) | |
| out = F.silu(out[:, :, -seq_len:]) | |
| out = out.to(hidden_states.dtype) | |
| return out | |
| def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): | |
| """This function is intended to align with the l2norm implementation in the FLA library.""" | |
| inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) | |
| return x * inv_norm | |
| def torch_chunk_gated_delta_rule( | |
| query, | |
| key, | |
| value, | |
| g, | |
| beta, | |
| chunk_size=64, | |
| initial_state=None, | |
| output_final_state=False, | |
| use_qk_l2norm_in_kernel=False, | |
| ): | |
| initial_dtype = query.dtype | |
| if use_qk_l2norm_in_kernel: | |
| query = l2norm(query, dim=-1, eps=1e-6) | |
| key = l2norm(key, dim=-1, eps=1e-6) | |
| query, key, value, beta, g = [ | |
| x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) | |
| ] | |
| batch_size, num_heads, sequence_length, k_head_dim = key.shape | |
| v_head_dim = value.shape[-1] | |
| pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size | |
| query = F.pad(query, (0, 0, 0, pad_size)) | |
| key = F.pad(key, (0, 0, 0, pad_size)) | |
| value = F.pad(value, (0, 0, 0, pad_size)) | |
| beta = F.pad(beta, (0, pad_size)) | |
| g = F.pad(g, (0, pad_size)) | |
| total_sequence_length = sequence_length + pad_size | |
| scale = 1 / (query.shape[-1] ** 0.5) | |
| query = query * scale | |
| v_beta = value * beta.unsqueeze(-1) | |
| k_beta = key * beta.unsqueeze(-1) | |
| # reshape to chunks | |
| query, key, value, k_beta, v_beta = [ | |
| x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) | |
| ] | |
| g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) | |
| mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) | |
| # chunk decay | |
| g = g.cumsum(dim=-1) | |
| decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() | |
| attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) | |
| for i in range(1, chunk_size): | |
| row = attn[..., i, :i].clone() | |
| sub = attn[..., :i, :i].clone() | |
| attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) | |
| attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) | |
| value = attn @ v_beta | |
| k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) | |
| last_recurrent_state = ( | |
| torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) | |
| if initial_state is None | |
| else initial_state.to(value) | |
| ) | |
| core_attn_out = torch.zeros_like(value) | |
| mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) | |
| # for each chunk | |
| for i in range(0, total_sequence_length // chunk_size): | |
| q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] | |
| attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) | |
| v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state | |
| v_new = v_i - v_prime | |
| attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state | |
| core_attn_out[:, :, i] = attn_inter + attn @ v_new | |
| last_recurrent_state = ( | |
| last_recurrent_state * g[:, :, i, -1, None, None].exp() | |
| + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new | |
| ) | |
| if not output_final_state: | |
| last_recurrent_state = None | |
| core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) | |
| core_attn_out = core_attn_out[:, :, :sequence_length] | |
| core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) | |
| return core_attn_out, last_recurrent_state | |
| def torch_recurrent_gated_delta_rule( | |
| query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False | |
| ): | |
| initial_dtype = query.dtype | |
| if use_qk_l2norm_in_kernel: | |
| query = l2norm(query, dim=-1, eps=1e-6) | |
| key = l2norm(key, dim=-1, eps=1e-6) | |
| query, key, value, beta, g = [ | |
| x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) | |
| ] | |
| batch_size, num_heads, sequence_length, k_head_dim = key.shape | |
| v_head_dim = value.shape[-1] | |
| scale = 1 / (query.shape[-1] ** 0.5) | |
| query = query * scale | |
| core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value) | |
| last_recurrent_state = ( | |
| torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) | |
| if initial_state is None | |
| else initial_state.to(value) | |
| ) | |
| for i in range(sequence_length): | |
| q_t = query[:, :, i] | |
| k_t = key[:, :, i] | |
| v_t = value[:, :, i] | |
| g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) | |
| beta_t = beta[:, :, i].unsqueeze(-1) | |
| last_recurrent_state = last_recurrent_state * g_t | |
| kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) | |
| delta = (v_t - kv_mem) * beta_t | |
| last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) | |
| core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) | |
| if not output_final_state: | |
| last_recurrent_state = None | |
| core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) | |
| return core_attn_out, last_recurrent_state | |
| class InternS2PreviewGatedDeltaNet(nn.Module): | |
| def __init__(self, config: InternS2PreviewConfig, layer_idx: int): | |
| super().__init__() | |
| self.hidden_size = config.hidden_size | |
| self.num_v_heads = config.linear_num_value_heads | |
| self.num_k_heads = config.linear_num_key_heads | |
| self.head_k_dim = config.linear_key_head_dim | |
| self.head_v_dim = config.linear_value_head_dim | |
| self.key_dim = self.head_k_dim * self.num_k_heads | |
| self.value_dim = self.head_v_dim * self.num_v_heads | |
| self.conv_kernel_size = config.linear_conv_kernel_dim | |
| self.layer_idx = layer_idx | |
| self.activation = config.hidden_act | |
| self.act = ACT2FN[config.hidden_act] | |
| self.layer_norm_epsilon = config.rms_norm_eps | |
| # QKV | |
| self.conv_dim = self.key_dim * 2 + self.value_dim | |
| self.conv1d = nn.Conv1d( | |
| in_channels=self.conv_dim, | |
| out_channels=self.conv_dim, | |
| bias=False, | |
| kernel_size=self.conv_kernel_size, | |
| groups=self.conv_dim, | |
| padding=self.conv_kernel_size - 1, | |
| ) | |
| # time step projection (discretization) | |
| # instantiate once and copy inv_dt in init_weights of PretrainedModel | |
| self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) | |
| A = torch.empty(self.num_v_heads).uniform_(0, 16) | |
| self.A_log = nn.Parameter(torch.log(A)) | |
| self.norm = ( | |
| InternS2PreviewRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon) | |
| if FusedRMSNormGated is None | |
| else FusedRMSNormGated( | |
| self.head_v_dim, | |
| eps=self.layer_norm_epsilon, | |
| activation=self.activation, | |
| device=torch.cuda.current_device(), | |
| dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), | |
| ) | |
| ) | |
| self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) | |
| self.causal_conv1d_fn = causal_conv1d_fn | |
| self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update | |
| self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule | |
| self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule | |
| if not is_fast_path_available: | |
| logger.warning_once( | |
| "The fast path is not available because one of the required library is not installed. Falling back to " | |
| "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and" | |
| " https://github.com/Dao-AILab/causal-conv1d" | |
| ) | |
| self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False) | |
| self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) | |
| self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) | |
| self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| cache_params: InternS2PreviewDynamicCache | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| attention_mask: torch.Tensor | None = None, | |
| ): | |
| hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) | |
| # Set up dimensions for reshapes later | |
| batch_size, seq_len, _ = hidden_states.shape | |
| use_precomputed_states = ( | |
| cache_params is not None | |
| and cache_params.has_previous_state | |
| and seq_len == 1 | |
| and cache_position is not None | |
| ) | |
| # getting projected states from cache if it exists | |
| if cache_params is not None: | |
| conv_state = cache_params.conv_states[self.layer_idx] | |
| recurrent_state = cache_params.recurrent_states[self.layer_idx] | |
| mixed_qkv = self.in_proj_qkv(hidden_states) | |
| mixed_qkv = mixed_qkv.transpose(1, 2) | |
| z = self.in_proj_z(hidden_states) | |
| z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) | |
| b = self.in_proj_b(hidden_states) | |
| a = self.in_proj_a(hidden_states) | |
| if use_precomputed_states: | |
| # 2. Convolution sequence transformation | |
| # NOTE: the conv state is updated in `causal_conv1d_update` | |
| mixed_qkv = self.causal_conv1d_update( | |
| mixed_qkv, | |
| conv_state, | |
| self.conv1d.weight.squeeze(1), | |
| self.conv1d.bias, | |
| self.activation, | |
| ) | |
| else: | |
| if cache_params is not None: | |
| conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) | |
| cache_params.conv_states[self.layer_idx] = conv_state | |
| if self.causal_conv1d_fn is not None: | |
| mixed_qkv = self.causal_conv1d_fn( | |
| x=mixed_qkv, | |
| weight=self.conv1d.weight.squeeze(1), | |
| bias=self.conv1d.bias, | |
| activation=self.activation, | |
| seq_idx=None, | |
| ) | |
| else: | |
| mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) | |
| mixed_qkv = mixed_qkv.transpose(1, 2) | |
| query, key, value = torch.split( | |
| mixed_qkv, | |
| [ | |
| self.key_dim, | |
| self.key_dim, | |
| self.value_dim, | |
| ], | |
| dim=-1, | |
| ) | |
| query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) | |
| key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) | |
| value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) | |
| beta = b.sigmoid() | |
| # If the model is loaded in fp16, without the .float() here, A might be -inf | |
| g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) | |
| if self.num_v_heads // self.num_k_heads > 1: | |
| query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) | |
| key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) | |
| if not use_precomputed_states: | |
| core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( | |
| query, | |
| key, | |
| value, | |
| g=g, | |
| beta=beta, | |
| initial_state=None, | |
| output_final_state=cache_params is not None, | |
| use_qk_l2norm_in_kernel=True, | |
| ) | |
| else: | |
| core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( | |
| query, | |
| key, | |
| value, | |
| g=g, | |
| beta=beta, | |
| initial_state=recurrent_state, | |
| output_final_state=cache_params is not None, | |
| use_qk_l2norm_in_kernel=True, | |
| ) | |
| # Update cache | |
| if cache_params is not None: | |
| cache_params.recurrent_states[self.layer_idx] = last_recurrent_state | |
| # reshape input data into 2D tensor | |
| core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) | |
| z = z.reshape(-1, self.head_v_dim) | |
| core_attn_out = self.norm(core_attn_out, z) | |
| core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) | |
| output = self.out_proj(core_attn_out) | |
| return output | |
| def rotate_half(x): | |
| """Rotates half the hidden dims of the input.""" | |
| x1 = x[..., : x.shape[-1] // 2] | |
| x2 = x[..., x.shape[-1] // 2 :] | |
| return torch.cat((-x2, x1), dim=-1) | |
| # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb | |
| def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): | |
| """Applies Rotary Position Embedding to the query and key tensors. | |
| Removes the interleaving of cos and sin from GLM | |
| Args: | |
| q (`torch.Tensor`): The query tensor. | |
| k (`torch.Tensor`): The key tensor. | |
| cos (`torch.Tensor`): The cosine part of the rotary embedding. | |
| sin (`torch.Tensor`): The sine part of the rotary embedding. | |
| unsqueeze_dim (`int`, *optional*, defaults to 1): | |
| The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and | |
| sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note | |
| that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and | |
| k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes | |
| cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have | |
| the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. | |
| Returns: | |
| `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. | |
| """ | |
| cos = cos.unsqueeze(unsqueeze_dim) | |
| sin = sin.unsqueeze(unsqueeze_dim) | |
| # Keep half or full tensor for later concatenation | |
| rotary_dim = cos.shape[-1] | |
| q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] | |
| k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] | |
| # Apply rotary embeddings on the first half or full tensor | |
| q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) | |
| k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) | |
| # Concatenate back to full shape | |
| q_embed = torch.cat([q_embed, q_pass], dim=-1) | |
| k_embed = torch.cat([k_embed, k_pass], dim=-1) | |
| return q_embed, k_embed | |
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: | |
| """ | |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, | |
| num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) | |
| """ | |
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape | |
| if n_rep == 1: | |
| return hidden_states | |
| hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) | |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) | |
| def eager_attention_forward( | |
| module: nn.Module, | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| attention_mask: torch.Tensor | None, | |
| scaling: float, | |
| dropout: float = 0.0, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ): | |
| key_states = repeat_kv(key, module.num_key_value_groups) | |
| value_states = repeat_kv(value, module.num_key_value_groups) | |
| attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling | |
| if attention_mask is not None: | |
| attn_weights = attn_weights + attention_mask | |
| attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) | |
| attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) | |
| attn_output = torch.matmul(attn_weights, value_states) | |
| attn_output = attn_output.transpose(1, 2).contiguous() | |
| return attn_output, attn_weights | |
| class InternS2PreviewRMSNorm(nn.Module): | |
| def __init__(self, dim: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.eps = eps | |
| self.weight = nn.Parameter(torch.zeros(dim)) | |
| def _norm(self, x): | |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| def forward(self, x): | |
| output = self._norm(x.float()) | |
| # Llama does x.to(float16) * w whilst InternS2Preview is (x * w).to(float16) | |
| # See https://github.com/huggingface/transformers/pull/29402 | |
| output = output * (1.0 + self.weight.float()) | |
| return output.type_as(x) | |
| def extra_repr(self): | |
| return f"{tuple(self.weight.shape)}, eps={self.eps}" | |
| class InternS2PreviewAttention(nn.Module): | |
| """Multi-headed attention from 'Attention Is All You Need' paper""" | |
| def __init__(self, config: InternS2PreviewConfig, layer_idx: int): | |
| super().__init__() | |
| self.config = config | |
| self.layer_idx = layer_idx | |
| self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) | |
| self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads | |
| self.scaling = self.head_dim**-0.5 | |
| self.attention_dropout = config.attention_dropout | |
| self.is_causal = True | |
| self.q_proj = nn.Linear( | |
| config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias | |
| ) | |
| self.k_proj = nn.Linear( | |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias | |
| ) | |
| self.v_proj = nn.Linear( | |
| config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias | |
| ) | |
| self.o_proj = nn.Linear( | |
| config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias | |
| ) | |
| self.q_norm = InternS2PreviewRMSNorm( | |
| self.head_dim, eps=config.rms_norm_eps | |
| ) # unlike olmo, only on the head dim! | |
| self.k_norm = InternS2PreviewRMSNorm( | |
| self.head_dim, eps=config.rms_norm_eps | |
| ) # thus post q_norm does not need reshape | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], | |
| attention_mask: torch.Tensor | None, | |
| past_key_values: Cache | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| **kwargs: Unpack[FlashAttentionKwargs], | |
| ) -> tuple[torch.Tensor, torch.Tensor | None]: | |
| input_shape = hidden_states.shape[:-1] | |
| hidden_shape = (*input_shape, -1, self.head_dim) | |
| query_states, gate = torch.chunk( | |
| self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 | |
| ) | |
| gate = gate.reshape(*input_shape, -1) | |
| query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) | |
| key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) | |
| value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) | |
| cos, sin = position_embeddings | |
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) | |
| if past_key_values is not None: | |
| # sin and cos are specific to RoPE models; cache_position needed for the static cache | |
| cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} | |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) | |
| attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( | |
| self.config._attn_implementation, eager_attention_forward | |
| ) | |
| attn_output, attn_weights = attention_interface( | |
| self, | |
| query_states, | |
| key_states, | |
| value_states, | |
| attention_mask, | |
| dropout=0.0 if not self.training else self.attention_dropout, | |
| scaling=self.scaling, | |
| **kwargs, | |
| ) | |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() | |
| attn_output = attn_output * torch.sigmoid(gate) | |
| attn_output = self.o_proj(attn_output) | |
| return attn_output, attn_weights | |
| class InternS2PreviewMLP(nn.Module): | |
| def __init__(self, config: InternS2PreviewConfig, intermediate_size: int): | |
| super().__init__() | |
| self.config = config | |
| self.hidden_size = config.hidden_size | |
| self.intermediate_size = intermediate_size | |
| self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | |
| self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) | |
| self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) | |
| self.act_fn = ACT2FN[config.hidden_act] | |
| def forward(self, x): | |
| down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) | |
| return down_proj | |
| class InternS2PreviewExperts(nn.Module): | |
| """Collection of expert weights stored as 3D tensors.""" | |
| def __init__(self, config): | |
| super().__init__() | |
| self.num_experts = config.num_experts | |
| self.hidden_dim = config.hidden_size | |
| self.intermediate_dim = config.moe_intermediate_size | |
| self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | |
| self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) | |
| self.act_fn = ACT2FN[config.hidden_act] | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| top_k_index: torch.Tensor, | |
| top_k_weights: torch.Tensor, | |
| ) -> torch.Tensor: | |
| final_hidden_states = torch.zeros_like(hidden_states) | |
| with torch.no_grad(): | |
| expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) | |
| expert_mask = expert_mask.permute(2, 1, 0) | |
| expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() | |
| for expert_idx in expert_hit: | |
| expert_idx = expert_idx[0] | |
| if expert_idx == self.num_experts: | |
| continue | |
| top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) | |
| current_state = hidden_states[token_idx] | |
| gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) | |
| current_hidden_states = self.act_fn(gate) * up | |
| current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) | |
| current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] | |
| final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) | |
| return final_hidden_states | |
| class InternS2PreviewTopKRouter(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.top_k = config.num_experts_per_tok | |
| self.num_experts = config.num_experts | |
| self.hidden_dim = config.hidden_size | |
| self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) | |
| def forward(self, hidden_states): | |
| hidden_states = hidden_states.reshape(-1, self.hidden_dim) | |
| router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) | |
| router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) | |
| router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) | |
| router_top_value /= router_top_value.sum(dim=-1, keepdim=True) | |
| router_top_value = router_top_value.to(router_logits.dtype) | |
| router_scores = router_top_value | |
| return router_logits, router_scores, router_indices | |
| class InternS2PreviewSparseMoeBlock(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.gate = InternS2PreviewTopKRouter(config) | |
| self.experts = InternS2PreviewExperts(config) | |
| self.shared_expert = InternS2PreviewMLP(config, intermediate_size=config.shared_expert_intermediate_size) | |
| self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) | |
| def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| batch_size, sequence_length, hidden_dim = hidden_states.shape | |
| hidden_states_reshaped = hidden_states.view(-1, hidden_dim) | |
| shared_expert_output = self.shared_expert(hidden_states_reshaped) | |
| _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) | |
| expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) | |
| shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output | |
| expert_output += shared_expert_output | |
| expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) | |
| return expert_output | |
| class InternS2PreviewDecoderLayer(GradientCheckpointingLayer): | |
| def __init__(self, config: InternS2PreviewTextConfig, layer_idx: int): | |
| super().__init__() | |
| self.hidden_size = config.hidden_size | |
| self.layer_type = config.layer_types[layer_idx] | |
| if self.layer_type == "linear_attention": | |
| self.linear_attn = InternS2PreviewGatedDeltaNet(config, layer_idx) | |
| elif self.layer_type == "full_attention": | |
| self.self_attn = InternS2PreviewAttention(config, layer_idx) | |
| self.mlp = InternS2PreviewSparseMoeBlock(config) | |
| self.input_layernorm = InternS2PreviewRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |
| self.post_attention_layernorm = InternS2PreviewRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_embeddings: tuple[torch.Tensor, torch.Tensor], | |
| attention_mask: torch.Tensor | None = None, | |
| position_ids: torch.LongTensor | None = None, | |
| past_key_values: Cache | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| **kwargs: Unpack[FlashAttentionKwargs], | |
| ) -> torch.FloatTensor: | |
| residual = hidden_states | |
| hidden_states = self.input_layernorm(hidden_states) | |
| # Token Mixer | |
| if self.layer_type == "linear_attention": | |
| hidden_states = self.linear_attn( | |
| hidden_states=hidden_states, | |
| cache_params=past_key_values, | |
| cache_position=cache_position, | |
| attention_mask=attention_mask, | |
| ) | |
| elif self.layer_type == "full_attention": | |
| # Self Attention | |
| hidden_states, _ = self.self_attn( | |
| hidden_states=hidden_states, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| past_key_values=past_key_values, | |
| cache_position=cache_position, | |
| position_embeddings=position_embeddings, | |
| **kwargs, | |
| ) | |
| hidden_states = residual + hidden_states | |
| # Fully Connected | |
| residual = hidden_states | |
| hidden_states = self.post_attention_layernorm(hidden_states) | |
| hidden_states = self.mlp(hidden_states) | |
| # For the MoE layers, we need to unpack | |
| if isinstance(hidden_states, tuple): | |
| hidden_states, _ = hidden_states | |
| hidden_states = residual + hidden_states | |
| return hidden_states | |
| class InternS2PreviewPreTrainedModel(PreTrainedModel): | |
| config: InternS2PreviewConfig | |
| base_model_prefix = "model" | |
| supports_gradient_checkpointing = True | |
| _no_split_modules = ["InternS2PreviewDecoderLayer", "InternS2PreviewVisionBlock"] | |
| _skip_keys_device_placement = "past_key_values" | |
| _supports_flash_attn = True | |
| _supports_sdpa = True | |
| _keys_to_ignore_on_load_unexpected = [r"^mtp.*"] | |
| _can_record_outputs = { | |
| "router_logits": OutputRecorder(InternS2PreviewTopKRouter, index=0), | |
| "hidden_states": InternS2PreviewDecoderLayer, | |
| "attentions": InternS2PreviewAttention, | |
| } | |
| _is_stateful = True | |
| def _init_weights(self, module): | |
| super()._init_weights(module) | |
| if isinstance(module, InternS2PreviewGatedDeltaNet): | |
| init.ones_(module.dt_bias) | |
| init.copy_(module.A_log, torch.empty_like(module.A_log).uniform_(0, 16).log_()) | |
| # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) | |
| elif isinstance(module, InternS2PreviewRMSNorm): | |
| init.zeros_(module.weight) | |
| elif isinstance(module, InternS2PreviewExperts): | |
| init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) | |
| init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) | |
| elif isinstance(module, InternS2PreviewSparseMoeBlock): | |
| init.normal_(module.gate.weight, mean=0.0, std=self.config.initializer_range) | |
| elif isinstance(module, InternS2PreviewVisionRotaryEmbedding): | |
| inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim)) | |
| init.copy_(module.inv_freq, inv_freq) | |
| class InternS2PreviewVisionMLP(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.hidden_size = config.hidden_size | |
| self.intermediate_size = config.intermediate_size | |
| self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) | |
| self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) | |
| self.act_fn = ACT2FN[config.hidden_act] | |
| def forward(self, hidden_state): | |
| return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) | |
| class InternS2PreviewVisionPatchEmbed(nn.Module): | |
| def __init__(self, config) -> None: | |
| super().__init__() | |
| self.patch_size = config.patch_size | |
| self.temporal_patch_size = config.temporal_patch_size | |
| self.in_channels = config.in_channels | |
| self.embed_dim = config.hidden_size | |
| kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] | |
| self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| target_dtype = self.proj.weight.dtype | |
| hidden_states = hidden_states.view( | |
| -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size | |
| ) | |
| hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) | |
| return hidden_states | |
| class InternS2PreviewVisionPatchMerger(nn.Module): | |
| def __init__(self, config: InternS2PreviewVisionConfig, use_postshuffle_norm=False) -> None: | |
| super().__init__() | |
| self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) | |
| self.use_postshuffle_norm = use_postshuffle_norm | |
| self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) | |
| self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) | |
| self.act_fn = nn.GELU() | |
| self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) | |
| x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) | |
| return x | |
| def apply_rotary_pos_emb_vision( | |
| q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| orig_q_dtype = q.dtype | |
| orig_k_dtype = k.dtype | |
| q, k = q.float(), k.float() | |
| cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() | |
| q_embed = (q * cos) + (rotate_half(q) * sin) | |
| k_embed = (k * cos) + (rotate_half(k) * sin) | |
| q_embed = q_embed.to(orig_q_dtype) | |
| k_embed = k_embed.to(orig_k_dtype) | |
| return q_embed, k_embed | |
| class InternS2PreviewVisionAttention(nn.Module): | |
| def __init__(self, config: InternS2PreviewVisionConfig) -> None: | |
| super().__init__() | |
| self.dim = config.hidden_size | |
| self.num_heads = config.num_heads | |
| self.head_dim = self.dim // self.num_heads | |
| self.num_key_value_groups = 1 # needed for eager attention | |
| self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) | |
| self.proj = nn.Linear(self.dim, self.dim) | |
| self.scaling = self.head_dim**-0.5 | |
| self.config = config | |
| self.attention_dropout = 0.0 | |
| self.is_causal = False | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| cu_seqlens: torch.Tensor, | |
| rotary_pos_emb: torch.Tensor | None = None, | |
| position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, | |
| **kwargs, | |
| ) -> torch.Tensor: | |
| seq_length = hidden_states.shape[0] | |
| query_states, key_states, value_states = ( | |
| self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) | |
| ) | |
| cos, sin = position_embeddings | |
| query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) | |
| query_states = query_states.transpose(0, 1).unsqueeze(0) | |
| key_states = key_states.transpose(0, 1).unsqueeze(0) | |
| value_states = value_states.transpose(0, 1).unsqueeze(0) | |
| attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( | |
| self.config._attn_implementation, eager_attention_forward | |
| ) | |
| if is_flash_attention_requested(self.config): | |
| # Flash Attention: Use cu_seqlens for variable length attention | |
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() | |
| attn_output, _ = attention_interface( | |
| self, | |
| query_states, | |
| key_states, | |
| value_states, | |
| attention_mask=None, | |
| scaling=self.scaling, | |
| dropout=0.0 if not self.training else self.attention_dropout, | |
| cu_seq_lens_q=cu_seqlens, | |
| cu_seq_lens_k=cu_seqlens, | |
| max_length_q=max_seqlen, | |
| max_length_k=max_seqlen, | |
| is_causal=False, | |
| **kwargs, | |
| ) | |
| else: | |
| # Other implementations: Process each chunk separately | |
| lengths = cu_seqlens[1:] - cu_seqlens[:-1] | |
| splits = [ | |
| torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) | |
| ] | |
| attn_outputs = [ | |
| attention_interface( | |
| self, | |
| q, | |
| k, | |
| v, | |
| attention_mask=None, | |
| scaling=self.scaling, | |
| dropout=0.0 if not self.training else self.attention_dropout, | |
| is_causal=False, | |
| **kwargs, | |
| )[0] | |
| for q, k, v in zip(*splits) | |
| ] | |
| attn_output = torch.cat(attn_outputs, dim=1) | |
| attn_output = attn_output.reshape(seq_length, -1).contiguous() | |
| attn_output = self.proj(attn_output) | |
| return attn_output | |
| class InternS2PreviewVisionBlock(GradientCheckpointingLayer): | |
| def __init__(self, config, attn_implementation: str = "sdpa") -> None: | |
| super().__init__() | |
| self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) | |
| self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) | |
| self.attn = InternS2PreviewVisionAttention(config=config) | |
| self.mlp = InternS2PreviewVisionMLP(config=config) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| cu_seqlens: torch.Tensor, | |
| rotary_pos_emb: torch.Tensor | None = None, | |
| position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, | |
| **kwargs, | |
| ) -> torch.Tensor: | |
| hidden_states = hidden_states + self.attn( | |
| self.norm1(hidden_states), | |
| cu_seqlens=cu_seqlens, | |
| rotary_pos_emb=rotary_pos_emb, | |
| position_embeddings=position_embeddings, | |
| **kwargs, | |
| ) | |
| hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) | |
| return hidden_states | |
| class InternS2PreviewVisionModel(InternS2PreviewPreTrainedModel): | |
| config: InternS2PreviewVisionConfig | |
| _no_split_modules = ["InternS2PreviewVisionBlock"] | |
| _can_record_outputs = { | |
| "hidden_states": InternS2PreviewVisionBlock, | |
| "attentions": InternS2PreviewVisionAttention, | |
| } | |
| def __init__(self, config, *inputs, **kwargs) -> None: | |
| super().__init__(config, *inputs, **kwargs) | |
| self.spatial_merge_size = config.spatial_merge_size | |
| self.patch_size = config.patch_size | |
| self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size | |
| self.patch_embed = InternS2PreviewVisionPatchEmbed( | |
| config=config, | |
| ) | |
| self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) | |
| self.num_grid_per_side = int(config.num_position_embeddings**0.5) | |
| head_dim = config.hidden_size // config.num_heads | |
| self.rotary_pos_emb = InternS2PreviewVisionRotaryEmbedding(head_dim // 2) | |
| self.blocks = nn.ModuleList([InternS2PreviewVisionBlock(config) for _ in range(config.depth)]) | |
| self.merger = InternS2PreviewVisionPatchMerger( | |
| config=config, | |
| use_postshuffle_norm=False, | |
| ) | |
| self.gradient_checkpointing = False | |
| self.post_init() | |
| def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: | |
| merge_size = self.spatial_merge_size | |
| grid_thw_list = grid_thw.tolist() | |
| max_hw = max(max(h, w) for _, h, w in grid_thw_list) | |
| freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) | |
| device = freq_table.device | |
| total_tokens = sum(t * h * w for t, h, w in grid_thw_list) | |
| pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) | |
| offset = 0 | |
| for num_frames, height, width in grid_thw_list: | |
| merged_h, merged_w = height // merge_size, width // merge_size | |
| block_rows = torch.arange(merged_h, device=device) # block row indices | |
| block_cols = torch.arange(merged_w, device=device) # block col indices | |
| intra_row = torch.arange(merge_size, device=device) # intra-block row offsets | |
| intra_col = torch.arange(merge_size, device=device) # intra-block col offsets | |
| # Compute full-resolution positions | |
| row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] | |
| col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] | |
| row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) | |
| col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) | |
| coords = torch.stack((row_idx, col_idx), dim=-1) | |
| if num_frames > 1: | |
| coords = coords.repeat(num_frames, 1) | |
| num_tokens = coords.shape[0] | |
| pos_ids[offset : offset + num_tokens] = coords | |
| offset += num_tokens | |
| embeddings = freq_table[pos_ids] # lookup rotary embeddings | |
| embeddings = embeddings.flatten(1) | |
| return embeddings | |
| def fast_pos_embed_interpolate(self, grid_thw): | |
| grid_thw_list = grid_thw.tolist() | |
| grid_ts = [row[0] for row in grid_thw_list] | |
| grid_hs = [row[1] for row in grid_thw_list] | |
| grid_ws = [row[2] for row in grid_thw_list] | |
| device = self.pos_embed.weight.device | |
| idx_list = [[] for _ in range(4)] | |
| weight_list = [[] for _ in range(4)] | |
| for t, h, w in grid_thw_list: | |
| h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) | |
| w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) | |
| h_idxs_floor = h_idxs.int() | |
| w_idxs_floor = w_idxs.int() | |
| h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) | |
| w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) | |
| dh = h_idxs - h_idxs_floor | |
| dw = w_idxs - w_idxs_floor | |
| base_h = h_idxs_floor * self.num_grid_per_side | |
| base_h_ceil = h_idxs_ceil * self.num_grid_per_side | |
| indices = [ | |
| (base_h[None].T + w_idxs_floor[None]).flatten(), | |
| (base_h[None].T + w_idxs_ceil[None]).flatten(), | |
| (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), | |
| (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), | |
| ] | |
| weights = [ | |
| ((1 - dh)[None].T * (1 - dw)[None]).flatten(), | |
| ((1 - dh)[None].T * dw[None]).flatten(), | |
| (dh[None].T * (1 - dw)[None]).flatten(), | |
| (dh[None].T * dw[None]).flatten(), | |
| ] | |
| for i in range(4): | |
| idx_list[i].extend(indices[i].tolist()) | |
| weight_list[i].extend(weights[i].tolist()) | |
| idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) | |
| weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) | |
| pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] | |
| patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] | |
| patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) | |
| patch_pos_embeds_permute = [] | |
| merge_size = self.config.spatial_merge_size | |
| for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): | |
| pos_embed = pos_embed.repeat(t, 1) | |
| pos_embed = ( | |
| pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) | |
| .permute(0, 1, 3, 2, 4, 5) | |
| .flatten(0, 4) | |
| ) | |
| patch_pos_embeds_permute.append(pos_embed) | |
| patch_pos_embeds = torch.cat(patch_pos_embeds_permute) | |
| return patch_pos_embeds | |
| def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: | |
| """ | |
| Args: | |
| hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): | |
| The final hidden states of the model. | |
| grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): | |
| The temporal, height and width of feature shape of each image in LLM. | |
| Returns: | |
| `torch.Tensor`: hidden_states. | |
| """ | |
| hidden_states = self.patch_embed(hidden_states) | |
| pos_embeds = self.fast_pos_embed_interpolate(grid_thw) | |
| hidden_states = hidden_states + pos_embeds | |
| rotary_pos_emb = self.rot_pos_emb(grid_thw) | |
| seq_len, _ = hidden_states.size() | |
| hidden_states = hidden_states.reshape(seq_len, -1) | |
| rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) | |
| emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) | |
| position_embeddings = (emb.cos(), emb.sin()) | |
| cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( | |
| dim=0, | |
| # Select dtype based on the following factors: | |
| # - FA2 requires that cu_seqlens_q must have dtype int32 | |
| # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw | |
| # See https://github.com/huggingface/transformers/pull/34852 for more information | |
| dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, | |
| ) | |
| cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) | |
| for blk in self.blocks: | |
| hidden_states = blk( | |
| hidden_states, | |
| cu_seqlens=cu_seqlens, | |
| position_embeddings=position_embeddings, | |
| **kwargs, | |
| ) | |
| merged_hidden_states = self.merger(hidden_states) | |
| return BaseModelOutputWithPooling( | |
| last_hidden_state=hidden_states, | |
| pooler_output=merged_hidden_states, | |
| ) | |
| class InternS2PreviewTextRotaryEmbedding(nn.Module): | |
| inv_freq: torch.Tensor # fix linting for `register_buffer` | |
| def __init__(self, config: InternS2PreviewTextConfig, device=None): | |
| super().__init__() | |
| self.max_seq_len_cached = config.max_position_embeddings | |
| self.original_max_seq_len = config.max_position_embeddings | |
| self.config = config | |
| self.rope_type = self.config.rope_parameters["rope_type"] | |
| rope_init_fn: Callable = self.compute_default_rope_parameters | |
| if self.rope_type != "default": | |
| rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] | |
| inv_freq, self.attention_scaling = rope_init_fn(self.config, device) | |
| self.register_buffer("inv_freq", inv_freq, persistent=False) | |
| self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) | |
| self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10]) | |
| # NOTE: does not use `self`; kept as a staticmethod so that the `self.compute_default_rope_parameters` | |
| # access in `__init__` passes `config` (rather than the module instance) as the first argument. | |
| @staticmethod | |
| def compute_default_rope_parameters( | |
| config: InternS2PreviewTextConfig | None = None, | |
| device: Optional["torch.device"] = None, | |
| seq_len: int | None = None, | |
| ) -> tuple["torch.Tensor", float]: | |
| """ | |
| Computes the inverse frequencies according to the original RoPE implementation | |
| Args: | |
| config ([`~transformers.PreTrainedConfig`]): | |
| The model configuration. | |
| device (`torch.device`): | |
| The device to use for initialization of the inverse frequencies. | |
| seq_len (`int`, *optional*): | |
| The current sequence length. Unused for this type of RoPE. | |
| Returns: | |
| Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the | |
| post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). | |
| """ | |
| base = config.rope_parameters["rope_theta"] | |
| partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) | |
| head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads | |
| dim = int(head_dim * partial_rotary_factor) | |
| attention_factor = 1.0 # Unused in this type of RoPE | |
| # Compute the inverse frequencies | |
| inv_freq = 1.0 / ( | |
| base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) | |
| ) | |
| return inv_freq, attention_factor | |
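| # The expression above implements the standard RoPE frequency schedule, | |
| # inv_freq[i] = 1 / rope_theta ** (2 * i / dim) for i in range(dim // 2). | |
| # Illustrative sketch with hypothetical values (not the model's actual config): | |
| # >>> import torch | |
| # >>> base, dim = 10000.0, 8 | |
| # >>> 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) | |
| # # -> roughly [1.0, 0.1, 0.01, 0.001] | |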
| # power user: used with advanced RoPE types (e.g. dynamic rope) | |
| def forward(self, x, position_ids): | |
| # In contrast to other models, InternS2Preview has different position ids for the grids | |
| # So we expand the inv_freq to shape (3, ...) | |
| if position_ids.ndim == 2: | |
| position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) | |
| inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) | |
| position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) | |
| device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" | |
| with maybe_autocast(device_type=device_type, enabled=False): # Force float32 | |
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) | |
| freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) | |
| emb = torch.cat((freqs, freqs), dim=-1) | |
| cos = emb.cos() * self.attention_scaling | |
| sin = emb.sin() * self.attention_scaling | |
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) | |
| def apply_interleaved_mrope(self, freqs, mrope_section): | |
| """Apply interleaved MRoPE to 3D rotary embeddings. | |
| Reorganizes frequency layout from chunked [TTT...HHH...WWW] to | |
| interleaved [THWTHWTHW...TT], preserving frequency continuity. | |
| Args: | |
| freqs: (3, bs, seq_len, head_dim // 2) | |
| mrope_section: (3,) | |
| Returns: | |
| freqs_t: (bs, seq_len, head_dim // 2) | |
| """ | |
| freqs_t = freqs[0] # just overwrite the first dimension T | |
| for dim, offset in enumerate((1, 2), start=1): # H, W | |
| length = mrope_section[dim] * 3 | |
| idx = slice(offset, length, 3) | |
| freqs_t[..., idx] = freqs[dim, ..., idx] | |
| return freqs_t | |
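| # Illustrative sketch of the interleaving above, assuming the default | |
| # mrope_section = [11, 11, 10] and head_dim // 2 = 32: positions 1, 4, ..., 31 | |
| # take the H frequencies, positions 2, 5, ..., 29 take the W frequencies, and | |
| # positions 0, 3, ..., 30 keep the T frequencies. Mirrored with plain tensors: | |
| # >>> import torch | |
| # >>> freqs = torch.stack([torch.full((1, 1, 32), float(v)) for v in (0, 1, 2)]) | |
| # >>> out = freqs[0].clone() | |
| # >>> out[..., 1:33:3] = freqs[1][..., 1:33:3]  # 11 H slots | |
| # >>> out[..., 2:30:3] = freqs[2][..., 2:30:3]  # 10 W slots | |
| # >>> out[0, 0, :6] | |
| # tensor([0., 1., 2., 0., 1., 2.]) | |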
| class InternS2PreviewTextModel(InternS2PreviewPreTrainedModel): | |
| def __init__(self, config: InternS2PreviewTextConfig): | |
| super().__init__(config) | |
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) | |
| self.layers = nn.ModuleList( | |
| [InternS2PreviewDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] | |
| ) | |
| self.norm = InternS2PreviewRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | |
| self.rotary_emb = InternS2PreviewTextRotaryEmbedding(config=config) | |
| self.gradient_checkpointing = False | |
| # Initialize weights and apply final processing | |
| self.post_init() | |
| def forward( | |
| self, | |
| input_ids: torch.LongTensor | None = None, | |
| attention_mask: torch.Tensor | None = None, | |
| position_ids: torch.LongTensor | None = None, | |
| past_key_values: Cache | None = None, | |
| inputs_embeds: torch.FloatTensor | None = None, | |
| use_cache: bool | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> BaseModelOutputWithPast: | |
| if (input_ids is None) ^ (inputs_embeds is not None): | |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") | |
| if inputs_embeds is None: | |
| inputs_embeds = self.embed_tokens(input_ids) | |
| if use_cache and past_key_values is None: | |
| past_key_values = InternS2PreviewDynamicCache(config=self.config) | |
| if cache_position is None: | |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 | |
| cache_position = torch.arange( | |
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device | |
| ) | |
| # mrope: the hard coded `3` is for temporal, height and width. | |
| if position_ids is None: | |
| position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) | |
| elif position_ids.ndim == 2: | |
| position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) | |
| if position_ids.ndim == 3 and position_ids.shape[0] == 4: | |
| text_position_ids = position_ids[0] | |
| position_ids = position_ids[1:] | |
| else: | |
| text_position_ids = position_ids[0] | |
| causal_mask = create_causal_mask( | |
| config=self.config, | |
| inputs_embeds=inputs_embeds, | |
| attention_mask=attention_mask, | |
| cache_position=cache_position, | |
| past_key_values=past_key_values, | |
| position_ids=text_position_ids, | |
| ) | |
| linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) | |
| hidden_states = inputs_embeds | |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) | |
| for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): | |
| layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask | |
| hidden_states = decoder_layer( | |
| hidden_states, | |
| position_embeddings=position_embeddings, | |
| attention_mask=layer_mask, | |
| position_ids=position_ids, | |
| past_key_values=past_key_values, | |
| use_cache=use_cache, | |
| cache_position=cache_position, | |
| **kwargs, | |
| ) | |
| hidden_states = self.norm(hidden_states) | |
| return InternS2PreviewModelOutputWithPast( | |
| last_hidden_state=hidden_states, | |
| past_key_values=past_key_values, | |
| ) | |
| def _update_linear_attn_mask(self, attention_mask, cache_position): | |
| """ | |
| NOTE: Left-padding is used for linear attention mask. | |
| No need for zeroing states when | |
| 1. Cached forward | |
| 2. Attending to all inputs | |
| """ | |
| linear_attn_mask = attention_mask | |
| if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): | |
| linear_attn_mask = None | |
| return linear_attn_mask | |
| class InternS2PreviewTimeSeriesAttention(nn.Module): | |
| """Multi-headed attention from 'Attention Is All You Need' paper""" | |
| def __init__( | |
| self, | |
| embed_dim: int, | |
| num_heads: int, | |
| dropout: float = 0.0, | |
| is_decoder: bool = False, | |
| bias: bool = True, | |
| is_causal: bool = False, | |
| layer_idx: int | None = None, | |
| config: InternS2PreviewTimeSeriesConfig | None = None, | |
| num_key_value_groups: int = 1, | |
| ): | |
| super().__init__() | |
| self.embed_dim = embed_dim | |
| self.num_heads = num_heads | |
| self.dropout = dropout | |
| self.head_dim = embed_dim // num_heads | |
| self.config = config | |
| if (self.head_dim * num_heads) != self.embed_dim: | |
| raise ValueError( | |
| f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" | |
| f" and `num_heads`: {num_heads})." | |
| ) | |
| self.scaling = self.head_dim**-0.5 | |
| self.is_decoder = is_decoder | |
| self.is_causal = is_causal | |
| if layer_idx is None and is_decoder: | |
| logger.warning_once( | |
| f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " | |
| "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " | |
| "when creating this class." | |
| ) | |
| self.layer_idx = layer_idx | |
| self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) | |
| self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | |
| self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | |
| self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) | |
| self.num_key_value_groups = num_key_value_groups | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| key_value_states: torch.Tensor | None = None, | |
| past_key_values: Cache | None = None, | |
| attention_mask: torch.Tensor | None = None, | |
| output_attentions: bool = False, | |
| cache_position: torch.Tensor | None = None, | |
| # TODO: we need a refactor so that the different attention modules can get their specific kwargs | |
| # ATM, encoder, decoder, and encoder-decoder attention are all mixed together here | |
| **kwargs: Unpack[FlashAttentionKwargs], | |
| ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: | |
| """Input shape: Batch x Time x Channel""" | |
| # if key_value_states are provided this layer is used as a cross-attention layer | |
| # for the decoder | |
| is_cross_attention = key_value_states is not None | |
| # determine input shapes | |
| bsz, tgt_len = hidden_states.shape[:-1] | |
| q_input_shape = (bsz, tgt_len, -1, self.head_dim) | |
| # Scaling is susceptible to floating point imprecision, | |
| # which can lead to different results (this varies from model | |
| # to model, e.g. intern_s2_preview_time_series is one such case). We therefore keep the | |
| # original order of scaling to follow the original implementation | |
| # and enforce no scaling (1.0) in the attention call below. | |
| query_states = self.q_proj(hidden_states) * self.scaling | |
| query_states = query_states.view(*q_input_shape) | |
| query_states = query_states.transpose(1, 2).contiguous() | |
| # Check if an encoder-decoder model is being used. Otherwise we'll get `DynamicCache` | |
| if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): | |
| is_updated = past_key_values.is_updated.get(self.layer_idx) | |
| if is_cross_attention: | |
| # after the first generated id, we can subsequently re-use all key/value_states from cache | |
| past_key_values.is_updated[self.layer_idx] = True | |
| past_key_values = past_key_values.cross_attention_cache | |
| else: | |
| past_key_values = past_key_values.self_attention_cache | |
| # use key_value_states if cross attention | |
| current_states = key_value_states if key_value_states is not None else hidden_states | |
| if is_cross_attention and past_key_values and is_updated: | |
| # reuse k,v, cross_attentions | |
| key_states = past_key_values.layers[self.layer_idx].keys | |
| value_states = past_key_values.layers[self.layer_idx].values | |
| else: | |
| key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) | |
| value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) | |
| key_states = key_states.transpose(1, 2).contiguous() | |
| value_states = value_states.transpose(1, 2).contiguous() | |
| if past_key_values is not None: | |
| # save all key/value_states to cache to be re-used for fast auto-regressive generation | |
| cache_position = cache_position if not is_cross_attention else None | |
| key_states, value_states = past_key_values.update( | |
| key_states, value_states, self.layer_idx, {"cache_position": cache_position} | |
| ) | |
| attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( | |
| self.config._attn_implementation, eager_attention_forward | |
| ) | |
| attn_output, attn_weights = attention_interface( | |
| self, | |
| query_states, | |
| key_states, | |
| value_states, | |
| attention_mask, | |
| dropout=0.0 if not self.training else self.dropout, | |
| scaling=1.0, | |
| output_attentions=output_attentions, | |
| **kwargs, | |
| ) | |
| attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() | |
| attn_output = self.out_proj(attn_output) | |
| return attn_output, attn_weights | |
| class InternS2PreviewTimeSeriesEncoderLayer(GradientCheckpointingLayer): | |
| def __init__(self, config: InternS2PreviewTimeSeriesConfig): | |
| super().__init__() | |
| self.embed_dim = config.d_model | |
| self.self_attn = InternS2PreviewTimeSeriesAttention( | |
| embed_dim=self.embed_dim, | |
| num_heads=config.encoder_attention_heads, | |
| dropout=config.attention_dropout, | |
| config=config, | |
| ) | |
| self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) | |
| self.dropout = config.dropout | |
| self.activation_fn = ACT2FN[config.activation_function] | |
| self.activation_dropout = config.activation_dropout | |
| self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) | |
| self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) | |
| self.final_layer_norm = nn.LayerNorm(self.embed_dim) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| attention_mask: torch.Tensor, | |
| output_attentions: bool = False, | |
| ) -> torch.Tensor: | |
| """ | |
| Args: | |
| hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` | |
| attention_mask (`torch.FloatTensor`): attention mask of size | |
| `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. | |
| output_attentions (`bool`, *optional*): | |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under | |
| returned tensors for more detail. | |
| """ | |
| residual = hidden_states | |
| hidden_states = self.self_attn_layer_norm(hidden_states) | |
| hidden_states, attn_weights = self.self_attn( | |
| hidden_states=hidden_states, | |
| attention_mask=attention_mask, | |
| output_attentions=output_attentions, | |
| ) | |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) | |
| hidden_states = residual + hidden_states | |
| residual = hidden_states | |
| hidden_states = self.final_layer_norm(hidden_states) | |
| hidden_states = self.activation_fn(self.fc1(hidden_states)) | |
| hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) | |
| hidden_states = self.fc2(hidden_states) | |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) | |
| hidden_states = residual + hidden_states | |
| if hidden_states.dtype == torch.float16: | |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 | |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) | |
| return hidden_states, attn_weights | |
| class InternS2PreviewTimeSeriesEncoder(PreTrainedModel): | |
| config: InternS2PreviewTimeSeriesConfig | |
| base_model_prefix = "model" | |
| main_input_name = "input_features" | |
| input_modalities = ("audio", "text") | |
| supports_gradient_checkpointing = True | |
| _no_split_modules = ["InternS2PreviewTimeSeriesEncoderLayer", "WhisperDecoderLayer"] | |
| _supports_flash_attn = True | |
| _supports_sdpa = True | |
| _supports_flex_attn = True | |
| _can_compile_fullgraph = True | |
| def __init__(self, config: InternS2PreviewTimeSeriesConfig): | |
| super().__init__(config) | |
| self.dropout = config.dropout | |
| self.layerdrop = config.encoder_layerdrop | |
| self.embed_dim = config.d_model | |
| self.num_mel_bins = config.num_mel_bins | |
| # self.padding_idx = config.pad_token_id | |
| self.max_source_positions = config.max_source_positions | |
| self.embed_scale = math.sqrt(self.embed_dim) if config.scale_embedding else 1.0 | |
| self.conv1 = nn.Conv1d(self.num_mel_bins, self.embed_dim, kernel_size=3, padding=1) | |
| self.conv2 = nn.Conv1d(self.embed_dim, self.embed_dim, kernel_size=3, stride=2, padding=1) | |
| self.embed_positions = nn.Embedding(self.max_source_positions, self.embed_dim) | |
| self.layers = nn.ModuleList( | |
| [InternS2PreviewTimeSeriesEncoderLayer(config) for _ in range(config.encoder_layers)] | |
| ) | |
| self.layer_norm = nn.LayerNorm(config.d_model) | |
| self.gradient_checkpointing = False | |
| self.post_init() | |
| self.mask_type = None | |
| self.chunk_length = None | |
| self.adapt_in = nn.Linear(config.ts_adapt_in_dim, 80) | |
| self.adapt_out = nn.Linear(self.embed_dim, config.ts_adapt_out_dim) | |
| def _freeze_parameters(self): | |
| for param in self.parameters(): | |
| param.requires_grad = False | |
| self._requires_grad = False | |
| def get_input_embeddings(self) -> nn.Module: | |
| return self.conv1 | |
| def set_input_embeddings(self, value: nn.Module): | |
| self.conv1 = value | |
| def define_masktype(self, masktype, chunk_length=None): | |
| self.mask_type = masktype | |
| self.chunk_length = chunk_length | |
| def _make_causal_mask( | |
| self, input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 | |
| ): | |
| """ | |
| Make causal mask used for uni-directional (causal) self-attention. | |
| """ | |
| bsz, tgt_len = input_ids_shape | |
| mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) | |
| mask_cond = torch.arange(mask.size(-1), device=device) | |
| mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) | |
| mask = mask.to(dtype) | |
| if past_key_values_length > 0: | |
| mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) | |
| return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) | |
| # Copied from transformers.models.bart.modeling_bart._expand_mask | |
| def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): | |
| """ | |
| Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. | |
| """ | |
| # print(mask.size()) | |
| bsz, src_len = mask.size() | |
| tgt_len = tgt_len if tgt_len is not None else src_len | |
| expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) | |
| inverted_mask = 1.0 - expanded_mask | |
| return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) | |
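| # For illustration: a padding mask like [[1, 1, 0]] expands to an additive bias of | |
| # shape (1, 1, tgt_len, 3) whose last column is torch.finfo(dtype).min and whose | |
| # first two columns are 0, so padded positions are excluded from attention. | |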
| def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): | |
| # create causal mask | |
| # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] | |
| combined_attention_mask = None | |
| if input_shape[-1] > 1: | |
| combined_attention_mask = self._make_causal_mask( | |
| input_shape, | |
| inputs_embeds.dtype, | |
| device=inputs_embeds.device, | |
| past_key_values_length=past_key_values_length, | |
| ) | |
| if attention_mask is not None: | |
| # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] | |
| expanded_attn_mask = self._expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) | |
| combined_attention_mask = ( | |
| expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask | |
| ) | |
| return combined_attention_mask | |
| def prepare_chunk_attention_mask(self, attention_mask, input_shape, inputs_embeds): | |
| block_size = round(self.chunk_length / 4 * 2) | |
| matrix_size = input_shape[1] | |
| matrix = torch.ones(matrix_size, matrix_size) | |
| num_full_blocks = round(matrix_size // block_size) | |
| remainder = matrix_size % block_size | |
| for i in range(num_full_blocks): | |
| row_start = i * block_size | |
| col_start = i * block_size | |
| matrix[row_start : row_start + block_size, col_start : col_start + block_size] = torch.zeros( | |
| block_size, block_size | |
| ) | |
| if remainder > 0: | |
| last_row_start = num_full_blocks * block_size | |
| last_col_start = num_full_blocks * block_size | |
| matrix[last_row_start : last_row_start + remainder, last_col_start : last_col_start + remainder] = ( | |
| torch.zeros(remainder, remainder) | |
| ) | |
| matrix = matrix * -65504 | |
| matrix = matrix.unsqueeze(0).unsqueeze(0).repeat(input_shape[0], 1, 1, 1) | |
| attention_mask = matrix.to(inputs_embeds.device) | |
| return attention_mask | |
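| # Illustrative sketch of the block-diagonal mask built above, assuming | |
| # chunk_length = 4 (block_size = 2) and 5 frames: positions within the same | |
| # chunk attend to each other (0), everything else is masked out (-65504): | |
| # [[     0,      0, -65504, -65504, -65504], | |
| #  [     0,      0, -65504, -65504, -65504], | |
| #  [-65504, -65504,      0,      0, -65504], | |
| #  [-65504, -65504,      0,      0, -65504], | |
| #  [-65504, -65504, -65504, -65504,      0]] | |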
| def forward( | |
| self, | |
| input_features, | |
| attention_mask=None, | |
| head_mask=None, | |
| output_attentions=None, | |
| output_hidden_states=None, | |
| return_dict=None, | |
| ): | |
| # (N, T, C) -> (T, N, C) -> (N, C, T) | |
| input_features = input_features.permute(1, 0, 2) | |
| input_features = self.adapt_in(input_features) | |
| input_features = input_features.permute(1, 2, 0) | |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
| output_hidden_states = ( | |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
| ) | |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| # (N, C, T) -> (N, C, T//2) | |
| inputs_embeds = nn.functional.gelu(self.conv1(input_features)) | |
| inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) | |
| # (N, C, T) -> (N, T, C) | |
| inputs_embeds = inputs_embeds.permute(0, 2, 1) # torch.Size([1, 100, 768]) | |
| embed_pos = self.embed_positions.weight # torch.Size([1500, 768]) | |
| if inputs_embeds.shape[1] > embed_pos.shape[0]: | |
| target_len = inputs_embeds.shape[1] | |
| padding = [0, 0, 0, target_len - embed_pos.shape[0]] | |
| embed_pos = nn.functional.pad(embed_pos, pad=padding, mode="constant", value=0) | |
| hidden_states = inputs_embeds[:, : embed_pos.shape[0], :] + embed_pos | |
| else: | |
| hidden_states = inputs_embeds + embed_pos[: inputs_embeds.shape[1], :] | |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) | |
| encoder_states = () if output_hidden_states else None | |
| all_attentions = () if output_attentions else None | |
| input_shape = inputs_embeds.size()[:-1] | |
| past_key_values_length = 0 | |
| attention_mask = None | |
| if self.mask_type == "chunk": | |
| attention_mask = self.prepare_chunk_attention_mask(attention_mask, input_shape, inputs_embeds) | |
| else: | |
| attention_mask = self._prepare_decoder_attention_mask( | |
| attention_mask, input_shape, inputs_embeds, past_key_values_length | |
| ) | |
| if head_mask is not None: | |
| assert head_mask.size()[0] == (len(self.layers)), ( | |
| f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." | |
| ) | |
| for idx, encoder_layer in enumerate(self.layers): | |
| if output_hidden_states: | |
| encoder_states = encoder_states + (self.layer_norm(hidden_states),) | |
| # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) | |
| to_drop = False | |
| if self.training: | |
| dropout_probability = torch.rand([]) | |
| if dropout_probability < self.layerdrop: # skip the layer | |
| to_drop = True | |
| if to_drop: | |
| layer_outputs = (None, None) | |
| else: | |
| if self.gradient_checkpointing and self.training: | |
| def create_custom_forward(module): | |
| def custom_forward(*inputs): | |
| return module(*inputs, output_attentions) | |
| return custom_forward | |
| layer_outputs = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(encoder_layer), | |
| hidden_states, | |
| attention_mask, | |
| (head_mask[idx] if head_mask is not None else None), | |
| ) | |
| else: | |
| layer_outputs = encoder_layer( | |
| hidden_states, | |
| attention_mask, | |
| # TODO test whether OK | |
| # layer_head_mask=(head_mask[idx] if head_mask is not None else None), | |
| output_attentions=output_attentions, | |
| ) | |
| hidden_states = layer_outputs[0] | |
| if output_attentions: | |
| all_attentions = all_attentions + (layer_outputs[1],) | |
| # (N, T, C) -> (T, N, C) | |
| hidden_states = hidden_states.permute(1, 0, 2) | |
| hidden_states = self.layer_norm(hidden_states) | |
| hidden_states = self.adapt_out(hidden_states) | |
| # (T, N, C) -> (N, T, C) | |
| hidden_states = hidden_states.permute(1, 0, 2) | |
| if output_hidden_states: | |
| encoder_states = encoder_states + (hidden_states,) | |
| if not return_dict: | |
| return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) | |
| return ModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) | |
| class InternS2PreviewTimeSeriesConcatSubsampling(nn.Module): | |
| def __init__(self, in_channels: int, concat_size: int): | |
| super().__init__() | |
| self.in_channels = in_channels | |
| self.out_channels = in_channels * concat_size | |
| def forward(self, ts_signals: torch.Tensor, ts_lens: torch.Tensor): | |
| if ts_signals.shape[1] % 2 != 0: | |
| ts_signals = ts_signals[:, :-1, :] | |
| even_frames = ts_signals[:, ::2, :] | |
| odd_frames = ts_signals[:, 1::2, :] | |
| ts_signals = torch.cat((even_frames, odd_frames), dim=2) | |
| ts_lens = ts_lens // 2 | |
| return ts_signals, ts_lens | |
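| # Shape sketch (hypothetical sizes): stacking each even frame with the following | |
| # odd frame halves the time dimension and doubles the channel dimension. | |
| # >>> import torch | |
| # >>> sub = InternS2PreviewTimeSeriesConcatSubsampling(in_channels=128, concat_size=2) | |
| # >>> x, lens = sub(torch.randn(1, 10, 128), torch.tensor([10])) | |
| # >>> x.shape, lens | |
| # (torch.Size([1, 5, 256]), tensor([5])) | |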
| class InternS2PreviewTimeSeriesFixPositionalEncoding(nn.Module): | |
| def __init__(self, d_model: int, max_len: int = 20000): | |
| super().__init__() | |
| pe = torch.zeros(max_len, d_model, dtype=torch.float) | |
| position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) | |
| div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)) | |
| pe[:, 0::2] = torch.sin(position * div_term) | |
| pe[:, 1::2] = torch.cos(position * div_term) | |
| pe = pe.unsqueeze(0).transpose(0, 1).to(torch.float32) # (max_len, 1, d_model) | |
| self.register_buffer("pe", pe, persistent=True) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # x: (seq_len, batch_size, d_model) | |
| x = x + self.pe[: x.size(0), :] | |
| return x.clone() | |
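| # The buffer built in __init__ is the standard sinusoidal encoding: | |
| # pe[pos, 2i] = sin(pos / 10000 ** (2i / d_model)) and | |
| # pe[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model)). | |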
| class InternS2PreviewTimeSeriesMultiChannelAdaptiveSubsampling(nn.Module): | |
| def __init__(self, hidden_dim: int = 128, nhead: int = 8, num_encoder_layers: int = 1): | |
| super().__init__() | |
| self.conv = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=5, stride=1, padding=2) | |
| encoder_layers = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead) | |
| self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers) | |
| self.pos_encoder = InternS2PreviewTimeSeriesFixPositionalEncoding(d_model=hidden_dim) | |
| self.subsampling = InternS2PreviewTimeSeriesConcatSubsampling(128, 2) | |
| def forward( | |
| self, inputs: torch.Tensor, input_lens: torch.Tensor, sr: torch.Tensor, channels: torch.LongTensor | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| features, feature_lens = self.forward_patch(inputs, input_lens, sr, channels) | |
| outputs = features | |
| output_lens = feature_lens | |
| return outputs, output_lens | |
| def forward_patch( | |
| self, inputs: torch.Tensor, input_lens: torch.Tensor, sr: torch.Tensor, channels: torch.LongTensor | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| sr = sr.float() | |
| strides = torch.floor(160 / ((1 + torch.exp(-sr / 100)) ** 6)) | |
| patch_sizes = strides * 2 | |
| patched_outputs = [] | |
| output_lens = [] | |
| for i in range(len(inputs)): | |
| le = input_lens[i] | |
| channel = channels[i] | |
| seq = inputs[i, :le, :channel] # [seq_len, channel] | |
| ps = patch_sizes[i].item() | |
| st = strides[i].item() | |
| output_len = torch.ceil((le - ps) / st) + 1 | |
| pad_len = ((output_len - 1) * st + ps - le).long().item() | |
| if seq.ndim == 1: | |
| seq = seq.unsqueeze(-1) | |
| seq = nn.functional.pad(seq, (0, 0, 0, pad_len), "constant", 0) | |
| assert output_len > 0, (seq.shape, ps, st, le, output_len) | |
| output_lens.append(output_len) | |
| indices = (torch.arange(0, output_len * st, st).unsqueeze(1) + torch.arange(ps)).long() | |
| patched = seq[indices] | |
| output = self.forward_encoder(patched) # [num_patch, D] | |
| patched_outputs.append(output) | |
| outputs = nn.utils.rnn.pad_sequence(patched_outputs, batch_first=True) | |
| output_lens = torch.tensor(output_lens).squeeze().to(outputs.device).long() | |
| if output_lens.ndim == 0: | |
| output_lens = output_lens.unsqueeze(0) | |
| outputs, output_lens = self.subsampling(outputs.clone(), output_lens.clone()) | |
| return outputs, output_lens | |
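| # The sampling-rate-adaptive patching above uses | |
| # stride = floor(160 / (1 + exp(-sr / 100)) ** 6) and patch_size = 2 * stride, | |
| # so slower signals get finer patches and faster signals get coarser ones: | |
| # as sr -> 0 the stride approaches floor(160 / 2 ** 6) = 2 (patch size 4), | |
| # at sr = 100 it is 24 (patch size 48), and for very large sr it approaches 160. | |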
| def forward_encoder(self, x: torch.Tensor) -> torch.Tensor: | |
| num_patch, patch_len, C = x.shape | |
| # conv1 | |
| x = x.reshape(num_patch * C, 1, patch_len) # each channel is treated as an independent sample into conv1 | |
| x = nn.functional.relu(self.conv(x)) # [B*C, D1, L] | |
| x = x.permute(2, 0, 1) # [L, B*C, D1] | |
| x = self.pos_encoder(x) # [L, B*C, D1] | |
| x = self.transformer_encoder(x.to(torch.bfloat16)) | |
| x = x.mean(0) | |
| x = x.reshape(num_patch, C, -1) | |
| return x.mean(1) | |
| class InternS2PreviewTimeSeriesProjector(nn.Module): | |
| def __init__(self, config: InternS2PreviewTimeSeriesConfig): | |
| super().__init__() | |
| self.layer_norm = nn.LayerNorm(config.ts_hidden_dim) | |
| self.linear_1 = nn.Linear(config.ts_hidden_dim, config.out_hidden_size) | |
| self.act = ACT2FN[config.activation_function] | |
| self.linear_2 = nn.Linear(config.out_hidden_size, config.out_hidden_size) | |
| def forward(self, ts_features): | |
| hidden_states = self.layer_norm(ts_features) | |
| hidden_states = self.linear_1(hidden_states) | |
| hidden_states = self.act(hidden_states) | |
| hidden_states = self.linear_2(hidden_states) | |
| return hidden_states | |
| class InternS2PreviewTimeSeriesModel(PreTrainedModel): | |
| main_input_name = "time_series_signals" | |
| _supports_flash_attn_2 = False | |
| config_class = InternS2PreviewTimeSeriesConfig | |
| _no_split_modules = ["InternS2PreviewTimeSeriesEncoderLayer"] | |
| def __init__(self, config: InternS2PreviewTimeSeriesConfig): | |
| super().__init__(config) | |
| self.config = config | |
| self.encoder_embed = InternS2PreviewTimeSeriesMultiChannelAdaptiveSubsampling() | |
| self.encoder = InternS2PreviewTimeSeriesEncoder(config) | |
| self.projector = InternS2PreviewTimeSeriesProjector(config) | |
| def get_input_embeddings(self): | |
| return self.encoder_embed | |
| def make_pad_mask(self, lengths: torch.Tensor) -> torch.Tensor: | |
| r""" | |
| lengths (`torch.Tensor` of shape `(batch_size,)`): | |
| A 1-D tensor containing sentence lengths. | |
| Returns: | |
| A 2-D bool tensor, where masked positions are filled with `True` and non-masked positions are filled with `False`. | |
| Example: | |
| ```python | |
| >>> lengths = torch.tensor([1, 3, 2, 5]) | |
| >>> self.make_pad_mask(lengths) | |
| tensor([[False, True, True, True, True], | |
| [False, False, False, True, True], | |
| [False, False, True, True, True], | |
| [False, False, False, False, False]]) | |
| ``` | |
| """ | |
| assert lengths.ndim == 1, lengths.ndim | |
| max_len = lengths.max() | |
| n = lengths.size(0) | |
| seq_range = torch.arange(0, max_len, device=lengths.device) | |
| expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len) | |
| return expanded_lengths >= lengths.unsqueeze(-1) | |
| def forward( | |
| self, | |
| time_series_signals: torch.FloatTensor = None, | |
| ts_lens: torch.Tensor = None, | |
| sr: torch.Tensor = None, | |
| channels: torch.LongTensor | None = None, | |
| output_hidden_states: bool = None, | |
| return_dict: bool = None, | |
| time_series_embeds: torch.FloatTensor = None, | |
| ): | |
| output_hidden_states = ( | |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
| ) | |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| if time_series_signals is None and time_series_embeds is None: | |
| raise ValueError("You have to specify time_series_signals or time_series_embeds") | |
| if ( | |
| time_series_embeds is not None | |
| and len(time_series_embeds.shape) == 3 | |
| and time_series_embeds.shape[-1] == self.config.ts_adapt_in_dim | |
| ): | |
| time_series_embeds = time_series_embeds | |
| else: | |
| if time_series_signals.ndim == 3: | |
| time_series_embeds, ts_lens = self.encoder_embed(time_series_signals, ts_lens, sr, channels) | |
| else: | |
| raise ValueError(f"wrong time_series_signals size: {time_series_signals[0].shape}") | |
| # [B, 64000, 1] -> [B, 200, 256] -> [B, 100, 1024] | |
| encoder_outputs = self.encoder( | |
| input_features=time_series_embeds, | |
| output_hidden_states=output_hidden_states, | |
| return_dict=return_dict, | |
| ) | |
| # ts_lens after encoder | |
| ts_lens = (ts_lens + 1) // 2 | |
| assert torch.all(ts_lens > 0), f"The length of time_series_embeds is too small. ts_lens: {ts_lens}" | |
| src_key_padding_mask = self.make_pad_mask(ts_lens) | |
| last_hidden_state = encoder_outputs.last_hidden_state | |
| ts_pad_mask = src_key_padding_mask | |
| ts_embeds = self.projector(last_hidden_state) | |
| return ts_embeds, ts_pad_mask | |
| class InternS2PreviewModelOutputWithPast(ModelOutput): | |
| r""" | |
| past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): | |
| It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). | |
| Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see | |
| `past_key_values` input) to speed up sequential decoding. | |
| rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): | |
| The rope index difference between sequence length and multimodal rope. | |
| """ | |
| last_hidden_state: torch.FloatTensor | None = None | |
| past_key_values: Cache | None = None | |
| hidden_states: tuple[torch.FloatTensor] | None = None | |
| attentions: tuple[torch.FloatTensor] | None = None | |
| rope_deltas: torch.LongTensor | None = None | |
| router_logits: tuple[torch.FloatTensor] | None = None | |
| class InternS2PreviewCausalLMOutputWithPast(ModelOutput): | |
| r""" | |
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): | |
| Language modeling loss (for next-token prediction). | |
| logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): | |
| Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | |
| past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): | |
| It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). | |
| Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see | |
| `past_key_values` input) to speed up sequential decoding. | |
| rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): | |
| The rope index difference between sequence length and multimodal rope. | |
| """ | |
| loss: torch.FloatTensor | None = None | |
| logits: torch.FloatTensor | None = None | |
| past_key_values: Cache | None = None | |
| hidden_states: tuple[torch.FloatTensor] | None = None | |
| attentions: tuple[torch.FloatTensor] | None = None | |
| rope_deltas: torch.LongTensor | None = None | |
| router_logits: tuple[torch.FloatTensor] | None = None | |
| aux_loss: torch.FloatTensor | None = None | |
| class InternS2PreviewModel(InternS2PreviewPreTrainedModel): | |
| base_model_prefix = "model" | |
| _checkpoint_conversion_mapping = {} | |
| # Reference: fix gemma3 grad acc #37208 | |
| accepts_loss_kwargs = False | |
| config: InternS2PreviewConfig | |
| _no_split_modules = [ | |
| "Qwen3_5MoeTextDecoderLayer", | |
| "Qwen3_5MoeVisionBlock", | |
| "InternS2PreviewTimeSeriesEncoderLayer", | |
| ] | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.visual = InternS2PreviewVisionModel._from_config(config.vision_config) | |
| self.language_model = InternS2PreviewTextModel._from_config(config.text_config) | |
| self.rope_deltas = None # cache rope_deltas here | |
| self.time_series = InternS2PreviewTimeSeriesModel._from_config(config.ts_config) | |
| # Initialize weights and apply final processing | |
| self.post_init() | |
| def get_input_embeddings(self): | |
| return self.language_model.get_input_embeddings() | |
| def set_input_embeddings(self, value): | |
| self.language_model.set_input_embeddings(value) | |
| def get_rope_index( | |
| self, | |
| input_ids: torch.LongTensor | None = None, | |
| image_grid_thw: torch.LongTensor | None = None, | |
| video_grid_thw: torch.LongTensor | None = None, | |
| attention_mask: torch.Tensor | None = None, | |
| **kwargs, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Different from the original implementation, InternS2Preview use timestamps rather than absolute time position ids.""" | |
| # Since we use timestamps to separate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split | |
| if video_grid_thw is not None: | |
| video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) | |
| video_grid_thw[:, 0] = 1 | |
| image_grid_thw_list = image_grid_thw.tolist() if image_grid_thw is not None else None | |
| video_grid_thw_list = video_grid_thw.tolist() if video_grid_thw is not None else None | |
| spatial_merge_size = self.config.vision_config.spatial_merge_size | |
| image_token_id = self.config.image_token_id | |
| video_token_id = self.config.video_token_id | |
| vision_start_token_id = self.config.vision_start_token_id | |
| mrope_position_deltas = [] | |
| total_input_ids = input_ids | |
| if attention_mask is None: | |
| attention_mask = torch.ones_like(total_input_ids) | |
| position_ids = torch.zeros( | |
| 3, | |
| input_ids.shape[0], | |
| input_ids.shape[1], | |
| dtype=input_ids.dtype, | |
| device=input_ids.device, | |
| ) | |
| image_index, video_index = 0, 0 | |
| attention_mask = attention_mask.to(total_input_ids.device) | |
| for i, input_ids in enumerate(total_input_ids): | |
| input_ids = input_ids[attention_mask[i] == 1] | |
| image_nums, video_nums = 0, 0 | |
| vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) | |
| vision_tokens = input_ids[vision_start_indices + 1] | |
| image_nums = (vision_tokens == image_token_id).sum() | |
| video_nums = (vision_tokens == video_token_id).sum() | |
| input_tokens = input_ids.tolist() | |
| llm_pos_ids_list: list = [] | |
| st = 0 | |
| remain_images, remain_videos = image_nums, video_nums | |
| for _ in range(image_nums + video_nums): | |
| if image_token_id in input_tokens and remain_images > 0: | |
| ed_image = input_tokens.index(image_token_id, st) | |
| else: | |
| ed_image = len(input_tokens) + 1 | |
| if video_token_id in input_tokens and remain_videos > 0: | |
| ed_video = input_tokens.index(video_token_id, st) | |
| else: | |
| ed_video = len(input_tokens) + 1 | |
| if ed_image < ed_video: | |
| t, h, w = image_grid_thw_list[image_index] | |
| image_index += 1 | |
| remain_images -= 1 | |
| ed = ed_image | |
| else: | |
| t, h, w = video_grid_thw_list[video_index] | |
| video_index += 1 | |
| remain_videos -= 1 | |
| ed = ed_video | |
| llm_grid_t, llm_grid_h, llm_grid_w = ( | |
| t, | |
| h // spatial_merge_size, | |
| w // spatial_merge_size, | |
| ) | |
| text_len = ed - st | |
| st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 | |
| llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) | |
| # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) | |
| t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() | |
| h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() | |
| w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() | |
| llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) | |
| st = ed + llm_grid_t * llm_grid_h * llm_grid_w | |
| if st < len(input_tokens): | |
| st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 | |
| text_len = len(input_tokens) - st | |
| llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) | |
| llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) | |
| position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) | |
| mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) | |
| mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) | |
| return position_ids, mrope_position_deltas | |
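| # Illustrative layout (hypothetical prompt): 4 tokens preceding the image | |
| # placeholders (including the vision start token) followed by one image whose | |
| # merged grid is t=1, h=2, w=2: | |
| #   text part  -> (t, h, w) = (0,0,0), (1,1,1), (2,2,2), (3,3,3) | |
| #   image part -> t = [4, 4, 4, 4], h = [4, 4, 5, 5], w = [4, 5, 4, 5] | |
| #   any trailing text resumes at position 6 on all three axes. | |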
| def get_video_features( | |
| self, | |
| pixel_values_videos: torch.FloatTensor, | |
| video_grid_thw: torch.LongTensor | None = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple | BaseModelOutputWithPooling: | |
| r""" | |
| pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): | |
| The tensors corresponding to the input videos. | |
| video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): | |
| The temporal, height and width of feature shape of each video in LLM. | |
| """ | |
| # Same implementation as for images | |
| return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) | |
| def get_image_features( | |
| self, | |
| pixel_values: torch.FloatTensor, | |
| image_grid_thw: torch.LongTensor | None = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple | BaseModelOutputWithPooling: | |
| r""" | |
| pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): | |
| The tensors corresponding to the input images. | |
| image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): | |
| The temporal, height and width of feature shape of each image in LLM. | |
| """ | |
| pixel_values = pixel_values.type(self.visual.dtype) | |
| vision_output: BaseModelOutputWithPooling = self.visual( | |
| pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs | |
| ) | |
| image_embeds = vision_output.pooler_output | |
| split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() | |
| image_embeds = torch.split(image_embeds, split_sizes) | |
| vision_output.pooler_output = image_embeds | |
| return vision_output | |
| def get_placeholder_mask( | |
| self, | |
| input_ids: torch.LongTensor, | |
| inputs_embeds: torch.FloatTensor, | |
| image_features: torch.FloatTensor | None = None, | |
| video_features: torch.FloatTensor | None = None, | |
| ): | |
| """ | |
| Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is | |
| equal to the length of multimodal features. If the lengths are different, an error is raised. | |
| """ | |
| if input_ids is None: | |
| special_image_mask = inputs_embeds == self.get_input_embeddings()( | |
| torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| special_image_mask = special_image_mask.all(-1) | |
| special_video_mask = inputs_embeds == self.get_input_embeddings()( | |
| torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| special_video_mask = special_video_mask.all(-1) | |
| else: | |
| special_image_mask = input_ids == self.config.image_token_id | |
| special_video_mask = input_ids == self.config.video_token_id | |
| n_image_tokens = special_image_mask.sum() | |
| special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) | |
| if image_features is not None: | |
| torch_compilable_check( | |
| inputs_embeds[special_image_mask].numel() == image_features.numel(), | |
| f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}", | |
| ) | |
| n_video_tokens = special_video_mask.sum() | |
| special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) | |
| if video_features is not None: | |
| torch_compilable_check( | |
| inputs_embeds[special_video_mask].numel() == video_features.numel(), | |
| f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}", | |
| ) | |
| return special_image_mask, special_video_mask | |
| def compute_3d_position_ids( | |
| self, | |
| input_ids: torch.Tensor | None, | |
| inputs_embeds: torch.Tensor | None, | |
| image_grid_thw: torch.Tensor | None = None, | |
| video_grid_thw: torch.Tensor | None = None, | |
| attention_mask: torch.Tensor | None = None, | |
| past_key_values: torch.Tensor | None = None, | |
| ) -> torch.Tensor | None: | |
| past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length() | |
| can_compute_mrope = input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None) | |
| if can_compute_mrope and (self.rope_deltas is None or past_key_values_length == 0): | |
| position_ids, rope_deltas = self.get_rope_index( | |
| input_ids, | |
| image_grid_thw=image_grid_thw, | |
| video_grid_thw=video_grid_thw, | |
| attention_mask=attention_mask, | |
| ) | |
| self.rope_deltas = rope_deltas | |
| # Use pre-calculated rope-deltas to infer correct 3D position ids | |
| elif self.rope_deltas is not None: | |
| batch_size, seq_length, _ = inputs_embeds.shape | |
| if attention_mask is not None: | |
| position_ids = attention_mask.long().cumsum(-1) - 1 | |
| position_ids = position_ids.masked_fill(attention_mask == 0, 0) | |
| position_ids = position_ids.view(1, batch_size, -1).repeat(3, 1, 1).to(inputs_embeds.device) | |
| else: | |
| position_ids = torch.arange(past_key_values_length, past_key_values_length + seq_length) | |
| position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1).to(inputs_embeds.device) | |
| delta = self.rope_deltas.repeat_interleave(batch_size // self.rope_deltas.shape[0], dim=0) | |
| position_ids = position_ids + delta.to(device=position_ids.device) | |
| else: | |
| # Can't build correct 3D positions. Let the model infer it from `cache_position` | |
| position_ids = None | |
| return position_ids | |
| def forward( | |
| self, | |
| input_ids: torch.LongTensor = None, | |
| attention_mask: torch.Tensor | None = None, | |
| position_ids: torch.LongTensor | None = None, | |
| past_key_values: Cache | None = None, | |
| inputs_embeds: torch.FloatTensor | None = None, | |
| pixel_values: torch.Tensor | None = None, | |
| pixel_values_videos: torch.FloatTensor | None = None, | |
| ts_values: torch.FloatTensor | list[torch.FloatTensor] = None, | |
| image_grid_thw: torch.LongTensor | None = None, | |
| video_grid_thw: torch.LongTensor | None = None, | |
| ts_lens: torch.Tensor | list[torch.Tensor] = None, | |
| ts_sr: torch.FloatTensor | list[torch.FloatTensor] = None, | |
| ts_channels: torch.LongTensor | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple | InternS2PreviewModelOutputWithPast: | |
| r""" | |
| ts_values (`torch.FloatTensor` of shape `(batch_size, seq_len, num_channels)`, *optional*): | |
| The tensors corresponding to the input time series signals. | |
| ts_lens (`torch.Tensor` of shape `(batch_size,)`, *optional*): | |
| The valid lengths of each time series signal in the batch. | |
| ts_sr (`torch.FloatTensor` of shape `(batch_size,)`, *optional*): | |
| The sampling rates of each time series signal in the batch. | |
| """ | |
| if (input_ids is None) ^ (inputs_embeds is not None): | |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") | |
| if inputs_embeds is None: | |
| inputs_embeds = self.get_input_embeddings()(input_ids) | |
| if pixel_values is not None: | |
| image_outputs: BaseModelOutputWithPooling = self.get_image_features( | |
| pixel_values, image_grid_thw, return_dict=True | |
| ) | |
| image_embeds = image_outputs.pooler_output | |
| image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) | |
| image_mask, _ = self.get_placeholder_mask( | |
| input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds | |
| ) | |
| inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) | |
| if pixel_values_videos is not None: | |
| video_outputs: BaseModelOutputWithPooling = self.get_video_features( | |
| pixel_values_videos, video_grid_thw, return_dict=True | |
| ) | |
| video_embeds = video_outputs.pooler_output | |
| video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) | |
| _, video_mask = self.get_placeholder_mask( | |
| input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds | |
| ) | |
| inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) | |
| if ts_values is not None: | |
| ts_features, ts_pad_mask = self.get_ts_feature(ts_values, ts_lens, ts_sr, ts_channels) # [B, T, C], [B, T] | |
| ts_features = ts_features[~ts_pad_mask].to( | |
| inputs_embeds.device, inputs_embeds.dtype | |
| ) # [num_valid_ts_tokens, C] | |
| B, N, C = inputs_embeds.shape | |
| input_ids = input_ids.reshape(B * N) | |
| inputs_embeds = inputs_embeds.reshape(B * N, C) | |
| # replace ts_token placeholders in inputs_embeds with the time-series features | |
| ts_placeholder = input_ids == self.config.ts_token_id | |
| n_ts_placeholders = ts_placeholder.sum().item() | |
| n_ts_tokens = ts_features.size(0) | |
| assert n_ts_placeholders == n_ts_tokens, ( | |
| f"[ERROR]: Mismatch: <TS_CONTEXT> tokens={n_ts_placeholders}, ts_embeds_valid={n_ts_tokens}" | |
| ) | |
| try: | |
| # TODO why not scatter? | |
| inputs_embeds[ts_placeholder] = inputs_embeds[ts_placeholder] * 0.0 + ts_features | |
| except Exception as e: | |
| print( | |
| f"warning: {e}, inputs_embeds[selected].shape={inputs_embeds[ts_placeholder].shape}, ts_embeds_valid.shape={ts_features.shape}" | |
| ) | |
| inputs_embeds[ts_placeholder] = inputs_embeds[ts_placeholder] * 0.0 + ts_features[:n_ts_placeholders] | |
| inputs_embeds = inputs_embeds.reshape(B, N, C) | |
| if position_ids is None: | |
| position_ids = self.compute_3d_position_ids( | |
| input_ids=input_ids, | |
| image_grid_thw=image_grid_thw, | |
| video_grid_thw=video_grid_thw, | |
| inputs_embeds=inputs_embeds, | |
| attention_mask=attention_mask, | |
| past_key_values=past_key_values, | |
| ) | |
| outputs = self.language_model( | |
| input_ids=None, | |
| position_ids=position_ids, | |
| attention_mask=attention_mask, | |
| past_key_values=past_key_values, | |
| inputs_embeds=inputs_embeds, | |
| cache_position=cache_position, | |
| **kwargs, | |
| ) | |
| return InternS2PreviewModelOutputWithPast( | |
| **outputs, | |
| rope_deltas=self.rope_deltas, | |
| ) | |
| def get_ts_feature( | |
| self, | |
| ts_values: torch.FloatTensor | list[torch.FloatTensor], | |
| ts_lens: torch.Tensor | list[torch.Tensor], | |
| sr: torch.FloatTensor | list[torch.FloatTensor], | |
| ts_channels: torch.LongTensor | None = None, | |
| ) -> tuple[torch.FloatTensor, torch.Tensor]: | |
| r""" | |
| ts_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, input_size)`): | |
| The time series values to be fed to a model. | |
| ts_lens (`torch.Tensor` of shape `(batch_size,)`): | |
| The length of the time series. | |
| sr (`torch.FloatTensor` of shape `(batch_size,)`): | |
| The sampling rate of the time series. | |
| """ | |
| ts_embeds, ts_pad_mask = self.time_series( | |
| time_series_signals=ts_values, | |
| ts_lens=ts_lens, | |
| sr=sr, | |
| channels=ts_channels, | |
| output_hidden_states=False, | |
| return_dict=True, | |
| ) | |
| return ts_embeds, ts_pad_mask | |
| # NOTE: Cannot inherit from `Qwen3_5MoeForCausalLM` here due to a converter limitation: when the modular | |
| # `__init__` contains statements before `super().__init__()`, `_fix_init_location` unconditionally moves | |
| # `super().__init__()` to the top, which prevents pre-processing `config` (e.g. extracting `config.text_config`) | |
| # before the parent call. The full body is therefore written out explicitly. | |
| class InternS2PreviewForCausalLM(InternS2PreviewPreTrainedModel, GenerationMixin): | |
| # NOTE: `config` is annotated as `InternS2PreviewConfig` rather than `InternS2PreviewTextConfig` because remote | |
| # code only exposes a single AutoConfig entry. The auto factory checks `model_class.config_class` against | |
| # the loaded config type; a mismatch raises ValueError, so the annotation must match `InternS2PreviewConfig`. | |
| config: InternS2PreviewConfig | |
| _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} | |
| _tp_plan = {"lm_head": "colwise_gather_output"} | |
| _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} | |
| _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"] | |
| def __init__(self, config: InternS2PreviewConfig | InternS2PreviewTextConfig): | |
| if isinstance(config, InternS2PreviewConfig): | |
| config = config.text_config | |
| super().__init__(config) | |
| self.model = InternS2PreviewTextModel(config) | |
| self.vocab_size = config.vocab_size | |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) | |
| self.router_aux_loss_coef = config.router_aux_loss_coef | |
| self.num_experts = config.num_experts | |
| self.num_experts_per_tok = config.num_experts_per_tok | |
| # Initialize weights and apply final processing | |
| self.post_init() | |
| def load_balancing_loss_func( | |
| gate_logits: torch.Tensor | tuple[torch.Tensor] | None, | |
| num_experts: int | None = None, | |
| top_k=2, | |
| attention_mask: torch.Tensor | None = None, | |
| ) -> torch.Tensor | int: | |
| r""" | |
| Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. | |
| See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss | |
| function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between | |
| experts is too unbalanced. | |
| Args: | |
| gate_logits: | |
| Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of | |
| shape [batch_size X sequence_length, num_experts]. | |
| num_experts: | |
| Number of experts | |
| top_k: | |
| The number of experts to route per-token, can be also interpreted as the `top-k` routing | |
| parameter. | |
| attention_mask (`torch.Tensor`, *optional*): | |
| The attention_mask used in forward function | |
| shape [batch_size X sequence_length] if not None. | |
| Returns: | |
| The auxiliary loss. | |
| """ | |
| if gate_logits is None or not isinstance(gate_logits, tuple): | |
| return 0 | |
| if isinstance(gate_logits, tuple): | |
| compute_device = gate_logits[0].device | |
| concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) | |
| routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) | |
| _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) | |
| expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) | |
| if attention_mask is None: | |
| # Compute the percentage of tokens routed to each experts | |
| tokens_per_expert = torch.mean(expert_mask.float(), dim=0) | |
| # Compute the average probability of routing to these experts | |
| router_prob_per_expert = torch.mean(routing_weights, dim=0) | |
| else: | |
| batch_size, sequence_length = attention_mask.shape | |
| num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) | |
| # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask | |
| expert_attention_mask = ( | |
| attention_mask[None, :, :, None, None] | |
| .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) | |
| .reshape(-1, top_k, num_experts) | |
| .to(compute_device) | |
| ) | |
| # Compute the percentage of tokens routed to each experts | |
| tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( | |
| expert_attention_mask, dim=0 | |
| ) | |
| # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert | |
| router_per_expert_attention_mask = ( | |
| attention_mask[None, :, :, None] | |
| .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) | |
| .reshape(-1, num_experts) | |
| .to(compute_device) | |
| ) | |
| # Compute the average probability of routing to these experts | |
| router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( | |
| router_per_expert_attention_mask, dim=0 | |
| ) | |
| overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) | |
| return overall_loss * num_experts | |
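| # A minimal, hypothetical usage sketch for `load_balancing_loss_func`; the shapes and values below | |
| # are illustrative assumptions, not fixed model settings: | |
| # | |
| #     import torch | |
| #     num_layers, num_tokens, num_experts, top_k = 2, 8, 4, 2 | |
| #     gate_logits = tuple(torch.randn(num_tokens, num_experts) for _ in range(num_layers)) | |
| #     aux = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=top_k) | |
| #     # `aux` is a scalar tensor; `forward` below scales it by `router_aux_loss_coef` and adds it | |
| #     # to the language-modeling loss when `output_router_logits=True` is passed. | |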
| class InternS2PreviewForConditionalGeneration(InternS2PreviewPreTrainedModel, GenerationMixin): | |
| _checkpoint_conversion_mapping = {} | |
| _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} | |
| # Reference: fix gemma3 grad acc #37208 | |
| accepts_loss_kwargs = False | |
| config: InternS2PreviewConfig | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.model = InternS2PreviewModel(config) | |
| self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) | |
| self.post_init() | |
| def get_input_embeddings(self): | |
| return self.model.get_input_embeddings() | |
| def set_input_embeddings(self, value): | |
| self.model.set_input_embeddings(value) | |
| def get_video_features( | |
| self, | |
| pixel_values_videos: torch.FloatTensor, | |
| video_grid_thw: torch.LongTensor | None = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple | BaseModelOutputWithPooling: | |
| r""" | |
| pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): | |
| The tensors corresponding to the input videos. | |
| video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): | |
| The temporal, height and width of feature shape of each video in LLM. | |
| """ | |
| return self.model.get_video_features( | |
| pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs | |
| ) | |
| def get_image_features( | |
| self, | |
| pixel_values: torch.FloatTensor, | |
| image_grid_thw: torch.LongTensor | None = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple | BaseModelOutputWithPooling: | |
| r""" | |
| pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): | |
| The tensors corresponding to the input images. | |
| image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): | |
| The temporal, height and width of feature shape of each image in LLM. | |
| """ | |
| return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) | |
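| # A hypothetical sketch of calling `get_image_features` directly (the analogous call exists for | |
| # videos above); `processed` stands for the output of the model's processor and is an assumed | |
| # name, as are the exact shapes implied by the docstring: | |
| # | |
| #     image_embeds = model.get_image_features( | |
| #         pixel_values=processed["pixel_values"], | |
| #         image_grid_thw=processed["image_grid_thw"], | |
| #     ) | |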
| def forward( | |
| self, | |
| input_ids: torch.LongTensor | None = None, | |
| attention_mask: torch.Tensor | None = None, | |
| position_ids: torch.LongTensor | None = None, | |
| past_key_values: Cache | None = None, | |
| inputs_embeds: torch.FloatTensor | None = None, | |
| labels: torch.LongTensor | None = None, | |
| pixel_values: torch.Tensor | None = None, | |
| pixel_values_videos: torch.FloatTensor | None = None, | |
| ts_values: torch.FloatTensor | list[torch.FloatTensor] | None = None, | |
| image_grid_thw: torch.LongTensor | None = None, | |
| video_grid_thw: torch.LongTensor | None = None, | |
| ts_lens: torch.Tensor | list[torch.Tensor] | None = None, | |
| ts_sr: torch.FloatTensor | list[torch.FloatTensor] | None = None, | |
| ts_channels: torch.LongTensor | None = None, | |
| cache_position: torch.LongTensor | None = None, | |
| logits_to_keep: int | torch.Tensor = 0, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple | InternS2PreviewCausalLMOutputWithPast: | |
| r""" | |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., | |
| config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | |
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. | |
| image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): | |
| The temporal, height and width of feature shape of each image in LLM. | |
| video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): | |
| The temporal, height and width of feature shape of each video in LLM. | |
| Example: | |
| ```python | |
| >>> from PIL import Image | |
| >>> import requests | |
| >>> from transformers import AutoProcessor, InternS2PreviewForConditionalGeneration | |
| >>> model = InternS2PreviewForConditionalGeneration.from_pretrained("internlm/Intern-S2-Preview-FP8", trust_remote_code=True, dtype="auto", device_map="auto") | |
| >>> processor = AutoProcessor.from_pretrained("internlm/Intern-S2-Preview-FP8", trust_remote_code=True) | |
| >>> messages = [ | |
| ...     { | |
| ...         "role": "user", | |
| ...         "content": [ | |
| ...             { | |
| ...                 "type": "image", | |
| ...                 "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", | |
| ...             }, | |
| ...             {"type": "text", "text": "Describe this image in short."}, | |
| ...         ], | |
| ...     } | |
| ... ] | |
| >>> # Preparation for inference | |
| >>> inputs = processor.apply_chat_template( | |
| ...     messages, | |
| ...     tokenize=True, | |
| ...     add_generation_prompt=True, | |
| ...     return_dict=True, | |
| ...     return_tensors="pt", | |
| ... ) | |
| >>> inputs = inputs.to(model.device) | |
| >>> # Generate | |
| >>> generated_ids = model.generate(**inputs, max_new_tokens=128) | |
| >>> generated_ids_trimmed = [ | |
| ...     out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) | |
| ... ] | |
| >>> processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | |
| "A woman in a plaid shirt sits on a sandy beach at sunset, smiling as she gives a high-five to a yellow Labrador Retriever wearing a harness. The ocean waves roll in the background." | |
| ```""" | |
| outputs = self.model( | |
| input_ids=input_ids, | |
| pixel_values=pixel_values, | |
| pixel_values_videos=pixel_values_videos, | |
| ts_values=ts_values, | |
| image_grid_thw=image_grid_thw, | |
| video_grid_thw=video_grid_thw, | |
| ts_lens=ts_lens, | |
| ts_sr=ts_sr, | |
| ts_channels=ts_channels, | |
| position_ids=position_ids, | |
| attention_mask=attention_mask, | |
| past_key_values=past_key_values, | |
| inputs_embeds=inputs_embeds, | |
| cache_position=cache_position, | |
| **kwargs, | |
| ) | |
| hidden_states = outputs[0] | |
| # Only compute necessary logits, and do not upcast them to float if we are not computing the loss | |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep | |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) | |
| loss = None | |
| if labels is not None: | |
| loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) | |
| aux_loss = None | |
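| # The balancing loss is only computed when the caller explicitly requests router logits | |
| # (typically during MoE training); plain inference skips the extra softmax/top-k bookkeeping. | |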
| if kwargs.get("output_router_logits", False): | |
| aux_loss = load_balancing_loss_func( | |
| outputs.router_logits, | |
| self.config.text_config.num_experts, | |
| self.config.text_config.num_experts_per_tok, | |
| attention_mask, | |
| ) | |
| if labels is not None: | |
| loss += self.config.text_config.router_aux_loss_coef * aux_loss.to( | |
| loss.device | |
| ) # make sure to reside in the same device | |
| return InternS2PreviewCausalLMOutputWithPast( | |
| loss=loss, | |
| aux_loss=aux_loss, | |
| logits=logits, | |
| past_key_values=outputs.past_key_values, | |
| hidden_states=outputs.hidden_states, | |
| attentions=outputs.attentions, | |
| rope_deltas=outputs.rope_deltas, | |
| router_logits=outputs.router_logits, | |
| ) | |
| def prepare_inputs_for_generation( | |
| self, | |
| input_ids, | |
| past_key_values=None, | |
| attention_mask=None, | |
| inputs_embeds=None, | |
| cache_position=None, | |
| position_ids=None, | |
| use_cache=True, | |
| pixel_values=None, | |
| pixel_values_videos=None, | |
| ts_values=None, | |
| image_grid_thw=None, | |
| video_grid_thw=None, | |
| ts_lens=None, | |
| ts_sr=None, | |
| ts_channels=None, | |
| is_first_iteration=False, | |
| **kwargs, | |
| ): | |
| model_inputs = super().prepare_inputs_for_generation( | |
| input_ids, | |
| past_key_values=past_key_values, | |
| attention_mask=attention_mask, | |
| inputs_embeds=inputs_embeds, | |
| cache_position=cache_position, | |
| position_ids=position_ids, | |
| use_cache=use_cache, | |
| pixel_values=pixel_values, | |
| pixel_values_videos=pixel_values_videos, | |
| ts_values=ts_values, | |
| image_grid_thw=image_grid_thw, | |
| video_grid_thw=video_grid_thw, | |
| ts_lens=ts_lens, | |
| ts_sr=ts_sr, | |
| ts_channels=ts_channels, | |
| is_first_iteration=is_first_iteration, | |
| **kwargs, | |
| ) | |
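| # Time-series inputs are only consumed during the prefill pass; once the KV cache holds the | |
| # prompt, later decoding steps must not re-embed them, so they are cleared below. | |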
| if not is_first_iteration and use_cache: | |
| model_inputs["ts_values"] = None | |
| model_inputs["ts_lens"] = None | |
| model_inputs["ts_sr"] = None | |
| model_inputs["ts_channels"] = None | |
| return model_inputs | |
| def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): | |
| # Overwritten -- requires 3D position ids | |
| text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs) | |
| # Early exit in case we are continuing generation from past kv | |
| past_length = 0 | |
| if (cache := model_kwargs.get("past_key_values")) is not None: | |
| past_length = cache.get_seq_length() | |
| if past_length != 0 and self.model.rope_deltas is not None: | |
| text_positions += self.model.rope_deltas | |
| return text_positions | |
| # Otherwise compute 3d position ids for vision tokens and concat with text position ids | |
| if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: | |
| inputs_tensor = model_kwargs["input_ids"] | |
| is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] | |
| if is_input_ids and ( | |
| model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None | |
| ): | |
| model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} | |
| vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs) | |
| self.model.rope_deltas = rope_deltas | |
| else: | |
| vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) | |
| self.model.rope_deltas = torch.zeros( | |
| inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device | |
| ) | |
| # Concatenate "text + vision" positions into [4, bs, seq-len] | |
| text_positions = text_positions[None, ...] | |
| position_ids = torch.cat([text_positions, vision_positions], dim=0) | |
| return position_ids | |
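| # Layout note (based on the concatenation above): `position_ids` ends up with shape | |
| # [4, batch_size, seq_len], where row 0 holds the plain text positions and rows 1-3 hold the | |
| # positions returned by `get_rope_index` (assumed to be temporal/height/width components, | |
| # following the 3D-rope convention of similar vision-language models). | |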
| def _get_image_nums_and_video_nums( | |
| self, | |
| input_ids: torch.LongTensor | None, | |
| inputs_embeds: torch.Tensor | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Get the number of images and videos in each sample, which is needed to compute the split lengths of the flattened vision tensors. | |
| These counts are recomputed here instead of being passed through the processor, to avoid unpredictable impacts from processor interface changes. | |
| Args: | |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |
| Indices of input sequence tokens in the vocabulary. | |
| Returns: | |
| image_nums (`torch.LongTensor` of shape `(batch_size,)`): the number of images in each sample | |
| video_nums (`torch.LongTensor` of shape `(batch_size,)`): the number of videos in each sample | |
| """ | |
| image_token_id = self.config.image_token_id | |
| video_token_id = self.config.video_token_id | |
| vision_start_token_id = self.config.vision_start_token_id | |
| if inputs_embeds is not None: | |
| vision_start_mask = ( | |
| inputs_embeds | |
| == self.get_input_embeddings()( | |
| torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| )[..., 0] | |
| image_mask = ( | |
| inputs_embeds | |
| == self.get_input_embeddings()( | |
| torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| )[..., 0] | |
| video_mask = ( | |
| inputs_embeds | |
| == self.get_input_embeddings()( | |
| torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) | |
| ) | |
| )[..., 0] | |
| else: | |
| vision_start_mask = input_ids == vision_start_token_id | |
| image_mask = input_ids == image_token_id | |
| video_mask = input_ids == video_token_id | |
| vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) | |
| image_nums = torch.sum(vision_first_mask & image_mask, dim=1) | |
| video_nums = torch.sum(vision_first_mask & video_mask, dim=1) | |
| return image_nums, video_nums | |
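| # A small, hypothetical walk-through of the counting above (the token ids are made up for the | |
| # example; the real ids come from the model config): | |
| # | |
| #     # assume vision_start_token_id=10, image_token_id=11, video_token_id=12 | |
| #     input_ids = torch.tensor([[1, 10, 11, 11, 2, 10, 12, 3]]) | |
| #     # rolling the vision_start mask by one marks the token right after each vision_start, so | |
| #     # the image placeholder at index 2 and the video placeholder at index 6 are counted: | |
| #     # image_nums == tensor([1]), video_nums == tensor([1]) | |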
| def _expand_inputs_for_generation( | |
| self, | |
| expand_size: int = 1, | |
| is_encoder_decoder: bool = False, | |
| input_ids: torch.LongTensor | None = None, | |
| **model_kwargs, | |
| ) -> tuple[torch.LongTensor, dict[str, Any]]: | |
| # Overwritten -- InternS2Preview uses timestamps and removes second_per_grid_ts | |
| # Support for expanding tensors without a batch size dimension | |
| # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw | |
| # pixel_values.shape[0] is sum(seqlen_images for samples) | |
| # image_grid_thw.shape[0] is sum(num_images for samples) | |
| if expand_size == 1: | |
| return input_ids, model_kwargs | |
| visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] | |
| def _expand_dict_for_generation_visual(dict_to_expand): | |
| image_grid_thw = model_kwargs.get("image_grid_thw", None) | |
| video_grid_thw = model_kwargs.get("video_grid_thw", None) | |
| image_nums, video_nums = self._get_image_nums_and_video_nums( | |
| input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) | |
| ) | |
| # video_nums: (batch_size,) | |
| # video_nums above counts videos via the vision_start tokens in input_ids, but InternS2Preview | |
| # appends a vision_start token to every frame of every video, so the real per-sample video counts have to be recovered from video_grid_thw | |
| if video_grid_thw is not None: | |
| cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0) | |
| cumulative_token_video_counts = torch.cumsum(video_nums, dim=0) | |
| # Find video boundaries in cumulative_frame_counts | |
| video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts) | |
| # example: video_boundary_indices = [3, 5] means video_nums = [4, 2] | |
| video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices])) | |
| def _repeat_interleave_samples(x, lengths, repeat_times): | |
| samples = torch.split(x, lengths) | |
| repeat_args = [repeat_times] + [1] * (x.dim() - 1) | |
| result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) | |
| return result | |
| for key in dict_to_expand: | |
| if key == "pixel_values": | |
| # split images into samples | |
| samples = torch.split(image_grid_thw, list(image_nums)) | |
| # compute the sequence length of images for each sample | |
| lengths = [torch.prod(sample, dim=1).sum() for sample in samples] | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| elif key == "image_grid_thw": | |
| # get the num of images for each sample | |
| lengths = list(image_nums) | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| elif key == "pixel_values_videos": | |
| samples = torch.split(video_grid_thw, list(video_nums)) | |
| lengths = [torch.prod(sample, dim=1).sum() for sample in samples] | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| elif key == "video_grid_thw": | |
| lengths = list(video_nums) | |
| dict_to_expand[key] = _repeat_interleave_samples( | |
| dict_to_expand[key], lengths=lengths, repeat_times=expand_size | |
| ) | |
| return dict_to_expand | |
| def _expand_dict_for_generation(dict_to_expand): | |
| for key in dict_to_expand: | |
| if key == "position_ids" and dict_to_expand[key].ndim == 3: | |
| dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=1) | |
| elif ( | |
| key != "cache_position" | |
| and dict_to_expand[key] is not None | |
| and isinstance(dict_to_expand[key], torch.Tensor) | |
| and key not in visual_keys | |
| ): | |
| dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) | |
| return dict_to_expand | |
| model_kwargs = _expand_dict_for_generation_visual(model_kwargs) | |
| if input_ids is not None: | |
| input_ids = input_ids.repeat_interleave(expand_size, dim=0) | |
| model_kwargs = _expand_dict_for_generation(model_kwargs) | |
| if is_encoder_decoder: | |
| if model_kwargs.get("encoder_outputs") is None: | |
| raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") | |
| model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) | |
| return input_ids, model_kwargs | |
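| # A hypothetical illustration of the per-sample expansion above for `expand_size=2` (e.g. | |
| # `num_return_sequences=2`): a batch whose first sample has 2 images and whose second sample has | |
| # 1 image would have its `image_grid_thw` rows split as [2, 1], with each group repeated twice, | |
| # so every copy of a prompt stays paired with its own images instead of interleaving rows across samples. | |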
| def time_series(self): | |
| return self.model.time_series | |
| def get_ts_feature( | |
| self, | |
| ts_values: torch.FloatTensor | list[torch.FloatTensor], | |
| ts_lens: torch.Tensor | list[torch.Tensor], | |
| sr: torch.FloatTensor | list[torch.FloatTensor], | |
| ): | |
| r""" | |
| ts_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, input_size)`): | |
| The time series values to be fed to a model. | |
| ts_lens (`torch.Tensor` of shape `(batch_size,)`): | |
| The length of the time series. | |
| sr (`torch.FloatTensor` of shape `(batch_size,)`): | |
| The sampling rate of the time series. | |
| """ | |
| return self.model.get_ts_feature(ts_values, ts_lens, sr) | |
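| # A hypothetical sketch of calling `get_ts_feature` with a synthetic signal; the batch size, | |
| # length, and sampling rate below are illustrative assumptions only: | |
| # | |
| #     ts_values = torch.randn(1, 16000, 1)   # (batch_size, sequence_length, input_size) | |
| #     ts_lens = torch.tensor([16000]) | |
| #     sr = torch.tensor([200.0]) | |
| #     ts_embeds = model.get_ts_feature(ts_values, ts_lens, sr) | |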
| __all__ = [ | |
| "InternS2PreviewVisionModel", | |
| "InternS2PreviewTextModel", | |
| "InternS2PreviewModel", | |
| "InternS2PreviewForCausalLM", | |
| "InternS2PreviewForConditionalGeneration", | |
| "InternS2PreviewPreTrainedModel", | |
| ] | |