|
|
"""PyTorch MarkupDM model.""" |
|
|
|
|
|
import contextlib |
|
|
import math |
|
|
import os |
|
|
from typing import Any |
|
|
|
|
|
import rff.layers |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import ( |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
GenerationMixin, |
|
|
PreTrainedModel, |
|
|
) |
|
|
from transformers.loss.loss_utils import LOSS_MAPPING |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from transformers.utils import logging |
|
|
|
|
|
from .configuration_markupdm import MarkupDMConfig |
|
|
from .loss_utils import WeightedCausalLMLoss |
|
|
|
|
|
logger = logging.get_logger(__name__)


# Register the custom weighted loss with transformers' loss registry so that
# PreTrainedModel.loss_function can resolve it by name — presumably selected
# via the model config's loss_type ("WeightedCausalLMLoss"); confirm against
# the config class.
LOSS_MAPPING["WeightedCausalLMLoss"] = WeightedCausalLMLoss
|
|
|
|
|
|
|
|
class MarkupDMForCausalLM(PreTrainedModel, GenerationMixin):
    """Multimodal causal LM that interleaves text tokens with VQ image tokens.

    Wraps a trainable causal text backbone (``text_model``) and a frozen VQ
    image tokenizer (``vision_model``).  Token ids ``>= config.vocab_size``
    denote image (codebook) tokens; their input embeddings are looked up in
    the VQ codebook, concatenated with a Fourier-feature encoding of the
    token's position inside its image, and projected into the text model's
    hidden size.  The output distribution is the concatenation of the text
    logits (first ``config.vocab_size`` entries) and a linear vision head
    over the image codebook.
    """

    # Shared wrapper config; the annotation documents the concrete type.
    config: MarkupDMConfig
    config_class = MarkupDMConfig

    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(
        self,
        config: MarkupDMConfig,
        text_model: PreTrainedModel,
        vision_model: PreTrainedModel,
    ) -> None:
        """Assemble the composite model from a text backbone and a VQ model.

        Args:
            config: Shared :class:`MarkupDMConfig` holding the sub-configs
                (``text_model``, ``vision_model``) plus the multimodal
                hyperparameters used below.
            text_model: Causal LM backbone; put into train mode.
            vision_model: VQ model (``model_type == "vqmodel"``); frozen and
                put into eval mode.

        Raises:
            ValueError: If ``config`` is not a :class:`MarkupDMConfig`.
        """
        if not isinstance(config, self.config_class):
            raise ValueError(f"Config: {config} has to be of type {self.config_class}")

        logger.info(f"MarkupDM config: {config}")
        super().__init__(config)

        # Only the text backbone is trained; the vision tokenizer is frozen.
        self.text_model = text_model.train()
        self.vision_model = vision_model.eval().requires_grad_(False)

        if self.text_model.config.to_dict() != self.config.text_model.to_dict():
            logger.warning(
                f"Config of the text model: {self.text_model.__class__} is"
                f"overwritten by shared text config: {self.config.text_model}"
            )
        if self.vision_model.config.to_dict() != self.config.vision_model.to_dict():
            logger.warning(
                f"Config of the vision model: {self.vision_model.__class__} is"
                f"overwritten by shared vision config: {self.config.vision_model}"
            )

        # Single source of truth: both sub-models point at the sub-configs
        # stored on this wrapper's config.
        self.text_model.config = self.config.text_model
        self.vision_model.config = self.config.vision_model

        # Grow the text embedding table when the shared vocabulary is larger
        # than the backbone's original one.
        base_size = self.text_model.config.vocab_size
        if base_size < self.config.vocab_size:
            self.text_model.resize_token_embeddings(self.config.vocab_size)
            new_size = self.text_model.get_input_embeddings().num_embeddings
            logger.info(f"Resize embedding layer from {base_size} to {new_size} tokens")

        d_text = self.text_model.config.hidden_size
        assert self.vision_model.config.model_type == "vqmodel"
        d_vision = self.vision_model.model.embed_dim
        image_pos_size = self.config.image_pos_size
        sigma = self.config.image_pos_sigma
        # rff's PositionalEncoding emits 2*m features per input dimension,
        # so m = ceil(image_pos_size / 2) yields at least image_pos_size
        # features (trimmed in embed_tokens when image_pos_size is odd).
        m = math.ceil(image_pos_size / 2)
        self.image_vocab_size = self.vision_model.model.n_embed

        # proj_vpos: Fourier-feature encoding of the within-image position.
        # proj_vt:   (codebook embedding ++ position features) -> text hidden.
        # vis_head:  text hidden -> logits over the image codebook.
        self.proj_vpos = rff.layers.PositionalEncoding(sigma, m)
        self.proj_vt = nn.Linear(d_vision + image_pos_size, d_text)
        self.vis_head = nn.Linear(d_text, self.image_vocab_size)

        # Tokens per image: the VQ encoder downsamples each spatial dimension
        # by 2**(num_resolutions - 1).
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.config.image_size // scale_factor
        self.num_image_tokens = latent_size**2

        self.post_init()

        # Optionally freeze the (possibly resized) input embedding table;
        # done after post_init() so later (re)initialization cannot interfere.
        if config.freeze_text_embeddings:
            self.text_model.get_input_embeddings().requires_grad_(False)

    def tie_weights(self) -> None:
        """Delegate input/output embedding tying to the text backbone."""
        self.text_model.tie_weights()

    @classmethod
    def from_pretrained(cls, *args: Any, **kwargs: Any) -> "MarkupDMForCausalLM":
        """Build the sub-models from ``kwargs["config"]`` and load weights.

        The text and vision sub-models are instantiated architecture-only from
        their sub-configs and handed to ``PreTrainedModel.from_pretrained``,
        which loads the checkpoint into the assembled wrapper.

        Raises:
            AssertionError: If ``config`` is missing from ``kwargs``.
        """
        assert "config" in kwargs, "Config must be provided"
        config = kwargs["config"]
        # Accept either the new `dtype` or the legacy `torch_dtype` kwarg.
        dtype = kwargs.get("dtype", kwargs.get("torch_dtype", None))

        text_model = AutoModelForCausalLM.from_config(
            config.text_model,
            dtype=dtype,
            attn_implementation=config._attn_implementation,
        )

        # Silence stdout chatter while constructing the remote-code VQ model.
        # Fix: manage the devnull handle with `with` — the previous code
        # passed an unclosed open() into redirect_stdout and leaked it.
        with open(os.devnull, "w") as devnull:
            with contextlib.redirect_stdout(devnull):
                vision_model = AutoModel.from_config(
                    config.vision_model,
                    trust_remote_code=True,
                    dtype=dtype,
                )

        return super().from_pretrained(
            *args,
            **kwargs,
            text_model=text_model,
            vision_model=vision_model,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor]] | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.Tensor | None = None,
        num_items_in_batch: int | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast:
        """Run the backbone on mixed text/image embeddings and join the heads.

        Args:
            input_ids: Token ids; ids ``>= config.vocab_size`` are image
                tokens (offset into the VQ codebook).
            inputs_embeds: Precomputed embeddings; computed via
                :meth:`embed_tokens` when ``None``.
            image_mask: Boolean mask of image positions; derived from
                ``input_ids`` when ``None``.
            image_pos_ids: Per-token index within its image (required by
                :meth:`embed_tokens` when image tokens are present).
            labels: When given, the loss is computed and text/vision logits
                are masked to the modality of the *next* token.
            num_items_in_batch: Forwarded to the loss function for averaging.

        Returns:
            :class:`CausalLMOutputWithPast` (or a tuple when
            ``return_dict=False``) whose ``logits`` concatenate the text
            vocabulary and the image codebook along the last dimension.

        Raises:
            ValueError: On any unexpected non-``None`` keyword argument.
        """
        # Reject unknown keyword arguments instead of silently ignoring them.
        for key, value in kwargs.items():
            if value is not None:
                raise ValueError(f"Unknown argument: {key}={value}")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Ids at or above the text vocabulary are image (codebook) tokens.
        if image_mask is None:
            image_mask = input_ids >= self.config.vocab_size

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(
                input_ids,
                image_mask=image_mask,
                image_pos_ids=image_pos_ids,
            )

        fwd_kwargs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            # Hidden states are always required to feed the vision head below.
            "output_hidden_states": True,
            "output_attentions": output_attentions,
        }
        if self.config.text_model.model_type == "starcoder2":
            fwd_kwargs["cache_position"] = cache_position
        outputs = self.text_model(**fwd_kwargs)

        # Keep only the text-vocabulary portion of the backbone's logits.
        text_logits = outputs.logits[:, :, : self.config.vocab_size]

        # Image-codebook logits from the final hidden states.
        last_hidden_states = outputs.hidden_states[-1]
        vision_logits = self.vis_head(last_hidden_states)

        if labels is not None:
            # Shift the image mask left by one step: position t predicts
            # token t+1, so text logits are disabled where the *next* token
            # is an image token, and vision logits everywhere else.
            # NOTE(review): these writes are in place; text_logits is a view
            # of outputs.logits, so outputs.logits is mutated here too.
            shift_mask = F.pad(image_mask[:, 1:], (0, 1), value=False)
            text_logits[shift_mask] = -float("inf")
            vision_logits[~shift_mask] = -float("inf")

        # Joint distribution over [text vocab | image codebook].
        logits = torch.cat([text_logits, vision_logits], dim=-1)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                image_vocab_size=self.image_vocab_size,
                image_loss_weight=self.config.image_loss_weight,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )

    def embed_tokens(
        self,
        input_ids: torch.Tensor,
        image_mask: torch.Tensor | None = None,
        image_pos_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed a mixed sequence of text ids and image codebook ids.

        Text positions use the backbone's embedding table.  Image positions
        (where ``image_mask`` is true) use the VQ codebook embedding of
        ``input_ids - config.vocab_size``, concatenated with a Fourier
        encoding of ``image_pos_ids / num_image_tokens`` and projected to
        the text hidden size.

        Raises:
            AssertionError: If image tokens are present (``image_mask`` given)
                but ``image_pos_ids`` is ``None``.
        """
        if image_mask is None:
            # Pure text sequence: plain embedding lookup.
            return self.text_embed(input_ids)

        size = input_ids.size() + (self.text_model.config.hidden_size,)
        inputs_embeds = torch.zeros(size, device=self.device, dtype=self.dtype)

        # Fill text positions.
        text_embeds = self.text_embed(input_ids[~image_mask])
        inputs_embeds[~image_mask] = text_embeds

        # Image positions: shift ids back into the codebook index range.
        image_embeds = self.vis_embed(input_ids[image_mask] - self.config.vocab_size)

        assert image_pos_ids is not None
        # Normalize the within-image position to [0, 1) before encoding.
        image_pos = image_pos_ids / self.num_image_tokens
        image_pos = self.proj_vpos(image_pos.unsqueeze(-1)).to(image_embeds)
        # The encoder emits 2*ceil(image_pos_size/2) features; trim to
        # exactly image_pos_size (matters when image_pos_size is odd).
        image_pos = image_pos[image_mask][:, : self.config.image_pos_size]
        image_embeds = torch.cat([image_embeds, image_pos], dim=-1)

        # Project into the text model's hidden size and fill image positions.
        image_embeds = self.proj_vt(image_embeds)
        inputs_embeds[image_mask] = image_embeds

        return inputs_embeds

    def text_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up text token embeddings in the backbone's input table."""
        return self.text_model.get_input_embeddings()(input_ids)

    def vis_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up image token embeddings in the frozen VQ codebook."""
        return self.vision_model.model.quantize.embedding(input_ids)

    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, **model_kwargs: Any
    ) -> dict:
        """Augment the backbone's generation inputs with image bookkeeping.

        Delegates to the text model's ``prepare_inputs_for_generation`` and
        attaches ``image_pos_ids`` (each image token's index within its
        image, cycling modulo ``num_image_tokens``) and ``image_mask``,
        both aligned with the possibly-truncated ``inputs["input_ids"]``.
        """
        default_prepare_inputs = self.text_model.prepare_inputs_for_generation
        inputs = default_prepare_inputs(input_ids, **model_kwargs)

        # Per sequence, assign 0..num_image_tokens-1 cyclically to the image
        # positions of the *full* input_ids.
        base_ids = torch.arange(self.num_image_tokens, device=self.device)
        image_pos_ids = torch.zeros_like(input_ids)
        image_mask_all = input_ids >= self.config.vocab_size
        for i_batch, image_mask in enumerate(image_mask_all):
            # Fix: use tensor .sum() — the builtin sum() iterated the bool
            # mask element by element (one Python-level op per token).
            n_image = int(image_mask.sum())
            pos_ids = base_ids.repeat(n_image // self.num_image_tokens + 1)
            image_pos_ids[i_batch, image_mask] = pos_ids[:n_image]
        # The default preparation may truncate input_ids (e.g. when a KV
        # cache is present); keep only the trailing positions actually fed.
        length = inputs["input_ids"].size(1)
        inputs["image_pos_ids"] = image_pos_ids[:, -length:]

        inputs["image_mask"] = inputs["input_ids"] >= self.config.vocab_size

        return inputs
|
|
|