import torch
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

from transformers import Qwen2_5_VLForConditionalGeneration
from transformers.modeling_outputs import ModelOutput
from transformers.models.idefics2.configuration_idefics2 import Idefics2PerceiverConfig
from transformers.models.idefics2.modeling_idefics2 import Idefics2PerceiverResampler
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig


@dataclass
class TRASEROutput(ModelOutput):
    """Output of TRASER.forward; mirrors Qwen2_5_VLCausalLMOutputWithPast."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None


class TRASER(Qwen2_5_VLForConditionalGeneration):
    def __init__(self, config: Qwen2_5_VLConfig, **kwargs):
        super().__init__(config)
        # Fallback mechanism: copy extra kwargs onto the config, but only for
        # attributes the config does not already define.
        for k, v in kwargs.items():
            if not hasattr(config, k):
                setattr(config, k, v)
        self.config = config
        self._build_perceiver(dtype=config.torch_dtype, attn_impl=config._attn_implementation)
        self.post_init()

    def _build_perceiver(self, dtype: torch.dtype, attn_impl: str) -> None:
        h = int(getattr(self.config, "hidden_size", 2048))
        n_latents = int(getattr(self.config, "temporal_resampler_n_latents", 64))
        depth = int(getattr(self.config, "resampler_depth", 3))
        perceiver_cfg = Idefics2PerceiverConfig(
            hidden_size=h,
            resampler_n_latents=n_latents,
            resampler_depth=depth,
            _attn_implementation=attn_impl,
            torch_dtype=dtype,
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(perceiver_cfg)
        # Optional second resampler (on by default), configured with its own
        # latent count but sharing the hidden size and depth.
        if getattr(self.config, "object_resampler", True):
            second_n_latents = int(getattr(self.config, "object_resampler_n_latents", 32))
            second_perceiver_cfg = Idefics2PerceiverConfig(
                hidden_size=h,
                resampler_n_latents=second_n_latents,
                resampler_depth=depth,
                _attn_implementation=attn_impl,
                torch_dtype=dtype,
            )
            self.second_perceiver_resampler = Idefics2PerceiverResampler(second_perceiver_cfg)
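
    # Shape note (an inference from Idefics2PerceiverResampler's interface,
    # not something this file asserts): each resampler cross-attends a
    # variable-length token sequence into a fixed set of learned latents,
    # (batch, seq, hidden) -> (batch, resampler_n_latents, hidden). Judging by
    # the config attribute names, `perceiver_resampler` is meant to pool
    # tokens along the temporal axis and `second_perceiver_resampler` to pool
    # object-level tokens into a smaller latent set.
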
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )
model_inputs["position_ids"] = position_ids
if cache_position is not None and cache_position[0] != 0:
model_inputs["pixel_values"] = None
model_inputs["pixel_values_videos"] = None
model_inputs["position_ids"] = None
return model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> TRASEROutput:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if rope_deltas is not None:
            self.model.rope_deltas = rope_deltas
        # Prefill: the embeddings were already computed (and merged with the
        # visual features) upstream, and the KV cache is still empty.
        is_prefill = (inputs_embeds is not None) and (
            past_key_values is None
            or (hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0)
        )
        if is_prefill:
            outputs = self.model.language_model(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                return_dict=True,
            )
        else:
            # Decode step: embed the new tokens and rebuild the mrope position
            # ids (3 axes: temporal, height, width; identical for text-only
            # steps) from the cache offset plus the per-sample rope deltas
            # computed at prefill time.
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.model.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            pos = torch.arange(seq_length, device=inputs_embeds.device).view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                # Broadcast the per-sample delta when the batch was expanded
                # (e.g. beam search replicates each sample num_beams times).
                delta = delta.repeat_interleave(max(1, batch_size // delta.shape[0]), dim=0)
            pos = pos.add(delta).unsqueeze(0).expand(3, -1, -1)
            outputs = self.model.language_model(
                input_ids=None,
                position_ids=pos,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                **kwargs,
            )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
        return TRASEROutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.model.rope_deltas,
        )
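

# A minimal usage sketch, not part of the original module: the checkpoint name
# and resampler sizes below are illustrative assumptions, and loading a real
# Qwen2.5-VL checkpoint will leave the two resamplers randomly initialized,
# since their weights do not exist in the base model.
if __name__ == "__main__":
    model = TRASER.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct",  # assumed checkpoint; any Qwen2.5-VL model should work
        temporal_resampler_n_latents=64,  # extra kwargs land on the config via the fallback in __init__
        object_resampler_n_latents=32,
        resampler_depth=3,
        torch_dtype=torch.bfloat16,
    )
    # The resampler compresses a variable-length sequence into a fixed number
    # of latents: (batch, seq, hidden) -> (batch, n_latents, hidden).
    tokens = torch.zeros(1, 256, model.config.hidden_size, dtype=model.dtype)
    mask = torch.ones(1, 256, dtype=torch.long)
    latents = model.perceiver_resampler(tokens, attention_mask=mask)
    print(latents.shape)  # expected: torch.Size([1, 64, hidden_size])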