from dataclasses import dataclass
from typing import Literal, Optional, TypedDict, Union

import torch
import torch.nn as nn
from PIL import Image
from transformers import Qwen3Model
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple, logging

from configuration_step_vl import StepRoboticsConfig
from vision_encoder import StepRoboticsVisionEncoder

logger = logging.get_logger(__name__)


class StepVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    patch_pixel_values: Optional[torch.Tensor]
    num_patches: list[int]


class StepVLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    image_embeds: torch.Tensor


StepVLImageInputs = Union[StepVLImagePixelInputs, StepVLImageEmbeddingInputs]


@dataclass
class StepVLCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
        Contains pre-computed hidden states (keys and values in the self-attention blocks) that can be used (see
        the `past_key_values` input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None


def _flatten_embeddings(embeddings) -> torch.Tensor:
    """
    Recursively flattens and concatenates NestedTensors on all but the last
    dimension.
    """
    if isinstance(embeddings, torch.Tensor):
        # Base case: collapse all leading dimensions, keeping the embedding
        # dimension intact.
        return embeddings.flatten(0, -2)

    # Recursive case: flatten each element, then concatenate along dim 0.
    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings) -> str:
    """
    Constructs a debugging representation of the number of embeddings in the
    NestedTensors.
    """
    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])

    return " + ".join(
        _embedding_count_expression(inner) for inner in embeddings)
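

# A minimal sketch of how the two helpers above behave (hypothetical shapes,
# not taken from a real checkpoint):
#
#     embeds = [torch.randn(2, 16, 8), torch.randn(5, 8)]
#     _flatten_embeddings(embeds).shape    # torch.Size([37, 8]): 2*16 + 5 rows
#     _embedding_count_expression(embeds)  # "2 x 16 + 5"
#
# The string form is what the merge error message below reports when the
# number of vision embeddings does not match the number of placeholders.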


def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions where ``is_multimodal`` is ``True``, i.e. the positions that
    correspond to placeholder tokens in the original ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    num_expected_tokens = is_multimodal.sum().item()
    assert isinstance(num_expected_tokens, int)

    flattened = _flatten_embeddings(multimodal_embeddings)
    if flattened.shape[0] != num_expected_tokens:
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders")

    is_multimodal = is_multimodal.to(inputs_embeds.device)
    flattened = flattened.to(inputs_embeds.device)
    # Boolean-mask assignment scatters the flattened embeddings into the
    # placeholder rows in order.
    inputs_embeds[is_multimodal] = flattened
    return inputs_embeds


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings,
    placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can be a list of token ids (e.g., the token ids
    of img_start, img_break, and img_end tokens) when needed: this means the
    order of these tokens in ``input_ids`` MUST MATCH the order of their
    embeddings in ``multimodal_embeddings``, since we slice-merge instead of
    scattering individually.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is a text token
    - S is the image start token
    - I is an image embedding token
    - B is the image break token
    - E is the image end token

    then the image embeddings (corresponding to the I's) from the vision
    encoder must be padded with embeddings of S, B, and E in the same order
    as ``input_ids`` for a correct embedding merge.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if isinstance(placeholder_token_id, list):
        placeholder_token_id = torch.tensor(placeholder_token_id,
                                            device=input_ids.device)
        return _merge_multimodal_embeddings(
            inputs_embeds,
            torch.isin(input_ids, placeholder_token_id),
            multimodal_embeddings,
        )

    return _merge_multimodal_embeddings(
        inputs_embeds,
        (input_ids == placeholder_token_id),
        multimodal_embeddings,
    )
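

# A minimal usage sketch (hypothetical token id and hidden size):
#
#     IMG = 151655  # placeholder id; the real value comes from the config
#     input_ids = torch.tensor([1, 2, IMG, IMG, IMG, 3])
#     inputs_embeds = torch.zeros(6, 4096)
#     image_embeds = torch.randn(3, 4096)  # one row per IMG placeholder
#     merge_multimodal_embeddings(input_ids, inputs_embeds, image_embeds, IMG)
#     # rows 2..4 of inputs_embeds now hold image_embeds; the merge is in place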


class StepRoboticsPreTrainedModel(PreTrainedModel):
    config_class = StepRoboticsConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_static_cache = True
    _supports_attention_backend = True


class StepRoboticsModel(StepRoboticsPreTrainedModel, GenerationMixin):
    config: StepRoboticsConfig
    base_model_prefix = ""

    def __init__(self, config: StepRoboticsConfig):
        super().__init__(config)
        self.vision_model = StepRoboticsVisionEncoder(config.vision_config)
        self.language_model = Qwen3Model(config.text_config)
        self.vocab_size = config.text_config.vocab_size
        # The projector input width * 4 presumably matches the channel count
        # after the two vision downsamplers (see _process_image_features).
        self.vit_large_projector = nn.Linear(
            config.vision_config.width * 4,
            config.text_config.hidden_size,
            bias=config.projector_bias)
        self.image_placeholder_token_id = config.image_token_id

        self.post_init()

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
    ) -> torch.Tensor:
        # Note: this assumes a batch size of 1; the leading batch dimension
        # is squeezed away and restored at the end.
        input_ids = input_ids.squeeze(0)
        if multimodal_embeddings is None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)
        else:
            # Embed only the text tokens, then scatter text and image
            # embeddings into a single buffer.
            is_text = input_ids != self.config.image_token_id
            text_ids = input_ids[is_text]
            text_embeds = self.language_model.embed_tokens(text_ids)

            inputs_embeds = torch.empty(input_ids.shape[0],
                                        text_embeds.shape[-1],
                                        dtype=text_embeds.dtype,
                                        device=text_embeds.device)
            inputs_embeds[is_text] = text_embeds
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_id)
        inputs_embeds = inputs_embeds.unsqueeze(0)
        return inputs_embeds

    def set_input_embeddings(self, value):
        return self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[StepVLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        patch_pixel_values = kwargs.pop("patch_pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            # Collapse any leading batch/image dimensions into one.
            if pixel_values.dim() >= 3:
                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
            if patch_pixel_values is not None:
                patch_pixel_values = patch_pixel_values.view(
                    -1, *patch_pixel_values.shape[-3:])
                # An empty patch tensor means no tiled patches were provided.
                if patch_pixel_values.shape[0] == 0:
                    patch_pixel_values = None

            return StepVLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values.to(self.dtype).to(self.device),
                patch_pixel_values=patch_pixel_values.to(self.dtype).to(
                    self.device) if patch_pixel_values is not None else None,
                num_patches=num_patches,
            )

        if image_embeds is not None:
            if image_embeds.dim() >= 2:
                # Flatten to (num_tokens, hidden_size).
                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
            else:
                raise ValueError(
                    f"Unexpected shape for image_embeds: {image_embeds.shape}")

            return StepVLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds.to(self.dtype).to(self.device),
            )
        return None

    def _process_image_features(self,
                                image_features: torch.Tensor) -> torch.Tensor:
        # (B, P, C): P patch tokens per image, laid out on a square grid.
        B, P = image_features.shape[:2]
        HW = int(P ** 0.5)
        # (B, P, C) -> (B, C, HW, HW) so the downsamplers can run as convs.
        image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW)
        image_features = self.vision_model.vit_downsampler1(image_features)
        image_features = self.vision_model.vit_downsampler2(image_features)

        # Back to a token sequence: (B, C', H', W') -> (B, H'*W', C').
        B, C, H, W = image_features.shape
        image_features = image_features.view(B, -1, H * W).permute(0, 2, 1)
        image_features = self.vit_large_projector(image_features)
        return image_features
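
    # Illustrative shape flow for _process_image_features (the exact spatial
    # sizes and channel counts depend on the downsampler definitions in the
    # vision encoder; only the projector input of 4 * width is fixed here):
    #
    #     (B, P, C)  --permute/view-->  (B, C, sqrt(P), sqrt(P))
    #                --downsampler1-->  (B, C1, H1, W1)
    #                --downsampler2-->  (B, 4 * width, H2, W2)
    #                --view/permute-->  (B, H2 * W2, 4 * width)
    #                --projector----->  (B, H2 * W2, hidden_size)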

    def _get_vision_model_output(self,
                                 input_tensor: torch.Tensor) -> torch.Tensor:
        return self.vision_model(input_tensor)

    def _process_image_input(
        self, image_input: StepVLImageInputs
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        if image_input["type"] == "image_embeds":
            # Pre-computed embeddings bypass the vision encoder and the
            # downsampling/projection path entirely.
            return image_input["image_embeds"]

        image_features = self._get_vision_model_output(
            image_input["pixel_values"])
        patch_image_features = self._get_vision_model_output(
            image_input["patch_pixel_values"]
        ) if image_input["patch_pixel_values"] is not None else None
        num_patches = image_input["num_patches"]

        image_features = self._process_image_features(image_features)
        patch_image_features = self._process_image_features(
            patch_image_features) if patch_image_features is not None else None

        # For each image, emit its tiled patch features first (if any),
        # followed by the downsampled global view.
        merged_image_features = []
        cur_patch_idx = 0
        for i, num_patch in enumerate(num_patches):
            cur_feature = []
            if num_patch > 0:
                patch_slice = patch_image_features[
                    cur_patch_idx:cur_patch_idx + num_patch]
                cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1]))
            cur_feature.append(image_features[i].view(
                -1, image_features.shape[-1]))
            cur_patch_idx += num_patch
            merged_image_features.append(
                torch.cat(cur_feature) if len(cur_feature) > 1
                else cur_feature[0])

        return merged_image_features
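
    # Sketch of the merge bookkeeping above (hypothetical sizes): with
    # num_patches = [2, 0] and 49 tokens per processed view, the output is
    #
    #     [ cat(patch[0:2] -> (98, H), image[0] -> (49, H)) -> (147, H),
    #       image[1] -> (49, H) ]
    #
    # i.e. one tensor per image, patches first, then the global view.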

    def get_multimodal_embeddings(self, **kwargs):
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[list[Image.Image]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, StepVLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # Encode any image inputs and merge them with the text embeddings
            # before handing everything to the language model.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        # The multimodal tensors have been consumed above; drop them so they
        # are not forwarded to the text-only language model.
        for key in ("pixel_values", "patch_pixel_values", "num_patches",
                    "image_embeds", "patch_newline_mask"):
            kwargs.pop(key, None)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        output = StepVLCausalLMOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return output if return_dict else output.to_tuple()
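

# Putting the pieces together: the inner model encodes images, merges them
# into the token embedding sequence, and runs the Qwen3 backbone. A minimal
# sketch (the prompt must contain one ``image_token_id`` placeholder per
# vision token produced by the encoder):
#
#     out = model(input_ids=input_ids, pixel_values=pixel_values,
#                 num_patches=[0])
#     out.last_hidden_state  # (1, seq_len, hidden_size); the lm_head that
#                            # turns this into logits lives on the wrapper
#                            # class below.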


class Step3VL10BForCausalLM(StepRoboticsPreTrainedModel, GenerationMixin):
    _checkpoint_conversion_mapping = {
        "^vision_model": "model.vision_model",
        r"^model(?!\.(language_model|vision_model))": "model.language_model",
        "^vit_large_projector": "model.vit_large_projector",
    }
    _tied_weights_keys = ["lm_head.weight"]
    config: StepRoboticsConfig

    def __init__(self, config: StepRoboticsConfig):
        super().__init__(config)
        self.model = StepRoboticsModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size,
                                 config.text_config.vocab_size,
                                 bias=False)

        self.post_init()

    def get_input_embeddings(self):
        # The inner model's ``get_input_embeddings`` has a non-standard
        # signature, so return the token embedding table directly.
        return self.model.language_model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        # ``lm_head`` lives on this wrapper, not on ``self.model``.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    @property
    def language_model(self):
        return self.model.language_model

    @property
    def visual(self):
        # The inner model names its encoder ``vision_model``.
        return self.model.vision_model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        num_patches=None,
        patch_pixel_values=None,
        patch_newline_mask=None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, StepVLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example (a sketch; the processor class and checkpoint path are placeholders):

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor

        >>> model = Step3VL10BForCausalLM.from_pretrained("path/to/step-vl-checkpoint")
        >>> processor = AutoProcessor.from_pretrained("path/to/step-vl-checkpoint")

        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids=input_ids,
            num_patches=num_patches,
            patch_pixel_values=patch_pixel_values,
            patch_newline_mask=patch_newline_mask,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels,
                vocab_size=self.config.text_config.vocab_size)

        return StepVLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        **kwargs,
    ):
        # The base class handles the text inputs; this override only decides
        # when to forward the image tensors.
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # Pixel values are only needed on the prefill step, when the image
            # placeholder tokens are still part of the input sequence; cached
            # decoding steps operate on text tokens alone.
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
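
    # Generation sketch: on the prefill call cache_position starts at 0, so
    # pixel_values are attached; on later decode steps cache_position[0] > 0
    # and the model runs text-only. (Illustrative, assuming the default
    # GenerationMixin loop.)
    #
    #     step 0: model(input_ids=prompt_ids, pixel_values=pixels, ...)
    #     step 1: model(input_ids=last_token, past_key_values=cache, ...)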

    def _fix_state_dict_key_on_load(self, key: str) -> tuple[str, bool]:
        # Strip a legacy "language_model." prefix from checkpoint keys; the
        # returned flag tells the loader whether the key was renamed.
        if key.startswith("language_model."):
            return key[len("language_model."):], True

        return key, False
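

# Loading sketch: together, ``_checkpoint_conversion_mapping`` and
# ``_fix_state_dict_key_on_load`` rewrite flat checkpoint keys into the
# wrapped layout (hypothetical key names):
#
#     "language_model.layers.0.self_attn.q_proj.weight"
#         -> "layers.0.self_attn.q_proj.weight"       # prefix stripped on load
#     "vision_model.patch_embed.weight"
#         -> "model.vision_model.patch_embed.weight"  # regex remap
#     "vit_large_projector.weight"
#         -> "model.vit_large_projector.weight"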