# CoTyle / models/model.py
# liuhuijie
# update
# ce02858
import os
import sys
import math
import pickle
import random
import torch
import numpy as np
import requests
from .utils import get_suppression_coefficient
from io import BytesIO
from typing import Union, List, Optional, Any, Dict, Tuple, Callable
from dataclasses import dataclass
from PIL import Image
from transformers import (
AutoModel,
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedModel
)
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLCausalLMOutputWithPast
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2ForCausalLM,
Qwen2Config
)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
ModelOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import (
# LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
# from qwen_vl_utils import process_vision_info
from .vlm_unitok import UniTok
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2.modeling_qwen2 import *
class StyleGenerator(Qwen2ForCausalLM):
    """Qwen2 causal LM with optional frequency-based rescaling of next-token logits.

    ``forward`` behaves like ``Qwen2ForCausalLM.forward`` and additionally accepts
    ``code_freq``, ``code_freq_threshold`` and ``k``. When ``code_freq`` is given,
    those values are passed to ``get_suppression_coefficient`` and the returned
    coefficient multiplies the logits of the last kept position of batch element 0.
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        code_freq: Any = None,
        code_freq_threshold: Any = None,
        k: Any = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """Run the decoder, project to vocabulary logits and optionally apply suppression.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache, output_attentions, output_hidden_states, cache_position:
                forwarded unchanged to ``self.model``.
            labels: if given, ``self.loss_function`` computes the loss from the
                logits *before* any suppression is applied.
            logits_to_keep: an int keeps the last ``logits_to_keep`` sequence
                positions (``0`` keeps all of them); a tensor is used directly as
                the index into the sequence dimension.
            code_freq: statistics passed to ``get_suppression_coefficient``;
                ``None`` (the default) disables suppression entirely.
            code_freq_threshold, k: passed through to ``get_suppression_coefficient``.
            **kwargs: forwarded to ``self.model`` and ``self.loss_function``.

        Returns:
            ``CausalLMOutputWithPast`` holding the loss (or ``None``), the possibly
            rescaled logits, and the decoder's cache / hidden states / attentions.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if code_freq is not None:
            coefficient = get_suppression_coefficient(code_freq, code_freq_threshold, k).to(logits.device)
            # Rescale the last kept position: when a single position is kept it is
            # the same row as index 0; when all positions are kept, the last row is
            # the one whose logits predict the next token.
            # NOTE(review): only batch element 0 is rescaled, and a coefficient < 1
            # *raises* negative logits toward 0 — confirm both are intended by
            # get_suppression_coefficient's design.
            logits[0, -1] = logits[0, -1] * coefficient

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Config class name for transformers-style docstring templates. It is not
# referenced in the visible part of this file; presumably a leftover from the
# upstream Qwen2.5-VL modeling code — confirm before removing.
_CONFIG_FOR_DOC = "Qwen2_5_VLConfig"
@dataclass
class Qwen2_5_VLCausalLMOutputWithPastQuant(ModelOutput):
    """Return type of ``Qwen2_5_VLForConditionalGeneration_Quant.forward``.

    Carries the usual causal-LM outputs plus ``quant_info`` from the UniTok
    module. Keep the field order fixed: ``ModelOutput`` subclasses can also be
    indexed like tuples, in declaration order.
    """

    # Shifted-token cross-entropy loss; None unless labels were passed.
    loss: Optional[torch.FloatTensor] = None
    # Vocabulary logits from lm_head for every position.
    logits: Optional[torch.FloatTensor] = None
    # KV cache returned by the language model.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Per-layer hidden states when output_hidden_states is enabled.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer attention weights when output_attentions is enabled.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # RoPE offsets cached on the model (self.rope_deltas) at the time of the call.
    rope_deltas: Optional[torch.LongTensor] = None
    # Full output dict of the unitok call on the vision-tower path; {} otherwise.
    quant_info: Optional[Dict[str, Any]] = None
class Qwen2_5_VLForConditionalGeneration_Quant(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL variant whose image features are routed through a UniTok module.

    When ``pixel_values`` are given, the vision-tower features are grouped per
    image, passed through ``unitok``, and its ``'img_rec'`` output replaces the
    image placeholder embeddings. When ``codebook_id`` is given as well, the
    vision tower is skipped: embeddings are decoded from the codebook indices via
    ``unitok.quantizer.idx_to_f`` and ``unitok.post_quant_proj``.
    """

    # Tokens per image group fed to ``unitok``. The original implementation
    # hard-coded 196 here (presumably UniTok's fixed per-image token count —
    # confirm before changing).
    unitok_tokens_per_image = 196

    def _encode_images(self, unitok, pixel_values, image_grid_thw, codebook_id):
        """Return ``(image_embeds, unitok_info)``, one embedding row per image placeholder token.

        ``unitok_info`` is the full ``unitok`` output dict on the vision-tower path
        and ``{}`` on the ``codebook_id`` path.

        Raises:
            ValueError: if ``unitok`` is missing, or the vision features cannot be
                split into groups of ``unitok_tokens_per_image`` tokens.
        """
        if unitok is None:
            raise ValueError("`unitok` must be provided when `pixel_values` is given.")

        if codebook_id is not None:
            # Decode pre-computed codebook indices; the vision tower is not run.
            image_embeds = unitok.quantizer.idx_to_f(codebook_id.unsqueeze(0).to(self.visual.device))
            return unitok.post_quant_proj(image_embeds).squeeze(), {}

        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        num_tokens, dim = image_embeds.shape
        per_image = self.unitok_tokens_per_image
        if num_tokens % per_image != 0:
            raise ValueError(
                f"Vision features ({num_tokens} tokens) cannot be split into groups of "
                f"{per_image} tokens per image for UniTok."
            )
        grouped = image_embeds.reshape(num_tokens // per_image, per_image, dim)
        # Run UniTok with CUDA autocast disabled so an enclosing mixed-precision
        # context does not change the dtype it computes in.
        with torch.amp.autocast(device_type='cuda', enabled=False):
            unitok_out = unitok(grouped)
        image_embeds = unitok_out['img_rec'].squeeze().reshape(num_tokens, dim)
        return image_embeds, unitok_out

    def _scatter_placeholder_embeds(self, input_ids, inputs_embeds, features, token_id, modality):
        """Write ``features`` row by row into the ``token_id`` positions of ``inputs_embeds``.

        Raises:
            ValueError: if the number of ``token_id`` tokens differs from the
                number of feature rows.
        """
        n_tokens = (input_ids == token_id).sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{modality} features and {modality.lower()} tokens do not match: "
                f"tokens: {n_tokens}, features {n_features}"
            )
        mask = (input_ids == token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        features = features.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(mask, features)

    def forward(
        self,
        unitok: Optional[Any] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        codebook_id: Any = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPastQuant]:
        """Qwen2.5-VL forward pass with UniTok-processed image embeddings.

        Args:
            unitok: module applied to the grouped vision features, or used to
                decode ``codebook_id``; required whenever ``pixel_values`` is given.
            codebook_id: optional codebook indices; when given, image embeddings
                are decoded from them instead of running the vision tower.
            Remaining arguments follow ``Qwen2_5_VLForConditionalGeneration.forward``.

        Returns:
            ``Qwen2_5_VLCausalLMOutputWithPastQuant`` (with ``quant_info`` holding
            the unitok output dict or ``{}``). If ``return_dict`` is falsy, returns
            a tuple instead: ``(loss, logits, *rest)``, or ``(logits, *rest)`` when
            no loss was computed.

        Raises:
            ValueError: on placeholder-token / feature count mismatches or a
                missing ``unitok`` (see the helpers).
        """
        unitok_info = {}
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                image_embeds, unitok_info = self._encode_images(unitok, pixel_values, image_grid_thw, codebook_id)
                inputs_embeds = self._scatter_placeholder_embeds(
                    input_ids, inputs_embeds, image_embeds, self.config.image_token_id, "Image"
                )
            if pixel_values_videos is not None:
                # Videos bypass UniTok and use the raw vision-tower features.
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                inputs_embeds = self._scatter_placeholder_embeds(
                    input_ids, inputs_embeds, video_embeds, self.config.video_token_id, "Video"
                )
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPastQuant(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
            quant_info=unitok_info,
        )
class Qwen2_5_VL_Quant(nn.Module):
    """Wrapper pairing a UniTok module with a quantization-aware Qwen2.5-VL model.

    ``forward`` passes its arguments to the wrapped model together with
    ``self.unitok``, so callers do not have to supply the quantizer on every call.
    """

    def __init__(self, unitok, qwen2_5_vl):
        super().__init__()
        self.unitok = unitok
        self.qwen = qwen2_5_vl

    @property
    def dtype(self):
        """Current dtype of the wrapped model.

        Read on every access rather than captured in ``__init__``, so it stays
        correct after ``.to(dtype)`` / ``.half()`` casts of the wrapper.
        """
        return self.qwen.dtype

    def forward(self,
                input_ids,
                attention_mask,
                pixel_values=None,
                image_grid_thw=None,
                output_hidden_states=None,
                codebook_id=None,
                ):
        """Call the wrapped model with ``self.unitok`` and return its output unchanged."""
        return self.qwen(
            unitok=self.unitok,
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=output_hidden_states,
            codebook_id=codebook_id,
        )