Step-3.7-Flash / modeling_step3p7.py
luotingdan
update processor config
a9c0171
raw
history blame
56.8 kB
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Tuple, TypedDict, Union
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import (
create_causal_mask,
create_sliding_window_causal_mask,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple, logging
from .configuration_step3p7 import Step3p7Config, Step3p7TextConfig
from .vision_encoder import StepRoboticsVisionEncoder
logger = logging.get_logger(__name__)
# `create_causal_mask` spells its embeddings keyword differently across
# transformers releases ("inputs_embeds" vs "input_embeds"). Probe the
# installed signature once so the mask kwargs can be built under either name.
if "inputs_embeds" in inspect.signature(create_causal_mask).parameters:
    _MASK_INPUT_EMBEDS_ARG = "inputs_embeds"
else:
    _MASK_INPUT_EMBEDS_ARG = "input_embeds"
# Public API of this module.
__all__ = [
    "Step3p7Model",
]
class StepVLImagePixelInputs(TypedDict):
    """Image input given as pixel tensors (tagged ``type == "pixel_values"``)."""
    # Discriminator tag for the StepVLImageInputs union.
    type: Literal["pixel_values"]
    # Pixel tensor(s); exact layout is defined by the processor — TODO confirm.
    pixel_values: torch.Tensor
    # Optional extra patch/crop pixels; presumably paired with `num_patches` — verify against the processor.
    patch_pixel_values: Optional[torch.Tensor]
    # Per-image patch counts; presumably used to split `patch_pixel_values` per image — TODO confirm.
    num_patches: list[int]
class StepVLImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (tagged ``type == "image_embeds"``)."""
    # Discriminator tag for the StepVLImageInputs union.
    type: Literal["image_embeds"]
    # Precomputed image embeddings; presumably already in the text hidden size — TODO confirm.
    image_embeds: torch.Tensor
# Union of the two image-input dicts; the `type` key tells them apart.
StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]
def _flatten_embeddings(embeddings) -> torch.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings,
) -> torch.Tensor:
    """
    Overwrite the rows of ``inputs_embeds`` selected by the mask
    ``is_multimodal`` with ``multimodal_embeddings``.

    ``multimodal_embeddings`` may be a tensor or a nested sequence of tensors.
    It is flattened on all but the last dimension, and its rows are written,
    in order, into the positions where ``is_multimodal`` is True.

    Args:
        inputs_embeds: embedding tensor to update; indexed by ``is_multimodal``.
        is_multimodal: boolean mask of the positions to overwrite.
        multimodal_embeddings: tensor or nested sequence of tensors whose
            flattened row count must equal ``is_multimodal.sum()``.

    Returns:
        ``inputs_embeds``, which is modified in place.

    Raises:
        ValueError: if the number of multimodal rows does not match the
            number of masked positions.
    """
    # Convert explicitly instead of asserting the type: asserts are stripped
    # under ``python -O`` and are not meant for validation.
    num_expected_tokens = int(is_multimodal.sum().item())
    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders"
        )
    is_multimodal = is_multimodal.to(inputs_embeds.device)
    # Boolean-mask assignment requires source and destination dtypes to
    # match, so align the dtype as well as the device. This matters when the
    # multimodal features were computed in a different precision (e.g. fp32).
    flattened = flattened.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
    inputs_embeds[is_multimodal] = flattened
    return inputs_embeds
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Write ``multimodal_embeddings`` into the rows of ``inputs_embeds`` whose
    token in ``input_ids`` is a placeholder.

    ``placeholder_token_id`` may be a single id, or a list of ids (for
    example the ids of the image-start, image-break and image-end tokens).
    With a list, every matching position is filled in sequence order. This
    is a slice-merge, not a per-token scatter, so the multimodal embeddings
    must already contain rows for those special tokens, ordered exactly as
    the tokens appear in ``input_ids``.

    Example: for the token pattern "TTTTTSIIIBIIIBIIIETTT"
      - T: text token
      - S: image start token
      - I: image embedding token
      - B: image break token
      - E: image end token
    the vision-encoder embeddings for the I positions must be padded with
    S, B and E embeddings in that same order before merging.

    Note:
        ``inputs_embeds`` is modified in place and also returned.
    """
    if isinstance(placeholder_token_id, list):
        id_tensor = torch.tensor(placeholder_token_id, device=input_ids.device)
        placeholder_mask = torch.isin(input_ids, id_tensor)
    else:
        placeholder_mask = input_ids == placeholder_token_id
    return _merge_multimodal_embeddings(
        inputs_embeds,
        placeholder_mask,
        multimodal_embeddings,
    )
class Step3p7PreTrainedModel(PreTrainedModel):
    """Base class that plugs the multimodal Step3p7 model into transformers."""

    # Bind the configuration class so PreTrainedModel.from_pretrained can
    # build the config; without it, loading fails with a NoneType error.
    config_class = Step3p7Config
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    # Unexpected checkpoint keys matching these patterns are tolerated on load.
    # NOTE(review): `\.*` means "zero or more literal dots", so with search-style
    # matching each pattern covers any key containing e.g. "model.layers.45";
    # confirm whether `\..*` was intended.
    _keys_to_ignore_on_load_unexpected = [
        r"model\.layers\.45\.*",
        r"model\.layers\.46\.*",
        r"model\.layers\.47\.*",
    ]
    # Attention backend capabilities advertised to transformers.
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_static_cache = True
    _supports_attention_backend = True

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load pretrained weights, applying the class's checkpoint key renaming.

        Transformers applies checkpoint renaming only when ``key_mapping`` is
        passed explicitly; defining ``_checkpoint_conversion_mapping`` on the
        class is not enough by itself. When the class defines that mapping and
        the caller supplied no ``key_mapping``, a deep copy of the mapping is
        injected into ``kwargs``.
        """
        class_mapping = getattr(cls, "_checkpoint_conversion_mapping", None)
        caller_supplied_mapping = kwargs.get("key_mapping") is not None
        if class_mapping is not None and not caller_supplied_mapping:
            kwargs["key_mapping"] = copy.deepcopy(class_mapping)
        return super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
class Step3p7RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for a single attention layer.

    ``config.rope_theta`` and ``config.partial_rotary_factors`` may be
    per-layer lists. The entry for ``layer_idx`` (entry 0 when ``layer_idx``
    is None) is resolved here and written onto a private shallow copy of the
    config. The RoPE init function and ``dynamic_rope_update`` read that copy
    through ``self.config``.
    """

    def __init__(self, config: Step3p7TextConfig, device=None, layer_idx=None):
        super().__init__()
        self.layer_idx = layer_idx
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        list_idx = 0 if layer_idx is None else layer_idx
        rope_theta = config.rope_theta
        if isinstance(rope_theta, list):
            rope_theta = rope_theta[list_idx]
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        partial_rotary_factors = getattr(config, "partial_rotary_factors", None)
        if partial_rotary_factors is not None:
            partial_rotary_factor = partial_rotary_factors[list_idx]
        self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor
        # Shallow copy so the layer-specific scalars never leak into the
        # config object shared by the rest of the model.
        self.config = copy.copy(config)
        self.config.rope_theta = rope_theta
        self.config.partial_rotary_factor = partial_rotary_factor
        if getattr(config, "rope_parameters", None) is not None:
            self.config.rope_parameters = copy.deepcopy(config.rope_parameters)
            self.config.rope_parameters["rope_theta"] = rope_theta
            self.config.rope_parameters["partial_rotary_factor"] = (
                partial_rotary_factor
            )
            # Accept the current "rope_type" key or the legacy "type" key.
            # If neither is present (or both are None), treat the layer as
            # plain RoPE instead of failing with KeyError(None) below.
            self.rope_type = (
                self.config.rope_parameters.get("rope_type")
                or self.config.rope_parameters.get("type")
                or "default"
            )
        else:
            self.rope_type = "default"
        # "default" RoPE is computed locally; other types come from transformers.
        self.rope_init_fn = self.compute_default_rope_parameters
        if self.rope_type != "default":
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic_rope_update can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``position_ids`` is indexed as ``(batch, seq)``. Both outputs have
        shape ``(batch, seq, 2 * len(inv_freq))``. ``x`` supplies only the
        device and the output dtype.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float().to(x.device)
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(
            device_type=device_type, enabled=False
        ):  # Force float32
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    @staticmethod
    def compute_default_rope_parameters(
        config: Step3p7TextConfig | None = None,
        device: Optional["torch.device"] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Compute the inverse frequencies of the original RoPE formulation.

        Only the first ``int(head_dim * partial_rotary_factor)`` channels of
        each head are rotated.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. ``rope_theta`` and (optionally)
                ``partial_rotary_factor`` and ``head_dim`` are read from it.
            device (`torch.device`, *optional*):
                The device to use for initialization of the inverse frequencies.

        Returns:
            Tuple of (`torch.Tensor`, `float`): the inverse frequencies for
            the RoPE embeddings, and the post-processing scaling factor
            applied to the computed cos/sin (always 1.0 for this RoPE type).
        """
        base = config.rope_theta
        partial_rotary_factor = getattr(
            config, "partial_rotary_factor", 1.0
        )
        head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        dim = int(head_dim * partial_rotary_factor)
        attention_factor = 1.0  # Unused in this type of RoPE
        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, dim, 2, dtype=torch.int64).to(
                    device=device, dtype=torch.float
                )
                / dim
            )
        )
        return inv_freq, attention_factor
def rotate_half(x):
    """Map the last dimension ``[a, b]`` to ``[-b, a]``.

    ``a`` holds the first ``n // 2`` channels and ``b`` the rest, so for an
    odd width the second half is the longer one.
    """
    width = x.shape[-1]
    split_at = width // 2
    front, back = x.split([split_at, width - split_at], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Only the first ``cos.shape[-1]`` channels of the head dimension are
    rotated (partial rotary). The remaining channels pass through unchanged.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding,
            ``(batch_size, seq_len, rotary_dim)``.
        sin (`torch.Tensor`): The sine part of the rotary embedding, same shape as ``cos``.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which ``cos``/``sin`` are unsqueezed so they broadcast
            against ``q`` and ``k``. Use 1 when q/k are
            ``[batch_size, heads, seq_len, head_dim]`` and 2 when they are
            ``[batch_size, seq_len, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Without inserting a heads axis, a (batch, seq, dim) cos would be
    # right-aligned against (heads, seq, dim) of q/k. That breaks (or, when
    # batch == heads, silently misaligns) every batch larger than 1.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    # Apply rotary embeddings on the rotated slice (full tensor when rotary_dim == head_dim)
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
    # Concatenate back to full shape
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). The input must be
    4-D even when ``n_rep == 1``, in which case it is returned unchanged.
    """
    batch, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    repeated = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, dim)
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward.
# As in Llama4, the softmax below runs in the input dtype (no upcast to fp32).
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention.

    Args:
        module: the calling attention module; its ``num_key_value_groups``
            and ``training`` attributes are read.
        query: query states; the head axis is dim 1 and the sequence axis is dim 2.
        key, value: key/value states with ``num_key_value_heads`` heads,
            expanded via ``repeat_kv`` to match the query heads.
        attention_mask: additive mask, cropped to the key length; ``None``
            disables masking.
        scaling: factor applied to ``query @ key^T``.
        dropout: dropout probability, active only while ``module.training``.
        **kwargs: accepted for interface compatibility and ignored (e.g.
            ``sliding_window``, which presumably is already encoded in the mask).

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has its head and
        sequence axes swapped back and is made contiguous.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)
    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]
    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values)
    return context.transpose(1, 2).contiguous(), probs
@dataclass
class Step3p7CausalLMOutputWithPast(ModelOutput):
    r"""
    Output container for Step3p7 causal language modeling.

    Field order is significant: ModelOutput exposes fields positionally in
    ``to_tuple()``, skipping those that are None.

    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    last_hidden_state (`torch.FloatTensor`, *optional*):
        Final hidden states of the language model; presumably of shape
        `(batch_size, sequence_length, hidden_size)` — TODO confirm against the producer.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (*optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Pre-computed key/value states of the self-attention blocks that can be
        passed back as `past_key_values` to speed up sequential decoding.
        NOTE(review): annotated as a list of tensors, but the text model in
        this file creates and returns a `DynamicCache` (a `Cache` object).
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Hidden states collected from the decoder.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Attention weights; the text model in this file does not currently collect them.
    """
    loss: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
class Step3p7MLP(nn.Module):
    """Dense SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``.

    When ``swiglu_limit`` is set, the activated gate is clamped from above at
    the limit and the up projection is clamped to ``[-limit, limit]``.
    """

    def __init__(self, config, intermediate_size=None, swiglu_limit=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = inner = intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN["silu"]
        self.limit = swiglu_limit

    def forward(self, x):
        up = self.up_proj(x)
        gate = self.act_fn(self.gate_proj(x))
        limit = self.limit
        if limit is not None:
            # Only the upper side of the activated gate is bounded.
            gate = gate.clamp(max=limit)
            up = up.clamp(min=-limit, max=limit)
        return self.down_proj(gate * up)
def sigmoid_routing_function(gating_output: torch.Tensor, topk: int,
                             renormalize: bool):
    """Select top-k experts per row from sigmoid router scores.

    Scores are ``sigmoid(logits)`` computed in float32 and normalized to sum
    to 1 per row. Top-k is taken along dim 1. With ``renormalize`` the
    selected weights are rescaled to sum to 1.

    Returns:
        ``(weights, indices)``; ``indices`` keeps topk's int64 dtype.
    """
    scores = torch.sigmoid(gating_output.float())
    scores = scores / scores.sum(dim=-1, keepdim=True)
    weights, indices = scores.topk(topk, dim=1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, indices
def softmax_routing_function(gating_output: torch.Tensor, top_k: int,
                             renormalize: bool):
    """Select top-k experts per row from softmax router probabilities.

    Probabilities are computed in float32. Top-k is taken along dim 1. With
    ``renormalize`` the selected weights are rescaled to sum to 1.

    Returns:
        ``(weights, indices)`` with ``indices`` cast to int32.
    """
    probs = gating_output.float().softmax(dim=-1)
    # Redundant after softmax, but kept so results stay bit-identical.
    probs = probs / probs.sum(dim=-1, keepdim=True)
    weights, indices = probs.topk(top_k, dim=1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, indices.to(torch.int32)
class MoELinear(nn.Module):
    """Bias-free linear layers for all experts, stored as one stacked weight.

    ``weight`` has shape ``(num_experts, out_features, in_features)``. It is
    allocated uninitialized and is expected to be loaded from a checkpoint.
    """

    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        stacked = torch.empty(num_experts, out_features, in_features)
        self.weight = nn.Parameter(stacked)

    def forward(self, x, expert_id):
        """Apply expert ``expert_id``; the computation and output are float32."""
        x32 = x.float()
        expert_weight = self.weight[expert_id].float()
        return F.linear(x32, expert_weight)
class Step3p7MoEMLP(nn.Module):
    """Routed mixture-of-experts SwiGLU feed-forward block.

    Each token is routed to ``config.moe_top_k`` of ``config.moe_num_experts``
    experts. Routing uses, in priority order:
      * sigmoid scores with a (non-trainable) per-expert selection bias when
        ``config.use_moe_router_bias`` is set;
      * plain sigmoid scores when ``config.moe_router_activation == "sigmoid"``;
      * softmax + top-k otherwise.
    Routing weights are multiplied by ``moe_router_scaling_factor``. Expert
    projections run in float32 (see ``MoELinear``).
    """

    def __init__(self, config, swiglu_limit=None):
        super().__init__()
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_top_k
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size
        self.use_moe_router_bias = config.use_moe_router_bias
        if self.use_moe_router_bias:
            self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts,
                                                        dtype=torch.float32),
                                            requires_grad=False)
            self.custom_routing_function = self.router_bias_func
        elif config.moe_router_activation == "sigmoid":
            self.custom_routing_function = sigmoid_routing_function
        else:
            self.custom_routing_function = None
        self.need_fp32_gate = config.need_fp32_gate
        self.routed_scaling_factor = getattr(config,
                                             "moe_router_scaling_factor", 1.0)
        # gating
        self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
        self.act_fn = ACT2FN["silu"]
        self.limit = swiglu_limit
        self.up_proj = MoELinear(self.num_experts, self.hidden_size,
                                 self.moe_intermediate_size)
        self.gate_proj = MoELinear(self.num_experts, self.hidden_size,
                                   self.moe_intermediate_size)
        self.down_proj = MoELinear(self.num_experts,
                                   self.moe_intermediate_size,
                                   self.hidden_size)

    def router_bias_func(self, gating_output: torch.Tensor, topk: int,
                         renormalize: bool):
        """Pick experts by ``sigmoid(logits) + router_bias``, weight by unbiased scores.

        The bias only influences which experts are selected. The returned
        weights are the unbiased sigmoid scores of those experts, optionally
        renormalized to sum to 1 (with a small epsilon to avoid 0/0).
        """
        gate_prob = torch.sigmoid(gating_output.float())
        gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0)
        _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1)
        topk_prob = torch.gather(gate_prob, 1, indices)
        expert_topk_weight = topk_prob
        if renormalize:
            expert_topk_weight = expert_topk_weight / (
                torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20)
        return expert_topk_weight, indices

    def get_expert_output(self, inputs: torch.Tensor, expert_id):
        """Run expert ``expert_id``'s SwiGLU MLP on ``inputs`` (float32 output)."""
        up = self.up_proj(inputs, expert_id)
        gate = self.act_fn(self.gate_proj(inputs, expert_id))
        if self.limit is not None:
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
        return self.down_proj(gate * up, expert_id)

    def forward(self, hidden_states):
        """Route every token to its top-k experts and sum the weighted outputs.

        Args:
            hidden_states: ``(batch, seq_len, hidden)`` tensor.

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.need_fp32_gate:
            router_logits = torch.matmul(
                hidden_states.to(torch.float32),
                self.gate.weight.t().to(torch.float32),
            )
        else:
            # router_logits: (batch * sequence_length, n_experts)
            router_logits = self.gate(hidden_states)
        if self.custom_routing_function:
            routing_weights, selected_experts = self.custom_routing_function(
                router_logits, self.top_k, renormalize=True)
        else:
            routing_weights = F.softmax(router_logits,
                                        dim=1,
                                        dtype=torch.float)
            routing_weights, selected_experts = torch.topk(routing_weights,
                                                           self.top_k,
                                                           dim=-1)
        routing_weights = routing_weights * self.routed_scaling_factor
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device)
        # expert_mask[e, k, t] == 1 iff token t chose expert e as its k-th pick.
        expert_mask = F.one_hot(
            selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        # Visit only experts that received at least one token. Idle experts
        # would contribute nothing, yet evaluating them still casts their
        # full weight matrices to float32. Ascending order keeps the
        # index_add_ accumulation order, and hence the result, unchanged.
        experts_hit = (expert_mask.sum(dim=(1, 2)) > 0).nonzero().flatten().tolist()
        for expert_idx in experts_hit:
            # idx: which of the token's top-k slots; top_x: token positions.
            idx, top_x = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[top_x]
            current_hidden_states = (
                self.get_expert_output(current_state, expert_idx) *
                routing_weights[top_x, idx, None])
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype))
        return final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim)
class Step3p7RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a ``(1 + weight)`` scale.

    The statistics are computed in float32, and the result is cast back to
    the input dtype.

    NOTE(review): ``weight`` starts at ones, so a freshly initialized layer
    scales by 2. Checkpoint loading overwrites it, but confirm whether a
    zeros init (the usual pairing with ``1 + weight``) was intended.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        scale = self.weight.float() + 1
        return (x32 * inv_rms * scale).to(input_dtype)
class Step3p7Attention(nn.Module):
    """Grouped-query self-attention for one Step3p7 decoder layer.

    For its layer, the module resolves from the config:
      * whether sliding-window attention is used; sliding layers take their
        head/group counts from ``config.attention_other_setting``;
      * the RoPE scaling parameters, optionally restricted to the layer
        types listed in ``config.yarn_only_types``;
      * whether RoPE is applied at all (``config.use_rope_layers``).
    Queries and keys receive a per-head RMSNorm before RoPE. An optional
    per-head sigmoid gate (``g_proj``) scales the attention output before
    ``o_proj``.
    """

    def __init__(self, config: Step3p7TextConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_attention_groups
        layer_types = getattr(config, "layer_types", None) or []
        if layer_types:
            layer_type = layer_types[self.layer_idx]
        else:
            # Same fallback as Step3p7DecoderLayer: even layers slide.
            layer_type = ("sliding_attention"
                          if self.layer_idx % 2 == 0 else "full_attention")
        enable_sliding_window = layer_type == "sliding_attention"
        # Resolve this layer's RoPE scaling on a private shallow copy of the
        # config. Writing it onto the shared config would leave that object
        # holding whatever the last-constructed layer chose. The rotary
        # module is the only consumer of this setting here.
        rope_config = copy.copy(config)
        yarn_only_types = getattr(config, "yarn_only_types", None)
        if yarn_only_types and layer_type not in yarn_only_types:
            rope_config.rope_parameters = None
        else:
            rope_config.rope_parameters = getattr(config, "rope_scaling", None)
        sliding_window = config.sliding_window
        if enable_sliding_window:
            self.num_attention_heads = config.attention_other_setting[
                "num_attention_heads"]
            self.num_key_value_heads = config.attention_other_setting[
                "num_attention_groups"]
        self.sliding_window = sliding_window if enable_sliding_window else None
        self.head_dim = getattr(config, "head_dim",
                                config.hidden_size // self.num_attention_heads)
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        self.rotary_emb = Step3p7RotaryEmbedding(rope_config, layer_idx=layer_idx)
        self.q_size = self.num_attention_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.kv_size, bias=False)
        self.o_proj = nn.Linear(self.q_size, config.hidden_size, bias=False)
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)
        self.q_norm = Step3p7RMSNorm(self.head_dim,
                                     eps=config.rms_norm_eps)
        self.k_norm = Step3p7RMSNorm(self.head_dim,
                                     eps=config.rms_norm_eps)
        self.use_head_wise_attn_gate = config.use_head_wise_attn_gate
        if self.use_head_wise_attn_gate:
            self.g_proj = nn.Linear(config.hidden_size,
                                    self.num_attention_heads,
                                    bias=False)
        # Per-layer switch consulted in forward(); defaults to applying RoPE.
        self.use_rope = True
        use_rope_layers = getattr(config, "use_rope_layers", None)
        if use_rope_layers:
            self.use_rope = use_rope_layers[self.layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Attend over ``hidden_states`` (assumed ``(batch, seq, hidden)``).

        Returns ``(attn_output, attn_weights)``. ``attn_output`` has the
        input's leading shape and a ``hidden_size`` last dimension.
        ``attn_weights`` is whatever the attention backend returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Per-head RMSNorm on q/k, then move heads ahead of the sequence axis.
        query_states = self.q_norm(
            self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(
            self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(
            1, 2)
        if self.use_head_wise_attn_gate:
            gate_states = self.g_proj(hidden_states)
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        if self.use_rope:
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)
        attention_interface: Callable = eager_attention_forward
        # TODO: considering FP8;
        # RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype,
        # but got attn_mask.dtype: long int and query.dtype: c10::BFloat16 instead.
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # main diff with Llama
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1)
        if self.use_head_wise_attn_gate:
            # Scale each head's output by sigmoid(g_proj(x)) for that head.
            output = attn_output.view(
                *attn_output.shape[:-1], self.num_attention_heads,
                self.head_dim) * gate_states.unsqueeze(-1).sigmoid()
            attn_output = output.view(*attn_output.shape)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Step3p7DecoderLayer(GradientCheckpointingLayer):
    """One pre-norm decoder block.

    Computes ``h = x + attn(norm(x))`` followed by ``h + ffn(norm(h))``. The
    FFN is a dense SwiGLU MLP, or, on MoE layers, the sum of a routed MoE and
    an always-active shared expert.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Step3p7Attention(config, layer_idx)
        layer_types = getattr(config, "layer_types", None) or []
        if layer_types:
            self.attention_type = layer_types[layer_idx]
        elif layer_idx % 2 == 0:
            self.attention_type = "sliding_attention"
        else:
            self.attention_type = "full_attention"
        self.is_moe_layer = layer_idx in self._moe_layer_indices(config)
        self.use_moe = False
        swiglu_limit_shared = self._nonzero_limit(config.swiglu_limits_shared,
                                                  layer_idx)
        swiglu_limit = self._nonzero_limit(config.swiglu_limits, layer_idx)
        if self.is_moe_layer:
            self.moe = Step3p7MoEMLP(config, swiglu_limit=swiglu_limit)
            self.share_expert = Step3p7MLP(
                config,
                intermediate_size=config.share_expert_dim,
                swiglu_limit=swiglu_limit_shared)
            self.use_moe = True
        else:
            self.mlp = Step3p7MLP(config,
                                  intermediate_size=config.intermediate_size,
                                  swiglu_limit=swiglu_limit_shared)
        self.input_layernorm = Step3p7RMSNorm(config.hidden_size,
                                              eps=config.rms_norm_eps)
        self.post_attention_layernorm = Step3p7RMSNorm(config.hidden_size,
                                                       eps=config.rms_norm_eps)

    @staticmethod
    def _moe_layer_indices(config):
        """Indices of MoE layers.

        Taken from ``config.moe_layers_enum`` (a comma-separated string or an
        iterable of ints). When it is absent, every layer except layer 0.
        """
        spec = getattr(config, "moe_layers_enum", None)
        if spec is None:
            return set(range(1, config.num_hidden_layers))
        if isinstance(spec, str):
            return {int(token) for token in spec.split(',') if token.strip()}
        return {int(token) for token in spec}

    @staticmethod
    def _nonzero_limit(limits, layer_idx):
        """Per-layer SwiGLU clamp limit, or None when the list is empty or the entry is None/0."""
        if not limits:
            return None
        limit = limits[layer_idx]
        if limit is None or limit == 0:
            return None
        return limit

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output
        # Fully connected sub-block.
        ffn_input = self.post_attention_layernorm(hidden_states)
        if not self.use_moe:
            ffn_output = self.mlp(ffn_input)
        else:
            shared_output = self.share_expert(ffn_input)
            routed_output = self.moe(ffn_input)
            ffn_output = routed_output + shared_output
        # Tolerate FFNs that return (output, aux) tuples.
        ffn_hidden = ffn_output[0] if isinstance(ffn_output, tuple) else ffn_output
        return hidden_states + ffn_hidden
class Step3p7TextPreTrainedModel(PreTrainedModel):
    """Base class that plugs the text-only Step3p7 model into transformers.

    It mirrors ``Step3p7PreTrainedModel`` but is bound to ``Step3p7TextConfig``.
    """

    # Bind the configuration class so PreTrainedModel.from_pretrained can
    # build the config; without it, loading fails with a NoneType error.
    config_class = Step3p7TextConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    # Unexpected checkpoint keys matching these patterns are tolerated on load.
    # NOTE(review): `\.*` means "zero or more literal dots"; confirm whether
    # `\..*` was intended.
    _keys_to_ignore_on_load_unexpected = [
        r"model\.layers\.45\.*",
        r"model\.layers\.46\.*",
        r"model\.layers\.47\.*",
    ]
    # Attention backend capabilities advertised to transformers.
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_static_cache = True
    _supports_attention_backend = True
class Step3p7TextModel(Step3p7TextPreTrainedModel, GenerationMixin):
    """Decoder-only text backbone: token embeddings, decoder layers, final RMSNorm."""

    _no_split_modules = ["Step3p7DecoderLayer"]
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3p7TextConfig

    def __init__(self, config: Step3p7TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        self.layers = nn.ModuleList([
            Step3p7DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = Step3p7RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        layer_types = getattr(self.config, "layer_types", None) or []
        # Without explicit layer_types the decoder layers alternate (even
        # layers slide), so a sliding-window mask is always needed then.
        self.has_sliding_layers = (not layer_types or
                                   "sliding_attention" in layer_types)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self, input_ids=None):
        """Return the embedding module, or the embeddings of ``input_ids``.

        Called with no argument, this follows the transformers contract and
        returns the ``nn.Embedding`` module. Generation, weight tying and
        ``resize_token_embeddings`` rely on that form. Called with
        ``input_ids``, it embeds them, which existing callers in this file use.
        """
        if input_ids is None:
            return self.embed_tokens
        return self.embed_tokens(input_ids)

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When
        ``use_cache`` is on and no cache is passed, a ``DynamicCache`` is
        created. ``attention_mask`` may be a precomputed dict keyed by
        attention type; otherwise full and (when needed) sliding-window masks
        are built here. ``return_dict`` is handled by ``@can_return_tuple``.

        Note: attention weights are not collected. ``attentions`` is an empty
        tuple when ``output_attentions`` is set, and None otherwise.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids.to(self.embed_tokens.weight.device))
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length(
            ) if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens,
                                          past_seen_tokens +
                                          inputs_embeds.shape[1],
                                          device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        hidden_states = inputs_embeds
        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Keyword name differs across transformers versions.
            mask_kwargs[_MASK_INPUT_EMBEDS_ARG] = inputs_embeds
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping[
                    "sliding_attention"] = create_sliding_window_causal_mask(
                        **mask_kwargs)
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for decoder_layer in self.layers[:self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[
                    decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            # BaseModelOutputWithPast.hidden_states holds the embedding output
            # plus one entry per layer; the last entry is the final output.
            all_hidden_states += (hidden_states, )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class Step3p7Model(Step3p7PreTrainedModel, GenerationMixin):
    """Step3p7 vision-language backbone.

    Combines the vision encoder, the vision-to-text projector and the text
    decoder. ``forward`` returns decoder hidden states; the vocabulary
    projection (``lm_head``) is owned by ``Step3p7ForConditionalGeneration``.
    """

    config: Step3p7Config
    _tied_weights_keys = ["lm_head.weight"]
    base_model_prefix = ""
    # Keyword arguments that belong to the vision path. ``forward`` strips
    # them from ``**kwargs`` before calling the text decoder so image tensors
    # are not forwarded into the decoder layers.
    _vision_input_keys = (
        "pixel_values",
        "patch_pixel_values",
        "num_patches",
        "image_embeds",
        "patch_newline_mask",
    )

    def __init__(self, config: Step3p7Config):
        super().__init__(config)
        self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
        self.language_model = Step3p7TextModel(config.text_config)
        self.vocab_size = config.text_config.vocab_size
        # Projects vision features (``vision_config.width * 4`` channels) into
        # the text decoder's hidden size.
        self.vit_large_projector = nn.Linear(
            config.vision_config.width * 4,
            config.text_config.hidden_size,
            bias=config.projector_bias)
        self.image_placeholder_token_id = config.image_token_id
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        multimodal_embeddings=None,
    ):
        """Return the token-embedding module, or embed ``input_ids``.

        Called without arguments (the ``PreTrainedModel`` contract), this
        returns the text decoder's token-embedding module.

        Called with ``input_ids``, it returns embeddings shaped
        ``(*input_ids.shape, hidden)``. A 1-D ``input_ids`` is treated as a
        single sequence and gets a leading batch dimension of 1. When
        ``multimodal_embeddings`` is given, positions holding
        ``config.image_token_id`` are filled from it via
        ``merge_multimodal_embeddings``.
        """
        if input_ids is None:
            # NOTE(review): assumes Step3p7TextModel stores its token
            # embedding in ``embed_tokens`` — confirm.
            return self.language_model.embed_tokens

        # Work on a flat token stream so any batch size is supported. The
        # image features must therefore be ordered by (batch, position).
        flat_ids = input_ids.reshape(-1)
        if multimodal_embeddings is None:
            flat_embeds = self.language_model.get_input_embeddings(flat_ids)
        else:
            is_text = flat_ids != self.config.image_token_id
            text_embeds = self.language_model.get_input_embeddings(
                flat_ids[is_text])
            flat_embeds = torch.empty(flat_ids.shape[0],
                                      text_embeds.shape[-1],
                                      dtype=text_embeds.dtype,
                                      device=text_embeds.device)
            flat_embeds[is_text] = text_embeds
            # Image-token rows are still uninitialised here; the merge below
            # is expected to overwrite them with the image features.
            flat_embeds = merge_multimodal_embeddings(
                flat_ids, flat_embeds, multimodal_embeddings,
                self.config.image_token_id)

        if input_ids.dim() > 1:
            leading_shape = tuple(input_ids.shape)
        else:
            leading_shape = (1, flat_ids.shape[0])
        return flat_embeds.reshape(*leading_shape, flat_embeds.shape[-1])

    def set_input_embeddings(self, value):
        return self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[StepVLImageInputs]:
        """Normalise raw image kwargs into a typed image-input dict.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns
        ``None`` when neither is supplied.

        Raises:
            ValueError: if ``image_embeds`` has fewer than 2 dimensions.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            # Collapse every leading dim into one, keeping the last three
            # (presumably channels, height, width).
            if pixel_values.dim() >= 3:
                pixel_values = pixel_values.reshape(-1,
                                                    *pixel_values.shape[-3:])
            if patch_pixel_values is not None:
                patch_pixel_values = patch_pixel_values.reshape(
                    -1, *patch_pixel_values.shape[-3:])
                # An empty crop tensor means "no crops for any image".
                if patch_pixel_values.shape[0] == 0:
                    patch_pixel_values = None
            return StepVLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(device=self.device,
                                             dtype=self.dtype),
                patch_pixel_values=(patch_pixel_values.to(
                    device=self.device, dtype=self.dtype)
                                    if patch_pixel_values is not None else
                                    None),
                num_patches=num_patches,
            )

        if image_embeds is not None:
            if image_embeds.dim() < 2:
                raise ValueError(
                    f"Unexpected shape for image_embeds: {image_embeds.shape}")
            return StepVLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds.reshape(
                    -1, image_embeds.shape[-1]).to(device=self.device,
                                                   dtype=self.dtype),
            )

        return None

    def _process_image_features(self,
                                image_features: torch.Tensor) -> torch.Tensor:
        """Downsample vision-encoder tokens and project them to text space.

        ``image_features`` is indexed as ``(batch, tokens, channels)``, and
        ``tokens`` is expected to form a square grid. Returns
        ``(batch, tokens_after_downsampling, text_hidden_size)``.
        """
        batch, num_tokens = image_features.shape[:2]
        side = int(num_tokens**0.5)
        # (B, P, C) -> (B, C, side, side) so the downsamplers get a 2-D grid.
        grid = image_features.permute(0, 2, 1).view(batch, -1, side, side)
        grid = self.vision_model.vit_downsampler1(grid)
        grid = self.vision_model.vit_downsampler2(grid)
        out_h, out_w = grid.shape[-2:]
        tokens = grid.view(batch, -1, out_h * out_w).permute(0, 2, 1)
        return self.vit_large_projector(tokens)

    def _get_vision_model_output(self,
                                 input_tensor: torch.Tensor) -> torch.Tensor:
        return self.vision_model(input_tensor)

    def _process_image_input(
            self, image_input: StepVLImageInputs) -> list[torch.Tensor]:
        """Turn parsed image inputs into per-image token embeddings.

        For pixel inputs, each returned element is the image's crop tokens
        (if it has ``num_patches > 0`` crops) followed by its global-view
        tokens, flattened to ``(n_tokens, hidden)``. Precomputed
        ``image_embeds`` are returned as a single-element list, unchanged.
        """
        if image_input["type"] == "image_embeds":
            # Precomputed embeddings are taken as already being in the text
            # hidden space, so they skip the downsamplers and projector.
            return [image_input["image_embeds"]]

        image_features = self._process_image_features(
            self._get_vision_model_output(image_input["pixel_values"]))
        patch_pixel_values = image_input["patch_pixel_values"]
        patch_image_features = (self._process_image_features(
            self._get_vision_model_output(patch_pixel_values))
                                if patch_pixel_values is not None else None)

        num_patches = image_input["num_patches"]
        if num_patches is None:
            # No crop counts supplied: every image contributes only its
            # global view.
            num_patches = [0] * image_features.shape[0]
        elif isinstance(num_patches, torch.Tensor):
            num_patches = num_patches.reshape(-1).tolist()

        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[cur_patch_idx:cur_patch_idx +
                                                   num_patch]
                cur_feature.append(
                    patch_slice.reshape(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].reshape(
                -1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1 else
                cur_feature[0])
        return merged_image_features

    def get_multimodal_embeddings(self, **kwargs):
        """Encode image kwargs into per-image embeddings, or ``None``."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[list[Image.Image]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
        r"""
        Embed the (optionally multimodal) prompt and run the text decoder.

        Image inputs (``pixel_values``, ``patch_pixel_values``,
        ``num_patches``, ``image_embeds``, ``patch_newline_mask``) arrive
        through ``**kwargs``. They are used only when ``inputs_embeds`` is
        not given, and they are never forwarded to the text decoder.

        labels, logits_to_keep, images:
            Accepted for API compatibility and ignored here; the loss and
            logits are produced by ``Step3p7ForConditionalGeneration``.

        Returns:
            ``Step3p7CausalLMOutputWithPast`` with ``last_hidden_state``,
            ``past_key_values``, ``hidden_states`` and ``attentions``.
            ``@can_return_tuple`` consumes ``return_dict`` and converts the
            result to a tuple when that is requested.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)

        vision_kwargs = {
            key: kwargs.pop(key)
            for key in self._vision_input_keys if key in kwargs
        }

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "You must specify either input_ids or inputs_embeds")
            vision_embeddings = self.get_multimodal_embeddings(**vision_kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        return Step3p7CausalLMOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class Step3p7ForConditionalGeneration(Step3p7PreTrainedModel, GenerationMixin):
    """Step3p7 vision-language model with a causal language-modeling head."""

    # Regex rewrites from checkpoint key names onto this module tree.
    _checkpoint_conversion_mapping = {
        "^vision_model": "model.vision_model",
        r"^model(?!\.(language_model|vision_model))": "model.language_model",
        "^vit_large_projector": "model.vit_large_projector",
    }
    _tied_weights_keys = ["lm_head.weight"]
    config: Step3p7Config

    def __init__(self, config: Step3p7Config):
        super().__init__(config)
        self.model = Step3p7Model(config)
        # The head consumes text-decoder hidden states, so it is sized from
        # the text config (the same source Step3p7Model uses for its
        # projector output).
        self.lm_head = nn.Linear(config.text_config.hidden_size,
                                 config.text_config.vocab_size,
                                 bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the text decoder's token-embedding module."""
        # ``self.model.get_input_embeddings`` is used to embed token ids; the
        # PreTrainedModel contract here wants the module itself.
        # NOTE(review): assumes Step3p7TextModel stores it as ``embed_tokens``.
        return self.language_model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        # NOTE(review): ``lm_head`` lives on this class, not on ``self.model``,
        # so this delegation most likely returns None. That also disables
        # embedding/``lm_head`` weight tying. Returning ``self.lm_head`` would
        # enable tying whenever ``tie_word_embeddings`` is set — confirm the
        # checkpoint's intent before changing.
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        return self.model.vision_model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: Optional[torch.Tensor] = None,
        num_patches=None,
        patch_pixel_values=None,
        patch_newline_mask=None,
        image_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Step3p7CausalLMOutputWithPast]:
        r"""
        Run the multimodal backbone and project hidden states to vocab logits.

        All image inputs (``pixel_values``, ``image_embeds``,
        ``patch_pixel_values``, ``num_patches``, ``patch_newline_mask``) are
        forwarded to ``self.model``.

        labels (`torch.LongTensor`, *optional*):
            When given, ``self.loss_function`` computes the causal-LM loss over
            ``text_config.vocab_size``. Extra ``**kwargs`` (for example
            ``num_items_in_batch``) are passed through to it.

        Returns:
            ``Step3p7CausalLMOutputWithPast`` (or a tuple when ``return_dict``
            resolves to False) with ``loss``, ``logits``, ``past_key_values``,
            ``hidden_states`` and ``attentions``.
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            image_embeds=image_embeds,
            num_patches=num_patches,
            patch_pixel_values=patch_pixel_values,
            patch_newline_mask=patch_newline_mask,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Always ask for a ModelOutput: its attributes are read below.
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )

        output = Step3p7CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return output if return_dict else output.to_tuple()

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        patch_pixel_values=None,
        num_patches=None,
        image_embeds=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        """Build per-step model inputs, attaching image inputs only at prefill.

        Prefill is detected as ``cache_position[0] == 0`` when a cache
        position is available. Otherwise it falls back to
        ``past_key_values is None``.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        generation_cache_position = model_inputs.get("cache_position",
                                                     cache_position)
        is_prefill = past_key_values is None
        if (generation_cache_position is not None
                and generation_cache_position.numel() > 0):
            is_prefill = generation_cache_position[0].item() == 0
        if is_prefill:
            # During cached decoding, input ids no longer contain image
            # tokens, so every image input is passed at the first step only.
            model_inputs["pixel_values"] = pixel_values
            model_inputs["patch_pixel_values"] = patch_pixel_values
            model_inputs["num_patches"] = num_patches
            model_inputs["image_embeds"] = image_embeds
        return model_inputs

    def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
        """Strip a leading ``language_model.`` prefix from a checkpoint key.

        Returns ``(new_key, changed)``.
        """
        # NOTE(review): this override replaces whatever key fix-ups the base
        # class performs — confirm none are needed for these checkpoints.
        if key.startswith("language_model."):
            return key[len("language_model."):], True
        return key, False