|
|
from functools import partial
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers.activations import PytorchGELUTanh |
|
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
|
from transformers.utils import is_flash_attn_2_available, logging |
|
|
from transformers.integrations import use_kernel_forward_from_hub |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
|
|
|
|
if is_flash_attn_2_available(): |
|
|
from flash_attn import flash_attn_varlen_func |
|
|
else: |
|
|
flash_attn_varlen_func = None |
|
|
from collections.abc import Callable |
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.processing_utils import Unpack |
|
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
|
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache |
|
|
from transformers.generation import GenerationMixin |
|
|
|
|
|
import math |
|
|
from copy import deepcopy |
|
|
from typing import Union, Tuple, Sequence, Optional, List |
|
|
from einops import rearrange |
|
|
|
|
|
from .configuration_llava_uhd_v3 import LlavaUHDV3Config, LlavaUHDV3VisionConfig, LlavaUHDV3TextConfig |
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def multihead_attention( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
q_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
k_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
): |
|
|
"""Multi-head attention using flash attention 2. |
|
|
Args: |
|
|
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
|
|
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. |
|
|
The first element should be 0 and the last element should be q.shape[0]. |
|
|
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k. |
|
|
The first element should be 0 and the last element should be k.shape[0]. |
|
|
Returns: |
|
|
        output: tensor of shape (tot_seqlens, dim),
            where dim = num_heads * head_dim
|
|
""" |
|
|
|
|
|
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims" |
|
|
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]" |
|
|
assert ( |
|
|
k_cu_seqlens[-1] == k.shape[0] == v.shape[0] |
|
|
), "k_cu_seqlens must sum to k.shape[0]" |
|
|
assert q.dtype in [ |
|
|
torch.bfloat16, |
|
|
torch.float16, |
|
|
], f"unsupported dtype {q.dtype} for multihead attn" |
|
|
|
|
|
max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item() |
|
|
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item() |
|
|
attn_out = flash_attn_varlen_func( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
q_cu_seqlens, |
|
|
k_cu_seqlens, |
|
|
max_seqlen_q, |
|
|
max_seqlen_k, |
|
|
causal=False, |
|
|
) |
|
|
attn_out = attn_out.flatten(start_dim=-2) |
|
|
|
|
|
return attn_out |
|
|
|
|
|
def sdpa_attention( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
q_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
k_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
"""SDPA attention. |
|
|
Args: |
|
|
        q, k, v: packed tensors of shape (tot_seqlens, num_heads, head_dim).
|
|
""" |
|
|
seq_length = q.shape[0] |
|
|
attention_mask = torch.zeros( |
|
|
[1, seq_length, seq_length], device=q.device, dtype=torch.bool |
|
|
) |
|
|
for i in range(1, len(q_cu_seqlens)): |
|
|
attention_mask[ |
|
|
..., |
|
|
q_cu_seqlens[i - 1] : q_cu_seqlens[i], |
|
|
q_cu_seqlens[i - 1] : q_cu_seqlens[i], |
|
|
] = True |
|
|
q = q.transpose(0, 1) |
|
|
k = k.transpose(0, 1) |
|
|
v = v.transpose(0, 1) |
|
|
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) |
|
|
attn_output = attn_output.transpose(0, 1) |
|
|
attn_output = attn_output.reshape(seq_length, -1) |
|
|
return attn_output |
|
|
|
|
|
def eager_attention( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
q_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
k_cu_seqlens: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
seq_length = q.shape[0] |
|
|
attention_mask = torch.zeros( |
|
|
[1, seq_length, seq_length], device=q.device, dtype=torch.bool |
|
|
) |
|
|
for i in range(1, len(q_cu_seqlens)): |
|
|
attention_mask[ |
|
|
..., |
|
|
q_cu_seqlens[i - 1] : q_cu_seqlens[i], |
|
|
q_cu_seqlens[i - 1] : q_cu_seqlens[i], |
|
|
] = True |
|
|
q = q.transpose(0, 1) |
|
|
k = k.transpose(0, 1) |
|
|
v = v.transpose(0, 1) |
|
|
|
|
|
    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # mask out tokens belonging to other packed sequences with -inf before softmax
    attn_weight = attn_weight.masked_fill(~attention_mask, float("-inf"))
|
|
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype) |
|
|
|
|
|
attn_output = attn_weight @ v |
|
|
attn_output = attn_output.transpose(0, 1) |
|
|
attn_output = attn_output.reshape(seq_length, -1) |
|
|
return attn_output |
|
|
|
|
|
VL_VISION_ATTENTION_FUNCTIONS = { |
|
|
"flash_attention_2": multihead_attention, |
|
|
"sdpa": sdpa_attention, |
|
|
"eager": eager_attention, |
|
|
} |
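
# Illustrative usage sketch (not called by the model): the packed-attention helpers
# above take q/k/v of shape (tot_seqlens, num_heads, head_dim) together with
# cumulative sequence lengths; the sizes below are made up for demonstration.
def _demo_packed_attention():
    num_heads, head_dim = 4, 32
    # two images packed into one sequence: 6 tokens and 10 tokens
    cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32)
    q = torch.randn(16, num_heads, head_dim)
    k = torch.randn(16, num_heads, head_dim)
    v = torch.randn(16, num_heads, head_dim)
    out = sdpa_attention(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)
    assert out.shape == (16, num_heads * head_dim)
    return out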
|
|
|
|
|
def _apply_rope_input_validation(x, freqs_cis): |
|
|
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) |
|
|
assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape) |
|
|
assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape) |
|
|
assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype |
|
|
|
|
|
def apply_rope( |
|
|
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor |
|
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Args: (The leading dimensions of all inputs should be the same) |
|
|
xq: query, tensor of shape (..., num_heads, head_dim) |
|
|
xk: key, tensor of shape (..., num_heads, head_dim) |
|
|
freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid. |
|
|
Returns: |
|
|
xq_out, xk_out: tensors of shape (..., num_heads, head_dim) |
|
|
""" |
|
|
_apply_rope_input_validation(xq, freqs_cis) |
|
|
_apply_rope_input_validation(xk, freqs_cis) |
|
|
|
|
|
freqs_cis = freqs_cis.unsqueeze(-2) |
|
|
|
|
|
xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) |
|
|
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
|
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) |
|
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) |
|
|
return xq_out.type_as(xq), xk_out.type_as(xk) |
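
# Illustrative sketch (not called by the model): apply_rope expects a complex64
# `freqs_cis` whose leading dims match the token dims of q/k and whose last dim is
# head_dim // 2; the random angles below stand in for the precomputed 2D grid values.
def _demo_apply_rope():
    num_tokens, num_heads, head_dim = 8, 4, 32
    xq = torch.randn(num_tokens, num_heads, head_dim)
    xk = torch.randn(num_tokens, num_heads, head_dim)
    angles = torch.randn(num_tokens, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit-norm complex rotations
    xq_out, xk_out = apply_rope(xq, xk, freqs_cis)
    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
    return xq_out, xk_out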
|
|
|
|
|
class Learnable2DInterpPosEmb(nn.Module): |
|
|
def __init__( |
|
|
self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic" |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.height = height |
|
|
self.width = width |
|
|
self.interpolation_mode = interpolation_mode |
|
|
self.weight = nn.Parameter(torch.empty(height, width, dim)) |
|
|
self.reset_parameters() |
|
|
|
|
|
def reset_parameters(self): |
|
|
nn.init.normal_(self.weight) |
|
|
|
|
|
def forward(self, x, grid_hws) -> torch.Tensor: |
|
|
pos_embs = [] |
|
|
for shape in grid_hws.tolist(): |
|
|
shape = [int(i) for i in shape] |
|
|
            if tuple(shape) == tuple(self.weight.shape[:-1]):
|
|
pos_embs.append(self.weight.flatten(end_dim=1)) |
|
|
else: |
|
|
pos_embs.append( |
|
|
F.interpolate( |
|
|
self.weight.permute((2, 0, 1)).unsqueeze(0), |
|
|
size=shape, |
|
|
mode=self.interpolation_mode, |
|
|
) |
|
|
.squeeze(0) |
|
|
.permute((1, 2, 0)) |
|
|
.flatten(end_dim=1) |
|
|
) |
|
|
out = x + torch.cat(pos_embs) |
|
|
return out |
|
|
|
|
|
class MoonVisionPatchEmbed(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
out_dim: int, |
|
|
in_dim: int = 3, |
|
|
patch_size: Union[int, Tuple[int, int]] = (14, 14), |
|
|
pos_emb_height: int = 14, |
|
|
pos_emb_width: int = 14, |
|
|
): |
|
|
super().__init__() |
|
|
assert isinstance( |
|
|
patch_size, (int, Sequence) |
|
|
), f"Invalid patch_size type: {type(patch_size)}" |
|
|
if isinstance(patch_size, int): |
|
|
patch_size = (patch_size, patch_size) |
|
|
assert ( |
|
|
len(patch_size) == 2 |
|
|
), f"Expected patch_size to be a tuple of 2, got {patch_size}" |
|
|
self.patch_size = patch_size |
|
|
|
|
|
self.proj = nn.Conv2d( |
|
|
in_dim, out_dim, kernel_size=patch_size, stride=patch_size |
|
|
) |
|
|
|
|
|
self.pos_emb = Learnable2DInterpPosEmb( |
|
|
height=pos_emb_height, width=pos_emb_width, dim=out_dim |
|
|
) |
|
|
|
|
|
def forward(self, x, grid_hws) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
            x (L, in_dim, patch_h, patch_w): packed image patches
|
|
grid_hws (N, 2): grid height and width |
|
|
Returns: |
|
|
(L, Cout) tensor |
|
|
""" |
|
|
x = self.proj(x).view(x.size(0), -1) |
|
|
|
|
|
x = self.pos_emb(x, grid_hws) |
|
|
return x |
|
|
|
|
|
class Rope2DPosEmb(nn.Module): |
|
|
"""2D rotary position embedding with multi-resolution support. |
|
|
This class is intended to be used in the following way: |
|
|
1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis. |
|
|
2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration. |
|
|
3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation. |
|
|
The rope is shared across all attention layers and all heads. |
|
|
Refs: |
|
|
- RoFormer: https://arxiv.org/abs/2104.09864 |
|
|
- VisionLLaMA: https://arxiv.org/abs/2403.00522 |
|
|
- https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py |
|
|
Args: |
|
|
dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed) |
|
|
max_height (int): the maximum height of the 2D grid |
|
|
max_width (int): the maximum width of the 2D grid |
|
|
theta_base (float): the base of the theta |
|
|
device (str): the device to store the precomputed cis |
|
|
""" |
|
|
|
|
|
def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
assert self.dim % 4 == 0, "dim must be divisible by 4" |
|
|
self.max_height = max_height |
|
|
self.max_width = max_width |
|
|
self.theta_base = theta_base |
|
|
|
|
|
self.freqs_cis = None |
|
|
|
|
|
def extra_repr(self): |
|
|
return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}" |
|
|
|
|
|
def _precompute_freqs_cis(self, down_scale_rate, device: torch.device) -> torch.Tensor: |
|
|
"""Calculate the cis(freqs) for each position in the 2D grid. |
|
|
Return: complex tensor of shape (max_height, max_width, dim//2) and value: |
|
|
height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim)) |
|
|
            width axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
|
|
note: `cis` is a mathematical notation defined by cis x = cos x + i sin x, |
|
|
""" |
|
|
max_height = self.max_height // down_scale_rate |
|
|
max_width = self.max_width // down_scale_rate |
|
|
|
|
|
N = max_height * max_width |
|
|
flat_pos = torch.arange(0, N).float().to(device) |
|
|
x_pos = flat_pos % max_width |
|
|
y_pos = flat_pos // max_width |
|
|
dim_range = ( |
|
|
torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device) |
|
|
) |
|
|
freqs = 1.0 / (self.theta_base ** (dim_range / self.dim)) |
|
|
x_freqs = torch.outer(x_pos, freqs).float() |
|
|
y_freqs = torch.outer(y_pos, freqs).float() |
|
|
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) |
|
|
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) |
|
|
|
|
|
freqs_cis = torch.cat( |
|
|
[x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1 |
|
|
) |
|
|
|
|
|
freqs_cis = freqs_cis.reshape(max_height, max_width, -1) |
|
|
return freqs_cis |
|
|
|
|
|
def get_freqs_cis(self, grid_hws: torch.Tensor, down_scale_rate=1, init_freqs=False) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
grid_hws (torch.Tensor): grid height and width |
|
|
Returns: |
|
|
            freqs_cis: tensor of shape (sum(height * width), dim//2)
|
|
""" |
|
|
max_height = self.max_height // down_scale_rate |
|
|
max_width = self.max_width // down_scale_rate |
|
|
|
|
|
if self.freqs_cis is None or init_freqs: |
|
|
self.freqs_cis = self._precompute_freqs_cis(down_scale_rate, grid_hws.device) |
|
|
|
|
|
shapes = grid_hws.tolist() |
|
|
assert all( |
|
|
1 <= h <= max_height and 1 <= w <= max_width for h, w in shapes |
|
|
), ( |
|
|
shapes, |
|
|
max_height, |
|
|
max_width, |
|
|
) |
|
|
freqs_cis = torch.cat( |
|
|
[self.freqs_cis[:int(h), :int(w)].reshape(-1, self.dim // 2) for h, w in shapes], |
|
|
dim=0, |
|
|
) |
|
|
return freqs_cis |
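
# Illustrative sketch (not called by the model): Rope2DPosEmb precomputes cis values for
# a max_height x max_width grid, and `get_freqs_cis` gathers the entries for each
# image's (h, w) patch grid in packed order.
def _demo_rope_2d():
    head_dim = 64  # must be divisible by 4
    rope = Rope2DPosEmb(dim=head_dim, max_height=64, max_width=64)
    grid_hws = torch.tensor([[4, 6], [3, 3]])  # two images: 4x6 and 3x3 patches
    freqs_cis = rope.get_freqs_cis(grid_hws)
    assert freqs_cis.shape == (4 * 6 + 3 * 3, head_dim // 2)
    return freqs_cis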
|
|
|
|
|
class MLP2(nn.Module): |
|
|
""" |
|
|
Args: |
|
|
dims: [in_dim, hidden_dim, out_dim] |
|
|
bias: whether to use bias in linear layer. |
|
|
""" |
|
|
|
|
|
def __init__(self, dims: list[int], activation, bias=True): |
|
|
super().__init__() |
|
|
assert len(dims) == 3 |
|
|
self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) |
|
|
self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) |
|
|
self.activation = activation |
|
|
for m in [self.fc0, self.fc1]: |
|
|
nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) |
|
|
if m.bias is not None: |
|
|
nn.init.zeros_(m.bias) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
x = self.fc0(x) |
|
|
x = self.activation(x) |
|
|
return self.fc1(x) |
|
|
|
|
|
class PatchMergingLayer(nn.Module): |
|
|
def __init__(self, embed_dim, enable_merging=True, merging_method="avg_pooling", norm_layer=nn.LayerNorm): |
|
|
""" |
|
|
:param embed_dim: Transformer token 的嵌入维度 |
|
|
:param enable_merging: 是否启用 token 合并功能 |
|
|
:param merging_method: 选择 'mlp' 或 'avg_pooling' 作为合并方式 |
|
|
""" |
|
|
super().__init__() |
|
|
self.enable_merging = enable_merging |
|
|
self.merging_method = merging_method |
|
|
self.zero_init_fc = nn.Linear(embed_dim, embed_dim, bias=False) |
|
|
if self.merging_method == 'avg_pooling': |
|
|
pass |
|
|
elif self.merging_method == 'm_pooling': |
|
|
self.attn_layer = nn.Sequential( |
|
|
nn.Linear(embed_dim * 2, embed_dim), |
|
|
nn.GELU(), |
|
|
nn.Linear(embed_dim, embed_dim) |
|
|
) |
|
|
self.num_head = 16 |
|
|
|
|
|
def forward(self, x, cu_seqlens, spatial_shapes): |
|
|
if not self.enable_merging: |
|
|
return x, cu_seqlens |
|
|
cu_seqlens_out = cu_seqlens.clone() |
|
|
feature_x = x |
|
|
x_i_list = [] |
|
|
for i in range(1, len(cu_seqlens)): |
|
|
start_idx = cu_seqlens[i-1].item() |
|
|
end_idx = cu_seqlens[i].item() |
|
|
x_i = x[start_idx:end_idx, :] |
|
|
h, w = spatial_shapes[i-1] |
|
|
x_i = x_i.view(int(h), int(w), -1) |
|
|
|
|
|
if self.merging_method == 'avg_pooling': |
|
|
x_i = rearrange(x_i, 'h w c -> c h w') |
|
|
x_i = F.avg_pool2d(x_i, kernel_size=2, stride=2) |
|
|
x_i = rearrange(x_i, 'c h w -> (h w) c') |
|
|
elif self.merging_method == 'm_pooling': |
|
|
x_i = rearrange(x_i, '(h p1) (w p2) c -> (h w) (p1 p2) c', p1=2, p2=2) |
|
|
pooled_x_i = x_i.mean(-2, keepdim=True).expand(-1, 4, -1) |
|
|
fused_x_i = torch.cat([x_i, pooled_x_i], dim=-1) |
|
|
attn_logits = self.attn_layer(fused_x_i) |
|
|
|
|
|
attn_logits = rearrange(attn_logits, 'n s (m d) -> n m s d', m=self.num_head) |
|
|
attn_weights = F.softmax(attn_logits, dim=-2) |
|
|
attn_weights = rearrange(attn_weights, 'n m s d -> n s (m d)') |
|
|
|
|
|
x_i = (x_i * attn_weights).sum(-2) |
|
|
|
|
|
x_i_list.append(x_i) |
|
|
cu_seqlens_out[i] = cu_seqlens_out[i-1] + x_i.shape[0] |
|
|
x = torch.cat(x_i_list, dim=0) |
|
|
return x, cu_seqlens_out, spatial_shapes//2, feature_x |
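
# Illustrative sketch (not called by the model): with 'avg_pooling', PatchMergingLayer
# halves each spatial dimension of every packed image, so token counts shrink by 4x and
# the cumulative sequence lengths are recomputed accordingly.
def _demo_patch_merging():
    embed_dim = 16
    merger = PatchMergingLayer(embed_dim, enable_merging=True, merging_method="avg_pooling")
    spatial_shapes = torch.tensor([[4, 4]])  # one image of 4x4 patches
    cu_seqlens = torch.tensor([0, 16])
    x = torch.randn(16, embed_dim)
    x_merged, cu_out, shapes_out, _ = merger(x, cu_seqlens, spatial_shapes)
    assert x_merged.shape == (4, embed_dim)  # 2x2 pooling: 16 -> 4 tokens
    assert cu_out.tolist() == [0, 4] and shapes_out.tolist() == [[2, 2]]
    return x_merged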
|
|
|
|
|
class MoonVitEncoderLayer(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
layer_idx: int, |
|
|
num_heads: int, |
|
|
hidden_dim: int, |
|
|
mlp_dim: int, |
|
|
*, |
|
|
attn_implementation: str = "eager", |
|
|
activation=F.gelu, |
|
|
attn_bias: bool = False, |
|
|
enable_merging: bool = False, |
|
|
merging_method: str = "avg_pooling", |
|
|
        merger_layer_index: Optional[List[int]] = None,
|
|
): |
|
|
super().__init__() |
|
|
self.num_heads = num_heads |
|
|
self.hidden_dim = hidden_dim |
|
|
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads |
|
|
self.attn_implementation = attn_implementation |
|
|
|
|
|
self.norm0 = nn.LayerNorm(hidden_dim) |
|
|
self.norm1 = nn.LayerNorm(hidden_dim) |
|
|
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) |
|
|
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) |
|
|
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) |
|
|
|
|
|
if merger_layer_index is not None and layer_idx in merger_layer_index: |
|
|
self.merger = PatchMergingLayer( |
|
|
embed_dim=hidden_dim, |
|
|
enable_merging=enable_merging, |
|
|
merging_method=merging_method, |
|
|
) |
|
|
else: |
|
|
self.merger = None |
|
|
|
|
|
def attention_qkvpacked( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
cu_seqlens: torch.Tensor, |
|
|
rope_freqs_cis: Optional[torch.Tensor] = None, |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
            x (torch.Tensor): packed tokens of shape (tot_seqlens, hidden_dim)
            cu_seqlens (torch.Tensor): cumulative sequence lengths of the packed images
|
|
""" |
|
|
xqkv = self.wqkv(x) |
|
|
|
|
|
qkv_shape = xqkv.size()[:-1] + ( |
|
|
3, |
|
|
self.num_heads, |
|
|
self.hidden_size_per_attention_head, |
|
|
) |
|
|
|
|
|
xqkv = xqkv.view(*qkv_shape) |
|
|
xq, xk, xv = torch.unbind(xqkv, dim=-3) |
|
|
|
|
|
xq, xk = apply_rope(xq, xk, rope_freqs_cis) |
|
|
attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] |
|
|
attn_out = attn_func( |
|
|
xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens |
|
|
) |
|
|
|
|
|
attn_out = self.wo(attn_out) |
|
|
return attn_out |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
cu_seqlens: torch.Tensor, |
|
|
rope_freqs_cis: Union[torch.Tensor, None] = None, |
|
|
spatial_shapes: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
            hidden_states: non-packed (B, N, D) or packed (L, D). If non-packed, `cu_seqlens` should be None; if packed, it must be set.
|
|
Returns: |
|
|
            output: same shape as input, (B, N, D) for non-packed input, (L, D) for packed input
|
|
""" |
|
|
residual = hidden_states |
|
|
hidden_states = self.norm0(hidden_states) |
|
|
attn_out = self.attention_qkvpacked( |
|
|
hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis |
|
|
) |
|
|
hidden_states = residual + attn_out |
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.mlp(self.norm1(hidden_states)) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
if self.merger is not None: |
|
|
hidden_states, cu_seqlens, spatial_shapes, feature_x = self.merger( |
|
|
hidden_states, cu_seqlens, spatial_shapes |
|
|
) |
|
|
outputs = (hidden_states, cu_seqlens, spatial_shapes, feature_x) |
|
|
else: |
|
|
outputs = (hidden_states, cu_seqlens) |
|
|
|
|
|
return outputs |
|
|
|
|
|
class MoonVitEncoder(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
hidden_dim: int, |
|
|
num_layers: int, |
|
|
block_cfg: dict, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.blocks = nn.ModuleList( |
|
|
[MoonVitEncoderLayer(layer_idx=i, **block_cfg) for i in range(num_layers)] |
|
|
) |
|
|
self.final_layernorm = nn.LayerNorm(hidden_dim) |
|
|
self.rope_2d = Rope2DPosEmb( |
|
|
block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512 |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, grid_hws: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws) |
|
|
|
|
|
lengths = torch.cat( |
|
|
( |
|
|
torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype), |
|
|
grid_hws[:, 0] * grid_hws[:, 1], |
|
|
) |
|
|
) |
|
|
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) |
|
|
down_scale_rate = 1 |
|
|
feature_x_list = [] |
|
|
for _, block in enumerate(self.blocks): |
|
|
layer_outputs = block( |
|
|
hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis, spatial_shapes=grid_hws |
|
|
) |
|
|
if len(layer_outputs) > 2: |
|
|
down_scale_rate *= 2 |
|
|
hidden_states, cu_seqlens, grid_hws, feature_x = layer_outputs |
|
|
rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws, down_scale_rate=down_scale_rate) |
|
|
feature_x_list.append(feature_x) |
|
|
else: |
|
|
hidden_states, cu_seqlens = layer_outputs |
|
|
|
|
|
hidden_states = self.final_layernorm(hidden_states) |
|
|
return hidden_states, grid_hws |
|
|
|
|
|
|
|
|
class Qwen2MLP(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = config.hidden_size |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
|
|
def forward(self, x): |
|
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
return down_proj |
|
|
|
|
|
def rotate_half(x): |
|
|
"""Rotates half the hidden dims of the input.""" |
|
|
x1 = x[..., : x.shape[-1] // 2] |
|
|
x2 = x[..., x.shape[-1] // 2 :] |
|
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
|
|
"""Applies Rotary Position Embedding to the query and key tensors. |
|
|
|
|
|
Args: |
|
|
q (`torch.Tensor`): The query tensor. |
|
|
k (`torch.Tensor`): The key tensor. |
|
|
cos (`torch.Tensor`): The cosine part of the rotary embedding. |
|
|
sin (`torch.Tensor`): The sine part of the rotary embedding. |
|
|
position_ids (`torch.Tensor`, *optional*): |
|
|
Deprecated and unused. |
|
|
unsqueeze_dim (`int`, *optional*, defaults to 1): |
|
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
|
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
|
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
|
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
|
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
|
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
|
|
Returns: |
|
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
|
|
""" |
|
|
cos = cos.unsqueeze(unsqueeze_dim) |
|
|
sin = sin.unsqueeze(unsqueeze_dim) |
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed, k_embed |
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
""" |
|
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
|
""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
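
# Illustrative sketch (not called by the model): repeat_kv expands grouped-query KV heads
# so that every attention head has a matching key/value head.
def _demo_repeat_kv():
    kv = torch.randn(1, 2, 5, 8)       # (batch, num_kv_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=3)  # 2 KV heads shared by 6 attention heads
    assert expanded.shape == (1, 6, 5, 8)
    return expanded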
|
|
|
|
|
def eager_attention_forward( |
|
|
module: nn.Module, |
|
|
query: torch.Tensor, |
|
|
key: torch.Tensor, |
|
|
value: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
scaling: float, |
|
|
dropout: float = 0.0, |
|
|
**kwargs, |
|
|
): |
|
|
key_states = repeat_kv(key, module.num_key_value_groups) |
|
|
value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
|
|
|
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
|
|
if attention_mask is not None: |
|
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
|
|
attn_weights = attn_weights + causal_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
|
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
|
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
class Qwen2Attention(nn.Module): |
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
|
|
def __init__(self, config, layer_idx: int): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) |
|
|
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads |
|
|
self.scaling = self.head_dim**-0.5 |
|
|
self.attention_dropout = config.attention_dropout |
|
|
self.is_causal = True |
|
|
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) |
|
|
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) |
|
|
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) |
|
|
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor], |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
past_key_value: Optional[Cache] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
input_shape = hidden_states.shape[:-1] |
|
|
hidden_shape = (*input_shape, -1, self.head_dim) |
|
|
|
|
|
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
|
|
|
cos, sin = position_embeddings |
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
|
|
|
if past_key_value is not None: |
|
|
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
|
|
sliding_window = None |
|
|
if ( |
|
|
self.config.use_sliding_window |
|
|
and getattr(self.config, "sliding_window", None) is not None |
|
|
and self.layer_idx >= self.config.max_window_layers |
|
|
): |
|
|
sliding_window = self.config.sliding_window |
|
|
|
|
|
attention_interface: Callable = eager_attention_forward |
|
|
if self.config.attn_implementation != "eager": |
|
|
if self.config.attn_implementation == "sdpa" and kwargs.get("output_attentions", False): |
|
|
logger.warning_once( |
|
|
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " |
|
|
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' |
|
|
) |
|
|
else: |
|
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config.attn_implementation] |
|
|
|
|
|
attn_output, attn_weights = attention_interface( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
dropout=0.0 if not self.training else self.attention_dropout, |
|
|
scaling=self.scaling, |
|
|
sliding_window=sliding_window, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous() |
|
|
attn_output = self.o_proj(attn_output) |
|
|
return attn_output, attn_weights |
|
|
|
|
|
class Qwen2RMSNorm(nn.Module): |
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
""" |
|
|
Qwen2RMSNorm is equivalent to T5LayerNorm |
|
|
""" |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
input_dtype = hidden_states.dtype |
|
|
hidden_states = hidden_states.to(torch.float32) |
|
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
return self.weight * hidden_states.to(input_dtype) |
|
|
|
|
|
def extra_repr(self): |
|
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" |
|
|
|
|
|
class Qwen2DecoderLayer(nn.Module): |
|
|
def __init__(self, config, layer_idx: int): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) |
|
|
self.mlp = Qwen2MLP(config) |
|
|
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_value: Optional[Cache] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
use_cache: Optional[bool] = False, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
|
|
residual = hidden_states |
|
|
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
|
|
|
hidden_states, self_attn_weights = self.self_attn( |
|
|
hidden_states=hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_value=past_key_value, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=use_cache, |
|
|
cache_position=cache_position, |
|
|
position_embeddings=position_embeddings, |
|
|
**kwargs, |
|
|
) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
outputs = (hidden_states,) |
|
|
if output_attentions: |
|
|
outputs += (self_attn_weights,) |
|
|
|
|
|
return outputs |
|
|
|
|
|
class Qwen2RotaryEmbedding(nn.Module): |
|
|
def __init__(self, config, device=None): |
|
|
super().__init__() |
|
|
|
|
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: |
|
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
|
|
else: |
|
|
self.rope_type = "default" |
|
|
self.max_seq_len_cached = config.max_position_embeddings |
|
|
self.original_max_seq_len = config.max_position_embeddings |
|
|
|
|
|
self.config = config |
|
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
|
|
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
self.original_inv_freq = self.inv_freq |
|
|
|
|
|
@torch.no_grad() |
|
|
@dynamic_rope_update |
|
|
def forward(self, x, position_ids): |
|
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
|
|
position_ids_expanded = position_ids[:, None, :].float() |
|
|
|
|
|
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
|
|
with torch.autocast(device_type=device_type, enabled=False): |
|
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
cos = emb.cos() * self.attention_scaling |
|
|
sin = emb.sin() * self.attention_scaling |
|
|
|
|
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
|
|
|
|
|
|
|
class Qwen2vlPatchMerger(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
embed_dim, |
|
|
image_embed_dim=1024, |
|
|
compression_factor=(2,2), |
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6) |
|
|
): |
|
|
super().__init__() |
|
|
self.embed_dim = embed_dim |
|
|
self.image_embed_dim = image_embed_dim |
|
|
self.hidden_size = image_embed_dim * (compression_factor[0]*compression_factor[1]) |
|
|
self.nl = norm_layer(image_embed_dim) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(self.hidden_size, self.hidden_size), |
|
|
nn.GELU(), |
|
|
nn.Linear(self.hidden_size, embed_dim), |
|
|
) |
|
|
self.compression_factor = compression_factor |
|
|
|
|
|
def forward(self, x, tgt_size=(24,24), attn_mask=None): |
|
|
|
|
|
|
|
|
|
|
|
height, width = tgt_size |
|
|
|
|
|
if height * width != x.shape[1]: |
|
|
x = x[:, :int(height * width)] |
|
|
x = self.nl(x) |
|
|
|
|
|
x = x.permute(0, 2, 1).unflatten(-1, (int(height), int(width))) |
|
|
batch_size, dim, height, width = x.shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unfolded = x.unfold(2, self.compression_factor[0], self.compression_factor[0]).unfold(3, self.compression_factor[1], self.compression_factor[1]) |
|
|
unfolded = unfolded.contiguous().view(batch_size, dim, -1, self.compression_factor[0] * self.compression_factor[1]) |
|
|
unfolded = unfolded.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, dim*self.compression_factor[0] * self.compression_factor[1]) |
|
|
compressed_x = self.mlp(unfolded) |
|
|
return compressed_x |
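
# Illustrative sketch (not called by the model): Qwen2vlPatchMerger groups 2x2
# neighbourhoods of vision tokens, concatenates their channels, and projects them into
# the language-model hidden size; the sizes below are made up.
def _demo_patch_merger_projection():
    merger = Qwen2vlPatchMerger(embed_dim=128, image_embed_dim=64, compression_factor=(2, 2))
    x = torch.randn(1, 64, 64)        # one image of 8x8 vision tokens
    out = merger(x, tgt_size=(8, 8))
    assert out.shape == (1, 16, 128)  # (8//2) * (8//2) merged tokens
    return out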
|
|
|
|
|
class LlavaUHDV3PretrainedModel(PreTrainedModel): |
|
|
config: LlavaUHDV3Config |
|
|
base_model_prefix = "model" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["Qwen2DecoderLayer", "MoonViTEncoderLayer"] |
|
|
_skip_keys_device_placement = "past_key_values" |
|
|
_supports_flash_attn_2 = True |
|
|
_supports_sdpa = True |
|
|
|
|
|
_can_compile_fullgraph = True |
|
|
_supports_attention_backend = True |
|
|
|
|
|
def __init__(self, config, *args, **kwargs): |
|
|
super().__init__(config, *args, **kwargs) |
|
|
|
|
|
class LlavaUHDV3VisionTransformerPretrainedModel(LlavaUHDV3PretrainedModel): |
|
|
config: LlavaUHDV3VisionConfig |
|
|
_no_split_modules = ["MoonViTEncoderLayer"] |
|
|
|
|
|
def __init__(self, config, *inputs, **kwargs): |
|
|
super().__init__(config, *inputs, **kwargs) |
|
|
config = deepcopy(config) |
|
|
self.patch_size = config.patch_size |
|
|
self.patch_embed = MoonVisionPatchEmbed( |
|
|
out_dim=config.hidden_size, |
|
|
patch_size=config.patch_size, |
|
|
pos_emb_height=config.init_pos_emb_height, |
|
|
pos_emb_width=config.init_pos_emb_width, |
|
|
) |
|
|
|
|
|
if hasattr(config, "merger_layer_index"): |
|
|
merger_layer_index = config.merger_layer_index |
|
|
merging_method = config.merging_method |
|
|
|
|
|
if merger_layer_index is not None: |
|
|
enable_merging = True |
|
|
merging_method = merging_method if merging_method is not None else "avg_pooling" |
|
|
else: |
|
|
enable_merging = False |
|
|
merging_method = None |
|
|
|
|
|
self.encoder = MoonVitEncoder( |
|
|
hidden_dim=config.hidden_size, |
|
|
num_layers=config.num_hidden_layers, |
|
|
block_cfg={ |
|
|
"num_heads": config.num_attention_heads, |
|
|
"hidden_dim": config.hidden_size, |
|
|
"mlp_dim": config.intermediate_size, |
|
|
"activation": PytorchGELUTanh(), |
|
|
"attn_bias": True, |
|
|
"attn_implementation": self.config.attn_implementation, |
|
|
"enable_merging": enable_merging, |
|
|
"merging_method": merging_method, |
|
|
"merger_layer_index": merger_layer_index, |
|
|
}, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, pixel_values: torch.Tensor, grid_hws: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
pixel_values (torch.Tensor): The input pixel values. |
|
|
grid_hws (torch.Tensor): The grid height and width. |
|
|
Returns: |
|
|
torch.Tensor: The output tokens. |
|
|
""" |
|
|
pixel_values = pixel_values.to(torch.bfloat16) |
|
|
hidden_states = self.patch_embed(pixel_values, grid_hws) |
|
|
image_features, grid_hws = self.encoder(hidden_states, grid_hws) |
|
|
output_features = [] |
|
|
offset = 0 |
|
|
for grid_hw in grid_hws: |
|
|
h, w = grid_hw |
|
|
num_tokens = int(h * w) |
|
|
output_features.append(image_features[offset: offset+num_tokens].unsqueeze(0)) |
|
|
offset += num_tokens |
|
|
assert offset == image_features.shape[0], \ |
|
|
f"Used {offset} tokens, but image_features has {image_features.shape[0]} tokens!" |
|
|
return output_features |
|
|
|
|
|
class LlavaUHDV3TextModel(LlavaUHDV3PretrainedModel): |
|
|
config: LlavaUHDV3TextConfig |
|
|
_no_split_modules = ["Qwen2DecoderLayer"] |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
self.padding_idx = config.pad_token_id |
|
|
self.vocab_size = config.vocab_size |
|
|
|
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
|
self.layers = nn.ModuleList( |
|
|
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
|
|
) |
|
|
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.rotary_emb = Qwen2RotaryEmbedding(config=config) |
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def _init_weights(self, module): |
|
|
std = self.config.initializer_range |
|
|
if isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.padding_idx is not None: |
|
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.embed_tokens |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
**kwargs, |
|
|
) -> BaseModelOutputWithPast: |
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
if use_cache and past_key_values is None: |
|
|
past_key_values = DynamicCache() |
|
|
|
|
|
if cache_position is None: |
|
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
|
cache_position = torch.arange( |
|
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
|
|
) |
|
|
|
|
|
if position_ids is None: |
|
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
|
|
causal_mask = self._update_causal_mask( |
|
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
|
|
) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
all_self_attns = () if output_attentions else None |
|
|
|
|
|
for decoder_layer in self.layers[: self.config.num_hidden_layers]: |
|
|
if output_hidden_states: |
|
|
all_hidden_states += (hidden_states,) |
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
layer_outputs = self._gradient_checkpointing_func( |
|
|
partial(decoder_layer.__call__, **kwargs), |
|
|
hidden_states, |
|
|
causal_mask, |
|
|
position_ids, |
|
|
past_key_values, |
|
|
output_attentions, |
|
|
use_cache, |
|
|
cache_position, |
|
|
position_embeddings, |
|
|
) |
|
|
else: |
|
|
layer_outputs = decoder_layer( |
|
|
hidden_states, |
|
|
attention_mask=causal_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_value=past_key_values, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=use_cache, |
|
|
cache_position=cache_position, |
|
|
position_embeddings=position_embeddings, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
|
|
if output_attentions: |
|
|
all_self_attns += (layer_outputs[1],) |
|
|
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
|
|
|
if output_hidden_states: |
|
|
all_hidden_states += (hidden_states,) |
|
|
|
|
|
return BaseModelOutputWithPast( |
|
|
last_hidden_state=hidden_states, |
|
|
past_key_values=past_key_values if use_cache else None, |
|
|
hidden_states=all_hidden_states, |
|
|
attentions=all_self_attns, |
|
|
) |
|
|
|
|
|
def _update_causal_mask( |
|
|
self, |
|
|
attention_mask: torch.Tensor, |
|
|
input_tensor: torch.Tensor, |
|
|
cache_position: torch.Tensor, |
|
|
past_key_values: Cache, |
|
|
output_attentions: bool = False, |
|
|
): |
|
|
if self.config.attn_implementation == "flash_attention_2": |
|
|
if attention_mask is not None and past_key_values is not None: |
|
|
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] |
|
|
if is_padding_right: |
|
|
raise ValueError( |
|
|
"You are attempting to perform batched generation with padding_side='right'" |
|
|
" this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " |
|
|
" call `tokenizer.padding_side = 'left'` before tokenizing the input. " |
|
|
) |
|
|
if attention_mask is not None and 0.0 in attention_mask: |
|
|
return attention_mask |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
|
using_static_cache = isinstance(past_key_values, StaticCache) |
|
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) |
|
|
|
|
|
|
|
|
if ( |
|
|
self.config.attn_implementation == "sdpa" |
|
|
and not (using_static_cache or using_sliding_window_cache) |
|
|
and not output_attentions |
|
|
): |
|
|
if AttentionMaskConverter._ignore_causal_mask_sdpa( |
|
|
attention_mask, |
|
|
inputs_embeds=input_tensor, |
|
|
past_key_values_length=past_seen_tokens, |
|
|
sliding_window=self.config.sliding_window, |
|
|
is_training=self.training, |
|
|
): |
|
|
return None |
|
|
|
|
|
dtype, device = input_tensor.dtype, input_tensor.device |
|
|
min_dtype = torch.finfo(dtype).min |
|
|
sequence_length = input_tensor.shape[1] |
|
|
|
|
|
if using_sliding_window_cache or using_static_cache: |
|
|
target_length = past_key_values.get_max_cache_shape() |
|
|
|
|
|
else: |
|
|
target_length = ( |
|
|
attention_mask.shape[-1] |
|
|
if isinstance(attention_mask, torch.Tensor) |
|
|
else past_seen_tokens + sequence_length + 1 |
|
|
) |
|
|
|
|
|
|
|
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
|
|
attention_mask, |
|
|
sequence_length=sequence_length, |
|
|
target_length=target_length, |
|
|
dtype=dtype, |
|
|
device=device, |
|
|
cache_position=cache_position, |
|
|
batch_size=input_tensor.shape[0], |
|
|
config=self.config, |
|
|
past_key_values=past_key_values, |
|
|
) |
|
|
|
|
|
if ( |
|
|
self.config.attn_implementation == "sdpa" |
|
|
and attention_mask is not None |
|
|
and attention_mask.device.type in ["cuda", "xpu"] |
|
|
and not output_attentions |
|
|
): |
|
|
|
|
|
|
|
|
|
|
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
|
|
|
|
|
return causal_mask |
|
|
|
|
|
@staticmethod |
|
|
def _prepare_4d_causal_attention_mask_with_cache_position( |
|
|
attention_mask: torch.Tensor, |
|
|
sequence_length: int, |
|
|
target_length: int, |
|
|
dtype: torch.dtype, |
|
|
device: torch.device, |
|
|
cache_position: torch.Tensor, |
|
|
batch_size: int, |
|
|
config, |
|
|
past_key_values: Cache, |
|
|
): |
|
|
""" |
|
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
|
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
|
|
|
|
|
Args: |
|
|
attention_mask (`torch.Tensor`): |
|
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. |
|
|
sequence_length (`int`): |
|
|
The sequence length being processed. |
|
|
target_length (`int`): |
|
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. |
|
|
dtype (`torch.dtype`): |
|
|
The dtype to use for the 4D attention mask. |
|
|
device (`torch.device`): |
|
|
The device to place the 4D attention mask on. |
|
|
cache_position (`torch.Tensor`): |
|
|
Indices depicting the position of the input sequence tokens in the sequence. |
|
|
batch_size (`torch.Tensor`): |
|
|
Batch size. |
|
|
config (`Qwen2Config`): |
|
|
The model's configuration class |
|
|
past_key_values (`Cache`): |
|
|
The cache class that is being used currently to generate |
|
|
""" |
|
|
if attention_mask is not None and attention_mask.dim() == 4: |
|
|
|
|
|
causal_mask = attention_mask |
|
|
else: |
|
|
min_dtype = torch.finfo(dtype).min |
|
|
causal_mask = torch.full( |
|
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
|
|
) |
|
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
|
|
if config.sliding_window is not None: |
|
|
|
|
|
|
|
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: |
|
|
sliding_attend_mask = torch.arange(target_length, device=device) <= ( |
|
|
cache_position.reshape(-1, 1) - config.sliding_window |
|
|
) |
|
|
diagonal_attend_mask.bitwise_or_(sliding_attend_mask) |
|
|
causal_mask *= diagonal_attend_mask |
|
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
|
|
if attention_mask is not None: |
|
|
causal_mask = causal_mask.clone() |
|
|
if attention_mask.shape[-1] > target_length: |
|
|
attention_mask = attention_mask[:, :target_length] |
|
|
mask_length = attention_mask.shape[-1] |
|
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
|
|
causal_mask.device |
|
|
) |
|
|
padding_mask = padding_mask == 0 |
|
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
|
|
padding_mask, min_dtype |
|
|
) |
|
|
return causal_mask |
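
# Illustrative sketch (not called by the model): the static helper above expands a 2D
# padding mask into the additive 4D causal mask used by eager/SDPA attention. The
# minimal config stand-in below only provides the `sliding_window` attribute it reads.
def _demo_causal_mask():
    from types import SimpleNamespace

    mask_2d = torch.tensor([[1, 1, 1, 0]])  # one padded position at the end
    causal = LlavaUHDV3TextModel._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask=mask_2d,
        sequence_length=4,
        target_length=4,
        dtype=torch.float32,
        device=torch.device("cpu"),
        cache_position=torch.arange(4),
        batch_size=1,
        config=SimpleNamespace(sliding_window=None),
        past_key_values=None,
    )
    assert causal.shape == (1, 1, 4, 4)
    # the padded key position is masked out for every query position
    assert bool((causal[0, 0, :, 3] == torch.finfo(torch.float32).min).all())
    return causal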
|
|
|
|
|
class LlavaUHDV3Model(LlavaUHDV3PretrainedModel): |
|
|
config_class = LlavaUHDV3Config |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
config.model_type = "llava_uhd_v3" |
|
|
config.rope_scaling = None |
|
|
self.visual = LlavaUHDV3VisionTransformerPretrainedModel._from_config(config.vision_config) |
|
|
self.language_model = LlavaUHDV3TextModel._from_config(config.text_config) |
|
|
self.projector = Qwen2vlPatchMerger( |
|
|
embed_dim=config.text_config.hidden_size, |
|
|
image_embed_dim=config.vision_config.hidden_size, |
|
|
compression_factor=(2, 2), |
|
|
) |
|
|
self.rope_deltas = None |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
    def get_image_features(self, pixel_values, grid_hws):
        down_sample_ratio = 1
        merger_layer_index = getattr(self.config.vision_config, "merger_layer_index", None)
        if merger_layer_index is not None:
            down_sample_ratio = down_sample_ratio * len(merger_layer_index) ** 2
        image_features = self.visual(pixel_values, grid_hws)
        projected_image_features = []
        for image_feature, grid_hw in zip(image_features, grid_hws):
            grid_hw = (grid_hw[0] // down_sample_ratio, grid_hw[1] // down_sample_ratio)
            projected_image_feature = self.projector(image_feature, tgt_size=grid_hw)[0]
            projected_image_features.append(projected_image_feature)
        return projected_image_features
|
|
|
|
|
def prepare_inputs_labels_for_multimodal( |
|
|
self, |
|
|
input_ids, |
|
|
position_ids, |
|
|
attention_mask, |
|
|
past_key_values, |
|
|
labels, |
|
|
pixel_values, |
|
|
grid_hws |
|
|
): |
|
|
if pixel_values is None or input_ids.shape[1] == 1: |
|
|
return input_ids, position_ids, attention_mask, past_key_values, None, labels |
|
|
image_features = self.get_image_features(pixel_values, grid_hws) |
|
|
_labels = labels |
|
|
_position_ids = position_ids |
|
|
_attention_mask = attention_mask |
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones_like(input_ids, dtype=torch.bool) |
|
|
else: |
|
|
attention_mask = attention_mask.bool() |
|
|
if position_ids is None: |
|
|
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) |
|
|
if labels is None: |
|
|
labels = torch.full_like(input_ids, -100) |
|
|
|
|
|
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] |
|
|
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] |
|
|
|
|
|
new_input_embeds = [] |
|
|
new_labels = [] |
|
|
cur_image_idx = 0 |
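        # -200 is the LLaVA-style image-token placeholder id (IMAGE_TOKEN_INDEX) inserted
        # by the preprocessor; each occurrence is replaced below by that image's features.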
|
|
|
|
|
for batch_idx, cur_input_ids in enumerate(input_ids): |
|
|
num_images = (cur_input_ids == -200).sum() |
|
|
if num_images == 0: |
|
|
cur_image_features = image_features[cur_image_idx] |
|
|
cur_input_embeds_1 = self.language_model.embed_tokens(cur_input_ids) |
|
|
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0][0:0]], dim=0) |
|
|
new_input_embeds.append(cur_input_embeds) |
|
|
new_labels.append(labels[batch_idx]) |
|
|
cur_image_idx += 1 |
|
|
continue |
|
|
image_token_indices = [-1] + torch.where(cur_input_ids == -200)[0].tolist() + [cur_input_ids.shape[0]] |
|
|
cur_input_ids_noim = [] |
|
|
cur_labels = labels[batch_idx] |
|
|
cur_labels_noim = [] |
|
|
for i in range(len(image_token_indices) - 1): |
|
|
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) |
|
|
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) |
|
|
split_sizes = [x.shape[0] for x in cur_labels_noim] |
|
|
cur_input_embeds = self.language_model.embed_tokens(torch.cat(cur_input_ids_noim)) |
|
|
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) |
|
|
cur_new_input_embeds = [] |
|
|
cur_new_labels = [] |
|
|
for i in range(num_images + 1): |
|
|
cur_new_input_embeds.append(cur_input_embeds_no_im[i]) |
|
|
cur_new_labels.append(cur_labels_noim[i]) |
|
|
if i < num_images: |
|
|
try: |
|
|
cur_image_features = image_features[cur_image_idx] |
|
|
except IndexError: |
|
|
cur_image_features = image_features[cur_image_idx - 1] |
|
|
cur_image_idx += 1 |
|
|
cur_new_input_embeds.append(cur_image_features) |
|
|
cur_new_labels.append(torch.full((cur_image_features.shape[0],), -100, device=cur_labels.device, dtype=cur_labels.dtype)) |
|
|
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] |
|
|
cur_new_input_embeds = torch.cat(cur_new_input_embeds) |
|
|
cur_new_labels = torch.cat(cur_new_labels) |
|
|
new_input_embeds.append(cur_new_input_embeds) |
|
|
new_labels.append(cur_new_labels) |
|
|
|
|
|
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", 4096) |
|
|
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] |
|
|
new_labels = [x[:tokenizer_model_max_length] for x in new_labels] |
|
|
|
|
|
max_len = max(x.shape[0] for x in new_input_embeds) |
|
|
batch_size = len(new_input_embeds) |
|
|
new_input_embeds_padded = [] |
|
|
new_labels_padded = torch.full((batch_size, max_len), -100, dtype=new_labels[0].dtype, device=new_labels[0].device) |
|
|
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) |
|
|
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) |
|
|
|
|
|
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): |
|
|
cur_len = cur_new_embed.shape[0] |
|
|
new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) |
|
|
if cur_len > 0: |
|
|
new_labels_padded[i, :cur_len] = cur_new_labels |
|
|
attention_mask[i, :cur_len] = True |
|
|
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) |
|
|
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) |
|
|
|
|
|
if _labels is None: |
|
|
new_labels = None |
|
|
else: |
|
|
new_labels = new_labels_padded |
|
|
|
|
|
if _attention_mask is None: |
|
|
attention_mask = None |
|
|
else: |
|
|
attention_mask = attention_mask.to(dtype=_attention_mask.dtype) |
|
|
|
|
|
if _position_ids is None: |
|
|
position_ids = None |
|
|
|
|
|
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids = None, |
|
|
position_ids = None, |
|
|
attention_mask = None, |
|
|
past_key_values = None, |
|
|
inputs_embeds = None, |
|
|
labels = None, |
|
|
use_cache = None, |
|
|
output_attentions = None, |
|
|
output_hidden_states = None, |
|
|
pixel_values = None, |
|
|
grid_hws = None, |
|
|
return_dict = None, |
|
|
**kwargs, |
|
|
): |
|
|
if inputs_embeds is None: |
|
|
input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal( |
|
|
input_ids, position_ids, attention_mask, past_key_values, labels, pixel_values, grid_hws |
|
|
) |
|
|
|
|
|
output = self.language_model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
inputs_embeds=inputs_embeds, |
|
|
labels=labels, |
|
|
use_cache=use_cache, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
**kwargs, |
|
|
) |
|
|
        if labels is not None:
            return output, labels
        return output
|
|
|
|
|
class LlavaUHDV3ForCausalLM(LlavaUHDV3PretrainedModel, GenerationMixin): |
|
|
config_class = LlavaUHDV3Config |
|
|
_checkpoint_conversion_mapping = { |
|
|
"^visual": "model.visual", |
|
|
r"^model(?!\.(language_model|visual|projector))": "model.language_model", |
|
|
} |
|
|
|
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
self.model = LlavaUHDV3Model(config) |
|
|
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) |
|
|
self.post_init() |
|
|
|
|
|
@property |
|
|
def language_model(self): |
|
|
return self.model.language_model |
|
|
|
|
|
@property |
|
|
def visual(self): |
|
|
return self.model.visual |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.language_model.embed_tokens |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.lm_head |
|
|
|
|
|
def forward(self, input_ids, labels=None, attention_mask=None, pixel_values=None, grid_hws=None, **kwargs): |
|
|
if labels is not None: |
|
|
outputs, labels = self.model(input_ids, labels=labels, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, **kwargs) |
|
|
else: |
|
|
outputs = self.model(input_ids, labels=labels, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, **kwargs) |
|
|
hidden_states = outputs.last_hidden_state |
|
|
        slice_indices = slice(0, None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
loss = None |
|
|
|
|
|
if labels is not None: |
|
|
loss = self.loss_function( |
|
|
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs |
|
|
) |
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=outputs.past_key_values, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
input_ids: Optional[torch.Tensor] = None, |
|
|
pixel_values: Optional[torch.Tensor] = None, |
|
|
grid_hws: Optional[torch.Tensor] = None, |
|
|
**kwargs, |
|
|
): |
|
|
position_ids = kwargs.pop("position_ids", None) |
|
|
attention_mask = kwargs.pop("attention_mask", None) |
|
|
if "inputs_embeds" in kwargs: |
|
|
raise NotImplementedError("`inputs_embeds` is not supported") |
|
|
|
|
|
if pixel_values is not None: |
|
|
input_ids, position_ids, attention_mask, _, inputs_embeds, _ = self.model.prepare_inputs_labels_for_multimodal( |
|
|
input_ids, position_ids, attention_mask, None, None, pixel_values, grid_hws |
|
|
) |
|
|
else: |
|
|
inputs_embeds = self.model.language_model.embed_tokens(input_ids) |
|
|
|
|
|
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) |
|
|
|
|
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): |
|
|
pixel_values = kwargs.pop("pixel_values", None) |
|
|
grid_hws = kwargs.pop("grid_hws", None) |
|
|
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) |
|
|
if pixel_values is not None: |
|
|
inputs["pixel_values"] = pixel_values |
|
|
if grid_hws is not None: |
|
|
inputs["grid_hws"] = grid_hws |
|
|
return inputs |
|
|
|
|
|
__all__ = ["LlavaUHDV3ForCausalLM", "LlavaUHDV3Model", "LlavaUHDV3PretrainedModel", "LlavaUHDV3TextModel"] |
|
|
|
|
|
|