# Step3-VL-10B / vision_encoder.py
from typing import Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.activations import ACT2FN
from configuration_step_vl import StepRoboticsVisionEncoderConfig
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate last dimension halves (used by RoPE)."""
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
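
# Quick sanity check on a toy vector (hypothetical values): adjacent channel
# pairs (x1, x2) become (-x2, x1), a 90-degree rotation of each 2D pair, which
# is exactly the form RoPE's sin term needs.
#   >>> rotate_half(torch.tensor([1., 2., 3., 4.]))
#   tensor([-2.,  1., -4.,  3.])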
def apply_rotary_emb(freqs: torch.Tensor,
t: torch.Tensor,
start_index: int = 0,
scale: float = 1.0,
seq_dim: int = -2) -> torch.Tensor:
"""Apply 2D rotary embeddings to queries / keys."""
dtype = t.dtype
if t.ndim == 3:
seq_len = t.shape[seq_dim]
freqs = freqs[-seq_len:]
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], (
f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}")
t_left, t, t_right = (
t[..., :start_index],
t[..., start_index:end_index],
t[..., end_index:],
)
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
out = torch.cat((t_left, t, t_right), dim=-1)
return out.type(dtype)
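
# Shape sketch: with q/k of shape (B, heads, L, head_dim) and a cached freqs
# tensor of shape (1, 1, L, head_dim), the broadcast above rotates every head
# with the same per-position angles; channels outside
# [start_index, end_index) pass through untouched.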
class EncoderRope2D(nn.Module):
"""Cacheable 2D rotary positional embedding."""
    def __init__(
        self,
        dim: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        theta: Union[int, float] = 10000,
        max_freq: int = 10,
        num_freqs: int = 1,
        theta_rescale_factor: float = 1.0,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
    ):
        super().__init__()
        self.dim = dim
        self.max_grid_height = max_grid_height
        self.max_grid_width = max_grid_width
        self.use_cls_token = use_cls_token
        # NTK-style rescaling of the rotary base (theta-rescale trick).
        self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
        # `freqs_for`, `max_freq` and `num_freqs` are accepted so the config
        # plumbing keeps working, but only the theta-based "lang" schedule is
        # implemented below.
        self.freqs_for = freqs_for
        self.max_freq = max_freq
        self.num_freqs = num_freqs
        cache = self._compute_2d_freqs()
        self.register_buffer("freqs_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float],
dim: int) -> torch.Tensor:
freqs = 1.0 / (base**(
torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
return freqs
def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
inv_freq)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
def _compute_2d_freqs(self) -> torch.Tensor:
grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
if self.use_cls_token:
grid_h_range += 1
grid_w_range += 1
inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
self.max_grid_height, self.max_grid_width, -1)
freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
self.max_grid_height, self.max_grid_width, -1)
freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
self.max_grid_height * self.max_grid_width, -1)
if self.use_cls_token:
freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
freqs = freqs[None, None, ...]
return freqs
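
    # Cache layout note: the first half of the rotary channels encodes the
    # column (w) coordinate and the second half the row (h) coordinate; when a
    # CLS token is used, a leading all-zero row leaves that token unrotated.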
def forward(self, q: torch.Tensor, k: torch.Tensor,
grid_hw: tuple[int, int]):
        if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
            # Smaller grids gather the matching positions out of the cache.
            rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
            cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
            positions = (rows * self.max_grid_width + cols).reshape(-1).to(
                torch.long)
            if self.use_cls_token:
                positions = torch.cat(
                    [torch.zeros(1, device=q.device, dtype=torch.long),
                     positions + 1], dim=0)
            freqs = self.freqs_cache.index_select(2, positions)
        else:
            # The grid matches the cached shape, so reuse the cache directly.
            freqs = self.freqs_cache
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
return q, k
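
# Minimal usage sketch (hypothetical sizes, not the released configuration):
#   rope = EncoderRope2D(dim=64, max_grid_height=27, max_grid_width=27)
#   q = k = torch.randn(1, 12, 27 * 27, 64)  # (B, heads, Gh * Gw, head_dim)
#   q, k = rope(q, k, grid_hw=(27, 27))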
class EncoderLayerScale(nn.Module):
"""Per-channel residual scaling used when ls_init_value is set."""
def __init__(self, dim: int, init_values: float):
super().__init__()
self.gamma = nn.Parameter(torch.full((dim,), init_values))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # (B, L, D)
return hidden_states * self.gamma
class EncoderMLP(nn.Module):
"""Feed-forward network used inside each transformer block."""
def __init__(self, hidden_size: int, intermediate_size: int,
hidden_act: str):
super().__init__()
self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
self.act_fn = ACT2FN[hidden_act]
self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.c_proj(self.act_fn(self.c_fc(hidden_states)))
return hidden_states
class EncoderVisionAttention(nn.Module):
"""Multi-head self attention with optional 2D RoPE."""
def __init__(
self,
hidden_size: int,
num_heads: int,
max_grid_height: int,
max_grid_width: int,
use_cls_token: bool = False,
use_rope2d: bool = True,
rope_theta: Union[int, float] = 10000,
rope_max_freq: int = 10,
rope_num_freqs: int = 1,
rope_theta_rescale_factor: float = 1.0,
rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
):
super().__init__()
if hidden_size % num_heads != 0:
raise ValueError(
f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
)
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
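        # Packed QKV projection; the in_proj_weight / in_proj_bias names
        # mirror torch.nn.MultiheadAttention's parameter layout.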
self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.rope = None
if use_rope2d:
self.rope = EncoderRope2D(
dim=self.head_dim,
max_grid_height=max_grid_height,
max_grid_width=max_grid_width,
use_cls_token=use_cls_token,
theta=rope_theta,
max_freq=rope_max_freq,
num_freqs=rope_num_freqs,
theta_rescale_factor=rope_theta_rescale_factor,
freqs_for=rope_freqs_for,
)
def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
bsz, seq_len, _ = hidden_states.shape
qkv = F.linear(
hidden_states,
self.in_proj_weight,
self.in_proj_bias,
)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2)
if self.rope is not None:
q, k = self.rope(q, k, grid_hw=grid_hw)
v = v.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
q, k, v, is_causal=False, scale=self.scale)
attn_output = attn_output.transpose(1, 2).reshape(
bsz, seq_len, self.num_heads * self.head_dim)
return self.out_proj(attn_output)
class EncoderVisionBlock(nn.Module):
"""A single Vision Transformer block (self-attention + MLP)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
hidden_act: str,
layer_norm_eps: float,
ls_init_value: Optional[float] = None,
max_grid_height: Optional[int] = None,
max_grid_width: Optional[int] = None,
use_cls_token: bool = False,
use_rope2d: bool = True,
rope_kwargs: Optional[dict] = None,
):
super().__init__()
rope_kwargs = rope_kwargs or {}
self.attn = EncoderVisionAttention(
hidden_size,
num_heads,
max_grid_height=max_grid_height,
max_grid_width=max_grid_width,
use_cls_token=use_cls_token,
use_rope2d=use_rope2d,
**rope_kwargs,
)
self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
intermediate = int(hidden_size * mlp_ratio)
self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)
        # Fall back to identity when LayerScale is disabled (ls_init_value is
        # None); EncoderLayerScale cannot be built from a None init value.
        self.ls_1 = (EncoderLayerScale(hidden_size, ls_init_value)
                     if ls_init_value is not None else nn.Identity())
        self.ls_2 = (EncoderLayerScale(hidden_size, ls_init_value)
                     if ls_init_value is not None else nn.Identity())
def forward(self, hidden_states: torch.Tensor,
grid_hw: tuple[int, int]) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
hidden_states = residual + self.ls_1(hidden_states)
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.ls_2(hidden_states)
return hidden_states
class EncoderVisionTransformer(nn.Module):
"""Stack of encoder blocks parameterised by Step35VisionEncoderConfig."""
def __init__(
self,
embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
hidden_act: str,
layer_norm_eps: float,
ls_init_value: Optional[float] = None,
max_grid_height: Optional[int] = None,
max_grid_width: Optional[int] = None,
use_cls_token: bool = False,
use_rope2d: bool = True,
rope_kwargs: Optional[dict] = None,
):
super().__init__()
self.layers = depth
rope_kwargs = rope_kwargs or {}
self.resblocks = nn.ModuleList([
EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
layer_norm_eps,
max_grid_height=max_grid_height,
max_grid_width=max_grid_width,
use_cls_token=use_cls_token,
use_rope2d=use_rope2d,
ls_init_value=ls_init_value,
rope_kwargs=rope_kwargs)
for _ in range(depth)
])
def forward(self,
hidden_states: torch.Tensor,
grid_hw: tuple[int, int]) -> torch.Tensor:
for block in self.resblocks:
hidden_states = block(hidden_states, grid_hw=grid_hw)
return hidden_states
class StepRoboticsVisionEncoder(nn.Module):
"""
Vision encoder built from StepRoboticsVisionEncoderConfig.
The encoder performs patch embedding followed by a stack of transformer
blocks. Only the config fields defined in StepRoboticsVisionEncoderConfig (and
StepRoboticVLConfig.vision_config) are expected.
"""
def __init__(self, config: StepRoboticsVisionEncoderConfig):
super().__init__()
self.config = config
# Align commonly used attributes so downstream code (e.g. StepRoboticVL)
# can access them without extra renaming.
self.hidden_size = config.width
self.num_heads = config.heads
self.num_hidden_layers = config.layers
self.patch_size = config.patch_size
self.image_size = config.image_size
self.use_cls_token = getattr(config, "use_cls_token", False)
self.use_rope2d = getattr(config, "use_rope2d", True)
self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
self.layer_norm_eps = config.layer_norm_eps
self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
self.ls_init_value = getattr(config, "ls_init_value", None)
self.hidden_act = config.hidden_act
self.use_ln_pre = getattr(config, "use_ln_pre", False)
self.use_ln_post = getattr(config, "use_ln_post", True)
# Patch embedding.
self.conv1 = nn.Conv2d(in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False)
self.ln_pre = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_pre else nn.Identity()
self.ln_post = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) if self.use_ln_post else nn.Identity()
grid_size = self.image_size // self.patch_size
self.base_grid = (grid_size, grid_size)
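        # The RoPE cache and the absolute position table below are sized for
        # this base grid; other input resolutions gather from or interpolate
        # it at forward time.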
if self.use_cls_token:
self.class_embedding = nn.Parameter(
torch.randn(self.hidden_size) * (self.hidden_size**-0.5))
else:
self.class_embedding = None
if self.use_abs_posemb:
self.posemb_grid_size = self.image_size // self.patch_size
self.positional_embedding = nn.Parameter(
(self.hidden_size**-0.5) * torch.randn(
int(self.use_cls_token) + self.posemb_grid_size**2,
self.hidden_size,
))
self.transformer = EncoderVisionTransformer(
embed_dim=self.hidden_size,
depth=self.num_hidden_layers,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
hidden_act=self.hidden_act,
layer_norm_eps=self.layer_norm_eps,
ls_init_value=self.ls_init_value,
max_grid_height=self.base_grid[0],
max_grid_width=self.base_grid[1],
use_cls_token=self.use_cls_token,
use_rope2d=self.use_rope2d,
rope_kwargs={
"rope_theta": getattr(config, "rope_theta", 10000),
"rope_max_freq": getattr(config, "rope_max_freq", 10),
"rope_num_freqs": getattr(config, "rope_num_freqs", 1),
"rope_theta_rescale_factor":
getattr(config, "rope_theta_rescale_factor", 1.0),
"rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
},
)
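        # Strided convs that downsample the token grid 4x per spatial
        # dimension. They are not called in this module's forward();
        # presumably the wrapping model applies them to the returned features.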
self.vit_downsampler1 = nn.Conv2d(self.hidden_size,
self.hidden_size * 2,
kernel_size=3,
stride=2,
padding=1)
self.vit_downsampler2 = nn.Conv2d(self.hidden_size * 2,
self.hidden_size * 4,
kernel_size=3,
stride=2,
padding=1)
    def sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Absolute position embeddings for an arbitrary (grid_h, grid_w) grid.

        The learned table covers the square posemb_grid_size grid; other grid
        shapes are produced by bilinear interpolation, leaving any CLS
        embedding untouched.
        """
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]
pos_embed = self.positional_embedding
if self.use_cls_token:
cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
pos_embed = (pos_embed.reshape(1, self.posemb_grid_size,
self.posemb_grid_size,
-1).permute(0, 3, 1, 2).contiguous())
pos_embed = F.interpolate(pos_embed,
size=(grid_h, grid_w),
mode="bilinear",
align_corners=False)
pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)
if self.use_cls_token:
pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
return pos_embed[None, ...]
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values: Image tensor of shape (B, C, H, W); H and W must be
                multiples of patch_size.

        Returns:
            Patch features of shape (B, grid_h * grid_w, hidden_size); the CLS
            token, when used, is stripped before returning.
        """
bsz, _, height, width = pixel_values.shape
grid_h, grid_w = height // self.patch_size, width // self.patch_size
hidden_state = self.conv1(pixel_values) # (B, D, Gh, Gw)
hidden_state = hidden_state.flatten(2).transpose(1, 2) # (B, Gh*Gw, D)
if self.use_cls_token:
cls_token = self.class_embedding.view(1, 1,
-1).expand(bsz, -1, -1)
hidden_state = torch.cat([cls_token, hidden_state], dim=1)
if self.use_abs_posemb:
pos_emb = self.sample_abs_posemb(grid_h, grid_w)
hidden_state = hidden_state + pos_emb
hidden_state = self.ln_pre(hidden_state)
hidden_state = self.transformer(hidden_state, grid_hw=(grid_h, grid_w))
        # ln_post is nn.Identity() when disabled, so it is always safe to call.
        hidden_state = self.ln_post(hidden_state)
if self.use_cls_token:
hidden_state = hidden_state[:, 1:, :]
return hidden_state
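

if __name__ == "__main__":
    # Smoke test with small, hypothetical hyper-parameters (not the released
    # Step3-VL-10B configuration); it only checks shapes and runnability of
    # the transformer stack defined in this file.
    vit = EncoderVisionTransformer(
        embed_dim=64,
        depth=1,
        num_heads=4,
        mlp_ratio=4.0,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        ls_init_value=1.0,
        max_grid_height=8,
        max_grid_width=8,
    )
    x = torch.randn(2, 8 * 8, 64)  # (B, Gh * Gw, D)
    print(vit(x, grid_hw=(8, 8)).shape)  # torch.Size([2, 64, 64])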