"""PyTorch MarkupDM model."""
import contextlib
import math
import os
from typing import Any
import rff.layers
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoModel,
AutoModelForCausalLM,
GenerationMixin,
PreTrainedModel,
)
from transformers.loss.loss_utils import LOSS_MAPPING
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging
from .configuration_markupdm import MarkupDMConfig
from .loss_utils import WeightedCausalLMLoss
logger = logging.get_logger(__name__)
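
# Register the weighted loss so it can be selected by name through Transformers'
# LOSS_MAPPING lookup (used by `self.loss_function` in `forward` below).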
LOSS_MAPPING["WeightedCausalLMLoss"] = WeightedCausalLMLoss


class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin):  # type: ignore
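    """Causal LM over an interleaved sequence of text tokens and VQ image tokens.

    A trainable decoder-only text model handles the sequence; a frozen VQ vision
    model supplies the image-token codebook. Image tokens are ids offset by
    ``config.vocab_size``, embedded from the codebook together with a
    Fourier-feature encoding of their position inside the image, and predicted
    through a dedicated vision head whose logits are concatenated after the text
    logits.
    """
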
    config: MarkupDMConfig
    config_class = MarkupDMConfig
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(
        self,
        config: MarkupDMConfig,
        text_model: PreTrainedModel,
        vision_model: PreTrainedModel,
    ) -> None:
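        """Wrap pre-built sub-models under the shared config.

        The text model is kept trainable, the VQ vision model is frozen; extra
        projection layers and the vision head are created on top of them.
        """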
        if not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        # Initialize with config
        logger.info(f"MarkupDM config: {config}")
        super().__init__(config)

        # The text model is trained; the vision model stays frozen in eval mode.
        self.text_model = text_model.train()
        self.vision_model = vision_model.eval().requires_grad_(False)

        if self.text_model.config.to_dict() != self.config.text_model.to_dict():
            logger.warning(
                f"Config of the text model: {self.text_model.__class__} is "
                f"overwritten by shared text config: {self.config.text_model}"
            )
        if self.vision_model.config.to_dict() != self.config.vision_model.to_dict():
            logger.warning(
                f"Config of the vision model: {self.vision_model.__class__} is "
                f"overwritten by shared vision config: {self.config.vision_model}"
            )

        # Make sure that the individual models' configs refer to the shared config
        # so that updates to the config stay in sync
        self.text_model.config = self.config.text_model
        self.vision_model.config = self.config.vision_model

        # Resize the embedding layer to cover the extended vocabulary
        base_size = self.text_model.config.vocab_size
        if base_size < self.config.vocab_size:
            self.text_model.resize_token_embeddings(self.config.vocab_size)
            new_size = self.text_model.get_input_embeddings().num_embeddings
            logger.info(f"Resize embedding layer from {base_size} to {new_size} tokens")

        d_text = self.text_model.config.hidden_size
        assert self.vision_model.config.model_type == "vqmodel"
        d_vision = self.vision_model.model.embed_dim
        image_pos_size = self.config.image_pos_size
        sigma = self.config.image_pos_sigma
        m = math.ceil(image_pos_size / 2)  # (sin, cos): encoding has 2 * m features
        self.image_vocab_size = self.vision_model.model.n_embed

        # Define additional layers
        self.proj_vpos = rff.layers.PositionalEncoding(sigma, m)
        self.proj_vt = nn.Linear(d_vision + image_pos_size, d_text)
        self.vis_head = nn.Linear(d_text, self.image_vocab_size)

        # Compute num_image_tokens: the VQ encoder downsamples the image by
        # 2^(num_resolutions - 1), so a square image yields latent_size**2 tokens
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.config.image_size // scale_factor
        self.num_image_tokens = latent_size**2

        # Initialize weights and apply final processing
        self.post_init()

        # Freeze text embeddings if needed
        if config.freeze_text_embeddings:
            self.text_model.get_input_embeddings().requires_grad_(False)

    def tie_weights(self) -> None:
        self.text_model.tie_weights()

    @classmethod
    def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
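        """Instantiate the text and vision sub-models, then load checkpoint weights.

        Both sub-models are built from the shared config and passed through to
        ``__init__`` by the parent ``from_pretrained`` call, which then loads the
        pretrained weights on top.
        """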
assert "config" in kwargs, "Config must be provided"
config = kwargs["config"]
dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None))
# Initialize text model
text_model = AutoModelForCausalLM.from_config(
config.text_model,
dtype=dtype,
attn_implementation=config._attn_implementation,
)
# Initialize vision model
with contextlib.redirect_stdout(open(os.devnull, "w")):
vision_model = AutoModel.from_config(
config.vision_model,
trust_remote_code=True,
dtype=dtype,
)
return super().from_pretrained( # type: ignore
*args,
**kwargs,
text_model=text_model,
vision_model=vision_model,
)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor]] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        num_items_in_batch: int | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast:
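        """Run the text backbone over mixed text/image embeddings and build joint logits.

        ``input_ids`` uses a joint id space: ids below ``config.vocab_size`` are text
        tokens, ids at or above it are VQ image tokens. The returned logits have
        shape (B, L, vocab_size + image_vocab_size): text logits first, vision-head
        logits appended after them.
        """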
        for key in kwargs.keys():
            if kwargs[key] is not None:
                raise ValueError(f"Unknown argument: {key}={kwargs[key]}")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if image_mask is None:
            image_mask = input_ids >= self.config.vocab_size

        # Embed inputs
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids,
                image_mask=image_mask,
                image_pos_ids=image_pos_ids,
            )

        # Core forward pass
        fwd_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "output_hidden_states": True,
            "output_attentions": output_attentions,
        }
        if self.config.text_model.model_type == "starcoder2":
            fwd_kwargs["cache_position"] = cache_position
        outputs = self.text_model(**fwd_kwargs)

        # text_logits: (B, L, V)
        text_logits = outputs.logits[:, :, : self.config.vocab_size]
        # vision_logits: (B, L, C)
        last_hidden_states = outputs.hidden_states[-1]
        vision_logits = self.vis_head(last_hidden_states)

        if labels is not None:
            # Mask logits with the image mask shifted left by one: position t predicts
            # token t + 1, so positions whose next token is an image token keep only
            # vision logits, while all other positions keep only text logits.
            shift_mask = F.pad(image_mask[:, 1:], (0, 1), value=False)
            text_logits[shift_mask] = -float("inf")
            vision_logits[~shift_mask] = -float("inf")

        # Concatenate text and vision logits
        logits = torch.cat([text_logits, vision_logits], dim=-1)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                image_vocab_size=self.image_vocab_size,
                image_loss_weight=self.config.image_loss_weight,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )

    def embed_tokens(
        self,
        input_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
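        """Embed a mixed sequence of text ids and (offset) image ids.

        Text positions use the text model's embedding table. Image positions are
        looked up in the VQ codebook, concatenated with a Fourier-feature encoding
        of their normalized position inside the image, and projected to the text
        hidden size.
        """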
        if image_mask is None:
            return self.text_embed(input_ids)  # type: ignore

        # Prepare placeholders
        size = input_ids.size() + (self.text_model.config.hidden_size,)
        inputs_embeds = torch.zeros(size, device=self.device, dtype=self.dtype)

        # Embed text ids
        text_embeds = self.text_embed(input_ids[~image_mask])
        inputs_embeds[~image_mask] = text_embeds

        # Embed image ids (shift back from the joint id space to codebook indices)
        image_embeds = self.vis_embed(input_ids[image_mask] - self.config.vocab_size)

        # Concatenate positional embeddings
        assert image_pos_ids is not None
        image_pos = image_pos_ids / self.num_image_tokens
        image_pos = self.proj_vpos(image_pos.unsqueeze(-1)).to(image_embeds)
        image_pos = image_pos[image_mask][:, : self.config.image_pos_size]
        image_embeds = torch.cat([image_embeds, image_pos], dim=-1)  # type: ignore

        # Project image features and update inputs_embeds
        image_embeds = self.proj_vt(image_embeds)
        inputs_embeds[image_mask] = image_embeds

        return inputs_embeds

    def text_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.text_model.get_input_embeddings()(input_ids)  # type: ignore

    def vis_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.vision_model.model.quantize.embedding(input_ids)  # type: ignore

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **model_kwargs: Any
    ) -> dict:
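        """Extend the text model's generation inputs with image-token bookkeeping.

        Recomputes ``image_pos_ids`` (position of each image token within its image)
        and ``image_mask`` for the slice of ``input_ids`` that the text model's
        default preparation decided to feed at this step.
        """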
        # Prepare inputs with the default function
        default_prepare_inputs = self.text_model.prepare_inputs_for_generation
        inputs = default_prepare_inputs(input_ids, **model_kwargs)

        # Compute image_pos_ids: number the image tokens of each sample 0..N-1 modulo
        # num_image_tokens, so that (assuming each image occupies a full contiguous
        # block of num_image_tokens tokens) every image restarts its position count
        base_ids = torch.arange(self.num_image_tokens, device=self.device)
        image_pos_ids = torch.zeros_like(input_ids)
        image_mask_all = input_ids >= self.config.vocab_size
        for i_batch, image_mask in enumerate(image_mask_all):
            N = sum(image_mask)
            pos_ids = base_ids.repeat(N // self.num_image_tokens + 1)
            image_pos_ids[i_batch, image_mask] = pos_ids[:N]

        # Keep only the tail that matches the (possibly cache-trimmed) input slice
        length = inputs["input_ids"].size(1)
        inputs["image_pos_ids"] = image_pos_ids[:, -length:]
        inputs["image_mask"] = inputs["input_ids"] >= self.config.vocab_size
        return inputs  # type: ignore
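

# Example usage (a minimal sketch, not part of this module's API; the repository id,
# tokenizer, and prompt below are placeholders/assumptions):
#
#   from transformers import AutoConfig, AutoTokenizer
#
#   config = AutoConfig.from_pretrained("path/to/markupdm", trust_remote_code=True)
#   model = MarkupDMForCausalLM.from_pretrained("path/to/markupdm", config=config)
#   tokenizer = AutoTokenizer.from_pretrained("path/to/markupdm")
#   inputs = tokenizer("<div>", return_tensors="pt")
#   generated = model.generate(inputs.input_ids, max_new_tokens=32)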