# LongCat-Next / modular_longcat_next_visual.py
# Locke's picture
# code
# 3e934f1
from typing import Iterable, Optional, Tuple
import numpy as np
from safetensors.torch import load_file
import torch
import torch.utils.checkpoint
from torch import nn
from torch.amp import autocast
from torch.nn import functional as F
from einops import rearrange
from flash_attn import flash_attn_varlen_func
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2RMSNorm,
Qwen2_5_VisionTransformerPretrainedModel,
)
from transformers.utils import logging
from .image_refiner import (
ImageRefinerContainer,
RefinerImageProcessor,
RefinerPipeline,
de_transform,
tensor2pil,
)
from .refiner_modules import FlowMatchEulerDiscreteScheduler
logger = logging.get_logger(__name__)
def uniform_init(*shape):
    """Allocate a tensor of the given shape and fill it with Kaiming-uniform values.

    Args:
        *shape: Dimension sizes; forwarded to ``torch.zeros`` as one tuple.

    Returns:
        The freshly allocated tensor after ``nn.init.kaiming_uniform_`` has
        overwritten its contents in place.
    """
    # kaiming_uniform_ mutates its argument and hands the same tensor back.
    return nn.init.kaiming_uniform_(torch.zeros(shape))
class VQEmbedding(nn.Module):
    """Frozen VQ codebook with nearest-neighbour lookup.

    ``embed`` stores ``n_embed + 1`` rows. Every distance computation slices off
    the final row, so the returned indices always lie in ``[0, n_embed)``.
    ``embed_ema`` and ``cluster_size_ema`` are allocated as parameters as well;
    all parameters have ``requires_grad`` switched off.
    """

    def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5, init_std=0.02):
        super().__init__()
        self.ema = ema
        self.decay = decay
        self.eps = eps
        self.restart_unused_codes = restart_unused_codes
        self.n_embed = n_embed
        self.init_std = init_std

        # Only the EMA flavour is supported by this module.
        assert self.ema

        codebook = uniform_init(n_embed + 1, embed_dim).to(torch.float32)
        self.embed = nn.Parameter(codebook)
        # The EMA copy excludes the extra last row of the codebook.
        self.embed_ema = nn.Parameter(codebook[:-1, :].clone())
        self.cluster_size_ema = nn.Parameter(torch.ones(n_embed))
        del codebook

        # Nothing in this module is trained by back-propagation.
        for param in self.parameters():
            param.requires_grad_(False)

    @torch.no_grad()
    def compute_distances(self, inputs):
        """Return squared Euclidean distances from each input vector to each code.

        Args:
            inputs: Tensor whose last dimension equals the codebook width.

        Returns:
            Tensor shaped ``inputs.shape[:-1] + (n_codes,)`` where ``n_codes``
            is the number of codebook rows without the final one.
        """
        codes_t = self.embed[:-1, :].t()
        width = codes_t.shape[0]
        lead_shape = inputs.shape[:-1]
        assert inputs.shape[-1] == width

        flat = inputs.reshape(-1, width)
        sq_inputs = flat.pow(2.).sum(dim=1, keepdim=True)
        sq_codes = codes_t.pow(2.).sum(dim=0, keepdim=True)
        # ||x||^2 + ||e||^2 - 2 * x.e, evaluated as one fused addmm.
        dist = torch.addmm(sq_inputs + sq_codes, flat, codes_t, alpha=-2.0)
        return dist.reshape(*lead_shape, -1)

    @torch.no_grad()
    def find_nearest_embedding(self, inputs):
        """Return, per input vector, the index of the closest code."""
        return self.compute_distances(inputs).argmin(dim=-1)

    @autocast('cuda', enabled=True, dtype=torch.float32)
    @torch.no_grad()
    def forward(self, inputs):
        """Quantize ``inputs`` to their nearest codes.

        Returns:
            ``(embeds, embed_idxs)``: the selected codebook rows and their indices.
        """
        # Casting is a no-op when the tensor is already float32.
        inputs = inputs.to(torch.float32)
        nearest = self.find_nearest_embedding(inputs)
        return self.embed[nearest], nearest
class RQBottleneck(nn.Module):
    """
    Quantization bottleneck via Residual Quantization.

    Arguments:
        latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
        code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
        n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook)
            If isinstance(n_embed, int), the sizes of all codebooks are same.
        decay (float, List, or Tuple): per-codebook decay forwarded to ``VQEmbedding``.
        shared_codebook (bool): If True, codebooks are shared in all location. If False,
            uses separate codebooks along the ``depth'' dimension. (default: False)
        restart_unused_codes (bool): If True, it randomly assigns a feature vector in the current batch
            as the new embedding of unused codes in training. (default: True)
        commitment_loss (str): stored on the instance; not read by the code in this class.

    Raises:
        ValueError: if the shapes are not both 3-D, if the spatial dims of
            ``latent_shape`` are not multiples of those of ``code_shape``, or if
            ``shared_codebook`` is combined with iterable ``n_embed`` / ``decay``.
    """

    def __init__(self,
                 latent_shape,
                 code_shape,
                 n_embed,
                 decay=0.99,
                 shared_codebook=False,
                 restart_unused_codes=True,
                 commitment_loss='cumsum'
                 ):
        super().__init__()

        if not len(code_shape) == len(latent_shape) == 3:
            raise ValueError("incompatible code shape or latent shape")
        if any(y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])):
            raise ValueError("incompatible code shape or latent shape")

        # Residual quantization does not divide feature dims for quantization.
        # int(...) turns the NumPy scalar into a plain Python int.
        embed_dim = int(np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2])

        self.latent_shape = torch.Size(latent_shape)
        self.code_shape = torch.Size(code_shape)
        self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])

        self.shared_codebook = shared_codebook
        if self.shared_codebook:
            if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
                # Implicit concatenation keeps indentation out of the message.
                raise ValueError("Shared codebooks are incompatible "
                                 "with list types of momentums or sizes: Change it into int")

        self.restart_unused_codes = restart_unused_codes
        depth = self.code_shape[-1]
        self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed] * depth
        self.decay = decay if isinstance(decay, Iterable) else [decay] * depth
        assert len(self.n_embed) == depth
        assert len(self.decay) == depth

        if self.shared_codebook:
            # One module instance referenced at every depth level.
            codebook0 = VQEmbedding(self.n_embed[0],
                                    embed_dim,
                                    decay=self.decay[0],
                                    restart_unused_codes=restart_unused_codes,
                                    ).to(torch.float32)
            self.codebooks = nn.ModuleList([codebook0] * depth)
        else:
            codebooks = [VQEmbedding(self.n_embed[idx],
                                     embed_dim,
                                     decay=self.decay[idx],
                                     restart_unused_codes=restart_unused_codes,
                                     ).to(torch.float32) for idx in range(depth)]
            self.codebooks = nn.ModuleList(codebooks)

        self.commitment_loss = commitment_loss

    def to_code_shape(self, x):
        """Fold each (rH, rW) spatial tile into the channel dim: (B, H, W, D) -> (B, h, w, rH*rW*D)."""
        (B, H, W, D) = x.shape
        (rH, rW, _) = self.shape_divisor

        x = x.reshape(B, H // rH, rH, W // rW, rW, D)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, H // rH, W // rW, -1)
        return x

    def to_latent_shape(self, x):
        """Inverse of ``to_code_shape``: (B, h, w, rH*rW*D) -> (B, h*rH, w*rW, D)."""
        (B, h, w, _) = x.shape
        (_, _, D) = self.latent_shape
        (rH, rW, _) = self.shape_divisor

        x = x.reshape(B, h, w, rH, rW, D)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, h * rH, w * rW, D)
        return x

    def quantize(self, x):
        r"""
        Return list of quantized features and the selected codewords by the residual quantization.
        The code is selected by the residuals between x and quantized features by the previous codebooks.

        Arguments:
            x (Tensor): bottleneck feature maps to quantize.

        Returns:
            quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks.
            codes (LongTensor): codewords index, corresponding to quants.

        Shape:
            - x: (B, h, w, embed_dim)
            - quant_list[i]: (B, h, w, embed_dim)
            - codes: (B, h, w, d)
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D input (B, h, w, embed_dim), got {x.dim()} dims")

        ori_dtype = x.dtype
        x = x.to(torch.float32)
        # The lookup always runs in float32; Module.to converts in place.
        self.codebooks.to(torch.float32)

        residual_feature = x.detach().clone()
        aggregated_quants = torch.zeros_like(x)
        quant_list = []
        code_list = []
        for i in range(self.code_shape[-1]):
            quant, code = self.codebooks[i](residual_feature)

            # Each level quantizes what the previous levels left unexplained.
            residual_feature.sub_(quant)
            aggregated_quants.add_(quant)

            quant_list.append(aggregated_quants.clone().to(dtype=ori_dtype))
            code_list.append(code.unsqueeze(-1))

        codes = torch.cat(code_list, dim=-1)
        return quant_list, codes

    def forward(self, x):
        """Quantize ``x`` and return ``(quants_trunc, codebook_loss, codes)``.

        ``quants_trunc`` equals the deepest aggregated quantization in value but
        routes gradients straight through to ``x``. ``codebook_loss`` is the list
        ``[0, commitment_loss, 0, 0]`` laid out as
        (vq_loss, commit_loss, entropy_loss, codebook_usage).
        """
        x_reshaped = self.to_code_shape(x)
        # quantize() forces float32 precision internally.
        quant_list, codes = self.quantize(x_reshaped)

        commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
        quants_trunc = self.to_latent_shape(quant_list[-1])
        # Straight-through estimator: value of the quantized tensor, gradient of x.
        quants_trunc = x + (quants_trunc - x).detach()

        # Codebook usage is not tracked; it is reported as 0.
        codebook_usage = 0
        # (vq_loss, commit_loss, entropy_loss, codebook_usage), kept for format alignment.
        codebook_loss = [0, commitment_loss, 0, codebook_usage]

        return quants_trunc, codebook_loss, codes

    def compute_commitment_loss(self, x, quant_list):
        r"""
        Compute the commitment loss for the residual quantization.
        The loss is iteratively computed by aggregating quantized features.
        """
        partial_losses = [(x - quant.detach()).pow(2.0).mean() for quant in quant_list]
        return torch.mean(torch.stack(partial_losses))
class Qwen2_5_VisionRotaryEmbedding_Modified(nn.Module):
    """Rotary frequency table whose ``inv_freq`` is a plain attribute, not a buffer.

    Because ``inv_freq`` is not registered, ``forward`` moves it to the
    requested device itself on every call.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Deliberately a plain tensor attribute rather than register_buffer(...).
        self.inv_freq = 1.0 / (theta ** exponents)

    def forward(self, seqlen: int, device: torch.device) -> torch.Tensor:
        """Return the ``(seqlen, len(inv_freq))`` table of position-times-frequency angles."""
        inv_freq = self.inv_freq.to(device)
        # Keep the moved copy so later calls on the same device do not move it again.
        self.inv_freq = inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
class VisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
    """Vision backbone built on ``Qwen2_5_VisionTransformerPretrainedModel``.

    Changes relative to the parent that are visible in this class: the attention
    implementation on the config is forced to ``'flash_attention_2'``, the
    rotary module is ``Qwen2_5_VisionRotaryEmbedding_Modified``, the ``merger``
    sub-module is deleted, and ``forward`` returns the block outputs (still in
    window order) optionally together with ``window_index``.
    """

    def __init__(self, config):
        # Set before the parent constructor runs so it sees the overridden backend.
        config._attn_implementation = 'flash_attention_2'
        super().__init__(config)
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding_Modified(config.hidden_size // config.num_heads // 2)
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        # Falls back to 2 when the config carries no ``merge_size``.
        self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
        del self.merger # register visual.merger in visual_bridge_model

    def get_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's ``mlp.down_proj`` weight."""
        return self.blocks[0].mlp.down_proj.weight.dtype

    def get_device(self) -> torch.device:
        """Return the device of the first block's ``mlp.down_proj`` weight."""
        return self.blocks[0].mlp.down_proj.weight.device

    def rot_pos_emb(self, grid_thw):
        """Build rotary angles for every patch described by ``grid_thw``.

        For each ``(t, h, w)`` row, row and column indices are laid out so that
        the patches of one ``spatial_merge_size`` x ``spatial_merge_size`` tile
        are adjacent, then repeated ``t`` times. The (row, col) pairs index a
        frequency table sized by the largest h or w in ``grid_thw``.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every patch, regrouped into merge tiles.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every patch, regrouped the same way.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
        # Look up the row and column angles and concatenate them per patch.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        require_window_index: bool = False,
    ):
        """Run patch embedding and all transformer blocks.

        Args:
            pixel_values: per the original note, shape ``[NumOfPatches, 1176]``.
            grid_thw: per the original note, shape ``[NumOfSamples, 3]`` holding
                ``[grid_t, grid_h, grid_w]``.
            require_window_index: when True, also return ``window_index``.

        Returns:
            ``hidden_states``, or ``(hidden_states, window_index)``. The hidden
            states stay in the window-reordered patch order; nothing in this
            method undoes the reordering.
        """
        # NOTE(review): the input is always cast to bfloat16, regardless of get_dtype() — confirm intended.
        hidden_states = pixel_values.to(torch.bfloat16)
        grid_thw = grid_thw.to(pixel_values.device)

        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries so no zero-length window remains.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder patches (in groups of spatial_merge_unit) into window order.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        # Apply the same reordering to the rotary angles.
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative sequence lengths with one entry per frame (h*w repeated t times).
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # Blocks listed in fullatt_block_indexes attend over whole frames; the rest over windows.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                # Arguments are positional here; the third one is passed as None.
                # TODO confirm the order against the parent block's signature.
                hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings)
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    position_embeddings=position_embeddings,
                )

        if require_window_index:
            return hidden_states, window_index
        return hidden_states
class OmniVisualBridge(nn.Module):
    """Normalize, merge and project visual features, then undo the window ordering.

    Groups of ``merge_size ** 2`` consecutive feature rows are concatenated into
    one row of width ``hidden_size`` before a two-layer GELU MLP projects them to
    ``config.out_hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Default merge factor of 2 when the config does not define one.
        self.merge_size = getattr(config, 'merge_size', 2)
        self.hidden_size = config.hidden_size * (self.merge_size ** 2)
        # Stored under this name from the config's ``window_size``.
        self.window_index = config.window_size
        self.ln_q = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, config.out_hidden_size),
        )

    def forward(self, x: torch.Tensor, window_index) -> torch.Tensor:
        """Project ``x`` and return the rows reordered by ``argsort(window_index)``."""
        merged = self.ln_q(x).view(-1, self.hidden_size)
        projected = self.mlp(merged)
        # argsort of a permutation is its inverse, which restores the pre-window order.
        restore_order = torch.argsort(window_index)
        return projected[restore_order, :]
class VisualQuantizer(nn.Module):
    """Wrap an ``RQBottleneck`` (plus an optional input MLP) and return code indices.

    ``forward`` takes a 2-D feature matrix ``[L, D]`` and returns the third
    element of the ``info`` triple produced by ``encode``.
    """

    def __init__(self, quantizer_config):
        super().__init__()
        cfg = quantizer_config
        self.config = cfg
        self.depth = cfg.depth
        self.decay = cfg.decay
        self.codebook_size = cfg.codebook_size
        self.codebook_dim = cfg.codebook_dim
        self.shared_codebook = cfg.shared_codebook
        self.restart_unused_codes = cfg.restart_unused_codes
        self.in_channels = cfg.in_channels

        self.vq_loss_ratio = cfg.vq_loss_ratio
        self.entropy_loss_ratio = cfg.entropy_loss_ratio
        self.commit_loss_ratio = cfg.commit_loss_ratio

        # Fixed grid side of 448 / 14 = 32; latent and code grids share it.
        grid_side = int(448 / 14)
        self.quantize = RQBottleneck(
            latent_shape=[grid_side, grid_side, self.codebook_dim],
            code_shape=[grid_side, grid_side, self.depth],
            n_embed=self.codebook_size,
            decay=self.decay,
            shared_codebook=self.shared_codebook,
            restart_unused_codes=self.restart_unused_codes,
        )

        # Optional projection from in_channels down to codebook_dim.
        self.quant_conv = nn.Sequential(
            nn.LayerNorm(self.in_channels),
            nn.Linear(self.in_channels, self.in_channels),
            nn.GELU(),
            nn.Linear(self.in_channels, self.codebook_dim)
        ) if cfg.quant_conv else None

    def encode(self, x):
        """Quantize a ``[L, D]`` feature matrix.

        Returns:
            ``(quant, emb_loss, info, x.detach())``. On the ``"rq"`` path
            ``quant`` is laid out as ``(1, d, 1, L)`` and ``info`` is
            ``[None, None, codes]`` with ``codes`` flattened to two dimensions.
        """
        seq_len, _ = x.shape
        batch = 1

        feats = x.clone().unsqueeze(0)  # [L, D] -> [1, L, D]
        if self.quant_conv is not None:
            feats = self.quant_conv(feats)

        # Channel-first layout: N,L,d -> N,1,L,d -> N,d,1,L
        feats = feats.reshape(batch, 1, seq_len, self.codebook_dim).permute(0, 3, 1, 2)

        if self.config.quantizer_type != "rq":
            quant, emb_loss, info = self.quantize(feats)
            return quant, emb_loss, info, x.detach()

        # The residual quantizer consumes channel-last input: N,d,1,L -> N,1,L,d
        feats = feats.permute(0, 2, 3, 1).contiguous()
        quant, emb_loss, codes = self.quantize(feats)
        # n,h,w,lv -> n*h*w,lv
        info = [None, None, codes.reshape(-1, codes.shape[-1])]
        # Back to channel-first: N,1,L,d -> N,d,1,L
        quant = quant.permute(0, 3, 1, 2).contiguous()
        return quant, emb_loss, info, x.detach()

    def forward(self, x):
        """Return only the code indices produced by ``encode``."""
        # Full destructuring is kept so a malformed encode() result still fails here.
        _quant, (_vq, _commit, _entropy, _usage), (_ppl, _encodings, indices), _feat = \
            self.encode(x)
        return indices
class MLP(nn.Module):
    """Bias-free gated feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # Creation order is gate, down, up.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        # Activation is looked up by name in ACT2FN.
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)
class DecoderLayer(nn.Module):
    """Pre-norm residual feed-forward layer: ``x + mlp(layer_norm(x))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.visual_embedding_layer_intermediate_size,
            hidden_act=config.visual_embedding_layer_hidden_act,
        )
        # A standard LayerNorm whose epsilon is read from ``config.rms_norm_eps``.
        self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Return the input plus the MLP applied to its normalized copy."""
        update = self.mlp(self.pre_layernorm(hidden_states))
        return hidden_states + update
class VisualEmbeddingBridge(nn.Module):
    """Thin wrapper that passes embeddings through one ``DecoderLayer``."""

    def __init__(self, config):
        super().__init__()
        self.pre_buffer = DecoderLayer(config)

    def forward(self, embeding):
        """Return ``pre_buffer`` applied to ``embeding``."""
        buffered = self.pre_buffer(embeding)
        return buffered
class VisualVQBridge(nn.Module):
    """Chain ``OmniVisualBridge`` and ``VisualQuantizer``: features in, code indices out."""

    def __init__(self, visual_config):
        super().__init__()
        self.bridge = OmniVisualBridge(visual_config)
        self.quantizer = VisualQuantizer(visual_config.vq_config)

    def forward(
        self,
        visual_embed: torch.Tensor,
        window_index: torch.Tensor,
    ):
        """Bridge the encoder output, then quantize it and return the indices."""
        bridged = self.bridge(visual_embed, window_index)
        return self.quantizer(bridged)
class LongcatNextVisualTokenizer(nn.Module):
    """Visual tokenizer: encodes pixels to code indices and decodes indices to images.

    The encoder side (``visual_model`` + ``visual_bridge_model``) is built in the
    constructor. The decoder and refiner pipeline are created lazily, on the
    first call to ``lazy_decode_and_save``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.visual_model = VisualEncoder(config.visual_config)
        self.visual_bridge_model = VisualVQBridge(config.visual_config)
        self.visual_embedding_layer = VisualEmbeddingBridge(config)
        # Both are filled in lazily by _load_decoder_and_refiner().
        self.image_decoder = None
        self._refiner_pipeline = None

    @torch.no_grad()
    def encode(self, pixel_values: torch.Tensor, visual_grid_thw: torch.Tensor):
        """Return the code indices for ``pixel_values`` laid out per ``visual_grid_thw``."""
        visual_embed, window_index = self.visual_model(pixel_values, grid_thw=visual_grid_thw, require_window_index=True)
        indices = self.visual_bridge_model(visual_embed, window_index)
        return indices

    def _load_decoder_and_refiner(self, device):
        """Instantiate the image decoder and the refiner pipeline on ``device`` (bfloat16)."""
        print("lazy load image_decoder / image_refiner / _refiner_pipeline ...")
        vdc = self.config.visual_config.visual_decoder_config
        self.image_decoder = VisionTransformerDecoder.from_pretrained(
            vdc.image_decoder_config,
            vdc.weight_path,
        ).to(device=device, dtype=torch.bfloat16)
        image_refiner = ImageRefinerContainer.from_pretrained(vdc, vdc.weight_path).to(device=device, dtype=torch.bfloat16)
        sc = vdc.scheduler_config
        scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=sc.num_train_timesteps,
            dynamic_time_shift=sc.dynamic_time_shift)
        self._refiner_pipeline = RefinerPipeline(
            vae=image_refiner.vae,
            transformer=image_refiner.base_transformer,
            scheduler=scheduler,
            cond_proj=image_refiner.cond_proj,
        )
        self._refiner_pipeline.set_progress_bar_config(disable=False)

    @torch.no_grad()
    def lazy_decode_and_save(self, visual_ids, tokens_h, tokens_w, save_path):
        """Decode code indices to an image, refine it, and save the refined result.

        Args:
            visual_ids: Code indices in any form ``torch.as_tensor`` accepts. A
                1-D input is reshaped to ``(-1, n_codebooks)`` and a 2-D input
                gets a leading batch dimension.
            tokens_h: Height of the merged token grid.
            tokens_w: Width of the merged token grid.
            save_path: Output path; ``".png"`` is replaced by ``"_refined.png"``.

        Returns:
            A one-element list holding the path of the saved refined image.
            Only the first refined image is written.
        """
        device = next(self.parameters()).device
        if self.image_decoder is None:
            self._load_decoder_and_refiner(device)

        n_codebooks = len(self.config.visual_config.vq_config.codebook_sizes)
        data = torch.as_tensor(visual_ids, dtype=torch.long)
        if data.ndim == 1:
            data = data.view(-1, n_codebooks)
        if data.ndim == 2:
            data = data.unsqueeze(0)
        batch_size = data.shape[0]

        # Residual quantization: the feature is the sum of one code per codebook.
        quant_features = None
        for idx in range(n_codebooks):
            embed = self.visual_bridge_model.quantizer.quantize.codebooks[idx].embed
            feat = embed[data[..., idx].to(embed.device)]
            quant_features = feat if quant_features is None else quant_features + feat
        quant_features = quant_features.to(device)

        # tokens_h/tokens_w are the merged grid; expand to the full (unmerged) grid
        s = self.image_decoder.spatial_merge_size
        grid_thw_list = [(1, tokens_h * s, tokens_w * s)]
        grid_thw_batch = list(grid_thw_list) * batch_size
        image_mean = [0.48145466, 0.4578275, 0.40821073]
        image_std = [0.26862954, 0.26130258, 0.27577711]

        emb_2d = quant_features.reshape(-1, quant_features.shape[-1]).contiguous()
        # autocast wants the bare device type ("cuda", "cpu", "mps", ...), never
        # an indexed string such as "mps:0"; torch.device.type provides that.
        device_type = device.type
        with torch.amp.autocast(device_type=device_type, enabled=True, dtype=torch.float32):
            decoder_out = self.image_decoder(emb_2d, grid_thw_batch, return_pixel_features=False)
        decoded_tensors = decoder_out.get("images") or []
        # NOTE(review): the un-refined images and their path are computed but not
        # written, because the save below is disabled — confirm whether to drop them.
        decoded_images = [tensor2pil(t, image_mean, image_std) for t in decoded_tensors]
        decoded_path = save_path.replace(".png", "_decoded.png")
        # decoded_images[0].save(decoded_path)

        # Re-normalize the decoder output into the refiner's input format.
        ref_input = []
        for t in decoded_tensors:
            img_01 = de_transform(t, mean=image_mean, std=image_std, rescale_factor=1 / 255)
            img_norm = RefinerImageProcessor.normalize(img_01)
            ref_input.append(img_norm.squeeze(0).to(device))

        # Fixed per-sample seeds (42, 43, ...).
        generators = [torch.Generator(device=device).manual_seed(42 + b) for b in range(batch_size)]
        out = self._refiner_pipeline(
            encoder_hidden_states=quant_features,
            grid_thw_list=grid_thw_list,
            image=ref_input,
            generator=generators[0] if batch_size == 1 else generators,
            output_type="pil",
            return_dict=True,
        )
        refined_images = out.images
        refined_path = save_path.replace(".png", "_refined.png")
        refined_images[0].save(refined_path)
        return [refined_path]
# ---------------------------------------------------------------------------
# Vision Transformer Decoder
# ---------------------------------------------------------------------------
def _rotate_half(x):
    """Rotate adjacent pairs of the last dimension: ``(a, b) -> (-b, a)``.

    The last dimension is read as interleaved pairs, each pair is rotated, and
    the result is flattened back to the input's shape.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
class VisionRoPE2D(nn.Module):
    """2D Rotary Position Embedding for Q/K in vision decoder attention.

    The first half of the feature dimension is rotated by the first position
    column and the second half by the second position column.
    """

    def __init__(self, theta: float = 10000.0):
        super().__init__()
        self.theta = theta

    def _rope_half(self, x_half, pos_1d, theta):
        """Apply 1-D rotary embedding to ``x_half`` (3-D) using positions ``pos_1d``."""
        _, _, d_half = x_half.shape
        work_dtype = x_half.dtype
        # Frequencies are built in float32 and then cast to the input dtype.
        exponents = torch.arange(0, d_half, 2, device=x_half.device, dtype=torch.float32) / d_half
        inv_freq = (1.0 / (theta ** exponents)).to(work_dtype)
        angles = pos_1d.to(work_dtype)[:, None] * inv_freq[None, :]
        # Each angle serves one adjacent feature pair, hence the repeat by 2.
        cos_part = torch.repeat_interleave(torch.cos(angles), 2, dim=-1).unsqueeze(0)
        sin_part = torch.repeat_interleave(torch.sin(angles), 2, dim=-1).unsqueeze(0)
        return x_half * cos_part + _rotate_half(x_half) * sin_part

    def forward(self, x, positions_2d):
        """Rotate both halves of ``x`` and concatenate them back along the last dim."""
        split = x.shape[-1] // 2
        first_half = self._rope_half(x[:, :, :split], positions_2d[:, 0], self.theta)
        second_half = self._rope_half(x[:, :, split:], positions_2d[:, 1], self.theta)
        return torch.cat([first_half, second_half], dim=-1)
class VisionAttention(nn.Module):
    """Multi-headed attention with 2D RoPE + FlashAttention varlen.

    ``forward`` first tries the FlashAttention variable-length kernel and falls
    back to dense ``bmm`` attention when that path is unavailable. The dense
    path does not use ``seq_lens``.
    """

    def __init__(self, config, rope=None, rope_shift=0):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got embed_dim={self.embed_dim}, num_heads={self.num_heads})"
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout
        self.subln = config.subln
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "k_bias", True))
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "v_bias", True))
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "q_bias", True))
        self.inner_attn_ln = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps) if config.subln else nn.Identity()
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.rope = rope
        self.rope_shift = int(rope_shift)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, embed_dim)`` to ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _maybe_flash_attention(self, query_states, key_states, value_states, seq_lens, training):
        """Try the FlashAttention varlen kernel; return ``None`` when it cannot be used.

        Args:
            query_states, key_states, value_states: ``(bsz * num_heads, T, head_dim)``.
            seq_lens: Per-sequence lengths whose sum must equal ``T``.
            training: Enables attention dropout when True.

        Returns:
            The attention output shaped like ``query_states``, or ``None`` if the
            inputs are not on CUDA, ``seq_lens`` is missing or inconsistent, or
            the kernel raised. A raised error is reported as a ``RuntimeWarning``
            before falling back.
        """
        import warnings

        if not (query_states.is_cuda and (query_states.dtype in (torch.float16, torch.bfloat16, torch.float32))):
            return None
        if seq_lens is None:
            return None
        try:
            BxH, T, hd = query_states.shape
            H = self.num_heads
            if BxH % H != 0:
                return None
            B = BxH // H
            if int(seq_lens.sum().item()) != T:
                return None

            # (B*H, T, hd) -> (B*T, H, hd), the packed layout the varlen kernel takes.
            q = query_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
            k = key_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
            v = value_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()

            cu_q = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device)
            cu_q[1:] = torch.cumsum(seq_lens.to(torch.int32), dim=0)
            cu_k = cu_q
            max_seqlen = int(seq_lens.max().item())

            # float32 inputs are run in float16 and cast back afterwards.
            orig_dtype = q.dtype
            use_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.float16
            if q.dtype != use_dtype:
                q = q.to(use_dtype)
                k = k.to(use_dtype)
                v = v.to(use_dtype)

            # NOTE(review): forward() already multiplies the query by self.scale, and
            # softmax_scale=None presumably makes the kernel apply its own default
            # scale on top, which the dense fallback does not — confirm against the
            # flash-attn docs and the trained weights before changing.
            out = flash_attn_varlen_func(
                q, k, v, cu_q, cu_k, max_seqlen, max_seqlen,
                dropout_p=self.dropout if training else 0.0,
                softmax_scale=None, causal=False, return_attn_probs=False
            )
            if out.dtype != orig_dtype:
                out = out.to(orig_dtype)
            return out.view(B, -1, H, hd).transpose(1, 2).contiguous().view(B * H, T, hd)
        except Exception as exc:
            # Still best-effort, but no longer silent: the dense fallback is
            # slower and ignores seq_lens, so the caller should hear about it.
            warnings.warn(
                f"FlashAttention path failed ({type(exc).__name__}: {exc}); "
                "falling back to dense attention.",
                RuntimeWarning,
                stacklevel=2,
            )
            return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        positions_2d: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute self-attention over ``hidden_states`` of shape ``(bsz, tgt_len, embed_dim)``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_weights`` is only populated on
            the dense path with ``output_attentions=True``; otherwise it is ``None``.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()

        # The query is pre-scaled here, so the dense path needs no further scaling.
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._shape(query_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
        key_states = self._shape(key_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
        value_states = self._shape(value_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)

        if self.rope is not None and positions_2d is not None:
            if self.rope_shift > 0:
                # The first rope_shift tokens are left un-rotated.
                q_pref = query_states[:, :self.rope_shift, :]
                k_pref = key_states[:, :self.rope_shift, :]
                q_rot = self.rope(query_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
                k_rot = self.rope(key_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
                query_states = torch.cat([q_pref, q_rot], dim=1).type_as(value_states)
                key_states = torch.cat([k_pref, k_rot], dim=1).type_as(value_states)
            else:
                query_states = self.rope(query_states, positions_2d).type_as(value_states)
                key_states = self.rope(key_states, positions_2d).type_as(value_states)

        attn_output = self._maybe_flash_attention(
            query_states, key_states, value_states, seq_lens=seq_lens, training=self.training
        )

        if attn_output is not None:
            # The fused kernel does not expose attention probabilities.
            attn_weights_reshaped = None
        else:
            src_len = key_states.size(1)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            if causal_attention_mask is not None:
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
            if attention_mask is not None:
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            if output_attentions:
                attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            else:
                attn_weights_reshaped = None

            attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
            attn_output = torch.bmm(attn_probs, value_states)

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
        attn_output = self.inner_attn_ln(attn_output)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights_reshaped
class VisionSwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w3(ffn_ln(silu(w1(x)) * w2(x)))``.

    ``ffn_ln`` is an RMS norm over the intermediate width when ``config.subln``
    is truthy and an identity otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Two parallel up-projections (w1 feeds the gate) and one down-projection.
        self.w1 = nn.Linear(self.hidden_size, self.intermediate_size)
        self.w2 = nn.Linear(self.hidden_size, self.intermediate_size)
        self.w3 = nn.Linear(self.intermediate_size, self.hidden_size)
        self.act_fn = nn.SiLU()
        self.ffn_ln = Qwen2RMSNorm(self.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x``."""
        gate = self.w1(x)
        value = self.w2(x)
        gated = self.act_fn(gate) * value
        return self.w3(self.ffn_ln(gated))
class VisionMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(ffn_ln(act(fc1(x))))``.

    ``ffn_ln`` is an RMS norm over the intermediate width when ``config.subln``
    is truthy and an identity otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name in ACT2FN.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.ffn_ln = Qwen2RMSNorm(config.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform to ``hidden_states``."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(self.ffn_ln(expanded))
class VisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention residual followed by MLP residual.

    The feed-forward part is ``VisionSwiGLU`` when ``config.swiglu`` is truthy
    and ``VisionMLP`` otherwise.
    """

    def __init__(self, config, rope=None, rope_shift=0):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = VisionAttention(config, rope=rope, rope_shift=rope_shift)
        self.layer_norm1 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = VisionSwiGLU(config) if config.swiglu else VisionMLP(config)
        self.layer_norm2 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: Optional[bool] = False,
        positions_2d: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
        """Run one layer.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        # Attention sub-block with residual connection.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            positions_2d=positions_2d,
            seq_lens=seq_lens,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with residual connection.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``VisionEncoderLayer`` modules.

    Supports optional collection of per-layer hidden states and attentions, and
    optional activation checkpointing while training.
    """

    def __init__(self, config, rope=None, rope_shift=0):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [VisionEncoderLayer(config, rope=rope, rope_shift=rope_shift) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        positions_2d: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
    ):
        """Run all layers over ``inputs_embeds``.

        Args:
            inputs_embeds: Input hidden states for the first layer.
            attention_mask, causal_attention_mask, positions_2d, seq_lens:
                Forwarded unchanged to every layer.
            output_attentions: Collect attention weights (default False). Not
                collected on the checkpointing path.
            output_hidden_states: Collect the input of every layer plus the
                final output (default False).
            return_dict: Return a ``BaseModelOutput`` (default True) instead of
                a tuple of the non-``None`` results.
        """
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = True if return_dict is None else return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Bind the layer as a default argument. With use_reentrant=False
                # this function is called again during backward, and a plain
                # closure over the loop variable would by then refer to the
                # last layer.
                def custom_forward(hs, attn, causal, pos2d, seqlens, _layer=layer):
                    return _layer(
                        hs,
                        attention_mask=attn,
                        causal_attention_mask=causal,
                        output_attentions=False,
                        positions_2d=pos2d,
                        seq_lens=seqlens,
                    )[0]

                # Missing masks / seq_lens are replaced by placeholder tensors.
                hidden_states = self._gradient_checkpointing_func(
                    custom_forward,
                    hidden_states,
                    attention_mask if attention_mask is not None else torch.tensor(0., device=hidden_states.device),
                    causal_attention_mask if causal_attention_mask is not None else torch.tensor(0., device=hidden_states.device),
                    positions_2d,
                    seq_lens if seq_lens is not None else torch.tensor([], device=hidden_states.device),
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                    positions_2d=positions_2d,
                    seq_lens=seq_lens,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
class PatchUnMerger(nn.Module):
    """Learnable inverse of Qwen2_5_VLPatchMerger.

    Every input row is RMS-normalised, expanded by a two-layer MLP to
    ``context_dim * spatial_merge_size ** 2`` features, and then split into
    ``spatial_merge_size ** 2`` rows of width ``context_dim``.
    """

    def __init__(self, dim, context_dim, spatial_merge_size=2):
        """Create the normalisation layer and the expansion MLP.

        Args:
            dim: Feature width of each incoming (merged) token.
            context_dim: Feature width of each outgoing (un-merged) token.
            spatial_merge_size: Side length of the merge window; each input
                token yields ``spatial_merge_size ** 2`` output tokens.
        """
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.context_dim = context_dim

        expanded_width = (spatial_merge_size ** 2) * context_dim
        self.ln_q = Qwen2RMSNorm(dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(dim, expanded_width),
            nn.GELU(),
            nn.Linear(expanded_width, expanded_width),
        )

    def forward(self, x):
        """Expand ``x`` of shape ``(N, dim)`` to ``(N * merge_size**2, context_dim)``."""
        tokens_per_input = self.spatial_merge_size ** 2
        expanded = self.mlp(self.ln_q(x))
        # Each expanded row holds `tokens_per_input` consecutive output tokens.
        return expanded.view(expanded.shape[0] * tokens_per_input, self.context_dim)
def restore_spatial_structure_and_convert_to_images(patches, grid_thw_list, patch_size,
                                                    channel_dim=3, temporal_patch_size=2, merge_size=2):
    """Reassemble flat per-patch pixel features into one image tensor per grid.

    Args:
        patches: Tensor whose leading axis enumerates patches for all grids
            back to back. A tuple is unwrapped to its first element.
        grid_thw_list: Iterable of ``(t, h, w)`` patch-grid sizes; entries may
            be tensors or plain sequences.
        patch_size: Side length, in pixels, of one spatial patch.
        channel_dim: Number of image channels per patch.
        temporal_patch_size: Number of frames packed into one patch.
        merge_size: Side length of the merge window the patch order follows.

    Returns:
        A list with one tensor per grid, each of shape
        ``(channel_dim, h * patch_size, w * patch_size)`` (only the first
        reconstructed frame is kept).
    """
    flat = patches[0] if isinstance(patches, tuple) else patches

    images = []
    offset = 0
    for grid in grid_thw_list:
        dims = grid.tolist() if isinstance(grid, torch.Tensor) else grid
        frames, rows, cols = (int(d) for d in dims)

        # Take this grid's contiguous run of patches and advance the cursor.
        count = frames * rows * cols
        block = flat[offset:offset + count]
        offset += count

        # Token axis -> (frame, row window, col window, row in window, col in window);
        # feature axis -> (channel, temporal, pixel row, pixel col).
        block = block.reshape(
            frames, rows // merge_size, cols // merge_size, merge_size, merge_size,
            channel_dim, temporal_patch_size, patch_size, patch_size,
        )
        # Reorder to (frame, temporal, channel, row window, row in window, pixel row,
        # col window, col in window, pixel col) so rows and columns become contiguous.
        block = block.permute(0, 6, 5, 1, 3, 7, 2, 4, 8)
        stacked_frames = block.reshape(
            frames * temporal_patch_size, channel_dim, rows * patch_size, cols * patch_size
        )
        images.append(stacked_frames[0])
    return images
class VisionTransformerDecoder(nn.Module):
    """Decode visual embeddings into per-patch pixel features and images.

    ``forward`` chains: optional post-quantization projection ->
    ``PatchUnMerger`` -> ``norm_in`` -> ``VisionEncoder`` (with 2D RoPE) ->
    ``norm_out`` -> ``decoder_head``.
    """

    def __init__(self, config):
        """Build the decoder stack.

        Args:
            config: Object providing ``hidden_size``, ``patch_size``,
                ``spatial_merge_size``, ``codebook_dim``,
                ``temporal_patch_size``, ``layer_norm_eps`` and
                ``intermediate_size``; it is also handed to ``VisionEncoder``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.spatial_merge_size = config.spatial_merge_size
        self.codebook_dim = config.codebook_dim
        self.temporal_patch_size = config.temporal_patch_size
        self.rope2d = VisionRoPE2D(theta=10000.0)
        self.post_quant_conv = nn.Linear(self.codebook_dim, self.embed_dim)
        self.post_quant_norm = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.patch_unmerger = PatchUnMerger(self.embed_dim, self.embed_dim, self.spatial_merge_size)
        self.norm_in = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.encoder = VisionEncoder(config, rope=self.rope2d, rope_shift=0)
        self.norm_out = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        # One output vector per token holds a full 3-channel spatio-temporal patch.
        self.decoder_head = nn.Sequential(
            nn.Linear(self.embed_dim, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, 3 * self.patch_size * self.patch_size * self.temporal_patch_size),
        )

    @classmethod
    def from_pretrained(cls, config, model_path: str):
        """Load a pretrained model from a safetensors checkpoint.

        Only entries whose key starts with ``image_decoder.`` are used; the
        prefix is stripped and the result is loaded strictly. The returned
        model is in eval mode.
        """
        prefix = "image_decoder."
        model = cls(config)
        weight_dict = load_file(model_path, device="cpu")
        decoder_state = {
            key.removeprefix(prefix): value
            for key, value in weight_dict.items()
            if key.startswith(prefix)
        }
        model.load_state_dict(decoder_state, strict=True)
        model.eval()
        return model

    def _build_2d_positions(self, grid_thw_list):
        """Return ``(y, x)`` coordinates for every token of every grid.

        For each ``(t, gh, gw)`` entry the ``gh * gw`` row-major coordinates
        are emitted ``t`` times; grids are concatenated in order.

        Returns:
            A CPU ``torch.long`` tensor of shape ``(N, 2)``; ``(0, 2)`` when
            there are no tokens.

        NOTE(review): coordinates are generated in raster order, while
        ``restore_spatial_structure_and_convert_to_images`` interprets tokens
        as merge-window-major. Confirm this is what the encoder was trained
        with before changing it.
        """
        per_grid = []
        for (t, gh, gw) in grid_thw_list:
            t, gh, gw = int(t), int(gh), int(gw)
            ys = torch.arange(gh, dtype=torch.long).repeat_interleave(gw)
            xs = torch.arange(gw, dtype=torch.long).repeat(gh)
            frame_positions = torch.stack((ys, xs), dim=1)
            # Every temporal slice of a grid reuses the same spatial coordinates.
            per_grid.append(frame_positions.repeat(t, 1))
        if not per_grid:
            return torch.zeros((0, 2), dtype=torch.long)
        return torch.cat(per_grid, dim=0)

    def _build_attention_mask(self, grid_thw_list, device, dtype, B, num_heads):
        """Build an additive block-diagonal mask of shape ``(B, num_heads, L, L)``.

        Entries are ``0`` where query and key belong to the same grid and
        ``-inf`` otherwise, with ``L`` the total token count over all grids.
        ``forward`` in this class does not call this helper.
        """
        counts = [int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list]
        L = sum(counts)
        mask = torch.zeros((B, num_heads, L, L), device=device, dtype=dtype)
        neg_inf = float("-inf")
        start = 0
        for count in counts:
            end = start + count
            if start > 0:
                mask[:, :, start:end, :start] = neg_inf
            if end < L:
                mask[:, :, start:end, end:] = neg_inf
            start = end
        return mask

    def forward(self, embeddings, grid_thw, return_pixel_features=False, return_last_latent=False):
        """Decode embeddings into pixel features and, optionally, images.

        Args:
            embeddings: Token embeddings. When the last dimension equals
                ``codebook_dim`` they are first projected to ``embed_dim``.
            grid_thw: Tensor of ``(t, h, w)`` rows, or an iterable of such
                triples, one per image.
            return_pixel_features: When true, skip image reconstruction and
                return ``"images": None``.
            return_last_latent: When true, also return the detached
                ``norm_in`` output under ``"last_latent"``.

        Returns:
            Dict with ``"images"`` (list of tensors or ``None``),
            ``"pixel_features"`` and, if requested, ``"last_latent"``.
        """
        device = embeddings.device
        if isinstance(grid_thw, torch.Tensor):
            grid_thw_list = [(int(t), int(h), int(w)) for t, h, w in grid_thw.detach().cpu().tolist()]
        else:
            grid_thw_list = list(grid_thw)

        if embeddings.shape[-1] == self.codebook_dim:
            embeddings = self.post_quant_conv(embeddings)
            # NOTE(review): the reviewed copy of this file had lost its
            # indentation, so it could not be determined whether this norm is
            # conditional or always applied — confirm against upstream. Both
            # readings agree for codebook-width inputs.
            embeddings = self.post_quant_norm(embeddings)

        unmerged = self.patch_unmerger(embeddings)
        if unmerged.dim() == 2:
            # The encoder is fed a leading batch axis of size one.
            unmerged = unmerged.unsqueeze(0)
        _, total_tokens, _ = unmerged.shape
        hidden_states = self.norm_in(unmerged)

        positions_2d = self._build_2d_positions(grid_thw_list).to(device)
        # Per-image token counts; presumably used by the encoder to keep
        # attention within each image — verify against VisionEncoder.
        seq_lens = torch.tensor(
            [int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list],
            device=device,
            dtype=torch.int32,
        )
        assert positions_2d.shape[0] == total_tokens, f"positions_2d {positions_2d.shape[0]} != L {total_tokens}"

        last_latent = hidden_states.detach().squeeze(0) if return_last_latent else None

        enc_out = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            positions_2d=positions_2d,
            seq_lens=seq_lens,
        )
        hidden_states = self.norm_out(enc_out.last_hidden_state)
        pixel_features = self.decoder_head(hidden_states).squeeze(0)

        if return_pixel_features:
            out_imgs = None
        else:
            out_imgs = restore_spatial_structure_and_convert_to_images(
                pixel_features,
                grid_thw_list,
                self.patch_size,
                temporal_patch_size=self.temporal_patch_size,
                merge_size=self.spatial_merge_size,
            )

        ret = {"images": out_imgs, "pixel_features": pixel_features}
        if last_latent is not None:
            ret["last_latent"] = last_latent
        return ret