# GexT_V1 / modeling_gex.py
from typing import List, Optional, Tuple, Union
from functools import partial
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torchvision import transforms
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from torchvision.transforms.functional import InterpolationMode
from transformers import (
Qwen2Config,
Qwen2Model,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Qwen3Model,
Qwen3Config,
)
try:
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers import LigerLayerNorm
from liger_kernel.transformers.layer_norm import LigerLayerNormFunction
    # Thin adapter matching torch.nn.functional.layer_norm's signature
    # (normalized_shape is unused by the Liger kernel path).
    def liger_layer_norm(input, normalized_shape, weight, bias, eps):
return LigerLayerNormFunction.apply(input, weight, bias, eps)
use_liger = True
except ImportError:
use_liger = False
from .configuration_gex import GexConfig, GexTConfig
LayerNorm = (
partial(LigerLayerNorm, bias=True) if use_liger else partial(nn.LayerNorm, eps=1e-6)
)
layer_norm = liger_layer_norm if use_liger else torch.nn.functional.layer_norm
BOS_TOKEN_IDS: int = 151652
EOS_TOKEN_IDS: int = 151643
IMG_PAD_IDS: int = 151655
IMG_END_IDS: int = 25
@torch.no_grad()
def process_batch_labels(labels, pad_token_id=EOS_TOKEN_IDS):
    # Create a mask marking every pad_token_id position
    pad_mask = labels == pad_token_id
    # Find the position of the first pad_token_id in each sample
    first_pad_pos = pad_mask.int().argmax(dim=1, keepdim=True)  # shape: (bsz, 1)
    # argmax returns 0 when a row contains no pad token; set those rows to 256
    # (presumably the maximum label length) so the mask below leaves them intact
    first_pad_pos[first_pad_pos == 0] = 256
    # Build the mask of positions to replace with -100
    replace_mask = torch.arange(labels.size(1), device=labels.device) > first_pad_pos
    # Apply the replacement (keeping the first pad_token_id)
    labels[replace_mask] = -100
return labels
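# A minimal illustration of process_batch_labels, assuming pad_token_id = 0:
#   labels = [[5, 7, 0, 0, 0]]  ->  [[5, 7, 0, -100, -100]]
# i.e. everything after the first pad token (the effective EOS) is ignored by the loss.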
class GexImageEvalProcessor:
def __init__(self, image_size=1024, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
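# Usage sketch for GexImageEvalProcessor (assumes a local "page.png"; any PIL RGB image works):
#   from PIL import Image
#   processor = GexImageEvalProcessor()
#   pixel_values = processor(Image.open("page.png").convert("RGB"))  # shape: (3, 1024, 1024)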
class LayerNorm2d(nn.Module):
    """LayerNorm applied over the channel dimension of NCHW feature maps."""
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.num_channels = num_channels
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
return layer_norm(
x,
normalized_shape=(self.num_channels,),
weight=self.weight,
bias=self.bias,
eps=self.eps,
).permute(0, 3, 1, 2)
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = 64
self.scale = 64**-0.5
self.seq_len = input_size[0] * input_size[1]
self.input_size = input_size
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
# self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
# self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
self.rel_pos_h = nn.Parameter(
torch.zeros(input_size[0], input_size[0], self.head_dim)
)
self.rel_pos_w = nn.Parameter(
torch.zeros(input_size[1], input_size[1], self.head_dim)
)
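    # NOTE: init_rel_pos appears to assume the parameters still hold the compact
    # (2 * size - 1, head_dim) tables from the commented-out layout above: it
    # expands them into dense per-(q, k) lookup tables so get_attn_bias can use a
    # plain einsum at every forward pass instead of re-indexing the compact table.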
def init_rel_pos(self):
q_size, k_size = self.input_size
q_coords = torch.arange(q_size)[:, None]
k_coords = torch.arange(k_size)[None, :]
relative_coords = (q_coords - k_coords) + (k_size - 1)
self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()])
self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()])
def get_attn_bias(self, q: torch.Tensor):
q = q.view(-1, *self.input_size, 64)
rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w)
return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape(
-1, self.num_heads, self.seq_len, self.seq_len
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = torch.split(
self.qkv(x).view(-1, self.seq_len, 3 * 768),
768,
dim=2,
)
q, k, v = (
i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous()
for i in qkv
)
attn_bias = self.get_attn_bias(q)
with sdpa_kernel(
[
SDPBackend.FLASH_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
],
set_priority=True,
):
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_bias, is_causal=False
)
attn_output = attn_output.transpose(1, 2).flatten(-2)
x = self.proj(attn_output)
return x.view(-1, *self.input_size, 768)
class MLP(nn.Module):
def __init__(
self,
):
super().__init__()
self.lin1 = nn.Linear(768, 4 * 768)
self.lin2 = nn.Linear(4 * 768, 768)
self.act = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
class Block(nn.Module):
def __init__(self, idx: int, window_size: int = 14):
super().__init__()
self.idx = idx
self.window_size = window_size
self.norm1 = LayerNorm(768)
self.attn = Attention(
dim=768,
num_heads=12,
input_size=(64, 64) if window_size == 0 else (14, 14),
)
self.norm2 = LayerNorm(768)
self.mlp = MLP()
@staticmethod
def window_partition(x: torch.Tensor) -> torch.Tensor:
        # Pad the (64, 64) token grid to (70, 70) so it tiles exactly into 5x5 windows of 14x14
        x = F.pad(x, (0, 0, 0, 6, 0, 6))
x = (
x.view(-1, 5, 14, 5, 14, 768)
.permute(0, 1, 3, 2, 4, 5)
.contiguous()
.view(-1, 14, 14, 768)
)
return x
@staticmethod
def window_unpartition(x: torch.Tensor) -> torch.Tensor:
x = (
x.view(-1, 5, 5, 14, 14, 768)
.permute(0, 1, 3, 2, 4, 5)
.contiguous()
.view(-1, 70, 70, 768)
)
        # Drop the padding added in window_partition and restore the (64, 64) grid
        return x[:, :64, :64, :].contiguous()
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
if self.window_size > 0:
x = self.window_partition(x)
x = self.attn(x)
if self.window_size > 0:
x = self.window_unpartition(x)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class GexVit(nn.Module):
def __init__(self, global_attn_indexes=[2, 5, 8, 11], **kwargs):
super().__init__()
self.global_attn_indexes = global_attn_indexes
self.patch_embed = PatchEmbed()
self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768))
self.blocks = nn.ModuleList(
[
Block(idx=i, window_size=14 if i not in global_attn_indexes else 0)
for i in range(12)
]
)
self.neck = nn.ModuleList(
[
nn.Conv2d(
768,
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
]
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
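        # Shape flow for a 1024x1024 input: patch_embed -> (B, 64, 64, 768),
        # neck -> (B, 256, 64, 64), net_2 -> (B, 512, 32, 32),
        # net_3 -> (B, 1024, 16, 16): 256 image tokens of width 1024 once flattened.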
x = self.patch_embed(x)
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = x.permute(0, 3, 1, 2)
for m in self.neck:
x = m(x)
x = self.net_2(x)
x = self.net_3(x)
return x
class GexQwenModel(Qwen2Model):
config_class = GexConfig
_auto_class = "AutoModel"
def __init__(self, config: Qwen2Config):
super().__init__(config)
self.vit = GexVit()
self.vit_proj = nn.Linear(1024, 1024)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
if inputs_embeds is None and input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids)
assert images is not None
# img_pos = input_ids == IMG_PAD_IDS
# if torch.any(img_pos):
vit_feature = self.vit(images).flatten(2).permute(0, 2, 1)
vit_feature = self.vit_proj(vit_feature)
# img_ids = img_pos.nonzero().squeeze_()
# inputs_embeds[img_ids[:, 0], img_ids[:, 1]] = vit_feature.view(-1,1024)
        # Slots 1..256 (right after BOS) hold the 256 projected image-token embeddings
        inputs_embeds[:, 1:257, :] = vit_feature
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
return super().forward(
input_ids=None,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
class GexQwenForCausalLM(Qwen2ForCausalLM):
config_class = GexConfig
# supports_gradient_checkpointing = True
_auto_class = "AutoModelForCausalLM"
def __init__(self, config):
super().__init__(config)
self.model = GexQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
self.image_preprocess = GexImageEvalProcessor()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
images: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
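        # Training-time convenience: when only labels are supplied, build the decoder
        # input by prepending BOS plus 256 image-pad slots and right-shifting the
        # labels by one, then prepend matching image-pad targets so logits and the
        # pre-shifted labels stay aligned position-for-position.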
if labels is not None and input_ids is None:
input_ids: torch.Tensor = labels
shifted_input_ids = input_ids.new_zeros(
(input_ids.shape[0], input_ids.shape[1] + 256), device=input_ids.device
)
shifted_input_ids[:, 257:].copy_(input_ids[:, :-1])
            decoder_start_token_id = BOS_TOKEN_IDS
shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids[:, 1:257] = IMG_PAD_IDS
input_ids = shifted_input_ids
            imgs_pad: torch.Tensor = torch.full(
(1, 256), IMG_PAD_IDS, device=self.device, dtype=torch.long
)
labels = torch.cat(
[
imgs_pad.expand(labels.shape[0], -1),
process_batch_labels(labels),
],
dim=-1,
)
# labels = process_batch_labels(labels)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
images=images,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
# if (past_key_values is None or len(past_key_values) <= 0):
# logits = self.lm_head(hidden_states[:, 256:, :])
# # if labels is not None:
# # lb = labels[:,256:].contiguous()
# # del labels
# # labels = lb
# else:
# slice_indices = (
# slice(-logits_to_keep, None)
# if isinstance(logits_to_keep, int)
# else logits_to_keep
# )
# logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=None,
shift_labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def generate(self, *args, images, **kwargs):
pad = torch.tensor(
            [[BOS_TOKEN_IDS] + [IMG_PAD_IDS] * 256],
dtype=torch.long,
device=self.device,
)
if (input_ids := kwargs.pop("input_ids", None)) is not None:
input_ids = torch.cat(
[pad.expand(input_ids.shape[0], -1), input_ids], dim=-1
)
else:
input_ids = pad.expand(images.shape[0], -1)
res = super().generate(
*args,
input_ids=input_ids,
images=images,
            max_length=kwargs.pop("max_length", 10) + 257,  # +257 for the BOS + 256 image tokens prepended above
**kwargs,
)
return res
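# Usage sketch for GexQwenForCausalLM.generate (hypothetical names; `pixel_values`
# comes from model.image_preprocess, `tokenizer` from the hosting repo):
#   output_ids = model.generate(images=pixel_values.unsqueeze(0), max_length=512)
#   text = tokenizer.decode(output_ids[0], skip_special_tokens=True)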
class GexTQwenModel(Qwen3Model):
config_class = GexTConfig
_auto_class = "AutoModel"
def __init__(self, config: Qwen3Config):
super().__init__(config)
self.vit = GexVit()
self.vit_proj = nn.Linear(1024, 1024)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
if inputs_embeds is None and input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids)
assert images is not None
# img_pos = input_ids == IMG_PAD_IDS
# if torch.any(img_pos):
vit_feature = self.vit(images).flatten(2).permute(0, 2, 1)
vit_feature = self.vit_proj(vit_feature)
# img_ids = img_pos.nonzero().squeeze_()
# inputs_embeds[img_ids[:, 0], img_ids[:, 1]] = vit_feature.view(-1,1024)
        # Slots 1..256 (right after BOS) hold the 256 projected image-token embeddings
        inputs_embeds[:, 1:257, :] = vit_feature
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
return super().forward(
input_ids=None,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**flash_attn_kwargs,
)
class GexTQwenForCausalLM(Qwen3ForCausalLM):
config_class = GexTConfig
# supports_gradient_checkpointing = True
_auto_class = "AutoModelForCausalLM"
def __init__(self, config):
super().__init__(config)
self.model = GexTQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
self.image_preprocess = GexImageEvalProcessor()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
images: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
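        # Same teacher-forcing layout as GexQwenForCausalLM.forward, with one extra
        # IMG_END token (position 257) inserted between the image slots and the text.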
if labels is not None and input_ids is None:
input_ids: torch.Tensor = labels
shifted_input_ids = input_ids.new_zeros(
(input_ids.shape[0], input_ids.shape[1] + 257), device=input_ids.device
)
shifted_input_ids[:, 258:].copy_(input_ids[:, :-1])
            decoder_start_token_id = BOS_TOKEN_IDS
shifted_input_ids[:, 0] = decoder_start_token_id
shifted_input_ids[:, 257] = IMG_END_IDS
shifted_input_ids[:, 1:257] = IMG_PAD_IDS
input_ids = shifted_input_ids
            imgs_pad: torch.Tensor = torch.full(
(1, 257), IMG_PAD_IDS, device=self.device, dtype=torch.long
)
imgs_pad[:, -1] = IMG_END_IDS
labels = torch.cat(
[
imgs_pad.expand(labels.shape[0], -1),
process_batch_labels(labels),
],
dim=-1,
) # type: ignore
# labels = process_batch_labels(labels)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
images=images,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# if (past_key_values is None or len(past_key_values) <= 0):
# logits = self.lm_head(hidden_states[:, 256:, :])
# # if labels is not None:
# # lb = labels[:,256:].contiguous()
# # del labels
# # labels = lb
# else:
# slice_indices = (
# slice(-logits_to_keep, None)
# if isinstance(logits_to_keep, int)
# else logits_to_keep
# )
# logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
if self.training and use_liger:
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=None,
shift_labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
logits = None
else:
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = self.loss_function(
logits=logits,
labels=None,
shift_labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
else:
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def generate(self, *args, images, **kwargs):
pad = torch.tensor(
            [[BOS_TOKEN_IDS] + [IMG_PAD_IDS] * 256 + [IMG_END_IDS]],
dtype=torch.long,
device=self.device,
)
if (input_ids := kwargs.pop("input_ids", None)) is not None:
input_ids = torch.cat(
[pad.expand(input_ids.shape[0], -1), input_ids], dim=-1
)
else:
input_ids = pad.expand(images.shape[0], -1)
res = super().generate(
*args,
input_ids=input_ids,
images=images,
            max_length=kwargs.pop("max_length", 25) + 258,  # +258 for BOS + 256 image tokens + IMG_END prepended above
**kwargs,
)
return res
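# End-to-end sketch (hypothetical repo id and file names; adjust to your checkpoint):
#   from PIL import Image
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("MosRat/GexT_V1", trust_remote_code=True).cuda().eval()
#   tokenizer = AutoTokenizer.from_pretrained("MosRat/GexT_V1", trust_remote_code=True)
#   image = model.image_preprocess(Image.open("formula.png").convert("RGB"))
#   output_ids = model.generate(images=image.unsqueeze(0).to(model.device), max_length=512)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))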