|
|
from typing import List, Optional, Tuple, Union |
|
|
from functools import partial |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.nn as nn |
|
|
from torch.nn.attention import SDPBackend, sdpa_kernel |
|
|
|
|
|
from torchvision import transforms |
|
|
from transformers.cache_utils import Cache |
|
|
from transformers.modeling_outputs import ( |
|
|
BaseModelOutputWithPast, |
|
|
CausalLMOutputWithPast, |
|
|
) |
|
|
|
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
from transformers import ( |
|
|
Qwen2Config, |
|
|
Qwen2Model, |
|
|
Qwen2ForCausalLM, |
|
|
Qwen3ForCausalLM, |
|
|
Qwen3Model, |
|
|
Qwen3Config, |
|
|
) |
|
|
|
|
|
# Optional liger-kernel acceleration: when installed, use its fused LayerNorm
# and fused causal-LM loss; otherwise fall back to plain torch equivalents.
try:
    from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
    from liger_kernel.transformers import LigerLayerNorm
    from liger_kernel.transformers.layer_norm import LigerLayerNormFunction

    def liger_layer_norm(input, normalized_shape, weight, bias, eps):
        """Adapter matching torch.nn.functional.layer_norm's signature.

        LigerLayerNormFunction does not take ``normalized_shape``; the
        parameter is accepted (and ignored) only so call sites can pass the
        same keyword arguments to either backend.
        """
        return LigerLayerNormFunction.apply(input, weight, bias, eps)

    use_liger = True
except ImportError:
    use_liger = False
|
|
|
|
|
|
|
|
from .configuration_gex import GexConfig, GexTConfig |
|
|
|
|
|
|
|
|
# Module/functional layer-norm entry points, switched on liger availability.
# NOTE(review): the liger branch does not pass eps=1e-6 here, so
# LigerLayerNorm's own default eps applies — confirm the two defaults match.
LayerNorm = (
    partial(LigerLayerNorm, bias=True) if use_liger else partial(nn.LayerNorm, eps=1e-6)
)
layer_norm = liger_layer_norm if use_liger else torch.nn.functional.layer_norm

# Special token ids (Qwen vocabulary). The "TOEKN" misspelling is kept because
# these names are referenced throughout this file and by external code.
BOS_TOEKN_IDS: int = 151652  # sequence/image-start token prepended at position 0
EOS_TOEKN_IDS: int = 151643  # end-of-sequence; also acts as label padding
IMG_PAD_IDS: int = 151655  # placeholder id filling the 256 image-token slots
IMG_END_IDS: int = 25  # image-end marker (used by the GexT/Qwen3 variant only)
|
|
|
|
|
|
|
|
@torch.no_grad()
def process_batch_labels(labels, pad_token_id=151643):  # default = EOS_TOEKN_IDS
    """Mask out label positions after the first pad/EOS token, in place.

    For each row, the first occurrence of ``pad_token_id`` is kept (so the
    model is still supervised to emit it) and every position strictly after
    it is replaced with -100, the ignore index of the cross-entropy loss.
    Rows containing no pad token are left unchanged.

    Args:
        labels: ``(batch, seq_len)`` integer tensor of target ids; mutated
            in place.
        pad_token_id: token id marking the end of the valid targets.

    Returns:
        The same ``labels`` tensor, mutated in place.
    """
    seq_len = labels.size(1)
    pad_mask = labels == pad_token_id

    # argmax returns 0 both when the first pad sits in column 0 and when the
    # row has no pad at all, so disambiguate with an explicit any(). Pad-free
    # rows get a sentinel of seq_len (past the end) instead of the previous
    # hard-coded 256, which silently assumed 256-token sequences.
    first_pad_pos = pad_mask.int().argmax(dim=1, keepdim=True)
    has_pad = pad_mask.any(dim=1, keepdim=True)
    first_pad_pos = torch.where(
        has_pad, first_pad_pos, torch.full_like(first_pad_pos, seq_len)
    )

    # True for every column strictly after the row's first pad token.
    replace_mask = torch.arange(seq_len, device=labels.device) > first_pad_pos

    labels[replace_mask] = -100

    return labels
|
|
|
|
|
|
|
|
class GexImageEvalProcessor:
    """Eval-time image preprocessing: bicubic square resize, tensor
    conversion, then channel normalization (CLIP statistics by default)."""

    _DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, image_size=1024, mean=None, std=None):
        mean = self._DEFAULT_MEAN if mean is None else mean
        std = self._DEFAULT_STD if std is None else std

        self.normalize = transforms.Normalize(mean, std)

        pipeline = [
            transforms.Resize(
                (image_size, image_size), interpolation=InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, item):
        # Apply the full resize -> tensor -> normalize pipeline.
        return self.transform(item)
|
|
|
|
|
|
|
|
class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for NCHW feature maps.

    Permutes to NHWC, applies the module-level ``layer_norm`` (liger-fused
    when available) over the channel dimension, then permutes back.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.num_channels = num_channels
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC so the channel dim is last for layer_norm.
        nhwc = x.permute(0, 2, 3, 1)
        normed = layer_norm(
            nhwc,
            normalized_shape=(self.num_channels,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        # Restore NCHW layout.
        return normed.permute(0, 3, 1, 2)
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a strided convolution."""

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, embed_dim, H', W') -> (B, H', W', embed_dim).
        return self.proj(x).permute(0, 2, 3, 1)
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with decomposed relative positional bias.

    Hard-coded for a 768-dim embedding with 64-dim heads; the relative
    positional tables are injected into SDPA as an additive ``attn_mask``.
    Input and output are spatial: ``(B, H, W, 768)`` with
    ``(H, W) == input_size``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = 64
        # NOTE(review): `scale` is unused — SDPA applies its own default
        # 1/sqrt(head_dim) scaling internally.
        self.scale = 64**-0.5
        # Tokens per attention map (full grid or one window).
        self.seq_len = input_size[0] * input_size[1]
        self.input_size = input_size

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        # Relative positional tables, declared in their final "expanded"
        # (q_size, k_size, head_dim) lookup layout used by get_attn_bias.
        self.rel_pos_h = nn.Parameter(
            torch.zeros(input_size[0], input_size[0], self.head_dim)
        )
        self.rel_pos_w = nn.Parameter(
            torch.zeros(input_size[1], input_size[1], self.head_dim)
        )

    def init_rel_pos(self):
        # Re-index flat relative-distance tables into explicit (q, k) lookup
        # tables, as used by get_attn_bias.
        # NOTE(review): this assumes rel_pos_h/rel_pos_w currently hold
        # weights in the flat (2*size-1, head_dim) layout (e.g. loaded from a
        # SAM-style checkpoint). Calling it twice, or on the zero-initialized
        # shape declared in __init__, would index past the first dimension —
        # confirm it is invoked exactly once after weight loading.
        q_size, k_size = self.input_size
        q_coords = torch.arange(q_size)[:, None]

        k_coords = torch.arange(k_size)[None, :]
        # Relative offset of each (q, k) pair, shifted to be >= 0.
        relative_coords = (q_coords - k_coords) + (k_size - 1)

        self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()])
        self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()])

    def get_attn_bias(self, q: torch.Tensor):
        """Build the decomposed (height + width) relative positional bias."""
        # (B, num_heads, seq_len, 64) -> (B*num_heads, H, W, 64).
        q = q.view(-1, *self.input_size, 64)

        # Per-axis query/rel-pos dot products: (B*heads, H, W, k_axis).
        rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h)
        rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w)

        # Outer-sum over key rows/cols, flattened to (B, heads, seq, seq).
        return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape(
            -1, self.num_heads, self.seq_len, self.seq_len
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Joint projection, then split into q/k/v of (B, seq_len, 768) each.
        qkv = torch.split(
            self.qkv(x).view(-1, self.seq_len, 3 * 768),
            768,
            dim=2,
        )

        # Each becomes (B, num_heads, seq_len, head_dim).
        q, k, v = (
            i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous()
            for i in qkv
        )

        attn_bias = self.get_attn_bias(q)
        # Let torch pick the fastest fused backend that supports an attn_mask.
        with sdpa_kernel(
            [
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
            ],
            set_priority=True,
        ):
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_bias, is_causal=False
            )
        # (B, heads, seq, head_dim) -> (B, seq, 768).
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        x = self.proj(attn_output)

        # Back to the spatial layout (B, H, W, 768).
        return x.view(-1, *self.input_size, 768)
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Transformer feed-forward block: 768 -> 3072 -> 768 with GELU."""

    def __init__(
        self,
    ):
        super().__init__()
        self.lin1 = nn.Linear(768, 4 * 768)
        self.lin2 = nn.Linear(4 * 768, 768)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand, activate, project back down.
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """ViT block: pre-norm attention + pre-norm MLP, both residual.

    ``window_size == 0`` means global attention over the full 64x64 token
    grid; any positive value routes attention through non-overlapping 14x14
    windows (the partition helpers hard-code the 64 -> 70 = 5*14 padding).
    """

    def __init__(self, idx: int, window_size: int = 14):
        super().__init__()

        self.idx = idx
        self.window_size = window_size

        self.norm1 = LayerNorm(768)

        self.attn = Attention(
            dim=768,
            num_heads=12,
            # Global blocks attend over the whole grid; windowed blocks only
            # over a single 14x14 window.
            input_size=(64, 64) if window_size == 0 else (14, 14),
        )

        self.norm2 = LayerNorm(768)
        self.mlp = MLP()

    @staticmethod
    def window_partition(x: torch.Tensor) -> torch.Tensor:
        # (B, 64, 64, 768): pad H and W by 6 to 70 so the grid divides into
        # a 5x5 arrangement of 14x14 windows, then fold the windows into the
        # batch dimension -> (B*25, 14, 14, 768).
        x = F.pad(x, (0, 0, 0, 6, 0, 6))
        x = (
            x.view(-1, 5, 14, 5, 14, 768)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, 14, 14, 768)
        )
        return x

    @staticmethod
    def window_unpartition(x: torch.Tensor) -> torch.Tensor:
        # Inverse of window_partition: reassemble the 5x5 grid of 14x14
        # windows into (B, 70, 70, 768), then crop the padding back to 64x64.
        x = (
            x.view(-1, 5, 5, 14, 14, 768)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, 70, 70, 768)
        )
        return x[:, :64, :64, :].contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            x = self.window_partition(x)

        x = self.attn(x)

        if self.window_size > 0:
            x = self.window_unpartition(x)

        # Residual around attention, then residual around the MLP.
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x
|
|
|
|
|
|
|
|
class GexVit(nn.Module):
    """SAM-style ViT image encoder producing a 1024-channel feature map.

    Pipeline: 16x16 patch embedding (64x64x768 token grid) plus a learned
    absolute position embedding, 12 transformer blocks (windowed attention
    except at ``global_attn_indexes``), a conv/LayerNorm2d "neck" down to
    256 channels, then two stride-2 convs (``net_2``/``net_3``) yielding a
    1024-channel map at 1/4 of the token-grid resolution.
    """

    def __init__(self, global_attn_indexes=(2, 5, 8, 11), **kwargs):
        """
        Args:
            global_attn_indexes: indices of blocks that attend globally over
                the full 64x64 grid; all other blocks use 14x14 windows.
                The default is an immutable tuple (not a list) to avoid the
                shared-mutable-default-argument pitfall.
            **kwargs: accepted and ignored for config compatibility.
        """
        super().__init__()
        self.global_attn_indexes = global_attn_indexes
        self.patch_embed = PatchEmbed()

        # Learned absolute positional embedding for the 64x64 patch grid.
        self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768))

        self.blocks = nn.ModuleList(
            [
                Block(idx=i, window_size=14 if i not in global_attn_indexes else 0)
                for i in range(12)
            ]
        )

        # SAM-style neck: 768 -> 256 channels with channel-wise LayerNorm.
        self.neck = nn.ModuleList(
            [
                nn.Conv2d(
                    768,
                    256,
                    kernel_size=1,
                    bias=False,
                ),
                LayerNorm2d(256),
                nn.Conv2d(
                    256,
                    256,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                LayerNorm2d(256),
            ]
        )

        # Two stride-2 convs: 64x64x256 -> 32x32x512 -> 16x16x1024.
        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(
            512, 1024, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images to a (B, 1024, H/64, W/64) feature map."""
        x = self.patch_embed(x)
        x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        # (B, H, W, C) -> (B, C, H, W) for the convolutional neck.
        x = x.permute(0, 3, 1, 2)

        for m in self.neck:
            x = m(x)

        x = self.net_2(x)
        x = self.net_3(x)

        return x
|
|
|
|
|
|
|
|
class GexQwenModel(Qwen2Model):
    """Qwen2 decoder with a GexVit image encoder spliced into its embeddings.

    When embeddings are computed here from ``input_ids``, the 256 projected
    ViT image tokens overwrite embedding positions 1..256 (position 0 is the
    BOS token; positions 1..256 carry IMG_PAD placeholders in the input ids).
    """

    config_class = GexConfig
    _auto_class = "AutoModel"

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.vit = GexVit()
        # Projects 1024-dim ViT features into the LM embedding space.
        # NOTE(review): assumes config.hidden_size == 1024 — confirm.
        self.vit_proj = nn.Linear(1024, 1024)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Qwen2Model.forward plus image-feature injection.

        ``images`` is required whenever embeddings are computed from
        ``input_ids`` here; it is unused when ``inputs_embeds`` is supplied.
        """
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
            assert images is not None

            # ViT map (B, 1024, H', W') -> (B, H'*W', 1024): one token per cell.
            vit_feature = self.vit(images).flatten(2).permute(0, 2, 1)
            vit_feature = self.vit_proj(vit_feature)

            # Overwrite the IMG_PAD placeholder embeddings in place.
            # NOTE(review): assumes the full prompt (>= 257 positions) is
            # embedded in this call; an incremental decode step would not line
            # up with this slice — confirm how generation reaches this path.
            inputs_embeds[:, 1:257, :] = vit_feature
        # Force the flash-attention SDPA backend for the decoder stack.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return super().forward(
                input_ids=None,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )
|
|
|
|
|
|
|
|
class GexQwenForCausalLM(Qwen2ForCausalLM):
    """Causal LM head over GexQwenModel (Qwen2 + GexVit image encoder).

    Sequence layout: ``[BOS][256 x IMG_PAD (replaced by image features)]
    [text tokens]``. When only ``labels`` are passed, the decoder inputs are
    derived from them (teacher forcing with a one-position right shift).
    """

    config_class = GexConfig

    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config):
        super().__init__(config)
        self.model = GexQwenModel(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Re-run weight init / tying after replacing the base model.
        self.post_init()

        self.image_preprocess = GexImageEvalProcessor()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Standard CausalLM forward with image injection and optional
        labels-only training input construction (see class docstring)."""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Training convenience path: when only labels are given, build the
        # decoder inputs as [BOS][256 x IMG_PAD][labels shifted right by one]
        # and prepend a 256-token IMG_PAD prefix to the labels, which are then
        # passed below as pre-shifted `shift_labels`.
        # NOTE(review): the image-prefix label positions carry IMG_PAD_IDS,
        # not -100, so the loss also supervises predicting IMG_PAD there —
        # confirm this is intended.
        if labels is not None and input_ids is None:
            input_ids: torch.Tensor = labels
            shifted_input_ids = input_ids.new_zeros(
                (input_ids.shape[0], input_ids.shape[1] + 256), device=input_ids.device
            )
            shifted_input_ids[:, 257:].copy_(input_ids[:, :-1])
            decoder_start_token_id = BOS_TOEKN_IDS
            shifted_input_ids[:, 0] = decoder_start_token_id
            shifted_input_ids[:, 1:257] = IMG_PAD_IDS
            input_ids = shifted_input_ids
            imgs_pad: torch.Tensor = torch.full(
                (1, 256), IMG_PAD_IDS, device=self.device, dtype=torch.long
            )
            labels = torch.cat(
                [
                    imgs_pad.expand(labels.shape[0], -1),
                    # Mask out everything after the first EOS in each row.
                    process_batch_labels(labels),
                ],
                dim=-1,
            )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            images=images,
            **kwargs,
        )

        hidden_states = outputs[0]

        # logits_to_keep == 0 -> slice(0, None), i.e. keep all positions.
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )

        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # `labels` are already aligned as next-token targets, so they are
            # passed as `shift_labels` (no further shifting by loss_function).
            loss = self.loss_function(
                logits=logits,
                labels=None,
                shift_labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, *args, images, **kwargs):
        """Generation with the [BOS][256 x IMG_PAD] image prefix prepended to
        any user-supplied input_ids; `images` is required (keyword-only)."""
        pad = torch.tensor(
            [[BOS_TOEKN_IDS] + [IMG_PAD_IDS] * 256],
            dtype=torch.long,
            device=self.device,
        )
        if (input_ids := kwargs.pop("input_ids", None)) is not None:
            input_ids = torch.cat(
                [pad.expand(input_ids.shape[0], -1), input_ids], dim=-1
            )
        else:
            input_ids = pad.expand(images.shape[0], -1)

        res = super().generate(
            *args,
            input_ids=input_ids,
            images=images,
            # Budget the 257-token prefix on top of the requested max_length.
            max_length=kwargs.pop("max_length", 10) + 257,
            **kwargs,
        )
        return res
|
|
|
|
|
|
|
|
class GexTQwenModel(Qwen3Model):
    """Qwen3 decoder with a GexVit image encoder spliced into its embeddings.

    Same injection scheme as GexQwenModel: the 256 projected ViT tokens
    overwrite embedding positions 1..256 when embeddings are computed here.
    """

    config_class = GexTConfig
    _auto_class = "AutoModel"

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.vit = GexVit()
        # Projects 1024-dim ViT features into the LM embedding space.
        # NOTE(review): assumes config.hidden_size == 1024 — confirm.
        self.vit_proj = nn.Linear(1024, 1024)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Qwen3Model.forward plus image-feature injection.

        ``images`` is required whenever embeddings are computed from
        ``input_ids`` here; it is unused when ``inputs_embeds`` is supplied.
        """
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)
            assert images is not None

            # ViT map (B, 1024, H', W') -> (B, H'*W', 1024): one token per cell.
            vit_feature = self.vit(images).flatten(2).permute(0, 2, 1)
            vit_feature = self.vit_proj(vit_feature)

            # Overwrite the IMG_PAD placeholder embeddings in place.
            # NOTE(review): assumes the full prompt (>= 257 positions) is
            # embedded in this call — confirm how generation reaches this path.
            inputs_embeds[:, 1:257, :] = vit_feature
        # Force the flash-attention SDPA backend for the decoder stack.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return super().forward(
                input_ids=None,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                **flash_attn_kwargs,
            )
|
|
|
|
|
|
|
|
class GexTQwenForCausalLM(Qwen3ForCausalLM):
    """Causal LM head over GexTQwenModel (Qwen3 + GexVit image encoder).

    Sequence layout differs from the Qwen2 variant by one extra slot:
    ``[BOS][256 x IMG_PAD (replaced by image features)][IMG_END][text]``.
    Uses the fused liger causal-LM loss during training when available.
    """

    config_class = GexTConfig

    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config):
        super().__init__(config)
        self.model = GexTQwenModel(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Re-run weight init / tying after replacing the base model.
        self.post_init()

        self.image_preprocess = GexImageEvalProcessor()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """CausalLM forward with image injection and optional labels-only
        training input construction (see class docstring)."""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        # Training convenience path: derive decoder inputs from labels as
        # [BOS][256 x IMG_PAD][IMG_END][labels shifted right by one], and
        # prepend the matching 257-token prefix (IMG_PAD*256 + IMG_END) to
        # the labels, passed below as pre-shifted `shift_labels`.
        # NOTE(review): the prefix label positions are not -100, so the loss
        # also supervises predicting IMG_PAD/IMG_END there — confirm intended.
        if labels is not None and input_ids is None:
            input_ids: torch.Tensor = labels
            shifted_input_ids = input_ids.new_zeros(
                (input_ids.shape[0], input_ids.shape[1] + 257), device=input_ids.device
            )
            shifted_input_ids[:, 258:].copy_(input_ids[:, :-1])
            decoder_start_token_id = BOS_TOEKN_IDS
            shifted_input_ids[:, 0] = decoder_start_token_id
            shifted_input_ids[:, 257] = IMG_END_IDS
            shifted_input_ids[:, 1:257] = IMG_PAD_IDS
            input_ids = shifted_input_ids
            imgs_pad: torch.Tensor = torch.full(
                (1, 257), IMG_PAD_IDS, device=self.device, dtype=torch.long
            )
            imgs_pad[:, -1] = IMG_END_IDS
            labels = torch.cat(
                [
                    imgs_pad.expand(labels.shape[0], -1),
                    # Mask out everything after the first EOS in each row.
                    process_batch_labels(labels),
                ],
                dim=-1,
            )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            images=images,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        loss = None
        if labels is not None:
            if self.training and use_liger:
                # Fused liger path computes the loss directly from hidden
                # states and the lm_head weight, skipping logit materialization.
                loss = LigerForCausalLMLoss(
                    hidden_states=hidden_states,
                    lm_head_weight=self.lm_head.weight,
                    labels=None,
                    shift_labels=labels,
                    hidden_size=self.config.hidden_size,
                    **kwargs,
                )
                logits = None

            else:
                # logits_to_keep == 0 -> slice(0, None): keep all positions.
                slice_indices = (
                    slice(-logits_to_keep, None)
                    if isinstance(logits_to_keep, int)
                    else logits_to_keep
                )

                logits = self.lm_head(hidden_states[:, slice_indices, :])
                # Labels are already aligned as next-token targets, hence
                # passed as `shift_labels` (no further shifting).
                loss = self.loss_function(
                    logits=logits,
                    labels=None,
                    shift_labels=labels,
                    vocab_size=self.config.vocab_size,
                    **kwargs,
                )
        else:
            # Inference: no loss, only (possibly sliced) logits.
            slice_indices = (
                slice(-logits_to_keep, None)
                if isinstance(logits_to_keep, int)
                else logits_to_keep
            )

            logits = self.lm_head(hidden_states[:, slice_indices, :])

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def generate(self, *args, images, **kwargs):
        """Generation with the [BOS][256 x IMG_PAD][IMG_END] prefix prepended
        to any user-supplied input_ids; `images` is required (keyword-only)."""
        pad = torch.tensor(
            [[BOS_TOEKN_IDS] + [IMG_PAD_IDS] * 256 + [IMG_END_IDS]],
            dtype=torch.long,
            device=self.device,
        )
        if (input_ids := kwargs.pop("input_ids", None)) is not None:
            input_ids = torch.cat(
                [pad.expand(input_ids.shape[0], -1), input_ids], dim=-1
            )
        else:
            input_ids = pad.expand(images.shape[0], -1)

        res = super().generate(
            *args,
            input_ids=input_ids,
            images=images,
            # Budget the 258-token prefix on top of the requested max_length.
            max_length=kwargs.pop("max_length", 25) + 258,
            **kwargs,
        )
        return res
|
|
|