# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from examples/modular-transformers/modular_new_task_model.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_new_task_model.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar

import torch
from torch import nn

from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
from ...generation import GenerationMixin
from ...masking_utils import create_masks_for_generate
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ..auto import AutoModel
from .configuration_new_task_model import NewTaskModelConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for NewTaskModel outputs, with hidden states and attentions.
    """
)
class NewTaskModelModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    image_hidden_states: torch.FloatTensor | None = None


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for NewTaskModel causal language model (or autoregressive) outputs.
    """
)
class NewTaskModelCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    image_hidden_states: torch.FloatTensor | None = None


class NewTaskModelMultiModalProjector(nn.Module):
    """Single linear layer mapping vision features from `vision_config.hidden_size` to `vision_config.projection_dim`."""

    def __init__(self, config: NewTaskModelConfig):
        super().__init__()
        self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)

    def forward(self, image_features):
        hidden_states = self.linear(image_features)

        return hidden_states


@auto_docstring
class NewTaskModelPreTrainedModel(PreTrainedModel):
    config: NewTaskModelConfig
    base_model_prefix = "model"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = ["NewTaskModelMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"

    _can_compile_fullgraph = False
    _supports_flash_attn = True
    _supports_sdpa = True

    _supports_flex_attn = True
    _supports_attention_backend = True


def token_type_ids_mask_function(
    token_type_ids: torch.Tensor | None,
    image_group_ids: torch.Tensor | None,
) -> Callable | None:
    """
    Build an extra mask function that allows bidirectional attention inside each image block.

    Args:
        token_type_ids: Per-token type ids indexed as `[batch_idx, position]`; positions equal to 1 are treated as
            image tokens. When `None`, no extra mask function is returned.
        image_group_ids: Per-token image-group ids indexed like `token_type_ids`. Must be provided whenever
            `token_type_ids` is, because the returned function indexes it unconditionally.

    Returns:
        `None` if `token_type_ids` is `None`, otherwise a callable `(batch_idx, head_idx, q_idx, kv_idx) -> bool`
        that is true only when both the query and the key/value position are image tokens of the same image group.
        Positions at or beyond `token_type_ids.shape[1]` (e.g. the padded part of a static cache) are treated as
        non-image tokens with group id -1.
    """
    # Do not return an additional mask in this case
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # If it's 1 for both query and key/value, we are in an image block
        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
        safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
        safe_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)

        token_type_ids_at_q_idx = token_type_ids[batch_idx, safe_q_idx]
        token_type_ids_at_q_idx = torch.where(q_idx < token_type_ids.shape[1], token_type_ids_at_q_idx, 0)

        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_kv_idx]
        token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)

        image_group_ids_at_q_idx = image_group_ids[batch_idx, safe_q_idx]
        image_group_ids_at_q_idx = torch.where(q_idx < image_group_ids.shape[1], image_group_ids_at_q_idx, -1)

        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_kv_idx]
        image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)

        is_image_block = (token_type_ids_at_q_idx == 1) & (token_type_ids_at_kv_idx == 1)
        same_image_block = image_group_ids_at_q_idx == image_group_ids_at_kv_idx

        # This is bidirectional attention whenever we are dealing with image tokens
        return is_image_block & same_image_block

    return inner_mask


def create_causal_mask_mapping(
    config: PreTrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    cache_position: torch.Tensor,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None,
    token_type_ids: torch.Tensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    is_training: bool | None = False,
    is_first_iteration: bool | None = None,
    **kwargs,
) -> dict:
    """
    Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
    for all kinds of forward passes. NewTaskModel uses a bidirectional mask on the prompt tokens.

    Uses `pixel_values` as an optional input to disambiguate edge cases.

    Raises:
        ValueError: if `is_training` is truthy and `token_type_ids` is `None`.
    """
    if is_training and token_type_ids is None:
        raise ValueError("`token_type_ids` is required as a model input when training")

    mask_kwargs = {
        "config": config.get_text_config(),
        "input_embeds": input_embeds,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
    }
    # Infer if prefill or decoding stage, if the flag isn't passed. This happens only when the mask is constructed
    # from `forward` call. If users run a `forward` call, we have no option to infer `is_first_iteration` because users may be
    # running generation with custom loop. Thus we need to infer it in a `non-perfect` way
    # NOTE: Determining prefill in that case requires checking data values, which is not compile-compatible.
    is_first_iteration = (
        is_first_iteration
        if is_first_iteration
        else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
    )

    if is_first_iteration or not kwargs.get("use_cache", True):
        if token_type_ids is not None:
            # The logic below was originally written for Gemma3, where `token_type_ids` is reversed. Let's reverse
            # it to then use exactly the same logic.
            token_type_ids = 1 - token_type_ids
        else:
            logger.warning_once(
                "It is a prefill stage but The `token_type_ids` is not provided. We recommend "
                "passing `token_type_ids` to the model to prevent bad attention masking."
            )
            # NOTE: this branch can't be reached when training because `token_type_ids` is required as a model input.
            token_type_ids = torch.ones_like(input_embeds)[:, :, 0]

    # Logic originally copied from Gemma3. It holds up for NewTaskModel as well because NewTaskModel assumes up to one image
    # per prompt AND we reverse `token_type_ids` above. Gemma3 uses a bidirectional mask for images, tagged through
    # `token_type_ids` 1s.
    if token_type_ids is not None and is_first_iteration:
        # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
        # undo the causal masking)

        # First find where a new image block starts: 1 if image and previous not image
        # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
        is_image = (token_type_ids == 1).to(cache_position.device)
        is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
        new_image_start = is_image & ~is_previous_image
        image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
        # Build the -1 filler from `image_group_ids` so both branches of `torch.where` share its device and dtype
        # (`token_type_ids` may live on a different device than `cache_position`).
        image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(image_group_ids, -1))
        mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
            token_type_ids.to(cache_position.device), image_group_ids
        )

    return create_masks_for_generate(**mask_kwargs)


@auto_docstring(
    custom_intro="""
    The Base NewTaskModel model which consists of a vision backbone and a language model without language modeling head.
    """
)
class NewTaskModelModel(NewTaskModelPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: NewTaskModelConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
        # Used in `forward` to detect an out-of-vocabulary image token id.
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.text_config_dtype = self.config.get_text_config().dtype or self.dtype
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # Scale projected features by 1 / sqrt(text hidden size).
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        """
        if input_ids is None:
            # Without ids, image placeholders are found by comparing each embedding to the image token's embedding.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | NewTaskModelModelOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration

        >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/new_task_model2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt,  return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # If the image token id is out of the text vocabulary, replace it with id 0 for the embedding lookup to
        # avoid index errors; when `pixel_values` are given, those positions are overwritten by image features below.
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # NewTaskModel positions are 1-indexed

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            causal_mask_mapping = create_causal_mask_mapping(
                self.config,
                inputs_embeds,
                attention_mask,
                cache_position,
                past_key_values,
                position_ids,
                token_type_ids,
                pixel_values,
                is_training=self.training,
            )

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return NewTaskModelModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )


@auto_docstring(
    custom_intro="""
    NewTaskModel vision-language backbone whose final hidden states are projected with `custom_text_proj` and
    L2-normalized into per-token embeddings of size `config.embedding_dim`.
    """
)
class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config):
        super().__init__(config)
        self.model = NewTaskModelModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.embedding_dim = self.config.embedding_dim
        self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

        if self.language_model._tied_weights_keys is not None:
            prefix = "model.language_model."
            prefixed_mapping = {
                f"{prefix}{target}": f"{prefix}{source}"
                for target, source in self.language_model._tied_weights_keys.items()
            }
            # Assign a fresh per-instance dict: calling `.update()` here would mutate the class-level
            # `_tied_weights_keys` shared by every instance of this class.
            if isinstance(self._tied_weights_keys, dict):
                self._tied_weights_keys = {**self._tied_weights_keys, **prefixed_mapping}
            else:
                self._tied_weights_keys = prefixed_mapping
        self.post_init()

    @property
    def language_model(self):
        """Text decoder owned by the wrapped `NewTaskModelModel` (i.e. `self.model.language_model`)."""
        return self.model.language_model

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_image_features(self, pixel_values):
        return self.model.get_image_features(pixel_values)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        token_type_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        num_logits_to_keep: int = 0,
    ) -> tuple | NewTaskModelCausalLMOutputWithPast:
        r"""
        num_logits_to_keep (`int`, *optional*, defaults to 0):
            Ignored. Embeddings are computed from the hidden states of every position, so there are no logits to
            truncate. Kept for backward compatibility with existing callers.

        Returns:
            A tuple whose first element is the L2-normalized embeddings of shape
            `(batch_size, sequence_length, config.embedding_dim)` (zeroed where `attention_mask` is 0), followed by
            the entries of the underlying `NewTaskModelModel` output as returned by `ModelOutput.to_tuple()`.
        """
        # Call the wrapped backbone directly: `super()` here resolves to `PreTrainedModel`, which has no
        # vision-language forward.
        vlm_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden_states = vlm_outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        # `vlm_outputs` is a `ModelOutput` (an OrderedDict subclass) and cannot be concatenated to a tuple directly.
        return (embeddings,) + vlm_outputs.to_tuple()

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        is_first_iteration=False,
        **kwargs,
    ):
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        # position_ids in NewTaskModel are 1-indexed. Shift out of place: the tensor returned by the parent may be
        # (a view of) the caller's `position_ids`, and an in-place `+=` would modify it as well.
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] = model_inputs["position_ids"] + 1

        # Pixel values are used only in the first iteration if available
        # In subsequent iterations, they are already merged with text and cached
        # NOTE: first iteration doesn't have to be prefill, it can be the first
        # iteration with a question and cached system prompt (continue generate from cache).
        # NOTE: use_cache=False needs pixel_values always
        if is_first_iteration or not use_cache:
            model_inputs["pixel_values"] = pixel_values

        return model_inputs

    @staticmethod
    def create_masks_for_generate(
        config: PreTrainedConfig,
        input_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None,
        cache_position: torch.Tensor,
        past_key_values: Cache | None,
        position_ids: torch.Tensor | None,
        token_type_ids: torch.Tensor | None = None,
        is_first_iteration: bool | None = False,
        **kwargs,
    ) -> dict:
        # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
        return create_causal_mask_mapping(
            config,
            input_embeds,
            attention_mask,
            cache_position,
            past_key_values,
            position_ids,
            token_type_ids,
            is_first_iteration=is_first_iteration,
            **{k: v for k, v in kwargs.items() if k != "pixel_values"},
        )

    def resize_token_embeddings(
        self, new_num_tokens: int | None = None, pad_to_multiple_of=None, mean_resizing=True
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

        # Update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        # Keep the backbone's cached vocab size in sync: its forward compares `image_token_id` against it.
        self.model.vocab_size = model_embeds.num_embeddings

        return model_embeds