"""PyTorch MarkupDM model.""" import contextlib import math import os from typing import Any import rff.layers import torch import torch.nn as nn import torch.nn.functional as F from transformers import ( AutoModel, AutoModelForCausalLM, GenerationMixin, PreTrainedModel, ) from transformers.loss.loss_utils import LOSS_MAPPING from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import logging from .configuration_markupdm import MarkupDMConfig from .loss_utils import WeightedCausalLMLoss logger = logging.get_logger(__name__) LOSS_MAPPING["WeightedCausalLMLoss"] = WeightedCausalLMLoss class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin): # type: ignore config: MarkupDMConfig config_class = MarkupDMConfig supports_gradient_checkpointing = True _supports_flash_attn_2 = True def __init__( self, config: MarkupDMConfig, text_model: PreTrainedModel, vision_model: PreTrainedModel, ) -> None: if not isinstance(config, self.config_class): raise ValueError(f"Config: {config} has to be of type {self.config_class}") # Initialize with config logger.info(f"MarkupDM config: {config}") super().__init__(config) self.text_model = text_model.train() self.vision_model = vision_model.eval().requires_grad_(False) if self.text_model.config.to_dict() != self.config.text_model.to_dict(): logger.warning( f"Config of the text model: {self.text_model.__class__} is" f"overwritten by shared text config: {self.config.text_model}" ) if self.vision_model.config.to_dict() != self.config.vision_model.to_dict(): logger.warning( f"Config of the vision model: {self.vision_model.__class__} is" f"overwritten by shared vision config: {self.config.vision_model}" ) # Make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced self.text_model.config = self.config.text_model self.vision_model.config = self.config.vision_model # Resize embedding layer base_size = self.text_model.config.vocab_size if base_size < self.config.vocab_size: self.text_model.resize_token_embeddings(self.config.vocab_size) new_size = self.text_model.get_input_embeddings().num_embeddings logger.info(f"Resize embedding layer from {base_size} to {new_size} tokens") d_text = self.text_model.config.hidden_size assert self.vision_model.config.model_type == "vqmodel" d_vision = self.vision_model.model.embed_dim image_pos_size = self.config.image_pos_size sigma = self.config.image_pos_sigma m = math.ceil(image_pos_size / 2) # (sin, cos) self.image_vocab_size = self.vision_model.model.n_embed # Define additional layers self.proj_vpos = rff.layers.PositionalEncoding(sigma, m) self.proj_vt = nn.Linear(d_vision + image_pos_size, d_text) self.vis_head = nn.Linear(d_text, self.image_vocab_size) # Compute num_image_tokens scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1) latent_size = self.config.image_size // scale_factor self.num_image_tokens = latent_size**2 # Initialize weights and apply final processing self.post_init() # Freeze text embeddings if needed if config.freeze_text_embeddings: self.text_model.get_input_embeddings().requires_grad_(False) def tie_weights(self) -> None: self.text_model.tie_weights() @classmethod def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM": assert "config" in kwargs, "Config must be provided" config = kwargs["config"] dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None)) # Initialize text model text_model = AutoModelForCausalLM.from_config( config.text_model, dtype=dtype, 

    @classmethod
    def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
        assert "config" in kwargs, "Config must be provided"
        config = kwargs["config"]
        dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None))

        # Initialize text model
        text_model = AutoModelForCausalLM.from_config(
            config.text_model,
            dtype=dtype,
            attn_implementation=config._attn_implementation,
        )

        # Initialize vision model
        with contextlib.redirect_stdout(open(os.devnull, "w")):
            vision_model = AutoModel.from_config(
                config.vision_model,
                trust_remote_code=True,
                dtype=dtype,
            )

        return super().from_pretrained(  # type: ignore
            *args,
            **kwargs,
            text_model=text_model,
            vision_model=vision_model,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor]] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        num_items_in_batch: int | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast:
        for key in kwargs.keys():
            if kwargs[key] is not None:
                raise ValueError(f"Unknown argument: {key}={kwargs[key]}")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if image_mask is None:
            image_mask = input_ids >= self.config.vocab_size

        # Embed inputs
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids,
                image_mask=image_mask,
                image_pos_ids=image_pos_ids,
            )

        # Core forward pass
        fwd_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_hidden_states": True,
            "output_attentions": output_attentions,
        }
        if self.config.text_model.model_type == "starcoder2":
            fwd_kwargs["cache_position"] = cache_position
        outputs = self.text_model(**fwd_kwargs)

        # text_logits: (B, L, V)
        text_logits = outputs.logits[:, :, : self.config.vocab_size]

        # vision_logits: (B, L, C)
        last_hidden_states = outputs.hidden_states[-1]
        vision_logits = self.vis_head(last_hidden_states)

        if labels is not None:
            # Mask logits with shifted image mask
            shift_mask = F.pad(image_mask[:, 1:], (0, 1), value=False)
            text_logits[shift_mask] = -float("inf")
            vision_logits[~shift_mask] = -float("inf")

        # Concatenate text and vision logits
        logits = torch.cat([text_logits, vision_logits], dim=-1)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                image_vocab_size=self.image_vocab_size,
                image_loss_weight=self.config.image_loss_weight,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
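
    # Token-id layout assumed throughout this class (as used by `forward` above):
    #   [0, config.vocab_size)                           -> text tokens (text LM head)
    #   [config.vocab_size, config.vocab_size + n_embed) -> image tokens (VQ codebook
    #                                                       indices, scored by vis_head)
    # `embed_tokens` below mirrors this split on the input side: text ids go through
    # the text embedding table, image ids are shifted down by `config.vocab_size`,
    # looked up in the VQ codebook, fused with a Fourier-feature positional encoding
    # (proj_vpos), and projected to the text hidden size (proj_vt).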

    def embed_tokens(
        self,
        input_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if image_mask is None:
            return self.text_embed(input_ids)  # type: ignore

        # Prepare placeholders
        size = input_ids.size() + (self.text_model.config.hidden_size,)
        inputs_embeds = torch.zeros(size, device=self.device, dtype=self.dtype)

        # Embed text ids
        text_embeds = self.text_embed(input_ids[~image_mask])
        inputs_embeds[~image_mask] = text_embeds

        # Embed image ids
        image_embeds = self.vis_embed(input_ids[image_mask] - self.config.vocab_size)

        # Concatenate positional embeddings
        assert image_pos_ids is not None
        image_pos = image_pos_ids / self.num_image_tokens
        image_pos = self.proj_vpos(image_pos.unsqueeze(-1)).to(image_embeds)
        image_pos = image_pos[image_mask][:, : self.config.image_pos_size]
        image_embeds = torch.cat([image_embeds, image_pos], dim=-1)  # type: ignore

        # Project image features and update inputs_embeds
        image_embeds = self.proj_vt(image_embeds)
        inputs_embeds[image_mask] = image_embeds

        return inputs_embeds

    def text_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.text_model.get_input_embeddings()(input_ids)  # type: ignore

    def vis_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.vision_model.model.quantize.embedding(input_ids)  # type: ignore

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **model_kwargs: Any
    ) -> dict:
        # Prepare inputs with the default function
        default_prepare_inputs = self.text_model.prepare_inputs_for_generation
        inputs = default_prepare_inputs(input_ids, **model_kwargs)

        # Compute image_pos_ids
        base_ids = torch.arange(self.num_image_tokens, device=self.device)
        image_pos_ids = torch.zeros_like(input_ids)
        image_mask_all = input_ids >= self.config.vocab_size
        for i_batch, image_mask in enumerate(image_mask_all):
            N = sum(image_mask)
            pos_ids = base_ids.repeat(N // self.num_image_tokens + 1)
            image_pos_ids[i_batch, image_mask] = pos_ids[:N]

        length = inputs["input_ids"].size(1)
        inputs["image_pos_ids"] = image_pos_ids[:, -length:]
        inputs["image_mask"] = inputs["input_ids"] >= self.config.vocab_size

        return inputs  # type: ignore
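

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not executed as part of this module).
# The checkpoint path, tokenizer, and generation arguments below are assumptions;
# `from_pretrained` above only requires that `config=` is provided.
#
#   from transformers import AutoTokenizer
#
#   config = MarkupDMConfig.from_pretrained("path/to/checkpoint")    # assumed path
#   model = MarkupDMForCausalLM.from_pretrained(
#       "path/to/checkpoint", config=config                          # assumed path
#   )
#   tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")  # assumed path
#
#   input_ids = tokenizer("...", return_tensors="pt").input_ids
#   output_ids = model.generate(input_ids, max_new_tokens=128)       # via GenerationMixin
# ---------------------------------------------------------------------------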