from typing import List, Optional, Tuple, Union from functools import partial import torch import torch.nn.functional as F import torch.nn as nn from torch.nn.attention import SDPBackend, sdpa_kernel from torchvision import transforms from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from torchvision.transforms.functional import InterpolationMode from transformers import ( Qwen2Config, Qwen2Model, Qwen2ForCausalLM, Qwen3ForCausalLM, Qwen3Model, Qwen3Config, ) try: from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss from liger_kernel.transformers import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNormFunction def liger_layer_norm(input, normalized_shape, weight, bias, eps): return LigerLayerNormFunction.apply(input, weight, bias, eps) use_liger = True except ImportError: use_liger = False from .configuration_gex import GexConfig, GexTConfig LayerNorm = ( partial(LigerLayerNorm, bias=True) if use_liger else partial(nn.LayerNorm, eps=1e-6) ) layer_norm = liger_layer_norm if use_liger else torch.nn.functional.layer_norm BOS_TOEKN_IDS: int = 151652 EOS_TOEKN_IDS: int = 151643 IMG_PAD_IDS: int = 151655 IMG_END_IDS: int = 25 @torch.no_grad def process_batch_labels(labels, pad_token_id=EOS_TOEKN_IDS): # 创建 mask:标记所有 pad_token_id 的位置 pad_mask = labels == pad_token_id # 找到每个样本第一个 pad_token_id 的位置 first_pad_pos = pad_mask.int().argmax(dim=1, keepdim=True) # shape: (bsz,) first_pad_pos[first_pad_pos == 0] = 256 # 生成要替换为 -100 的位置 mask replace_mask = torch.arange(labels.size(1), device=labels.device) > first_pad_pos # 执行替换(保留第一个 pad_token_id) labels[replace_mask] = -100 return labels class GexImageEvalProcessor: def __init__(self, image_size=1024, mean=None, std=None): if mean is None: mean = (0.48145466, 0.4578275, 0.40821073) if std is None: std = (0.26862954, 0.26130258, 0.27577711) self.normalize = transforms.Normalize(mean, std) self.transform = transforms.Compose( [ transforms.Resize( (image_size, image_size), interpolation=InterpolationMode.BICUBIC ), transforms.ToTensor(), self.normalize, ] ) def __call__(self, item): return self.transform(item) class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.num_channels = num_channels self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.permute(0, 2, 3, 1) return layer_norm( x, normalized_shape=(self.num_channels,), weight=self.weight, bias=self.bias, eps=self.eps, ).permute(0, 3, 1, 2) class PatchEmbed(nn.Module): """ Image to Patch Embedding. """ def __init__( self, kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), in_chans: int = 3, embed_dim: int = 768, ) -> None: """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. """ super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, input_size: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() self.num_heads = num_heads self.head_dim = 64 self.scale = 64**-0.5 self.seq_len = input_size[0] * input_size[1] self.input_size = input_size self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) # self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) # self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim)) self.rel_pos_h = nn.Parameter( torch.zeros(input_size[0], input_size[0], self.head_dim) ) self.rel_pos_w = nn.Parameter( torch.zeros(input_size[1], input_size[1], self.head_dim) ) def init_rel_pos(self): q_size, k_size = self.input_size q_coords = torch.arange(q_size)[:, None] k_coords = torch.arange(k_size)[None, :] relative_coords = (q_coords - k_coords) + (k_size - 1) self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()]) self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()]) def get_attn_bias(self, q: torch.Tensor): q = q.view(-1, *self.input_size, 64) rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h) rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w) return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape( -1, self.num_heads, self.seq_len, self.seq_len ) def forward(self, x: torch.Tensor) -> torch.Tensor: qkv = torch.split( self.qkv(x).view(-1, self.seq_len, 3 * 768), 768, dim=2, ) q, k, v = ( i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous() for i in qkv ) attn_bias = self.get_attn_bias(q) with sdpa_kernel( [ SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, ], set_priority=True, ): attn_output = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attn_bias, is_causal=False ) attn_output = attn_output.transpose(1, 2).flatten(-2) x = self.proj(attn_output) return x.view(-1, *self.input_size, 768) class MLP(nn.Module): def __init__( self, ): super().__init__() self.lin1 = nn.Linear(768, 4 * 768) self.lin2 = nn.Linear(4 * 768, 768) self.act = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) class Block(nn.Module): def __init__(self, idx: int, window_size: int = 14): super().__init__() self.idx = idx self.window_size = window_size self.norm1 = LayerNorm(768) self.attn = Attention( dim=768, num_heads=12, input_size=(64, 64) if window_size == 0 else (14, 14), ) self.norm2 = LayerNorm(768) self.mlp = MLP() @staticmethod def window_partition(x: torch.Tensor) -> torch.Tensor: x = F.pad(x, (0, 0, 0, 6, 0, 6)) x = ( x.view(-1, 5, 14, 5, 14, 768) .permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, 14, 14, 768) ) return x @staticmethod def window_unpartition(x: torch.Tensor) -> torch.Tensor: x = ( x.view(-1, 5, 5, 14, 14, 768) .permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, 70, 70, 768) ) return x[:, :64, :64, :].contiguous() def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) if self.window_size > 0: x = self.window_partition(x) x = self.attn(x) if self.window_size > 0: x = self.window_unpartition(x) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class GexVit(nn.Module): def __init__(self, global_attn_indexes=[2, 5, 8, 11], **kwargs): super().__init__() self.global_attn_indexes = global_attn_indexes self.patch_embed = PatchEmbed() self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768)) self.blocks = nn.ModuleList( [ Block(idx=i, window_size=14 if i not in global_attn_indexes else 0) for i in range(12) ] ) self.neck = nn.ModuleList( [ nn.Conv2d( 768, 256, kernel_size=1, bias=False, ), LayerNorm2d(256), nn.Conv2d( 256, 256, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(256), ] ) self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) self.net_3 = nn.Conv2d( 512, 1024, kernel_size=3, stride=2, padding=1, bias=False ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = x.permute(0, 3, 1, 2) for m in self.neck: x = m(x) x = self.net_2(x) x = self.net_3(x) return x class GexQwenModel(Qwen2Model): config_class = GexConfig _auto_class = "AutoModel" def __init__(self, config: Qwen2Config): super().__init__(config) self.vit = GexVit() self.vit_proj = nn.Linear(1024, 1024) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: if inputs_embeds is None and input_ids is not None: inputs_embeds = self.embed_tokens(input_ids) assert images is not None # img_pos = input_ids == IMG_PAD_IDS # if torch.any(img_pos): vit_feature = self.vit(images).flatten(2).permute(0, 2, 1) vit_feature = self.vit_proj(vit_feature) # img_ids = img_pos.nonzero().squeeze_() # inputs_embeds[img_ids[:, 0], img_ids[:, 1]] = vit_feature.view(-1,1024) inputs_embeds[:, 1:257, :] = vit_feature with sdpa_kernel(SDPBackend.FLASH_ATTENTION): return super().forward( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) class GexQwenForCausalLM(Qwen2ForCausalLM): config_class = GexConfig # supports_gradient_checkpointing = True _auto_class = "AutoModelForCausalLM" def __init__(self, config): super().__init__(config) self.model = GexQwenModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() self.image_preprocess = GexImageEvalProcessor() def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, images: Optional[torch.FloatTensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if labels is not None and input_ids is None: input_ids: torch.Tensor = labels shifted_input_ids = input_ids.new_zeros( (input_ids.shape[0], input_ids.shape[1] + 256), device=input_ids.device ) shifted_input_ids[:, 257:].copy_(input_ids[:, :-1]) decoder_start_token_id = BOS_TOEKN_IDS shifted_input_ids[:, 0] = decoder_start_token_id shifted_input_ids[:, 1:257] = IMG_PAD_IDS input_ids = shifted_input_ids imgs_pad: torch.Tenosr = torch.full( (1, 256), IMG_PAD_IDS, device=self.device, dtype=torch.long ) labels = torch.cat( [ imgs_pad.expand(labels.shape[0], -1), process_batch_labels(labels), ], dim=-1, ) # labels = process_batch_labels(labels) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, images=images, **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) logits = self.lm_head(hidden_states[:, slice_indices, :]) # if (past_key_values is None or len(past_key_values) <= 0): # logits = self.lm_head(hidden_states[:, 256:, :]) # # if labels is not None: # # lb = labels[:,256:].contiguous() # # del labels # # labels = lb # else: # slice_indices = ( # slice(-logits_to_keep, None) # if isinstance(logits_to_keep, int) # else logits_to_keep # ) # logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=None, shift_labels=labels, vocab_size=self.config.vocab_size, **kwargs, ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def generate(self, *args, images, **kwargs): pad = torch.tensor( [[BOS_TOEKN_IDS] + [IMG_PAD_IDS] * 256], dtype=torch.long, device=self.device, ) if (input_ids := kwargs.pop("input_ids", None)) is not None: input_ids = torch.cat( [pad.expand(input_ids.shape[0], -1), input_ids], dim=-1 ) else: input_ids = pad.expand(images.shape[0], -1) res = super().generate( *args, input_ids=input_ids, images=images, max_length=kwargs.pop("max_length", 10) + 257, **kwargs, ) return res class GexTQwenModel(Qwen3Model): config_class = GexTConfig _auto_class = "AutoModel" def __init__(self, config: Qwen3Config): super().__init__(config) self.vit = GexVit() self.vit_proj = nn.Linear(1024, 1024) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, **flash_attn_kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: if inputs_embeds is None and input_ids is not None: inputs_embeds = self.embed_tokens(input_ids) assert images is not None # img_pos = input_ids == IMG_PAD_IDS # if torch.any(img_pos): vit_feature = self.vit(images).flatten(2).permute(0, 2, 1) vit_feature = self.vit_proj(vit_feature) # img_ids = img_pos.nonzero().squeeze_() # inputs_embeds[img_ids[:, 0], img_ids[:, 1]] = vit_feature.view(-1,1024) inputs_embeds[:, 1:257, :] = vit_feature with sdpa_kernel(SDPBackend.FLASH_ATTENTION): return super().forward( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **flash_attn_kwargs, ) class GexTQwenForCausalLM(Qwen3ForCausalLM): config_class = GexTConfig # supports_gradient_checkpointing = True _auto_class = "AutoModelForCausalLM" def __init__(self, config): super().__init__(config) self.model = GexTQwenModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() self.image_preprocess = GexImageEvalProcessor() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, images: Optional[torch.FloatTensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if labels is not None and input_ids is None: input_ids: torch.Tensor = labels shifted_input_ids = input_ids.new_zeros( (input_ids.shape[0], input_ids.shape[1] + 257), device=input_ids.device ) shifted_input_ids[:, 258:].copy_(input_ids[:, :-1]) decoder_start_token_id = BOS_TOEKN_IDS shifted_input_ids[:, 0] = decoder_start_token_id shifted_input_ids[:, 257] = IMG_END_IDS shifted_input_ids[:, 1:257] = IMG_PAD_IDS input_ids = shifted_input_ids imgs_pad: torch.Tenosr = torch.full( # type: ignore (1, 257), IMG_PAD_IDS, device=self.device, dtype=torch.long ) imgs_pad[:, -1] = IMG_END_IDS labels = torch.cat( [ imgs_pad.expand(labels.shape[0], -1), process_batch_labels(labels), ], dim=-1, ) # type: ignore # labels = process_batch_labels(labels) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, images=images, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss # if (past_key_values is None or len(past_key_values) <= 0): # logits = self.lm_head(hidden_states[:, 256:, :]) # # if labels is not None: # # lb = labels[:,256:].contiguous() # # del labels # # labels = lb # else: # slice_indices = ( # slice(-logits_to_keep, None) # if isinstance(logits_to_keep, int) # else logits_to_keep # ) # logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: if self.training and use_liger: loss = LigerForCausalLMLoss( hidden_states=hidden_states, lm_head_weight=self.lm_head.weight, labels=None, shift_labels=labels, hidden_size=self.config.hidden_size, **kwargs, ) logits = None else: slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = self.loss_function( logits=logits, labels=None, shift_labels=labels, vocab_size=self.config.vocab_size, **kwargs, ) else: slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) logits = self.lm_head(hidden_states[:, slice_indices, :]) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def generate(self, *args, images, **kwargs): pad = torch.tensor( [[BOS_TOEKN_IDS] + [IMG_PAD_IDS] * 256 + [IMG_END_IDS]], dtype=torch.long, device=self.device, ) if (input_ids := kwargs.pop("input_ids", None)) is not None: input_ids = torch.cat( [pad.expand(input_ids.shape[0], -1), input_ids], dim=-1 ) else: input_ids = pad.expand(images.shape[0], -1) res = super().generate( *args, input_ids=input_ids, images=images, max_length=kwargs.pop("max_length", 25) + 258, **kwargs, ) return res