import os
import sys
import math
import pickle
import random
import torch
import numpy as np
import requests
from .utils import get_suppression_coefficient
from io import BytesIO
from typing import Union, List, Optional, Any, Dict, Tuple, Callable
from dataclasses import dataclass
from PIL import Image
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedModel
)
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLCausalLMOutputWithPast
)
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2ForCausalLM,
    Qwen2Config
)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
    ModelOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import (
    # LossKwargs,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    can_return_tuple,
    logging,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
# from qwen_vl_utils import process_vision_info
from .vlm_unitok import UniTok
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2.modeling_qwen2 import *  # noqa: F403 -- star import; relied on for KwargsForCausalLM used in the forward signature below


class StyleGenerator(Qwen2ForCausalLM):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        code_freq: Any = None,
        code_freq_threshold: Any = None,
        k: Any = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # Scale the next-token logits of the first sequence in the batch by the
        # code-frequency suppression coefficient (assumes a batch of one with a
        # single kept logit position, e.g. during generation).
        coefficient = get_suppression_coefficient(code_freq, code_freq_threshold, k).to(logits.device)
        logits[0][0] = logits[0][0] * coefficient

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
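

# Usage sketch (illustrative only, not part of the original module): one way the
# extra suppression arguments of StyleGenerator.forward could be exercised. The
# checkpoint name, the shape of `code_freq`, and the threshold / strength values
# below are assumptions; `get_suppression_coefficient` is this project's helper
# from `.utils`, and its exact contract is defined there.
def _style_generator_suppression_demo():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed checkpoint
    model = StyleGenerator.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    inputs = tokenizer("Generate style codes:", return_tensors="pt")
    code_freq = torch.zeros(model.config.vocab_size)  # assumed: per-code usage counts over the vocabulary
    out = model(
        **inputs,
        logits_to_keep=1,          # keep only the next-token logits
        code_freq=code_freq,
        code_freq_threshold=0.01,  # assumed threshold
        k=5.0,                     # assumed suppression strength
    )
    return out.logits  # shape (1, 1, vocab_size), already rescaled by the coefficient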


_CONFIG_FOR_DOC = "Qwen2_5_VLConfig"


@dataclass
class Qwen2_5_VLCausalLMOutputWithPastQuant(ModelOutput):
    """Qwen2.5-VL causal-LM output extended with the UniTok quantization info."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    quant_info: Optional[Dict[str, Any]] = None


class Qwen2_5_VLForConditionalGeneration_Quant(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL variant that routes visual features through a UniTok quantizer."""

    def forward(
        self,
        unitok: Optional[Any] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        codebook_id: Any = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPastQuant]:
        unitok_info = {}
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                if codebook_id is None:
                    # Round-trip the visual features through the UniTok quantizer and
                    # use the reconstruction in place of the original features.
                    image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                    b_n, dim = image_embeds.shape
                    image_embeds = image_embeds.reshape(b_n // 196, 196, dim)  # assumes 196 visual tokens per image
                    with torch.amp.autocast(device_type='cuda', enabled=False):
                        output = unitok(image_embeds)
                    image_embeds_recon, unitok_info = output['img_rec'].squeeze(), output
                    image_embeds = image_embeds_recon.reshape(b_n, dim)
                else:
                    # Decode visual features directly from pre-computed codebook indices.
                    image_embeds = unitok.quantizer.idx_to_f(codebook_id.unsqueeze(0).to(self.visual.device))
                    image_embeds = unitok.post_quant_proj(image_embeds).squeeze()

                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                mask = input_ids == self.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                image_embeds_for_hook = image_embeds.clone()
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                mask = input_ids == self.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)
        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2_5_VLCausalLMOutputWithPastQuant(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
            quant_info=unitok_info,
        )
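

# Illustrative helper (an assumption-labelled sketch, not original code): calling the
# quantized model with pre-computed UniTok codebook indices instead of letting it
# quantize on the fly. `codebook_id` is assumed to be a 1-D LongTensor with one code
# index per <|image_pad|> placeholder in `input_ids`; `pixel_values` must still be
# provided so the image branch of forward() is entered.
def _forward_from_codebook_ids(model, unitok, batch, codebook_id):
    return model(
        unitok=unitok,
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        pixel_values=batch["pixel_values"],
        image_grid_thw=batch["image_grid_thw"],
        codebook_id=codebook_id,
    )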


class Qwen2_5_VL_Quant(nn.Module):
    """Thin wrapper that bundles a UniTok quantizer with the quantized Qwen2.5-VL model."""

    def __init__(self, unitok, qwen2_5_vl):
        super().__init__()
        self.unitok = unitok
        self.qwen = qwen2_5_vl
        self.dtype = self.qwen.dtype

    def forward(
        self,
        input_ids,
        attention_mask,
        pixel_values=None,
        image_grid_thw=None,
        output_hidden_states=None,
        codebook_id=None,
    ):
        output = self.qwen(
            unitok=self.unitok,
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            output_hidden_states=output_hidden_states,
            codebook_id=codebook_id,
        )
        return output
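

# End-to-end usage sketch (illustrative; not part of the original module). The
# checkpoint name is an assumption, and building / loading the UniTok quantizer is
# project-specific, so a ready UniTok instance is taken as an argument rather than
# constructed here. Note that the quantization path above assumes 196 visual tokens
# per image, so inputs must be preprocessed to that size.
def _qwen_vl_quant_demo(unitok: UniTok, image: Image.Image, prompt: str):
    processor = Qwen2_5_VLProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # assumed checkpoint
    qwen = Qwen2_5_VLForConditionalGeneration_Quant.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype=torch.bfloat16
    )
    model = Qwen2_5_VL_Quant(unitok, qwen)

    messages = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt")

    # Visual features are routed through UniTok inside
    # Qwen2_5_VLForConditionalGeneration_Quant.forward.
    return model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        image_grid_thw=inputs["image_grid_thw"],
        output_hidden_states=True,
    )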