import torch
from typing import List, Tuple, Optional, Any
from dataclasses import dataclass

from transformers import Qwen2_5_VLForConditionalGeneration
from transformers.modeling_outputs import ModelOutput
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
from transformers.models.idefics2.modeling_idefics2 import Idefics2PerceiverResampler
from transformers.models.idefics2.configuration_idefics2 import Idefics2PerceiverConfig


@dataclass
class TRASEROutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None


class TRASER(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL backbone extended with Idefics2-style perceiver resamplers:
    a temporal resampler plus an optional second (object) resampler."""

    def __init__(self, config: Qwen2_5_VLConfig, **kwargs):
        super().__init__(config)
        # Fallback mechanism: copy any extra kwargs onto the config so resampler
        # hyperparameters can be supplied at construction/load time.
        for k, v in kwargs.items():
            if not hasattr(config, k):
                setattr(config, k, v)
        self.config = config
        self._build_perceiver(dtype=config.torch_dtype, attn_impl=config._attn_implementation)
        self.post_init()

    def _build_perceiver(self, dtype: torch.dtype, attn_impl: str) -> None:
        """Build the temporal resampler and, if enabled, the object resampler."""
        h = int(getattr(self.config, "hidden_size", 2048))
        n_latents = int(getattr(self.config, "temporal_resampler_n_latents", 64))
        depth = int(getattr(self.config, "resampler_depth", 3))
        perceiver_cfg = Idefics2PerceiverConfig(
            hidden_size=h,
            resampler_n_latents=n_latents,
            resampler_depth=depth,
            _attn_implementation=attn_impl,
            torch_dtype=dtype,
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(perceiver_cfg)

        if getattr(self.config, "object_resampler", True):
            second_n_latents = int(getattr(self.config, "object_resampler_n_latents", 32))
            second_perceiver_cfg = Idefics2PerceiverConfig(
                hidden_size=h,
                resampler_n_latents=second_n_latents,
                resampler_depth=depth,
                _attn_implementation=attn_impl,
                torch_dtype=dtype,
            )
            self.second_perceiver_resampler = Idefics2PerceiverResampler(second_perceiver_cfg)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )
        # Keep the externally supplied position_ids for the prefill step.
        model_inputs["position_ids"] = position_ids
        # Once the KV cache is primed, drop the vision inputs and position ids;
        # decode steps rebuild positions from the cached rope_deltas in forward().
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
            model_inputs["position_ids"] = None
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> TRASEROutput:
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if rope_deltas is not None:
            self.model.rope_deltas = rope_deltas

        # Prefill: the caller supplies inputs_embeds and the KV cache is empty.
        is_prefill = (inputs_embeds is not None) and (
            past_key_values is None
            or (hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0)
        )
        if is_prefill:
            # Forward the precomputed embeddings to the language model as-is.
            outputs = self.model.language_model(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                return_dict=True,
            )
        else:
            # Decode: embed the new token ids and rebuild the 3-axis mRoPE
            # position ids by offsetting a fresh arange with the cached rope_deltas.
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.model.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            pos = torch.arange(seq_length, device=inputs_embeds.device).view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                # Broadcast the per-sequence deltas across an expanded batch (e.g. beams).
                delta = delta.repeat_interleave(max(1, batch_size // delta.shape[0]), dim=0)
            pos = pos.add(delta).unsqueeze(0).expand(3, -1, -1)
            outputs = self.model.language_model(
                input_ids=None,
                position_ids=pos,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return TRASEROutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.model.rope_deltas,
        )
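

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint id and
    # hyperparameter values below are illustrative assumptions. Kwargs that
    # from_pretrained() does not consume reach __init__, whose fallback loop copies
    # them onto the config before _build_perceiver() reads them via getattr().
    model = TRASER.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct",  # hypothetical base checkpoint
        torch_dtype=torch.bfloat16,
        temporal_resampler_n_latents=64,
        object_resampler_n_latents=32,
        resampler_depth=3,
    )
    print(model.perceiver_resampler)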