| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """DiffusionVL model implementation.""" |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Callable, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput |
| from transformers.utils import logging |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.integrations import use_kernel_forward_from_hub |
|
|
| from .configuration_diffusionvl_qwen2_5_vl import DiffusionVL_Qwen2_5_VL_Config, DiffusionVL_Qwen2_5_VL_VisionConfig |
|
|
# Sentinel id that marks image placeholder positions in ``input_ids``. It is negative,
# so it is not a valid row of the token embedding table and must never be looked up directly.
IMAGE_TOKEN_INDEX = -200
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate the last dimension of ``x`` by half, as used by rotary position embeddings.

    The trailing dimension is split into two halves ``(a, b)`` (the second half gets the
    extra element when the size is odd) and recombined as ``(-b, a)``.

    Args:
        x: Input tensor of shape (..., head_dim).

    Returns:
        Tensor with the same shape as ``x``.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
def apply_rotary_pos_emb_vision(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to vision-encoder queries and keys.

    The rotation is evaluated in float32, and each result is cast back to the dtype of
    its own input. ``cos``/``sin`` receive a singleton axis at position -2 so that they
    broadcast over that axis of ``q``/``k``. In the vision attention it is the head axis:
    q/k are (seq_len, num_heads, head_dim) and cos/sin are (seq_len, head_dim).

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine part of rotary embedding.
        sin: Sine part of rotary embedding.

    Returns:
        Tuple of (rotated_q, rotated_k).
    """
    cos32 = cos.unsqueeze(-2).float()
    sin32 = sin.unsqueeze(-2).float()

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        t32 = t.float()
        return ((t32 * cos32) + (rotate_half(t32) * sin32)).to(t.dtype)

    return _rotate(q), _rotate(k)
|
|
|
|
def apply_multimodal_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: List[int],
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply multimodal rotary position embedding (M-RoPE) for 3D position encoding.

    ``cos``/``sin`` hold one table per position axis along dim 0 (temporal, height,
    width). The head_dim axis is cut into chunks whose sizes are ``mrope_section``
    repeated twice, one repetition for each half that ``rotate_half`` pairs up. Chunk
    ``i`` is taken from table ``i % 3``.

    Args:
        q: Query tensor of shape (batch, heads, seq_len, head_dim).
        k: Key tensor of shape (batch, heads, seq_len, head_dim).
        cos: Cosine tensor of shape (3, batch, seq_len, head_dim).
        sin: Sine tensor of shape (3, batch, seq_len, head_dim).
        mrope_section: List of 3 ints defining section sizes [temporal, height, width].
            For example, [16, 24, 24] for head_dim=128.
        unsqueeze_dim: Dimension to unsqueeze for broadcasting.

    Returns:
        Tuple of (rotated_q, rotated_k) with M-RoPE applied.
    """
    chunk_sizes = mrope_section * 2

    def _interleave_axes(table: torch.Tensor) -> torch.Tensor:
        chunks = table.split(chunk_sizes, dim=-1)
        picked = [chunk[idx % 3] for idx, chunk in enumerate(chunks)]
        return torch.cat(picked, dim=-1).unsqueeze(unsqueeze_dim)

    cos_mixed = _interleave_axes(cos)
    sin_mixed = _interleave_axes(sin)

    q_embed = q * cos_mixed + rotate_half(q) * sin_mixed
    k_embed = k * cos_mixed + rotate_half(k) * sin_mixed
    return q_embed, k_embed
|
|
@use_kernel_forward_from_hub("RMSNorm")
class DiffusionVL_Qwen2_5_VL_RMSNorm(nn.Module):
    """Root-mean-square layer norm mirroring ``Qwen2RMSNorm`` from ``modeling_qwen2.py``.

    The statistics are computed in float32. The normalized activations are cast back to
    the input dtype before the learned per-channel ``weight`` is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalized = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate key/value heads so grouped-query attention can treat them as full heads.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When ``n_rep == 1`` the input tensor itself is returned, without a copy.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Eager (non-fused) scaled dot-product attention.

    Key/value heads are repeated ``module.num_key_value_groups`` times to match the query
    heads. If given, ``attention_mask`` is cut to the key length and *added* to the scaled
    scores, so it is presumably an additive float mask (TODO confirm for all callers).
    Softmax runs in float32 and is cast back to ``query.dtype``. Dropout is only active
    while ``module.training`` is true.

    Returns:
        ``(attn_output, attn_weights)``, where ``attn_output`` has axes 1 and 2 swapped
        relative to ``query`` (sequence before heads) and is contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return output, probs
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        width, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=bias)
        self.up_proj = nn.Linear(width, inner, bias=bias)
        self.down_proj = nn.Linear(inner, width, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal pixel patches with a stride==kernel ``Conv3d``.

    Every group of ``in_channels * temporal_patch_size * patch_size**2`` input values is
    viewed as one (C, T, P, P) patch, cast to the projection weight dtype, and mapped to a
    single ``embed_dim`` vector.
    """

    def __init__(self, patch_size=14, temporal_patch_size=2, in_channels=3, embed_dim=1152):
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        kernel = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel, stride=kernel, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight_dtype = self.proj.weight.dtype
        patch_shape = (self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        patches = hidden_states.view(-1, *patch_shape).to(dtype=weight_dtype)
        return self.proj(patches).view(-1, self.embed_dim)
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles ``position * inv_freq`` for positions ``0 .. seqlen-1``.

    ``inv_freq`` has ``ceil(dim / 2)`` entries and is a non-persistent buffer, so it is not
    saved in the state dict.
    """

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionPatchMerger(nn.Module):
    """Merge each run of ``spatial_merge_size**2`` consecutive patch features into one token.

    Each patch feature (width ``context_dim``) is RMS-normalized. Every group is then
    concatenated into a ``context_dim * spatial_merge_size**2`` vector and mapped to ``dim``
    by a two-layer GELU MLP.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2):
        super().__init__()
        self.hidden_size = context_dim * spatial_merge_size ** 2
        self.ln_q = DiffusionVL_Qwen2_5_VL_RMSNorm(context_dim, eps=1e-6)
        width = self.hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        grouped = self.ln_q(x).view(-1, self.hidden_size)
        return self.mlp(grouped)
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionAttention(nn.Module):
    """Non-causal multi-head self-attention over packed, variable-length vision sequences.

    All patches are packed along a single sequence axis. ``cu_seqlens`` holds the cumulative
    boundaries of the independent attention segments, and no token attends across a
    boundary. The backend is chosen from ``config._attn_implementation``, defaulting to eager.
    """

    def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed features of shape (seq_len, hidden_size).
            cu_seqlens: 1-D cumulative segment boundaries, starting at 0.
            rotary_pos_emb: Raw rotary angles. Only used to derive ``(cos, sin)`` when
                ``position_embeddings`` is not supplied; the angles are duplicated along
                the last axis before taking cos/sin.
            position_embeddings: ``(cos, sin)`` tables applied to queries and keys.
            **kwargs: Forwarded to the attention backend.

        Returns:
            Tensor of shape (seq_len, hidden_size).

        Raises:
            ValueError: If neither ``position_embeddings`` nor ``rotary_pos_emb`` is given.
        """
        seq_length = hidden_states.shape[0]
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )

        if position_embeddings is None:
            # Honour the documented `rotary_pos_emb` argument instead of failing on unpacking None.
            if rotary_pos_emb is None:
                raise ValueError(
                    "DiffusionVL_Qwen2_5_VL_VisionAttention requires `position_embeddings` or `rotary_pos_emb`."
                )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            position_embeddings = (emb.cos(), emb.sin())
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim): the layout attention backends expect.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attn_implementation = getattr(self.config, "_attn_implementation", "eager")
        attention_interface: Callable = eager_attention_forward
        if attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
        dropout = 0.0 if not self.training else self.attention_dropout

        if attn_implementation == "flash_attention_2":
            # Varlen flash attention consumes the packed sequence and its boundaries directly.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other backends have no varlen support: attend within each segment separately,
            # then stitch the per-segment outputs back together along the sequence axis.
            lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            splits = [torch.split(tensor, lengths, dim=2) for tensor in (query_states, key_states, value_states)]
            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        return self.proj(attn_output)
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionBlock(GradientCheckpointingLayer):
    """Pre-norm vision transformer block: an attention residual followed by an MLP residual.

    ``attn_implementation`` is accepted for signature compatibility but is not used in this
    class.
    """

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=1e-6)
        self.attn = DiffusionVL_Qwen2_5_VL_VisionAttention(config=config)
        self.mlp = DiffusionVL_Qwen2_5_VL_VisionMLP(config, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionPreTrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base class for the DiffusionVL vision encoder."""

    config_class = DiffusionVL_Qwen2_5_VL_VisionConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DiffusionVL_Qwen2_5_VL_VisionBlock"]
    # Recent transformers releases check `_supports_flash_attn`; older ones check
    # `_supports_flash_attn_2`. Declaring both matches DiffusionVL_Qwen2_5_VL_PreTrainedModel
    # (which uses the newer flag) while staying compatible with older releases.
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_attention_backend = True
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionTransformer(DiffusionVL_Qwen2_5_VL_VisionPreTrainedModel):
    """Windowed vision transformer encoder: patch embedding followed by transformer blocks.

    ``forward`` stops before patch merging. It returns the per-patch hidden states, still in
    window-major order, together with ``window_index``, so that a separate projector can merge
    patches and restore the original order.
    """

    config_class = DiffusionVL_Qwen2_5_VL_VisionConfig
    _no_split_modules = ["DiffusionVL_Qwen2_5_VL_VisionBlock"]

    def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        # Blocks whose index is listed here attend over whole images; all others use windows.
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        # Number of patches that end up merged into one output token.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = DiffusionVL_Qwen2_5_VL_VisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        head_dim = config.hidden_size // config.num_heads
        # Each spatial axis gets head_dim // 4 frequencies. rot_pos_emb concatenates the row
        # and column angles (head_dim // 2 per patch), and forward duplicates them to head_dim.
        self.rotary_pos_emb = DiffusionVL_Qwen2_5_VL_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([DiffusionVL_Qwen2_5_VL_VisionBlock(config) for _ in range(config.depth)])
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute 2-D rotary angles for every patch described by ``grid_thw``.

        Args:
            grid_thw: (num_images, 3) tensor of (t, h, w) patch-grid sizes.

        Returns:
            Tensor of shape (total_patches, head_dim // 2) holding each patch's row-position
            angles followed by its column-position angles. Patches are ordered so that every
            ``spatial_merge_size x spatial_merge_size`` merge group is contiguous, and the
            per-frame layout is repeated ``t`` times.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every patch, reordered so that merge groups are contiguous.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

            # Column index of every patch, with the same reordering.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # Build one angle table covering the largest row/column index, then gather per patch.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
        """Compute the window-major order of merged tokens used by window attention.

        Merged tokens (one per ``spatial_merge_size x spatial_merge_size`` patch group) are
        tiled into square windows of ``window_size // spatial_merge_size // patch_size``
        merged tokens per side.

        Returns:
            window_index: 1-D permutation of merged-token indices (offset per image) that
                puts tokens into window-major order.
            cu_window_seqlens: Cumulative *patch* counts at each window boundary, starting
                at 0. It may contain repeated values for empty windows; ``forward`` removes
                these duplicates.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad the merged grid up to whole windows. When a side is already divisible, this
            # still adds one full window of padding; such windows end up empty.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks padding slots so they can be filtered out below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            # Bring the two window-grid axes together: (t, num_windows, win, win).
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merged tokens in each window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Window lengths are counted in merged tokens; scale them to patches for the blocks.
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode pixel patches.

        Args:
            hidden_states: Flattened pixel patches; they are reshaped by ``patch_embed``.
            grid_thw: (num_images, 3) tensor of (t, h, w) patch-grid sizes.
            **kwargs: Forwarded to every vision block.

        Returns:
            ``(hidden_states, window_index)``:
            - hidden_states: Per-patch features of shape (total_patches, hidden_size), in
              window-major order and not yet merged.
            - window_index: The merged-token permutation that produced that order.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        # int32 outside tracing, presumably because varlen attention kernels expect int32
        # cu_seqlens (TODO confirm).
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Remove repeated boundaries produced by empty (all-padding) windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder the patches and their rotary angles into window-major order. Whole merge
        # groups are moved together, so each group stays contiguous.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one segment of h * w patches per temporal slice,
        # repeated t times for each image.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        return hidden_states, window_index
|
|
|
|
class DiffusionVL_Qwen2_5_VL_VisionTower(nn.Module):
    """Thin wrapper that holds the vision transformer under the ``vision_tower`` attribute."""

    def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig):
        super().__init__()
        self.vision_tower = DiffusionVL_Qwen2_5_VL_VisionTransformer(config)
        self.spatial_merge_size = config.spatial_merge_size

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor = None):
        """Returns (hidden_states, window_index) tuple for MMProjector."""
        encoder = self.vision_tower
        return encoder(hidden_states, grid_thw)
|
|
|
|
class DiffusionVL_Qwen2_5_VL_MMProjector(nn.Module):
    """Project vision-encoder features to ``out_hidden_size`` tokens.

    Wraps a patch merger (``spatial_merge_size**2`` patches -> one token). When given the
    ``(hidden_states, window_index)`` tuple from the vision tower, it also undoes the
    window-major ordering produced by the encoder.
    """

    def __init__(self, config: DiffusionVL_Qwen2_5_VL_VisionConfig):
        super().__init__()
        self.merger = DiffusionVL_Qwen2_5_VL_VisionPatchMerger(
            dim=config.out_hidden_size,
            context_dim=config.hidden_size,
            spatial_merge_size=config.spatial_merge_size,
        )

    def forward(self, features_tuple):
        """Forward pass with merger and window index reversal."""
        if not isinstance(features_tuple, tuple):
            # A bare tensor: merge only, with no ordering to undo.
            return self.merger(features_tuple)

        hidden_states, window_index = features_tuple
        merged = self.merger(hidden_states)
        # argsort of a permutation is its inverse, which maps tokens back to original order.
        restore_order = torch.argsort(window_index)
        return merged[restore_order, :]
|
|
class DiffusionVL_Qwen2_5_VL_RotaryEmbedding(nn.Module):
    """Rotary position embedding for the language model, producing M-RoPE-shaped tables.

    The output always has a leading axis of size 3 (temporal, height, width), as consumed
    by ``apply_multimodal_rotary_pos_emb``. Plain ``(batch, seq)`` or ``(seq,)`` position
    ids are broadcast to all three axes.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        dim = config.hidden_size // config.num_attention_heads
        inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """
        Args:
            x: Tensor used only for its dtype (of the outputs) and its device type.
            position_ids: Position ids of shape (3, batch_size, seq_length) for M-RoPE,
                (batch_size, seq_length), or (seq_length,). The 2-D and 1-D forms are
                broadcast to all three M-RoPE axes; 1-D is treated as batch size 1.

        Returns:
            cos, sin: Tensors of shape (3, batch, seq_len, head_dim) in ``x.dtype``.

        Raises:
            ValueError: If ``position_ids`` has any other shape.
        """
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0)
        if position_ids.ndim == 2:
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
        if position_ids.ndim != 3 or position_ids.shape[0] != 3:
            # The previous fallback branch built mismatched matmul shapes for every such input.
            raise ValueError(
                "position_ids must have shape (3, batch, seq), (batch, seq) or (seq,); "
                f"got {tuple(position_ids.shape)}"
            )

        batch = position_ids.shape[1]
        # (3, batch, head_dim // 2, 1) @ (3, batch, 1, seq) -> (3, batch, head_dim // 2, seq)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, batch, -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Compute angles in float32 even under autocast to keep positions precise.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
class DiffusionVL_Qwen2_5_VL_MLP(nn.Module):
    """Decoder feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gate = self.act_fn(self.gate_proj(hidden_state))
        up = self.up_proj(hidden_state)
        return self.down_proj(gate * up)
|
|
|
|
| class DiffusionVL_Qwen2_5_VL_Attention(nn.Module): |
| """Non-causal attention for diffusion-based generation with KV-cache support.""" |
|
|
| def __init__(self, config, layer_idx): |
| super().__init__() |
| self.config = config |
| self.layer_idx = layer_idx |
| self.hidden_size = config.hidden_size |
| self.num_heads = config.num_attention_heads |
| self.head_dim = self.hidden_size // self.num_heads |
| self.num_key_value_heads = config.num_key_value_heads |
| self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
| self.scaling = self.head_dim ** -0.5 |
|
|
| self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) |
| self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) |
| self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) |
| self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
| |
| self.is_causal = False |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| position_ids=None, |
| past_key_values=None, |
| output_attentions=False, |
| use_cache=False, |
| cache_position=None, |
| position_embeddings=None, |
| store_kv=False, |
| **kwargs, |
| ): |
| bsz, q_len, _ = hidden_states.size() |
|
|
| query_states = self.q_proj(hidden_states) |
| key_states = self.k_proj(hidden_states) |
| value_states = self.v_proj(hidden_states) |
|
|
| query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
| if position_embeddings is not None: |
| cos, sin = position_embeddings |
| query_states, key_states = apply_multimodal_rotary_pos_emb( |
| query_states, key_states, cos, sin, |
| self.config.rope_scaling.get("mrope_section", [16, 24, 24]) |
| ) |
|
|
| |
| if past_key_values is not None and use_cache: |
| cache_kwargs = {"cache_position": cache_position} |
| if store_kv: |
| |
| key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) |
| else: |
| |
| cached_key = past_key_values.key_cache[self.layer_idx] if self.layer_idx < len(past_key_values.key_cache) else None |
| cached_value = past_key_values.value_cache[self.layer_idx] if self.layer_idx < len(past_key_values.value_cache) else None |
| if cached_key is not None and cached_value is not None: |
| key_states = torch.cat([cached_key, key_states], dim=2) |
| value_states = torch.cat([cached_value, value_states], dim=2) |
|
|
| |
| key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) |
| value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
| |
| if attention_mask is not None: |
| if isinstance(attention_mask, dict): |
| |
| attn_mask = attention_mask.get("full_attention", None) |
| else: |
| attn_mask = attention_mask |
| else: |
| attn_mask = None |
|
|
| if attn_mask is not None: |
| attn_output = F.scaled_dot_product_attention( |
| query_states, |
| key_states, |
| value_states, |
| attn_mask=attn_mask, |
| dropout_p=0.0, |
| is_causal=False, |
| scale=self.scaling, |
| ) |
| else: |
| attn_output = F.scaled_dot_product_attention( |
| query_states, |
| key_states, |
| value_states, |
| dropout_p=0.0, |
| is_causal=False, |
| scale=self.scaling, |
| ) |
|
|
| attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1) |
| attn_output = self.o_proj(attn_output) |
|
|
| return attn_output, None |
|
|
|
|
class DiffusionVL_Qwen2_5_VL_DecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention then MLP, each wrapped in a residual connection."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = DiffusionVL_Qwen2_5_VL_Attention(config, layer_idx)
        self.mlp = DiffusionVL_Qwen2_5_VL_MLP(config)
        self.input_layernorm = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        store_kv=False,
        **kwargs,
    ):
        """Apply the layer and return ``(hidden_states, attn_weights)`` from the attention module."""
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            store_kv=store_kv,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out, attn_weights
|
|
class DiffusionVL_Qwen2_5_VL_PreTrainedModel(PreTrainedModel):
    """``PreTrainedModel`` base class declaring DiffusionVL's config and capability flags."""

    # Configuration / checkpoint layout
    config_class = DiffusionVL_Qwen2_5_VL_Config
    base_model_prefix = "model"
    input_modalities = ["image", "video", "text"]

    # Device placement and model sharding
    _no_split_modules = ["DiffusionVL_Qwen2_5_VL_DecoderLayer", "DiffusionVL_Qwen2_5_VL_VisionBlock"]
    _skip_keys_device_placement = "past_key_values"

    # Feature / backend support flags
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True
|
|
|
|
class DiffusionVL_Qwen2_5_VL_Model(DiffusionVL_Qwen2_5_VL_PreTrainedModel):
    """Vision encoder, projector and bidirectional decoder stack of DiffusionVL.

    Image features from ``vision_tower`` + ``mm_projector`` are spliced into the token
    embeddings at ``IMAGE_TOKEN_INDEX`` placeholders before the decoder layers run.
    """

    def __init__(self, config: DiffusionVL_Qwen2_5_VL_Config):
        super().__init__(config)
        self.config = config

        # Vision path
        self.vision_tower = DiffusionVL_Qwen2_5_VL_VisionTower(config.vision_config)
        self.mm_projector = DiffusionVL_Qwen2_5_VL_MMProjector(config.vision_config)

        # Language path
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            DiffusionVL_Qwen2_5_VL_DecoderLayer(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = DiffusionVL_Qwen2_5_VL_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DiffusionVL_Qwen2_5_VL_RotaryEmbedding(config)

        # Presumably the block size for block-diffusion decoding; stored for code outside this
        # class (it is not read here).
        self.bd3lm_block_size = config.bd3lm_block_size

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings through vision tower and mm_projector.

        Args:
            pixel_values: Image tensor; cast to the patch-embedding weight dtype first.
            image_grid_thw: Grid dimensions (temporal, height, width) for each image.

        Returns:
            Image embeddings for all images, concatenated along dim 0 and in original patch
            order, ready to be merged with text embeddings.
        """
        pixel_values = pixel_values.to(dtype=self.vision_tower.vision_tower.patch_embed.proj.weight.dtype)
        hidden_states = self.vision_tower(pixel_values, image_grid_thw)
        image_embeds = self.mm_projector(hidden_states)
        return image_embeds

    @staticmethod
    def _merge_image_features(input_ids, inputs_embeds, image_features_list):
        """Replace each image placeholder with the next image's feature sequence.

        Images are consumed in order across the whole batch, so sample ``b`` uses the images
        following those of samples ``0 .. b-1``. Rows without placeholders (or all rows when
        ``input_ids`` is None) are kept unchanged. Placeholders left over once the features
        run out are dropped. Rows are right-padded with zeros to the longest merged length.

        Returns:
            Tensor of shape (batch, max_merged_len, hidden_size).
        """
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        merged_rows = []
        # Shared across the batch: resetting it per sample would make every sample re-use the
        # first images.
        next_image = 0

        for batch_idx in range(batch_size):
            cur_embeds = inputs_embeds[batch_idx]
            if input_ids is None:
                merged_rows.append(cur_embeds)
                continue

            cur_input_ids = input_ids[batch_idx]
            image_positions = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
            if not image_positions:
                merged_rows.append(cur_embeds)
                continue

            # Text segments lie strictly between consecutive placeholders.
            boundaries = [-1] + image_positions + [len(cur_input_ids)]
            pieces = []
            for seg_idx in range(len(boundaries) - 1):
                start = boundaries[seg_idx] + 1
                end = boundaries[seg_idx + 1]
                if start < end:
                    pieces.append(cur_embeds[start:end])
                if seg_idx < len(image_positions) and next_image < len(image_features_list):
                    pieces.append(image_features_list[next_image].to(cur_embeds.dtype))
                    next_image += 1

            merged_rows.append(torch.cat(pieces, dim=0) if pieces else cur_embeds)

        max_len = max(row.shape[0] for row in merged_rows)
        hidden_size = merged_rows[0].shape[-1]
        padded = merged_rows[0].new_zeros(batch_size, max_len, hidden_size)
        for row_idx, row in enumerate(merged_rows):
            padded[row_idx, : row.shape[0]] = row
        return padded

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        store_kv=False,
        pixel_values=None,
        image_grid_thw=None,
        **kwargs,
    ):
        """Forward pass with optional vision input processing.

        Args:
            input_ids: Token ids; image positions hold ``IMAGE_TOKEN_INDEX``.
            attention_mask: Passed unchanged to every layer.
            position_ids: Defaults to ``cache_position`` (broadcast to the 3 M-RoPE axes).
            past_key_values: Optional KV cache used by the attention layers.
            inputs_embeds: Precomputed embeddings; used instead of ``input_ids`` if given.
            use_cache: Defaults to ``config.use_cache``.
            output_attentions: Collect per-layer attention weights (always None here).
            output_hidden_states: Collect hidden states before each layer and after the norm.
            return_dict: Accepted for API compatibility; a ``BaseModelOutputWithPast`` is
                always returned.
            cache_position: Defaults to positions following the cached sequence length.
            store_kv: Whether layers write this step's keys/values into the cache.
            pixel_values: Image pixels; only used together with ``image_grid_thw``.
            image_grid_thw: (num_images, 3) patch-grid sizes for ``pixel_values``.

        Returns:
            BaseModelOutputWithPast.
        """
        output_attentions = output_attentions or False
        output_hidden_states = output_hidden_states or False
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        has_images = pixel_values is not None and image_grid_thw is not None

        if inputs_embeds is None:
            embed_ids = input_ids
            if has_images:
                # IMAGE_TOKEN_INDEX is negative and not a valid embedding row. Look up a valid
                # id (0) at those positions instead; placeholder positions are dropped in favour
                # of image features by the merge below.
                embed_ids = input_ids.masked_fill(input_ids == IMAGE_TOKEN_INDEX, 0)
            inputs_embeds = self.embed_tokens(embed_ids)

        if has_images:
            image_features = self.get_image_features(pixel_values, image_grid_thw)
            # Each image yields (t * h * w) / spatial_merge_size**2 merged tokens.
            merge_unit = self.vision_tower.spatial_merge_size ** 2
            split_sizes = (image_grid_thw.prod(dim=1) // merge_unit).tolist()
            image_features_list = list(torch.split(image_features, split_sizes))
            # NOTE(review): merging changes the sequence length; caller-supplied attention_mask /
            # position_ids are not adjusted here. Confirm callers pass them for the merged
            # sequence (or None).
            inputs_embeds = self._merge_image_features(input_ids, inputs_embeds, image_features_list)

        seq_length = inputs_embeds.shape[1]

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Shared (cos, sin) tables for all layers.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states, attn_weights = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                store_kv=store_kv,
            )

            if output_attentions:
                all_attentions += (attn_weights,)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
class DiffusionVL_Qwen2_5_VL_ForConditionalGeneration(DiffusionVL_Qwen2_5_VL_PreTrainedModel):
    r"""
    DiffusionVL model with a language-modeling head for diffusion-based generation.

    Instead of autoregressive decoding, :meth:`generate` runs BD3LM block
    diffusion: the output is produced block by block (``config.bd3lm_block_size``
    positions per block) and every block is iteratively denoised starting from
    mask tokens (``config.mask_token_id``).
    """

    _tied_weights_keys = ["lm_head.weight"]

    # Strategies accepted by ``generate_with_bd3lm(remasking_strategy=...)``.
    _REMASKING_STRATEGIES = ("low_confidence_static", "low_confidence_dynamic", "sequential")

    def __init__(self, config: DiffusionVL_Qwen2_5_VL_Config):
        super().__init__(config)
        self.model = DiffusionVL_Qwen2_5_VL_Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Diffusion decoding parameters.
        self.mask_token_id = config.mask_token_id
        self.block_size = config.bd3lm_block_size

        # Filled by generate_with_bd3lm(); None until a prompt has been prefilled.
        self.last_prefill_logits = None

        self.post_init()

    def get_model(self):
        """Return the backbone ``DiffusionVL_Qwen2_5_VL_Model``."""
        return self.model

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def tie_weights(self, *args, **kwargs):
        """Tie weights only when ``config.tie_word_embeddings`` is true (e.g. the 3B model).

        Any arguments are forwarded unchanged to the base implementation so the
        override stays compatible with callers that pass extra parameters.
        """
        if getattr(self.config, "tie_word_embeddings", False):
            return super().tie_weights(*args, **kwargs)
        return None

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pixel_values=None,
        image_grid_thw=None,
        **kwargs,
    ):
        """Run the backbone and LM head, optionally computing a shifted cross-entropy loss.

        When ``pixel_values`` is given and ``inputs_embeds`` is not, ``input_ids`` may
        contain ``IMAGE_TOKEN_INDEX`` placeholders. Each one is replaced by that
        image's features via :meth:`prepare_inputs_labels_for_multimodal`. This
        changes the sequence length, so ``labels`` (and any ``attention_mask``) must
        already describe the expanded sequence.

        Extra ``**kwargs`` are accepted for API compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast``, or its ``to_tuple()`` form when
            ``return_dict`` is ``False``.

        Raises:
            ValueError: if ``labels`` does not match the sequence length of the logits.
        """
        return_dict = return_dict if return_dict is not None else True

        if pixel_values is not None and inputs_embeds is None:
            # Splice the image features into the text embeddings at the placeholder
            # positions; embedding the raw ids would fail on the negative placeholder id.
            inputs_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids=input_ids,
                images=pixel_values,
                image_grid_thws=image_grid_thw,
            )
            input_ids = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            if labels.shape[-1] != logits.shape[-2]:
                raise ValueError(
                    f"labels length ({labels.shape[-1]}) does not match the logits sequence "
                    f"length ({logits.shape[-2]}); with images, labels must cover the "
                    "expanded (image-feature) sequence."
                )
            # Standard next-token objective: position t predicts label t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return output if return_dict else output.to_tuple()

    def _merge_vision_text(self, input_ids, vision_features):
        """Return the text embeddings of ``input_ids``.

        NOTE(review): ``vision_features`` is ignored; the helper is kept only for
        backward compatibility. ``forward`` merges images through
        ``prepare_inputs_labels_for_multimodal`` instead.
        """
        return self.model.embed_tokens(input_ids)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thws: Optional[torch.Tensor] = None,
        modalities: Optional[List] = None,
        gen_length: int = 256,
        steps: int = 8,
        temperature: float = 0.0,
        **kwargs,
    ):
        """
        Diffusion-based generation using the BD3LM algorithm.

        1. If ``images`` is given, merge image features into the prompt with
           :meth:`prepare_inputs_labels_for_multimodal`.
        2. Otherwise embed the prompt tokens directly.
        3. Decode with :meth:`generate_with_bd3lm`.

        Args:
            inputs: Prompt token ids ``[batch_size, seq_len]``. These may contain
                ``IMAGE_TOKEN_INDEX`` placeholders when ``images`` is given. If omitted,
                ``input_ids`` from ``kwargs`` is used instead.
            images: Pixel values for the vision tower.
            image_sizes: Accepted for API compatibility; unused.
            image_grid_thws: Per-image grid sizes ``(num_images, 3)``; required with ``images``.
            modalities: Accepted for API compatibility; unused.
            gen_length: Number of tokens to generate.
            steps: Number of diffusion steps per block.
            temperature: Sampling temperature (0 for greedy).
            **kwargs: Forwarded to :meth:`generate_with_bd3lm` (``top_k``, ``top_p``,
                ``remasking_strategy``, ``confidence_threshold``, ``eos_token_id``).

        Returns:
            Generated token ids ``[batch_size, gen_length]``.

        Raises:
            ValueError: if no prompt token ids are provided.
        """
        if inputs is None:
            # Accept the Hugging Face-style keyword as an alias for ``inputs``.
            inputs = kwargs.pop("input_ids", None)
        if inputs is None:
            raise ValueError("generate() requires prompt token ids via `inputs` or `input_ids`.")

        if images is not None:
            inputs_embeds = self.prepare_inputs_labels_for_multimodal(
                input_ids=inputs,
                images=images,
                image_grid_thws=image_grid_thws,
            )
        else:
            inputs_embeds = self.get_input_embeddings()(inputs)

        return self.generate_with_bd3lm(
            inputs_embeds=inputs_embeds,
            gen_length=gen_length,
            steps=steps,
            temperature=temperature,
            **kwargs,
        )

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids: torch.Tensor,
        images: torch.Tensor,
        image_grid_thws: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Build ``inputs_embeds`` by splicing image features into the text embeddings.

        Uses the LLaVA convention: every ``IMAGE_TOKEN_INDEX`` (-200) in ``input_ids``
        is replaced by the full feature sequence of the next image. Images are consumed
        in order across the whole batch (row 0's placeholders first, then row 1's, ...).
        Rows of different final length are right-padded with zeros.

        Args:
            input_ids: Token ids ``[batch_size, seq_len]`` containing image placeholders.
            images: Pixel values for all images of the batch.
            image_grid_thws: Per-image grid sizes ``(num_images, 3)``, as a tensor or nested sequence.

        Returns:
            ``[batch_size, max_len, hidden_size]`` embeddings on ``input_ids.device`` with the
            dtype of the input embedding layer.

        Raises:
            ValueError: if ``image_grid_thws`` is missing, or if the placeholders outnumber
                the available images.
        """
        if image_grid_thws is None:
            raise ValueError("image_grid_thws is required for vision processing")

        device = input_ids.device
        batch_size = input_ids.shape[0]
        embed_layer = self.get_input_embeddings()

        if isinstance(image_grid_thws, torch.Tensor):
            image_grid_thw = image_grid_thws.to(device)
        else:
            image_grid_thw = torch.tensor(image_grid_thws, device=device)

        image_features = self.model.get_image_features(images, image_grid_thw)

        # Each image contributes t*h*w / spatial_merge_size**2 feature rows, in order.
        merge_area = self.model.vision_tower.spatial_merge_size ** 2
        split_sizes = (image_grid_thw.prod(dim=1) // merge_area).tolist()
        image_features_list = list(torch.split(image_features, split_sizes))

        image_cursor = 0  # shared across rows: images are listed for the whole batch
        new_input_embeds_list = []
        for cur_input_ids in input_ids:
            is_image = cur_input_ids == IMAGE_TOKEN_INDEX
            image_positions = torch.nonzero(is_image, as_tuple=True)[0].tolist()

            # Placeholders are not valid vocabulary ids: embed a dummy id (0) there;
            # those rows are dropped below and replaced by image features.
            text_embeds = embed_layer(cur_input_ids.masked_fill(is_image, 0))
            if not image_positions:
                new_input_embeds_list.append(text_embeds)
                continue

            if image_cursor + len(image_positions) > len(image_features_list):
                raise ValueError(
                    f"Found more image placeholders ({image_cursor + len(image_positions)} so far) "
                    f"than images ({len(image_features_list)})."
                )

            # Keep empty text segments (leading/adjacent placeholders) so text and
            # image pieces stay in their original order.
            pieces = []
            prev = 0
            for pos in image_positions:
                pieces.append(text_embeds[prev:pos])
                pieces.append(
                    image_features_list[image_cursor].to(device=text_embeds.device, dtype=text_embeds.dtype)
                )
                image_cursor += 1
                prev = pos + 1
            pieces.append(text_embeds[prev:])
            new_input_embeds_list.append(torch.cat(pieces, dim=0))

        max_len = max(embeds.shape[0] for embeds in new_input_embeds_list)
        hidden_size = new_input_embeds_list[0].shape[-1]
        dtype = new_input_embeds_list[0].dtype

        inputs_embeds = torch.zeros(batch_size, max_len, hidden_size, dtype=dtype, device=device)
        for row, embeds in enumerate(new_input_embeds_list):
            inputs_embeds[row, : embeds.shape[0]] = embeds.to(device)

        return inputs_embeds

    @torch.no_grad()
    def generate_with_bd3lm(
        self,
        inputs_embeds: torch.FloatTensor,
        gen_length: int = 256,
        steps: int = 8,
        temperature: float = 0.0,
        top_k: int = 0,
        top_p: float = 1.0,
        remasking_strategy: str = 'low_confidence_static',
        confidence_threshold: float = 0.85,
        **kwargs,
    ):
        """
        BD3LM block-diffusion generation with KV-cache support.

        The prompt plus ``gen_length`` new positions is rounded up to whole blocks of
        ``self.block_size``. Blocks lying entirely inside the prompt are prefilled into
        the KV cache in one pass. Each remaining block starts as mask embeddings (except
        for prompt positions it contains) and is denoised for at most ``steps``
        iterations. In each iteration every position is predicted, and a scheduled
        number of still-masked positions is committed as chosen by
        ``remasking_strategy``. The finished block is then run once with
        ``store_kv=True`` to append it to the cache.

        Args:
            inputs_embeds: Prompt embeddings ``[batch_size, prompt_len, hidden_size]``.
            gen_length: Number of tokens to return.
            steps: Maximum number of denoising iterations per block (must be >= 1).
            temperature: Sampling temperature (0 for greedy).
            top_k: Top-k sampling parameter (0 disables).
            top_p: Top-p (nucleus) sampling parameter (1.0 disables).
            remasking_strategy: 'low_confidence_static', 'low_confidence_dynamic', or 'sequential'.
            confidence_threshold: Threshold for the 'low_confidence_dynamic' strategy.
            **kwargs: ``eos_token_id`` may be an int, an iterable of ints, or ``None`` to
                disable early stopping (default 151645). Generation stops after the first
                block in which every row has produced an EOS. Other keys are ignored.

        Returns:
            Generated token ids ``[batch_size, gen_length]``. Positions belonging to
            blocks skipped by early stopping keep ``mask_token_id``.

        Raises:
            ValueError: if ``steps < 1`` or ``remasking_strategy`` is unknown.

        Side effects:
            Sets ``self.last_prefill_logits`` to the float logits of the last prefilled
            position, or ``None`` when no full prompt block was prefilled.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        if remasking_strategy not in self._REMASKING_STRATEGIES:
            raise ValueError(f"Unknown remasking strategy: {remasking_strategy}")

        device = inputs_embeds.device
        batch_size, prompt_len = inputs_embeds.shape[0], inputs_embeds.shape[1]
        block_size = self.block_size
        mask_id = self.mask_token_id

        eos_ids = self._normalize_eos_ids(kwargs.get('eos_token_id', 151645))
        if eos_ids is not None:
            eos_ids = eos_ids.to(device)

        # Ceil division: enough whole blocks to hold prompt + generation.
        num_blocks = (prompt_len + gen_length + block_size - 1) // block_size
        total_length = num_blocks * block_size

        # Token ids for every position. Prompt positions are never read back or
        # returned, so they simply keep mask_id.
        x_ids = torch.full((batch_size, total_length), mask_id, dtype=torch.long, device=device)

        embed_layer = self.get_input_embeddings()
        mask_embed = embed_layer(torch.tensor([mask_id], device=embed_layer.weight.device)).to(device)
        x_embeds = mask_embed.repeat(batch_size, total_length, 1)
        x_embeds[:, :prompt_len] = inputs_embeds

        block_diffusion_mask = self._build_block_diffusion_mask(
            num_blocks, block_size, inputs_embeds.dtype, device
        )
        position_ids = torch.arange(total_length, device=device).unsqueeze(0).expand(batch_size, -1)

        # Prefill the blocks that are made up entirely of prompt tokens.
        prefill_blocks = prompt_len // block_size
        prefill_length = prefill_blocks * block_size

        past_key_values = DynamicCache()
        self.last_prefill_logits = None
        if prefill_length > 0:
            prefill_mask = block_diffusion_mask[:, :, :prefill_length, :prefill_length]
            prefill_outputs = self.model(
                inputs_embeds=x_embeds[:, :prefill_length],
                attention_mask={"full_attention": prefill_mask, "sliding_attention": prefill_mask},
                position_ids=position_ids[:, :prefill_length],
                past_key_values=past_key_values,
                use_cache=True,
                store_kv=True,
            )
            # Only the last position is kept, so project just that one instead of
            # materialising a [batch, prefill_length, vocab] float tensor.
            last_hidden = prefill_outputs.last_hidden_state[:, -1:, :]
            self.last_prefill_logits = self.lm_head(last_hidden).float()
            past_key_values = prefill_outputs.past_key_values

        num_transfer_tokens = self._get_num_transfer_tokens(block_size, steps)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for block_idx in range(prefill_blocks, num_blocks):
            block_start = block_idx * block_size
            block_end = block_start + block_size

            cur_block_embeds = x_embeds[:, block_start:block_end].clone()
            cur_block_ids = x_ids[:, block_start:block_end]

            # The block attends to every earlier block (through the cache) and to itself.
            cur_mask = block_diffusion_mask[:, :, block_start:block_end, :block_end]
            block_call_kwargs = dict(
                attention_mask={"full_attention": cur_mask, "sliding_attention": cur_mask},
                position_ids=position_ids[:, block_start:block_end],
                past_key_values=past_key_values,
                use_cache=True,
            )

            # Positions still holding the mask embedding (prompt positions inside a
            # partially filled block are not). Tracked explicitly from here on, so a
            # committed position is never re-masked even if the model samples mask_id.
            still_masked = torch.all(torch.abs(cur_block_embeds - mask_embed) < 1e-5, dim=-1)

            for step in range(steps):
                if not still_masked.any():
                    break

                # Denoising pass: read the cache but do not append this block to it.
                outputs = self.model(inputs_embeds=cur_block_embeds, store_kv=False, **block_call_kwargs)
                logits = self.lm_head(outputs.last_hidden_state).float()
                x0, x0_p = self._sample_tokens(logits, temperature, top_k, top_p)

                target_device = x0.device
                still_masked = still_masked.to(target_device)
                transfer_mask = self._select_transfer_mask(
                    x0_p.to(target_device),
                    still_masked,
                    int(num_transfer_tokens[step]),
                    remasking_strategy,
                    confidence_threshold,
                )

                cur_block_ids = torch.where(transfer_mask, x0, cur_block_ids.to(target_device))
                x0_embeds = embed_layer(x0.to(embed_layer.weight.device))
                cur_block_embeds = torch.where(
                    transfer_mask.unsqueeze(-1).to(x0_embeds.device),
                    x0_embeds,
                    cur_block_embeds.to(x0_embeds.device),
                )
                still_masked = still_masked & ~transfer_mask

            # Commit the finished block to the KV cache so later blocks can attend to it.
            self.model(inputs_embeds=cur_block_embeds, store_kv=True, **block_call_kwargs)

            x_embeds[:, block_start:block_end] = cur_block_embeds.to(x_embeds.device)
            x_ids[:, block_start:block_end] = cur_block_ids.to(x_ids.device)

            # Stop once every row has emitted an EOS among its generated positions.
            if eos_ids is not None:
                gen_ids = x_ids[:, max(prompt_len, block_start):block_end]
                finished |= torch.isin(gen_ids, eos_ids).any(dim=1)
                if bool(finished.all()):
                    break

        return x_ids[:, prompt_len:prompt_len + gen_length]

    @staticmethod
    def _build_block_diffusion_mask(num_blocks, block_size, dtype, device):
        """Return an additive block-causal mask of shape ``(1, 1, L, L)``, ``L = num_blocks * block_size``.

        An entry is 0 when the key's block index is <= the query's block index, and -inf otherwise.
        """
        block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device)).to(dtype)
        token_mask = block_mask.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
        token_mask = token_mask.unsqueeze(0).unsqueeze(1)
        return torch.where(token_mask == 0.0, torch.full_like(token_mask, float('-inf')), 0.0)

    @staticmethod
    def _normalize_eos_ids(eos_token_id):
        """Return EOS ids as a 1-D long tensor, or ``None`` when early stopping is disabled."""
        if eos_token_id is None:
            return None
        if isinstance(eos_token_id, torch.Tensor):
            return eos_token_id.flatten().long()
        if isinstance(eos_token_id, int):
            return torch.tensor([eos_token_id], dtype=torch.long)
        return torch.tensor(list(eos_token_id), dtype=torch.long)

    @staticmethod
    def _select_transfer_mask(x0_p, is_mask, num_to_transfer, strategy, confidence_threshold):
        """Choose which still-masked positions to commit in one denoising step.

        Args:
            x0_p: Probability of each sampled token, same shape as ``is_mask``.
            is_mask: Boolean ``[batch, block]`` tensor of positions that are still masked.
            num_to_transfer: Number of positions scheduled for this step.
            strategy: One of the remasking strategies (validated by the caller).
            confidence_threshold: Used by 'low_confidence_dynamic'.

        Returns:
            Boolean tensor shaped like ``is_mask``; True marks positions to commit.
        """
        transfer_mask = torch.zeros_like(is_mask, dtype=torch.bool)

        if strategy == 'sequential':
            # Left-to-right: commit the first masked positions.
            for j in range(is_mask.shape[0]):
                positions = is_mask[j].nonzero(as_tuple=True)[0]
                transfer_mask[j, positions[:num_to_transfer]] = True
            return transfer_mask

        # Confidence-based strategies: already-committed positions can never be picked.
        confidence = torch.where(is_mask, x0_p, torch.full_like(x0_p, -torch.inf))
        for j in range(is_mask.shape[0]):
            if strategy == 'low_confidence_dynamic':
                high_conf = confidence[j] > confidence_threshold
                if int(high_conf.sum()) >= num_to_transfer:
                    transfer_mask[j] = high_conf
                    continue
            k = min(num_to_transfer, int(is_mask[j].sum()))
            if k > 0:
                _, idx = torch.topk(confidence[j], k)
                transfer_mask[j, idx] = True
        return transfer_mask

    def _sample_tokens(self, logits, temperature=0.0, top_k=0, top_p=1.0):
        """Sample one token per position and return it with its probability.

        Args:
            logits: ``[batch, seq_len, vocab]`` logits.
            temperature: 0 selects argmax (probabilities from the unscaled softmax);
                otherwise logits are divided by it before filtering and sampling.
            top_k: Keep only the ``top_k`` largest logits (0 disables; clamped to the vocab size).
            top_p: Nucleus threshold (1.0 disables).

        Returns:
            ``(tokens, token_probs)``, both ``[batch, seq_len]``.
        """
        batch_size = logits.shape[0]
        seq_len = logits.shape[1]
        vocab_size = logits.shape[-1]
        flat_logits = logits.reshape(-1, vocab_size)

        if temperature == 0:
            tokens = torch.argmax(flat_logits, dim=-1, keepdim=True)
            probs = F.softmax(flat_logits, dim=-1)
        else:
            scaled = flat_logits / temperature

            if top_k > 0:
                kth_largest = torch.topk(scaled, min(top_k, vocab_size)).values[:, -1:]
                scaled = scaled.masked_fill(scaled < kth_largest, float('-inf'))

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_mask = cumulative_probs > top_p
                # Shift right so the token that crosses top_p is kept; always keep the best one.
                sorted_mask[:, 1:] = sorted_mask[:, :-1].clone()
                sorted_mask[:, 0] = False
                remove = torch.scatter(
                    torch.zeros_like(scaled, dtype=torch.bool), -1, sorted_indices, sorted_mask
                )
                scaled = scaled.masked_fill(remove, float('-inf'))

            probs = F.softmax(scaled, dim=-1)
            tokens = torch.multinomial(probs, num_samples=1)

        token_probs = torch.gather(probs, -1, tokens)
        return tokens.view(batch_size, seq_len), token_probs.view(batch_size, seq_len)

    def _get_num_transfer_tokens(self, block_length, steps):
        """Return an int64 schedule of how many tokens to unmask at each step.

        The schedule has ``steps + 1`` entries. The first ``steps`` entries sum to
        ``block_length``, with the first ``block_length % steps`` of them getting one
        extra token. The trailing entry equals ``block_length // steps``.
        ``steps == 0`` yields ``tensor([0])``.
        """
        if steps == 0:
            return torch.zeros(1, dtype=torch.int64)
        base, remainder = divmod(block_length, steps)
        num_transfer = torch.full((steps + 1,), base, dtype=torch.int64)
        num_transfer[:remainder] += 1
        return num_transfer
|
|
from transformers import AutoConfig
from transformers import AutoModelForCausalLM

# Make the custom classes resolvable through the Auto* factories: the config is
# registered under the "diffusionvl_qwenvl" model type, and that config class is
# mapped to the diffusion LM class for AutoModelForCausalLM.
AutoConfig.register("diffusionvl_qwenvl", DiffusionVL_Qwen2_5_VL_Config)
AutoModelForCausalLM.register(
    DiffusionVL_Qwen2_5_VL_Config,
    DiffusionVL_Qwen2_5_VL_ForConditionalGeneration,
)

# Public names exported by ``from <this module> import *``.
__all__ = [
    "DiffusionVL_Qwen2_5_VL_Config",
    "DiffusionVL_Qwen2_5_VL_VisionConfig",
    "DiffusionVL_Qwen2_5_VL_PreTrainedModel",
    "DiffusionVL_Qwen2_5_VL_Model",
    "DiffusionVL_Qwen2_5_VL_ForConditionalGeneration",
]
|
|