# LLaVA-UHD-v3 / modeling_llava_uhd_v3.py
# (uploaded by Sishxo: "Upload 12 files", commit 8f993ed)
from ast import Module
from cProfile import label
from functools import partial
from black import Mode
from matplotlib.pyplot import grid
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import PytorchGELUTanh
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.utils import is_flash_attn_2_available, logging
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
from collections.abc import Callable
from transformers.activations import ACT2FN
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
import math
from copy import deepcopy
from typing import Union, Tuple, Sequence, Optional, List
from einops import rearrange
from .configuration_llava_uhd_v3 import LlavaUHDV3Config, LlavaUHDV3VisionConfig, LlavaUHDV3TextConfig
# Module-level logger from transformers' logging utilities; used below for the
# SDPA -> eager fallback warning in the text attention.
logger = logging.get_logger(__name__)
##### MoonViT part #####
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Packed (variable-length) multi-head attention using FlashAttention-2.

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim),
            dtype bfloat16 or float16 (FlashAttention requirement).
        q_cu_seqlens (torch.Tensor, optional): cumulative sequence lengths of q.
            The first element should be 0 and the last element q.shape[0].
            If None, all of q is treated as one sequence.
        k_cu_seqlens (torch.Tensor, optional): cumulative sequence lengths of k.
            The first element should be 0 and the last element k.shape[0].
            If None, all of k is treated as one sequence.

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).

    Raises:
        ImportError: if the `flash_attn` package is not available.
    """
    # The module-level import falls back to None when flash-attn is missing;
    # fail with an actionable message instead of "'NoneType' is not callable".
    if flash_attn_varlen_func is None:
        raise ImportError(
            "attn_implementation='flash_attention_2' requires the `flash_attn` "
            "package, which is not available."
        )
    # Unified format legal check
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    # Default to a single un-packed sequence; flash-attn expects int32 offsets.
    if q_cu_seqlens is None:
        q_cu_seqlens = torch.tensor([0, q.shape[0]], device=q.device, dtype=torch.int32)
    if k_cu_seqlens is None:
        k_cu_seqlens = torch.tensor([0, k.shape[0]], device=k.device, dtype=torch.int32)
    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
    assert (
        k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
    ), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"
    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
    )
    # (tot_seqlens, num_heads, head_dim) -> (tot_seqlens, num_heads * head_dim)
    attn_out = attn_out.flatten(start_dim=-2)
    return attn_out
def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Packed multi-head attention via `F.scaled_dot_product_attention`.

    Tokens attend only within their own sequence, as delimited by
    ``q_cu_seqlens``. ``k_cu_seqlens`` is accepted for interface parity with
    the other vision attention functions but is not used (q and k are assumed
    to share the same packing).

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative sequence lengths, first 0, last tot_seqlens.

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).
    """
    total = q.shape[0]
    # Boolean block-diagonal mask: True where query and key share a sequence.
    keep = torch.zeros([1, total, total], device=q.device, dtype=torch.bool)
    for lo, hi in zip(q_cu_seqlens[:-1], q_cu_seqlens[1:]):
        keep[..., lo:hi, lo:hi] = True
    # SDPA batches over the leading axis, so move heads in front: (heads, seq, dim).
    q_h, k_h, v_h = (t.transpose(0, 1) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_h, k_h, v_h, keep, dropout_p=0.0)
    return out.transpose(0, 1).reshape(total, -1)
def eager_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference (non-fused) attention over packed variable-length sequences.

    Each token attends only to tokens of its own sequence, as delimited by
    ``q_cu_seqlens``. ``k_cu_seqlens`` is accepted for interface parity with
    the other vision attention functions but is not used (q and k are assumed
    to share the same packing).

    Args:
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
        q_cu_seqlens: cumulative sequence lengths (first 0, last tot_seqlens).
            If None, all tokens are treated as a single sequence.
        k_cu_seqlens: unused.

    Returns:
        Tensor of shape (tot_seqlens, num_heads * head_dim).
    """
    seq_length = q.shape[0]
    if q_cu_seqlens is None:
        attention_mask = torch.ones(
            [1, seq_length, seq_length], device=q.device, dtype=torch.bool
        )
    else:
        # Boolean block-diagonal mask: True where query and key share a sequence.
        attention_mask = torch.zeros(
            [1, seq_length, seq_length], device=q.device, dtype=torch.bool
        )
        for i in range(1, len(q_cu_seqlens)):
            start, end = q_cu_seqlens[i - 1], q_cu_seqlens[i]
            attention_mask[..., start:end, start:end] = True
    # (num_heads, seq, head_dim) so the matmuls batch over heads.
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # Adding a boolean mask would only shift allowed logits by +1 and leave
    # cross-sequence pairs visible; exclude them explicitly. finfo.min (rather
    # than -inf) keeps a fully masked row finite instead of NaN.
    attn_weight = attn_weight.masked_fill(
        ~attention_mask, torch.finfo(attn_weight.dtype).min
    )
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
    attn_output = attn_weight @ v
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output
# Maps an attn_implementation name to the packed-sequence vision attention
# function defined above; looked up per layer by MoonVitEncoderLayer.
VL_VISION_ATTENTION_FUNCTIONS = dict(
    flash_attention_2=multihead_attention,
    sdpa=sdpa_attention,
    eager=eager_attention,
)
def _apply_rope_input_validation(x, freqs_cis):
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
def apply_rope(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to queries and keys.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It
            contains the precomputed cis(freqs) for each position in the 2D grid.

    Returns:
        xq_out, xk_out: tensors with the same shapes and dtypes as xq / xk.
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)
    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2 (broadcast over heads)
    # View consecutive (even, odd) channel pairs as complex numbers:
    # ..., num_heads, head_dim/2. Each tensor uses its OWN shape, so q and k
    # may differ in head count.
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
class Learnable2DInterpPosEmb(nn.Module):
    """Learnable 2D positional embedding, resized to each image's patch grid.

    A (height, width, dim) parameter table is added to packed patch tokens.
    Grids matching (height, width) use the table directly; other grid sizes
    resample it with ``F.interpolate`` using ``interpolation_mode``.
    """

    def __init__(
        self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight)

    def forward(self, x, grid_hws) -> torch.Tensor:
        """Add positional embeddings to packed tokens.

        Args:
            x: packed tokens of shape (sum(h * w), dim), images concatenated in
                ``grid_hws`` order.
            grid_hws: (N, 2) per-image grid (h, w).

        Returns:
            ``x`` plus the per-image positional embeddings (same shape as x).
        """
        # Compare as tuples: a list never equals a torch.Size, which made the
        # native-size fast path unreachable.
        native_hw = tuple(self.weight.shape[:-1])
        pos_embs = []
        for shape in grid_hws.tolist():
            hw = tuple(int(i) for i in shape)
            if hw == native_hw:
                pos_embs.append(self.weight.flatten(end_dim=1))
            else:
                # (h, w, dim) -> (1, dim, h, w) for F.interpolate, then back to (h*w, dim).
                resized = F.interpolate(
                    self.weight.permute((2, 0, 1)).unsqueeze(0),
                    size=hw,
                    mode=self.interpolation_mode,
                )
                pos_embs.append(resized.squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
        out = x + torch.cat(pos_embs)
        return out
class MoonVisionPatchEmbed(nn.Module):
    """Project image patches to tokens and add a learnable 2D position embedding.

    The projection is a Conv2d whose kernel and stride both equal the patch
    size; positions come from a resizable ``Learnable2DInterpPosEmb`` table of
    (pos_emb_height, pos_emb_width).
    """

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: Union[int, Tuple[int, int]] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(
            patch_size, (int, Sequence)
        ), f"Invalid patch_size type: {type(patch_size)}"
        # A scalar patch size means a square patch.
        patch_size = (patch_size,) * 2 if isinstance(patch_size, int) else patch_size
        assert (
            len(patch_size) == 2
        ), f"Expected patch_size to be a tuple of 2, got {patch_size}"
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_emb = Learnable2DInterpPosEmb(
            height=pos_emb_height, width=pos_emb_width, dim=out_dim
        )

    def forward(self, x, grid_hws) -> torch.Tensor:
        """
        Args:
            x: patches, presumably (L, in_dim, patch_h, patch_w) with one patch
                per row so the conv yields one token each — TODO confirm.
            grid_hws (N, 2): grid height and width per image.

        Returns:
            (L, out_dim) tensor of position-embedded patch tokens.
        """
        num_patches = x.size(0)
        tokens = self.proj(x).view(num_patches, -1)
        return self.pos_emb(tokens, grid_hws)
class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training, create an instance of Rope2DPosEmb. This instance will
       hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis` to get the `freqs_cis`
       tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention
       layer and apply it just before each attention operation.

    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the per-head attention dimension; must be divisible by 4
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
    """

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        # Lazily built complex table of shape (h, w, dim // 2). It is a plain
        # attribute, not a registered buffer, so Module.to() does not move it;
        # get_freqs_cis rebuilds it when the device changes.
        self.freqs_cis = None

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    def _precompute_freqs_cis(self, down_scale_rate, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Return: complex64 tensor of shape
        (max_height // down_scale_rate, max_width // down_scale_rate, dim // 2) with
            width axis:  ret[h, w, 2*i]   = cis(w * theta_base**(-4*i/dim))
            height axis: ret[h, w, 2*i+1] = cis(h * theta_base**(-4*i/dim))
            for i in [0, dim//4).
        note: `cis` is a mathematical notation defined by cis x = cos x + i sin x.
        The value at (h, w) does not depend on the table size, so a larger
        table can serve any smaller grid by slicing.
        """
        max_height = self.max_height // down_scale_rate
        max_width = self.max_width // down_scale_rate
        N = max_height * max_width
        flat_pos = torch.arange(0, N).float().to(device)
        x_pos = flat_pos % max_width
        y_pos = flat_pos // max_width
        dim_range = (
            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
        )  # C/4
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2  -> interleaves (width, height) per frequency
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
        )
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(max_height, max_width, -1)
        return freqs_cis

    def get_freqs_cis(self, grid_hws: torch.Tensor, down_scale_rate=1, init_freqs=False) -> torch.Tensor:
        """Gather rotary frequencies for every patch of every image.

        Args:
            grid_hws (torch.Tensor): (N, 2) grid height and width per image.
            down_scale_rate (int): divisor applied to max_height / max_width
                (the grid shrinks after token merging).
            init_freqs (bool): force a rebuild of the cached table at this rate.

        Returns:
            freqs_cis: tensor of shape (sum(h * w), dim//2)
        """
        max_height = self.max_height // down_scale_rate
        max_width = self.max_width // down_scale_rate
        cached = self.freqs_cis
        # Rebuild when there is no table, when asked to, when the table lives
        # on another device, or when it is too small for this rate (e.g. it was
        # last built with a larger down_scale_rate) — slicing it would then
        # silently return fewer rows than requested.
        needs_rebuild = (
            cached is None
            or init_freqs
            or cached.device != grid_hws.device
            or cached.shape[0] < max_height
            or cached.shape[1] < max_width
        )
        if needs_rebuild:
            self.freqs_cis = self._precompute_freqs_cis(down_scale_rate, grid_hws.device)
        shapes = grid_hws.tolist()
        assert all(
            1 <= h <= max_height and 1 <= w <= max_width for h, w in shapes
        ), (
            shapes,
            max_height,
            max_width,
        )
        freqs_cis = torch.cat(
            [self.freqs_cis[:int(h), :int(w)].reshape(-1, self.dim // 2) for h, w in shapes],
            dim=0,
        )
        return freqs_cis
class MLP2(nn.Module):
    """Two-layer perceptron computing ``fc1(activation(fc0(x)))``.

    Args:
        dims: [in_dim, hidden_dim, out_dim]
        activation: callable applied between the two linear layers.
        bias: whether to use bias in linear layer.
    """

    def __init__(self, dims: list[int], activation, bias=True):
        super().__init__()
        assert len(dims) == 3
        in_dim, hidden_dim, out_dim = dims
        self.fc0 = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.fc1 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.activation = activation
        # Truncated-normal weights with std = sqrt(2 / fan_in); zero biases.
        for layer in (self.fc0, self.fc1):
            nn.init.trunc_normal_(layer.weight, std=math.sqrt(2 / layer.in_features))
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.activation(self.fc0(x))
        return self.fc1(hidden)
class PatchMergingLayer(nn.Module):
    """Downsample each image's token grid by 2x2 inside the vision encoder.

    Operates on packed tokens: ``x`` holds all images back to back,
    ``cu_seqlens`` delimits them and ``spatial_shapes`` gives each image's
    (h, w) token grid.
    """

    def __init__(self, embed_dim, enable_merging=True, merging_method="avg_pooling", norm_layer=nn.LayerNorm):
        """
        :param embed_dim: embedding dimension of the transformer tokens.
        :param enable_merging: whether token merging is enabled; when False,
            forward() returns its inputs unchanged.
        :param merging_method: 'avg_pooling' (plain 2x2 average) or 'm_pooling'
            (learned, multi-head weighted 2x2 pooling).
        :param norm_layer: accepted for interface compatibility; not used.
        """
        super().__init__()
        self.enable_merging = enable_merging
        self.merging_method = merging_method
        # Not used by forward(); presumably kept so existing checkpoints still
        # load — TODO confirm before removing.
        self.zero_init_fc = nn.Linear(embed_dim, embed_dim, bias=False)
        if self.merging_method == 'avg_pooling':
            pass
        elif self.merging_method == 'm_pooling':
            # Scores each token of a 2x2 window from [token, window mean].
            self.attn_layer = nn.Sequential(
                nn.Linear(embed_dim * 2, embed_dim),
                nn.GELU(),
                nn.Linear(embed_dim, embed_dim)
            )
        self.num_head = 16

    def forward(self, x, cu_seqlens, spatial_shapes):
        """Merge every 2x2 token window of every image.

        Args:
            x: packed tokens (L, embed_dim).
            cu_seqlens: (N+1,) cumulative token counts per image.
            spatial_shapes: (N, 2) tensor of per-image (h, w).

        Returns:
            If merging is disabled: ``(x, cu_seqlens)`` unchanged.
            Otherwise ``(merged_x, new_cu_seqlens, spatial_shapes // 2, x)``,
            where the last element is the un-merged input.

        Raises:
            ValueError: merging is enabled with an unsupported merging_method.
        """
        if not self.enable_merging:
            return x, cu_seqlens
        if self.merging_method not in ("avg_pooling", "m_pooling"):
            # Otherwise tokens would pass through un-merged while spatial_shapes
            # is still halved, corrupting every later layer.
            raise ValueError(f"Unsupported merging_method: {self.merging_method!r}")
        cu_seqlens_out = cu_seqlens.clone()  # (N+1, )
        feature_x = x
        # One host conversion instead of two .item() calls per image.
        bounds = cu_seqlens.tolist()
        grid_shapes = spatial_shapes.tolist()
        x_i_list = []
        for i in range(1, len(bounds)):
            x_i = x[bounds[i - 1]:bounds[i], :]
            h, w = grid_shapes[i - 1]
            x_i = x_i.view(int(h), int(w), -1)  # (h, w, embed_dim)
            if self.merging_method == 'avg_pooling':
                # Odd h / w are floored by avg_pool2d, consistent with spatial_shapes // 2.
                x_i = rearrange(x_i, 'h w c -> c h w')
                x_i = F.avg_pool2d(x_i, kernel_size=2, stride=2)
                x_i = rearrange(x_i, 'c h w -> (h w) c')
            else:  # 'm_pooling'; the rearrange requires even h and w
                x_i = rearrange(x_i, '(h p1) (w p2) c -> (h w) (p1 p2) c', p1=2, p2=2)
                pooled_x_i = x_i.mean(-2, keepdim=True).expand(-1, 4, -1)
                fused_x_i = torch.cat([x_i, pooled_x_i], dim=-1)
                attn_logits = self.attn_layer(fused_x_i)
                # Per-head softmax over the 4 window positions.
                attn_logits = rearrange(attn_logits, 'n s (m d) -> n m s d', m=self.num_head)
                attn_weights = F.softmax(attn_logits, dim=-2)
                attn_weights = rearrange(attn_weights, 'n m s d -> n s (m d)')
                x_i = (x_i * attn_weights).sum(-2)
            x_i_list.append(x_i)
            cu_seqlens_out[i] = cu_seqlens_out[i - 1] + x_i.shape[0]
        x = torch.cat(x_i_list, dim=0)  # (L', embed_dim)
        return x, cu_seqlens_out, spatial_shapes // 2, feature_x
class MoonVitEncoderLayer(nn.Module):
    """Pre-norm ViT block (attention + MLP) over packed tokens.

    Layers whose index appears in ``merger_layer_index`` are followed by a
    PatchMergingLayer that downsamples each image's token grid.
    """

    def __init__(
        self,
        layer_idx: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        *,
        attn_implementation: str = "eager",
        activation=F.gelu,
        attn_bias: bool = False,
        enable_merging: bool = False,
        merging_method: str = "avg_pooling",
        merger_layer_index: Optional[List[int]] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
        self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
        self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
        if merger_layer_index is not None and layer_idx in merger_layer_index:
            self.merger = PatchMergingLayer(
                embed_dim=hidden_dim,
                enable_merging=enable_merging,
                merging_method=merging_method,
            )
        else:
            self.merger = None

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Optional[torch.Tensor] = None,
    ):
        """Self-attention with a fused QKV projection and 2D RoPE.

        Args:
            x (torch.Tensor): packed tokens (L, hidden_dim).
            cu_seqlens (torch.Tensor): (N+1,) cumulative per-image token counts;
                attention does not cross image boundaries.
            rope_freqs_cis (torch.Tensor): (L, head_dim // 2) complex frequencies.

        Returns:
            (L, hidden_dim) attention output after the output projection.
        """
        xqkv = self.wqkv(x)
        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (L, 3, num_heads, head_dim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)
        xq, xk = apply_rope(xq, xk, rope_freqs_cis)
        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(
            xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
        )
        attn_out = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: Union[torch.Tensor, None] = None,
        spatial_shapes: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states: packed tokens (L, D).
            cu_seqlens: (N+1,) cumulative per-image token counts.
            rope_freqs_cis: rotary frequencies for the current grid.
            spatial_shapes: (N, 2) per-image (h, w); only used by the merger.

        Returns:
            ``(hidden_states, cu_seqlens)`` when no merging happens, or
            ``(hidden_states, cu_seqlens, spatial_shapes, pre_merge_features)``
            when this layer's merger is active.
        """
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        attn_out = self.attention_qkvpacked(
            hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
        )
        hidden_states = residual + attn_out
        residual = hidden_states
        hidden_states = self.mlp(self.norm1(hidden_states))
        hidden_states = residual + hidden_states
        if self.merger is not None:
            # Pass the merger's tuple through as-is: an active merger returns 4
            # items (incl. pre-merge features), a disabled one returns 2. A
            # fixed 4-way unpack here crashed on the disabled case.
            return tuple(self.merger(hidden_states, cu_seqlens, spatial_shapes))
        return hidden_states, cu_seqlens
class MoonVitEncoder(nn.Module):
    """Stack of MoonVitEncoderLayer blocks over packed patch tokens.

    All blocks share one Rope2DPosEmb (512 x 512 maximum grid). After a block
    that merges tokens, the per-image grid is halved and the rotary
    frequencies are recomputed for the smaller grid.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [MoonVitEncoderLayer(layer_idx=i, **block_cfg) for i in range(num_layers)]
        )
        self.final_layernorm = nn.LayerNorm(hidden_dim)
        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
        )

    def forward(
        self, hidden_states: torch.Tensor, grid_hws: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: packed patch tokens, presumably (sum(h * w), hidden_dim).
            grid_hws: (N, 2) per-image (h, w) patch grid.

        Returns:
            (final-layernormed hidden states, grid_hws after any merging).
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws)
        # cu_seqlens = [0, h0*w0, h0*w0 + h1*w1, ...] as int32.
        lengths = torch.cat(
            (
                torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype),
                grid_hws[:, 0] * grid_hws[:, 1],
            )
        )
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
        down_scale_rate = 1
        for block in self.blocks:
            layer_outputs = block(
                hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis, spatial_shapes=grid_hws
            )
            if len(layer_outputs) > 2:
                # Tokens were merged: grid halved. The pre-merge features (4th
                # item) are discarded instead of being collected into an unused
                # list that kept every one of them alive until return.
                down_scale_rate *= 2
                hidden_states, cu_seqlens, grid_hws, _ = layer_outputs
                rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws, down_scale_rate=down_scale_rate)
            else:
                hidden_states, cu_seqlens = layer_outputs
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states, grid_hws
##### Qwen2 part #####
class Qwen2MLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``, bias-free projections."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
def rotate_half(x):
    """Rotate the last dimension by halves: ``[x1, x2] -> [-x2, x1]``.

    The split point is ``last_dim // 2``, so for an odd size the second half
    is the larger one.
    """
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((upper.neg(), lower), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine part of the rotary embedding.
        sin (`torch.Tensor`): Sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis inserted into
            cos/sin so they broadcast against q/k. Use 1 when q/k are
            [batch, heads, seq, head_dim] and 2 when they are
            [batch, seq, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated (query, key).
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    Returns the input object itself when ``n_rep == 1``.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, then fold it in.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain matmul/softmax attention with grouped-query KV expansion.

    ``attention_mask`` (if given) is additive and is trimmed to the current
    key length. Returns ``(output, weights)`` where output is transposed to
    put the sequence axis before the head axis.
    """
    n_groups = module.num_key_value_groups
    key_states = repeat_kv(key, n_groups)
    value_states = repeat_kv(value, n_groups)
    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]
    # Softmax in float32 for stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
class Qwen2Attention(nn.Module):
    """Qwen2 multi-head (grouped-query) self-attention with rotary embeddings.

    Q/K/V projections carry a bias, the output projection does not. The
    attention backend is chosen from ``config.attn_implementation``; layers
    at index >= ``max_window_layers`` use ``sliding_window`` when
    ``use_sliding_window`` is set.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=True)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        def _split_heads(proj):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = _split_heads(self.q_proj)
        key_states = _split_heads(self.k_proj)
        value_states = _split_heads(self.v_proj)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        cfg = self.config
        use_window = (
            cfg.use_sliding_window
            and getattr(cfg, "sliding_window", None) is not None
            and self.layer_idx >= cfg.max_window_layers
        )
        sliding_window = cfg.sliding_window if use_window else None

        impl = cfg.attn_implementation
        attention_interface: Callable = eager_attention_forward
        if impl == "sdpa" and kwargs.get("output_attentions", False):
            # SDPA cannot return weights; keep the eager path.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        elif impl != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=sliding_window,  # main diff with Llama
            **kwargs,
        )
        attn_output = self.o_proj(attn_output.reshape(*input_shape, -1).contiguous())
        return attn_output, attn_weights
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    Normalizes in float32, casts back to the input dtype, then scales by a
    learnable per-channel weight initialized to ones. No mean subtraction, no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(h32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = (h32 * inv_rms).to(orig_dtype)
        return self.weight * normalized

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen2DecoderLayer(nn.Module):
    """Pre-norm Qwen2 decoder block: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Returns ``(hidden_states,)`` or ``(hidden_states, attn_weights)`` when output_attentions."""
        # Attention sub-block.
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out
        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
class Qwen2RotaryEmbedding(nn.Module):
    """1D rotary position embedding for the Qwen2 text decoder.

    Computes (cos, sin) tables for the given position ids. The inverse
    frequencies and attention scaling come from transformers'
    ``ROPE_INIT_FUNCTIONS`` for the configured rope type ("default" when
    ``config.rope_scaling`` is absent).
    """

    def __init__(self, config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        # NOTE(review): these attribute names (max_seq_len_cached,
        # original_max_seq_len, original_inv_freq, ...) are presumably read by
        # `dynamic_rope_update` — keep them stable; confirm against transformers.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent buffer: follows Module.to() but is not saved in the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)``, each (batch, seq_len, 2 * len(inv_freq)), in ``x.dtype``.

        ``x`` is used only for its device and dtype.
        """
        # (batch, n_freq, 1) @ (batch, 1, seq) -> (batch, n_freq, seq)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # autocast is disabled with "cpu" as the device type on MPS.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate so the table spans the full head_dim used by rotate_half.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
##### LlavaUHDV3 part #####
class Qwen2vlPatchMerger(nn.Module):
    """Fuse each compression window of vision tokens into one LLM-sized token.

    Tokens are layer-normed, regrouped into non-overlapping
    ``compression_factor[0] x compression_factor[1]`` windows of the (h, w)
    grid, concatenated per window and projected by a 2-layer GELU MLP to
    ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim,
        image_embed_dim=1024,
        compression_factor=(2, 2),
        norm_layer=partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_embed_dim = image_embed_dim
        window = compression_factor[0] * compression_factor[1]
        self.hidden_size = image_embed_dim * window
        self.nl = norm_layer(image_embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, embed_dim),
        )
        self.compression_factor = compression_factor

    def forward(self, x, tgt_size=(24, 24), attn_mask=None):
        """
        Args:
            x: (batch, num_tokens, image_embed_dim); tokens beyond h * w are dropped.
            tgt_size: (h, w) token grid of x.
            attn_mask: unused.

        Returns:
            (batch, (h // c0) * (w // c1), embed_dim) merged tokens.
        """
        grid_h, grid_w = tgt_size
        n_tokens = grid_h * grid_w
        if x.shape[1] != n_tokens:
            x = x[:, :int(n_tokens)]
        x = self.nl(x)
        # (b, L, dim) -> (b, dim, h, w)
        x = x.permute(0, 2, 1).unflatten(-1, (int(grid_h), int(grid_w)))
        batch_size, dim = x.shape[0], x.shape[1]
        step_h, step_w = self.compression_factor[0], self.compression_factor[1]
        window = step_h * step_w
        # (b, dim, h', w', step_h, step_w) -> (b, dim, h'*w', window)
        patches = x.unfold(2, step_h, step_h).unfold(3, step_w, step_w)
        patches = patches.contiguous().view(batch_size, dim, -1, window)
        # (b, h'*w', window * dim): window position major, channel minor.
        patches = patches.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, dim * window)
        return self.mlp(patches)
class LlavaUHDV3PretrainedModel(PreTrainedModel):
    """Shared PreTrainedModel base for the LLaVA-UHD-v3 vision and text towers."""

    config: LlavaUHDV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module class names that must stay on a single device when a device map is
    # inferred. Entries must match the class __name__ exactly; the vision block
    # is `MoonVitEncoderLayer` (the former "MoonViTEncoderLayer" matched nothing,
    # so vision layers could be split across devices).
    _no_split_modules = ["Qwen2DecoderLayer", "MoonVitEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
class LlavaUHDV3VisionTransformerPretrainedModel(LlavaUHDV3PretrainedModel):
    """MoonViT vision tower: patch embedding + encoder with optional token merging."""

    config: LlavaUHDV3VisionConfig
    # Must match the encoder block's class __name__ exactly.
    _no_split_modules = ["MoonVitEncoderLayer"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.patch_size = config.patch_size
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )
        # Merger settings are optional in the config. A missing attribute
        # previously left `merger_layer_index` unbound (NameError below).
        merger_layer_index = getattr(config, "merger_layer_index", None)
        merging_method = getattr(config, "merging_method", None)
        if merger_layer_index is not None:
            enable_merging = True
            merging_method = merging_method if merging_method is not None else "avg_pooling"
        else:
            enable_merging = False
            merging_method = None
        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": PytorchGELUTanh(),
                "attn_bias": True,
                "attn_implementation": self.config.attn_implementation,
                "enable_merging": enable_merging,
                "merging_method": merging_method,
                "merger_layer_index": merger_layer_index,
            },
        )

    def forward(
        self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
    ) -> List[torch.Tensor]:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values (patches of all
                images, packed) — exact layout as expected by MoonVisionPatchEmbed.
            grid_hws (torch.Tensor): (N, 2) patch-grid height and width per image.

        Returns:
            List of N tensors, each (1, h_i * w_i, hidden_size), where (h_i, w_i)
            is image i's grid after any token merging.
        """
        # Match the patch-embedding weights' dtype. The former hard-coded bf16
        # cast made every non-bf16 model fail in the conv.
        pixel_values = pixel_values.to(self.patch_embed.proj.weight.dtype)
        hidden_states = self.patch_embed(pixel_values, grid_hws)
        image_features, grid_hws = self.encoder(hidden_states, grid_hws)
        output_features = []
        offset = 0
        # One host transfer for all grid sizes instead of per-image tensor math.
        for h, w in grid_hws.tolist():
            num_tokens = int(h * w)
            output_features.append(image_features[offset: offset + num_tokens].unsqueeze(0))
            offset += num_tokens
        assert offset == image_features.shape[0], \
            f"Used {offset} tokens, but image_features has {image_features.shape[0]} tokens!"
        return output_features
class LlavaUHDV3TextModel(LlavaUHDV3PretrainedModel):
    """Qwen2-style decoder-only text backbone.

    Components: token embeddings, a stack of ``Qwen2DecoderLayer``, a final RMSNorm
    and rotary position embeddings. It returns hidden states only; the LM head is
    owned by ``LlavaUHDV3ForCausalLM``.
    """

    config: LlavaUHDV3TextConfig
    _no_split_modules = ["Qwen2DecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range).

        Linear biases are zeroed, and so is the embedding row at ``padding_idx``.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.embed_tokens

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Args:
            input_ids: Token ids. They are embedded only when ``inputs_embeds`` is None.
            attention_mask: 2-D padding mask or a prepared 4-D mask; see
                ``_update_causal_mask``.
            position_ids: Positions for rotary embeddings. Defaults to
                ``cache_position``.
            past_key_values: KV cache. A ``DynamicCache`` is created when
                ``use_cache`` is truthy and none is given.
            inputs_embeds: Precomputed input embeddings. They take precedence over
                ``input_ids``.
            use_cache, output_attentions, output_hidden_states: Standard HF flags.
            cache_position: Absolute positions of the current tokens. Derived from
                the cache length when omitted.
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            BaseModelOutputWithPast with the final normed hidden states. It also
            holds the cache (only when ``use_cache``) and the optional per-layer
            hidden states and attentions.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)
        # Checkpointed layers are re-run in backward; a KV cache would be appended
        # to a second time, so caching is disabled while training with checkpointing.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Positional order must match the decoder layer's forward signature.
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **kwargs),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the attention mask expected by the configured attention backend.

        Returns None when the backend can apply causality itself: flash-attention
        with no padding, or SDPA when ``_ignore_causal_mask_sdpa`` allows it. With
        flash-attention and padding present, the 2-D mask is returned as-is.
        Otherwise a 4-D additive mask is returned.

        Raises:
            ValueError: With flash-attention and a cache, if the batch appears to be
                right-padded.
        """
        if self.config.attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config.attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )
        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )
        if (
            self.config.attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`LlavaUHDV3TextConfig`):
                The model's configuration class; only `sliding_window` is read here.
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # True where the key position lies in the future of the query position.
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            # Keep min_dtype only at masked positions; allowed positions become 0.
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # Positions that are causally visible (0) but padded (mask 0) sum to 0.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
class LlavaUHDV3Model(LlavaUHDV3PretrainedModel):
    """Multimodal backbone: vision tower + projector + text decoder.

    Image placeholder tokens (``IMAGE_TOKEN_INDEX``) in ``input_ids`` are replaced
    by projected image features before the text decoder runs.
    """

    config_class = LlavaUHDV3Config
    # Placeholder id marking where image features are spliced into the sequence.
    IMAGE_TOKEN_INDEX = -200

    def __init__(self, config):
        super().__init__(config)
        config.model_type = "llava_uhd_v3"
        config.rope_scaling = None
        self.visual = LlavaUHDV3VisionTransformerPretrainedModel._from_config(config.vision_config)
        self.language_model = LlavaUHDV3TextModel._from_config(config.text_config)
        self.projector = Qwen2vlPatchMerger(
            embed_dim=config.text_config.hidden_size,
            image_embed_dim=config.vision_config.hidden_size,
            compression_factor=(2, 2),
        )
        self.rope_deltas = None
        # Initialize model layers here
        self.post_init()

    def get_image_features(self, pixel_values, grid_hws):
        """Encode images with the vision tower and project them to the text hidden size.

        Returns:
            list: One projected feature tensor per entry of ``grid_hws``.
        """
        downsample_ratio = 1
        merger_layer_index = getattr(self.config.vision_config, "merger_layer_index", None)
        if merger_layer_index is not None:
            # NOTE(review): the per-side ratio is len(merger_layer_index) ** 2. Confirm
            # this matches the encoder's merging scheme (e.g. 2x2 pooling per merge
            # layer would give 2 ** len(merger_layer_index)).
            downsample_ratio *= len(merger_layer_index) ** 2
        image_features = self.visual(pixel_values, grid_hws)
        projected_image_features = []
        for image_feature, grid_hw in zip(image_features, grid_hws):
            tgt_size = (grid_hw[0] // downsample_ratio, grid_hw[1] // downsample_ratio)
            projected_image_features.append(self.projector(image_feature, tgt_size=tgt_size)[0])
        return projected_image_features

    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        pixel_values,
        grid_hws
    ):
        """Splice image features into the token embeddings and re-pad the batch.

        When there are no images, or only a single new token (decode step), the
        inputs are returned unchanged with ``inputs_embeds=None``.

        Otherwise:
        - padding is stripped from each sample;
        - every ``IMAGE_TOKEN_INDEX`` is replaced by the next image's features,
          whose labels are set to -100;
        - each sample is truncated to ``config.tokenizer_model_max_length``
          (default 4096);
        - the batch is padded again, on the right by default or on the left when
          ``config.tokenizer_padding_side == "left"``.

        ``labels`` / ``attention_mask`` / ``position_ids`` are returned as None if
        they were passed in as None.

        Returns:
            tuple: ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. ``input_ids`` is None whenever
            ``inputs_embeds`` is produced.
        """
        if pixel_values is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels
        image_features = self.get_image_features(pixel_values, grid_hws)

        # Remember which optional inputs were supplied so outputs can mirror them.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, -100)

        # Drop padding: each sample becomes a 1-D tensor of its unmasked tokens.
        input_ids = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        labels = [lab[mask] for lab, mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = int((cur_input_ids == self.IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                text_embeds = self.language_model.embed_tokens(cur_input_ids)
                # Append a zero-length slice of image features. Presumably this keeps
                # the vision tower in the autograd graph for text-only samples.
                cur_input_embeds = torch.cat([text_embeds, cur_image_features[:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Split the sequence into text segments around each image placeholder.
            image_token_indices = (
                [-1]
                + torch.where(cur_input_ids == self.IMAGE_TOKEN_INDEX)[0].tolist()
                + [cur_input_ids.shape[0]]
            )
            cur_labels = labels[batch_idx]
            cur_input_ids_noim = []
            cur_labels_noim = []
            for start, end in zip(image_token_indices[:-1], image_token_indices[1:]):
                cur_input_ids_noim.append(cur_input_ids[start + 1:end])
                cur_labels_noim.append(cur_labels[start + 1:end])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split them apart again.
            cur_input_embeds = self.language_model.embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    if cur_image_idx < len(image_features):
                        cur_image_features = image_features[cur_image_idx]
                    else:
                        # More placeholders than images: reuse the previous image's features.
                        logger.warning_once(
                            "More image tokens than image features; reusing the last image's features."
                        )
                        cur_image_features = image_features[cur_image_idx - 1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],), -100,
                            device=cur_labels.device, dtype=cur_labels.dtype,
                        )
                    )
            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", 4096)
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        padding_side = getattr(self.config, "tokenizer_padding_side", "right")
        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), -100, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            pad = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device
            )
            # Slice of the padded row that holds real tokens.
            span = slice(max_len - cur_len, max_len) if padding_side == "left" else slice(0, cur_len)
            if padding_side == "left":
                new_input_embeds_padded.append(torch.cat((pad, cur_new_embed), dim=0))
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, pad), dim=0))
            if cur_len > 0:
                new_labels_padded[i, span] = cur_new_labels
                attention_mask[i, span] = True
                position_ids[i, span] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded
        attention_mask = None if _attention_mask is None else attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def forward(
        self,
        input_ids = None,
        position_ids = None,
        attention_mask = None,
        past_key_values = None,
        inputs_embeds = None,
        labels = None,
        use_cache = None,
        output_attentions = None,
        output_hidden_states = None,
        pixel_values = None,
        grid_hws = None,
        return_dict = None,
        **kwargs,
    ):
        """Merge image features into the inputs (if needed) and run the text decoder.

        Returns:
            The language model output (BaseModelOutputWithPast) when ``labels`` is
            None. Otherwise a tuple ``(last_hidden_state, labels)``, where the
            labels have been expanded to cover the image tokens.
        """
        if inputs_embeds is None:
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, pixel_values, grid_hws
            )
        # `labels` and `return_dict` are not language-model parameters. Passing them
        # would leak them through **kwargs into every decoder layer.
        output = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        if labels is not None:
            return output[0], labels
        return output
class LlavaUHDV3ForCausalLM(LlavaUHDV3PretrainedModel, GenerationMixin):
    """LLaVA-UHD-v3 with a language-modeling head on top of ``LlavaUHDV3Model``."""

    config_class = LlavaUHDV3Config
    _checkpoint_conversion_mapping = {
        "^visual": "model.visual",
        r"^model(?!\.(language_model|visual|projector))": "model.language_model",
    }
    # _tied_weights_keys = ["lm_head.weight", "model.language_model.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlavaUHDV3Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        return self.model.visual

    def get_input_embeddings(self):
        return self.language_model.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids=None, labels=None, attention_mask=None, pixel_values=None, grid_hws=None, **kwargs):
        """Compute logits and, when ``labels`` are given, the causal-LM loss.

        ``kwargs`` may include ``logits_to_keep``: an int keeps only the last N
        positions (0 keeps all, the default), and a slice/tensor is used as an
        index. All other kwargs are forwarded to the model and the loss function.
        """
        logits_to_keep = kwargs.pop("logits_to_keep", 0)
        model_outputs = self.model(
            input_ids, labels=labels, attention_mask=attention_mask,
            pixel_values=pixel_values, grid_hws=grid_hws, **kwargs,
        )
        if labels is not None:
            # With labels the backbone returns (hidden_states, expanded_labels).
            model_outputs, labels = model_outputs
        if isinstance(model_outputs, torch.Tensor):
            # Bare hidden-state tensor: no cache / extra outputs are available.
            hidden_states = model_outputs
            past_key_values = all_hidden_states = attentions = None
        else:
            hidden_states = model_outputs.last_hidden_state
            past_key_values = model_outputs.past_key_values
            all_hidden_states = model_outputs.hidden_states
            attentions = model_outputs.attentions
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        grid_hws: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Generate text, first splicing image features into the prompt embeddings.

        Generation always runs from ``inputs_embeds``, so the returned sequences
        contain only newly generated tokens.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is passed directly.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
        inputs_embeds = None
        if pixel_values is not None:
            input_ids, position_ids, attention_mask, _, inputs_embeds, _ = self.model.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, None, None, pixel_values, grid_hws
            )
        if inputs_embeds is None:
            # Covers text-only prompts and the early-return path of
            # prepare_inputs_labels_for_multimodal (e.g. single-token prompts).
            inputs_embeds = self.model.language_model.embed_tokens(input_ids)
        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Delegate to GenerationMixin, then re-attach ``pixel_values`` / ``grid_hws`` if present."""
        pixel_values = kwargs.pop("pixel_values", None)
        grid_hws = kwargs.pop("grid_hws", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if pixel_values is not None:
            inputs["pixel_values"] = pixel_values
        if grid_hws is not None:
            inputs["grid_hws"] = grid_hws
        return inputs
# Public API of this module.
__all__ = [
    "LlavaUHDV3ForCausalLM",
    "LlavaUHDV3Model",
    "LlavaUHDV3PretrainedModel",
    "LlavaUHDV3TextModel",
]
# Optional module-level alias, currently disabled:
# ModelClass = LlavaUHDV3ForCausalLM