# TRASER / modeling_traser.py
import torch
import torch.nn as nn
from typing import Any, List, Optional, Tuple
from dataclasses import dataclass

from transformers import Qwen2_5_VLForConditionalGeneration
from transformers.modeling_outputs import ModelOutput
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.idefics2.modeling_idefics2 import Idefics2PerceiverResampler
from transformers.models.idefics2.configuration_idefics2 import Idefics2PerceiverConfig


@dataclass
class TRASEROutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None


class TRASER(Qwen2_5_VLForConditionalGeneration):
    def __init__(self, config: Qwen2_5_VLConfig, **kwargs):
        super().__init__(config)
        # Update config with kwargs if provided (fallback mechanism)
        for k, v in kwargs.items():
            if not hasattr(config, k):
                setattr(config, k, v)
        self.config = config
        self._build_perceiver(dtype=config.torch_dtype, attn_impl=config._attn_implementation)
        self.post_init()

    def _build_perceiver(self, dtype: torch.dtype, attn_impl: str) -> None:
        # Temporal resampler: compresses a variable-length visual token sequence
        # into a fixed number of latent tokens.
        h = int(getattr(self.config, "hidden_size", 2048))
        n_latents = int(getattr(self.config, "temporal_resampler_n_latents", 64))
        depth = int(getattr(self.config, "resampler_depth", 3))
        perceiver_cfg = Idefics2PerceiverConfig(
            hidden_size=h,
            resampler_n_latents=n_latents,
            resampler_depth=depth,
            _attn_implementation=attn_impl,
            torch_dtype=dtype,
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(perceiver_cfg)
        # Optional second resampler for object-level tokens, enabled by default.
        if getattr(self.config, "object_resampler", True):
            second_n_latents = int(getattr(self.config, "object_resampler_n_latents", 32))
            second_perceiver_cfg = Idefics2PerceiverConfig(
                hidden_size=h,
                resampler_n_latents=second_n_latents,
                resampler_depth=depth,
                _attn_implementation=attn_impl,
                torch_dtype=dtype,
            )
            self.second_perceiver_resampler = Idefics2PerceiverResampler(second_perceiver_cfg)
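
    # Shape sketch (illustrative, not part of the original file): each resampler
    # maps a variable-length token sequence to a fixed set of latents, roughly
    #     feats = torch.randn(1, 576, h)                   # (B, N, h) visual tokens
    #     mask = torch.ones(1, 576, dtype=torch.long)      # (B, N) attention mask
    #     latents = self.perceiver_resampler(feats, mask)  # (B, n_latents, h)
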
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )
        model_inputs["position_ids"] = position_ids
        # After prefill the visual inputs already live in the KV cache: drop them
        # and let forward() rebuild position_ids from the stored rope_deltas.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
            model_inputs["position_ids"] = None
        return model_inputs
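
    # Illustrative view of what generate() receives from the hook above (values
    # are assumptions for exposition, not produced by this file):
    #     step 0 (prefill): pixel_values and a precomputed 3D position_ids pass through
    #     step t > 0:       pixel_values / pixel_values_videos / position_ids are None,
    #                       so forward() takes the rope_deltas decode path below
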
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> TRASEROutput:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if rope_deltas is not None:
            self.model.rope_deltas = rope_deltas
        # Prefill: the caller supplies precomputed multimodal embeddings and the
        # KV cache is still empty.
        is_prefill = (inputs_embeds is not None) and (
            past_key_values is None
            or (hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0)
        )
        if is_prefill:
            outputs = self.model.language_model(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                return_dict=True,
            )
        else:
            # Decode: embed the new token ids and rebuild the 3D (mRoPE) position
            # ids from the cache offset plus the rope_deltas stored at prefill.
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.model.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            pos = torch.arange(seq_length, device=inputs_embeds.device).view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(max(1, batch_size // delta.shape[0]), dim=0)
            pos = pos.add(delta).unsqueeze(0).expand(3, -1, -1)
            outputs = self.model.language_model(
                input_ids=None,
                position_ids=pos,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                **kwargs,
            )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
        return TRASEROutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.model.rope_deltas,
        )
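

# Usage sketch (illustrative: the checkpoint path and prompt are assumptions,
# not part of this file). TRASER loads like its Qwen2.5-VL base model.
if __name__ == "__main__":
    from transformers import AutoProcessor

    ckpt = "path/to/traser-checkpoint"  # hypothetical local path or repo id
    model = TRASER.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    processor = AutoProcessor.from_pretrained(ckpt)

    messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=[text], return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=16)
    print(processor.batch_decode(out, skip_special_tokens=True)[0])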