"""PyTorch Qwen2-VL model.""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.nn import CrossEntropyLoss |
|
|
|
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
from transformers.generation import GenerationMixin |
|
|
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
|
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput |
|
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
|
from transformers.processing_utils import Unpack |
|
|
from transformers.utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging |
|
|
from transformers.configuration_utils import PretrainedConfig, layer_type_validation |
|
|
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
from transformers.modeling_outputs import ( |
|
|
ModelOutput, |
|
|
) |
|
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( |
|
|
Qwen2_5_VLVisionConfig, |
|
|
Qwen2_5_VLTextConfig, |
|
|
Qwen2_5_VLConfig, |
|
|
) |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( |
|
|
Qwen2_5_VLAttention, |
|
|
Qwen2RMSNorm, |
|
|
Qwen2_5_VLRotaryEmbedding, |
|
|
) |
|
|
from DCMoE import UniMoEAudioSparseMoeBlock |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
FAST_INIT = True |
|
|
|
|
|


class Qwen2_5_VLMoETextConfig(Qwen2_5_VLTextConfig):
    model_type = "qwen2_5_vl_moe_text"

    def __init__(
        self,
        mlp_dynamic_expert_num=4,
        mlp_dynamic_null_expert_num=0,
        mlp_dynamic_top_p=0.7,
        mlp_dynamic_top_k=2,
        mlp_fixed_expert_num=2,
        dynamic_intermediate_size=8960,
        shared_intermediate_size=8960,
        ignore_differentiable_router=False,
        enable_expert_tensor_parallelism: bool = False,
        ep_size=1,
        fixed_ep_size=1,
        router_jitter_noise=0.01,
        input_jitter_noise=0.01,
        token_drop=False,
        drop_policy: str = "probs",
        min_capacity: int = 8,
        capacity_factor: float = 1.0,
        fp32_gate=True,
        avg_hidden_states_last=False,
        drop_token_num_print=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mlp_dynamic_expert_num = mlp_dynamic_expert_num
        self.mlp_dynamic_top_p = mlp_dynamic_top_p
        self.mlp_dynamic_top_k = mlp_dynamic_top_k
        self.mlp_fixed_expert_num = mlp_fixed_expert_num
        self.mlp_dynamic_null_expert_num = mlp_dynamic_null_expert_num
        self.dynamic_intermediate_size = dynamic_intermediate_size
        self.shared_intermediate_size = shared_intermediate_size
        self.ignore_differentiable_router = ignore_differentiable_router
        self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.fixed_ep_size = fixed_ep_size
        self.input_jitter_noise = input_jitter_noise
        self.router_jitter_noise = router_jitter_noise
        self.token_drop = token_drop
        self.drop_policy = drop_policy
        self.min_capacity = min_capacity
        self.capacity_factor = capacity_factor
        self.fp32_gate = fp32_gate
        self.avg_hidden_states_last = avg_hidden_states_last
        self.drop_token_num_print = drop_token_num_print


class UniMoEAudioConfig(PretrainedConfig):
    model_type = "uni_audio_rvq_qwen2_5vl_moe"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLMoETextConfig}
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        codec_vocab_size=1028,
        codec_delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],
        codec_channels=9,
        codec_eos_value=1024,
        codec_pad_value=1025,
        codec_bos_value=1026,
        codec_placeholder_value=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"](**kwargs)
        else:
            self.text_config = text_config

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.codec_vocab_size = codec_vocab_size
        self.codec_delay_pattern = codec_delay_pattern
        self.codec_channels = codec_channels
        self.codec_eos_value = codec_eos_value
        self.codec_pad_value = codec_pad_value
        self.codec_bos_value = codec_bos_value
        self.codec_placeholder_value = codec_placeholder_value

        super().__init__(**kwargs)
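
# Note on the codec defaults above: the special values EOS (1024), PAD (1025) and BOS (1026)
# sit just above the first 1024 code entries of the 1028-entry codec vocabulary, and
# `codec_delay_pattern` gives the per-channel frame delay used by the delayed-pattern
# decoding in `UniMoEAudio.generate`.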


@dataclass
class MoEQwen2_5VLCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    all_router_logits: Optional[Tuple] = None
    all_router_top_k: Optional[Tuple] = None
    all_router_expert_mask: Optional[Tuple] = None
    all_router_weight: Optional[Tuple] = None
    aux_balance_loss: Optional[torch.FloatTensor] = None


@dataclass
class BaseModelOutputWithPast(ModelOutput):
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    all_router_logits: Optional[Tuple] = None
    all_router_top_k: Optional[Tuple] = None
    all_router_weight: Optional[Tuple] = None
    all_router_expert_mask: Optional[Tuple] = None
    all_aux_loss: Optional[Tuple] = None
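
# Note: this BaseModelOutputWithPast intentionally redefines (and therefore shadows) the
# transformers class of the same name, extending it with the per-layer router bookkeeping
# tuples produced by the MoE decoder layers.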


class Qwen2_5_VLMoEDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen2_5_VLMoETextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
        self.mlp = UniMoEAudioSparseMoeBlock(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits_and_topk: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits, router_top_k, router_expert_mask, router_weight, aux_loss = self.mlp(
            hidden_states, padding_token_mask
        )
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if output_router_logits_and_topk:
            outputs += (router_logits,)
            outputs += (router_top_k,)
            outputs += (router_expert_mask,)
            outputs += (router_weight,)
            outputs += (aux_loss,)

        return outputs
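
# The decoder layer returns a plain tuple: hidden_states, followed by the attention weights
# when `output_attentions` is set, and then (router_logits, router_top_k, router_expert_mask,
# router_weight, aux_loss) when `output_router_logits_and_topk` is set; the text model below
# unpacks these fields via negative indices.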


class Qwen2_5_VLMoEPreTrainedModel(PreTrainedModel):
    config_class = UniMoEAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if FAST_INIT:
            if isinstance(module, UniMoEAudioSparseMoeBlock):
                module.gate.weight.data.normal_(mean=0.0, std=std)
                if module.gate.bias is not None:
                    module.gate.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
        else:
            if isinstance(module, (nn.Linear, nn.Conv3d)):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, Qwen2RMSNorm):
                module.weight.data.fill_(1.0)


class Qwen2_5_VLMoETextModel(Qwen2_5_VLMoEPreTrainedModel):
    config_class = Qwen2_5_VLMoETextConfig

    def __init__(self, config: Qwen2_5_VLMoETextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2_5_VLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits_and_topk: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits_and_topk else None
        all_router_top_k = () if output_router_logits_and_topk else None
        all_router_expert_mask = ()
        all_router_weight = ()
        all_aux_loss = ()
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                padding_token_mask=padding_token_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                output_router_logits_and_topk=output_router_logits_and_topk,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits_and_topk:
                all_router_logits += (layer_outputs[-5],)
                all_router_top_k += (layer_outputs[-4],)
                all_router_expert_mask += (layer_outputs[-3],)
                all_router_weight += (layer_outputs[-2],)
                all_aux_loss += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attns,
                    all_router_logits,
                    all_router_top_k,
                    all_router_expert_mask,
                    all_router_weight,
                    all_aux_loss,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            all_router_logits=all_router_logits,
            all_router_top_k=all_router_top_k,
            all_router_expert_mask=all_router_expert_mask,
            all_router_weight=all_router_weight,
            all_aux_loss=all_aux_loss,
        )


class UniMoEAudio(Qwen2_5_VLMoEPreTrainedModel):
    base_model_prefix = ""
    _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"]
    config_class = UniMoEAudioConfig
    _checkpoint_conversion_mapping = {
        "^visual": "visual",
        r"^model(?!\.(language_model|visual))": "language_model",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            config.vision_config, attn_implementation=config._attn_implementation
        )
        self.language_model = Qwen2_5_VLMoETextModel._from_config(config.text_config)
        self.rope_deltas = None
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.num_channels = config.codec_channels
        self.codec_vocab_size = config.codec_vocab_size
        self.codec_embed_tokens = nn.ModuleList(
            [nn.Embedding(self.codec_vocab_size, config.text_config.hidden_size) for embed_idx in range(self.num_channels)]
        )
        self.codec_placeholder_value = config.codec_placeholder_value
        self.codec_head = nn.Linear(config.text_config.hidden_size, self.num_channels * self.codec_vocab_size, bias=False)
        self.post_init()

    @property
    def cur_aux_weight(self):
        if self.training_steps >= self.l_aux_weight_decay_steps:
            return self.min_l_aux_weight
        return self.l_aux_weight - (self.l_aux_weight - self.min_l_aux_weight) / self.l_aux_weight_decay_steps * self.training_steps
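
    # NOTE: `cur_aux_weight` and the `self.training_steps += 1` update in `forward` rely on
    # `training_steps`, `l_aux_weight`, `min_l_aux_weight` and `l_aux_weight_decay_steps`,
    # which are not initialized in `__init__` and are expected to be attached to the model
    # by the training script before the first forward pass with labels.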

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )

                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas
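
    # get_rope_index builds the 3-axis (temporal, height, width) M-RoPE position ids: text
    # tokens advance all three axes together, while image/video patches receive separate
    # t/h/w indices derived from `image_grid_thw` / `video_grid_thw`, with the temporal axis
    # of videos scaled by `second_per_grid_ts * tokens_per_second`. `mrope_position_deltas`
    # records, per sequence, the offset between the largest assigned position and the
    # sequence length, used to continue positions during decoding.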

    def get_video_features(self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None):
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds

    def codec_embedding(self, codec_input_ids):
        x = None
        for i in range(self.num_channels):
            channel_tokens = codec_input_ids[..., i]
            channel_embed = self.codec_embed_tokens[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed
        return x

    def calculate_input_embedding(self, input_ids, codec_input_ids):
        inputs_embeds = self.language_model.embed_tokens(input_ids)
        if codec_input_ids is not None:
            codec_input_embeds = self.codec_embedding(codec_input_ids)

            codec_mask = (input_ids == self.codec_placeholder_value).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(codec_mask, codec_input_embeds)
        return inputs_embeds
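
    # Codec frames are embedded by summing the per-channel embeddings of the RVQ codebook
    # tokens (one nn.Embedding per channel). `calculate_input_embedding` then overwrites the
    # text embeddings at positions holding `codec_placeholder_value` with these codec
    # embeddings via masked_scatter, so audio frames and text share one token stream.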

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        codec_input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        codec_labels: Optional[torch.LongTensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits_and_topk: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, MoEQwen2_5VLCausalLMOutputWithPast]:
        return_dict = True
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if inputs_embeds is None:
            inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)

            if pixel_values is not None:
                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
                image_embeds = torch.cat(image_embeds, dim=0)

                if input_ids is None:
                    image_mask = inputs_embeds == self.get_input_embeddings()(
                        torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                    )
                    image_mask = image_mask.all(-1)
                else:
                    image_mask = input_ids == self.config.image_token_id

                n_image_tokens = image_mask.sum()
                image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
                n_image_features = image_embeds.shape[0]
                if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
                video_embeds = torch.cat(video_embeds, dim=0)

                if input_ids is None:
                    video_mask = inputs_embeds == self.get_input_embeddings()(
                        torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
                    )
                    video_mask = video_mask.all(-1)
                else:
                    video_mask = input_ids == self.config.video_token_id

                n_video_tokens = video_mask.sum()
                n_video_features = video_embeds.shape[0]
                video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
                if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                attention_mask_tensor = (1.0 - attention_mask_tensor).int()
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        if padding_token_mask is None:
            padding_token_mask = attention_mask.bool()

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            padding_token_mask=padding_token_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits_and_topk=output_router_logits_and_topk,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).float()
        codec_logits = self.codec_head(hidden_states).float()
        codec_logits = codec_logits.view((logits.shape[0], logits.shape[1], self.num_channels, self.codec_vocab_size))

        loss = None
        all_aux_loss = None
        if labels is not None:
            all_aux_loss = outputs.all_aux_loss if return_dict else outputs[-1]
            all_aux_loss = torch.mean(torch.cat([l.unsqueeze(0) for l in all_aux_loss], dim=0))
            aux_loss = self.cur_aux_weight * all_aux_loss
            self.training_steps += 1
            codec_loss = None

            if codec_labels is not None:
                for i in range(self.num_channels):
                    channel_logits = codec_logits[:, :, i].float()
                    channel_labels = codec_labels[:, :, i]
                    shift_channel_logits = channel_logits[..., :-1, :].contiguous()
                    shift_channel_labels = channel_labels[..., 1:].contiguous()

                    if i != 0 and (shift_channel_labels != -100).sum() == 0:
                        continue

                    loss_fct = CrossEntropyLoss()
                    shift_channel_logits = shift_channel_logits.view(-1, self.codec_vocab_size)
                    shift_channel_labels = shift_channel_labels.view(-1)
                    shift_channel_labels = shift_channel_labels.to(shift_channel_logits.device)
                    channel_loss = loss_fct(shift_channel_logits, shift_channel_labels)
                    codec_loss = channel_loss if codec_loss is None else codec_loss + channel_loss

            loss = aux_loss if codec_loss is None else codec_loss + aux_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MoEQwen2_5VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            all_router_logits=outputs.all_router_logits,
            all_router_top_k=outputs.all_router_top_k,
            all_router_expert_mask=outputs.all_router_expert_mask,
            all_router_weight=outputs.all_router_weight,
            aux_balance_loss=all_aux_loss,
        )
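
    # Training loss = sum over codec channels of the shifted cross-entropy on the RVQ codes
    # plus the load-balancing auxiliary loss, averaged over layers and scaled by the decaying
    # `cur_aux_weight`. The text-head `logits` are returned for inspection but are not part
    # of the loss.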

    @staticmethod
    def _sample_next_token(
        logits_BCxV: torch.Tensor,
        temperature: float,
        top_p: float,
        top_k: int,
        audio_eos_value: int,
    ) -> torch.Tensor:
        if temperature == 0.0:
            return torch.argmax(logits_BCxV, dim=-1)

        logits_BCxV = logits_BCxV / temperature

        if audio_eos_value is not None and audio_eos_value >= 0:
            top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
            eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
            mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
            mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
            logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)

        if top_k is not None:
            _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
            mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
            mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
            logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

        if top_p < 1.0:
            probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
            sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
            cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

            sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
            sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1)
            sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0])

            indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
            indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(
                dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
            )
            logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

        final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

        sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
        sampled_indices_C = sampled_indices_BC.squeeze(-1)
        return sampled_indices_C
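
    # Sampling operates on logits flattened to shape (batch * num_channels, vocab): logits
    # are divided by the temperature, EOS is suppressed wherever it is not already the
    # argmax, then top-k and nucleus (top-p) filtering are applied before multinomial
    # sampling; temperature == 0.0 falls back to greedy argmax.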

    def _decoder_step(
        self,
        tokens_Bx1xC: torch.Tensor,
        model_kwargs,
        cfg_scale: float,
        neg_input_size: int,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample=True,
        eos_prob_mul_factor=1.0,
        labels_Bx1xC=None,
        use_cache=True,
        enable_eos=True,
    ) -> Tuple[torch.Tensor, Dict]:
        B = tokens_Bx1xC.shape[0]
        audio_eos_value = self.config.codec_eos_value
        attention_mask = model_kwargs["attention_mask"]
        cache_position = model_kwargs["cache_position"]
        past_key_values = model_kwargs["past_key_values"]
        input_ids = model_kwargs["input_ids"]
        codec_input_ids = model_kwargs["codec_input_ids"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -tokens_Bx1xC.shape[1]:]
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        tokens_Bx1xC = tokens_Bx1xC.repeat_interleave(neg_input_size, dim=0)
        codec_input_ids = torch.cat((codec_input_ids, tokens_Bx1xC), dim=1) if codec_input_ids is not None else tokens_Bx1xC.clone()
        input_ids = torch.cat((input_ids, torch.ones(input_ids.shape[0], 1).to(input_ids) * self.codec_placeholder_value), dim=-1)

        if use_cache:
            codec_input_embeds = self.codec_embedding(tokens_Bx1xC)
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=codec_input_embeds,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                cache_position=cache_position,
            )
        else:
            batch_codec_input_ids = codec_input_ids.contiguous().view(-1, self.num_channels)

            inputs_embeds = self.calculate_input_embedding(input_ids, batch_codec_input_ids)
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=attention_mask.long().cumsum(-1) - 1,
                past_key_values=None,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                cache_position=None,
            )

        last_hidden_state = outputs.last_hidden_state
        codec_logits = self.codec_head(last_hidden_state).float()
        codec_logits = codec_logits.view((codec_logits.shape[0], codec_logits.shape[1], self.num_channels, self.codec_vocab_size))
        model_kwargs["past_key_values"] = outputs.past_key_values
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        model_kwargs["input_ids"] = input_ids
        model_kwargs["codec_input_ids"] = codec_input_ids

        logits_Bx1xCxV = codec_logits[:, -1:].clone()
        logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
        logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, neg_input_size, *logits_last_2BxCxV.shape[1:])
        if cfg_scale is not None:
            cond_logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :]
            logits_BxCxV = cond_logits_BxCxV
            for ni in range(neg_input_size - 1):
                uncond_logits_BxCxV = logits_last_Bx2xCxV[:, ni, :, :]
                cfg_weight = cfg_scale[ni] if isinstance(cfg_scale, list) else cfg_scale
                logits_BxCxV = logits_BxCxV + cfg_weight * (cond_logits_BxCxV - uncond_logits_BxCxV)
        else:
            logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :]

        if enable_eos:
            logits_BxCxV[:, :, audio_eos_value + 1:] = torch.full_like(
                logits_BxCxV[:, :, audio_eos_value + 1:],
                fill_value=-torch.inf,
            )
            logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
                logits_BxCxV[:, 1:, audio_eos_value:],
                fill_value=-torch.inf,
            )
            logits_BxCxV[:, 0, audio_eos_value] *= torch.tensor(eos_prob_mul_factor, device=self.device)
        else:
            logits_BxCxV[:, :, audio_eos_value:] = torch.full_like(
                logits_BxCxV[:, :, audio_eos_value:],
                fill_value=-torch.inf,
            )

        flat_logits_BCxV = logits_BxCxV.reshape(B * self.num_channels, -1)
        if do_sample:
            pred_BC = self._sample_next_token(
                flat_logits_BCxV.float(),
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                audio_eos_value=audio_eos_value,
            )
        else:
            pred_BC = torch.argmax(flat_logits_BCxV, dim=1)

        pred_BxC = pred_BC.view(B, self.num_channels)

        return pred_BxC, model_kwargs
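
    # Classifier-free guidance: each logical sample occupies `neg_input_size` adjacent rows
    # (negative/unconditional branches first, conditional branch last). The guided logits are
    # cond + cfg_weight * (cond - uncond) accumulated over every negative branch; when EOS is
    # enabled, ids above the EOS value are masked out and EOS itself is only allowed on
    # channel 0, with its logit rescaled by `eos_prob_mul_factor`.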

    def generate(
        self,
        input_ids,
        attention_mask,
        dec_output,
        max_tokens,
        min_tokens=None,
        codec_input_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        neg_input_size=2,
        cfg_scale=3.0,
        temperature: float = 1.2,
        top_p: float = 0.95,
        cfg_filter_top_k: int = 45,
        eos_prob_mul_factor: float = 0.8,
        do_sample: bool = True,
        debug_guidance_step: int = 0,
        use_cache=True,
    ):
        if codec_input_ids is not None:
            assert use_cache
        batch_size = input_ids.shape[0] // neg_input_size
        audio_eos_value = self.config.codec_eos_value
        audio_pad_value = self.config.codec_pad_value
        delay_pattern = self.config.codec_delay_pattern
        max_delay_pattern = max(delay_pattern)
        delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long)

        dec_step = min(dec_output.prefill_steps) - 1

        eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device)
        eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
        finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)

        bos_over = False
        model_kwargs = dict(attention_mask=attention_mask, use_cache=True)
        model_kwargs["past_key_values"] = DynamicCache()
        model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
        attention_mask = model_kwargs["attention_mask"]
        past_key_values = model_kwargs["past_key_values"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        cache_position = torch.arange(0, input_ids.shape[-1], device=input_ids.device)
        inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)
        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cache_position=cache_position,
        )

        model_kwargs["input_ids"] = input_ids
        model_kwargs["codec_input_ids"] = None
        model_kwargs["labels"] = torch.ones_like(input_ids[neg_input_size - 1::neg_input_size]) * -100
        labels_Bx1xC = dec_output.get_labels_at(0)
        if labels_Bx1xC is not None:
            model_kwargs["codec_labels"] = (torch.ones_like(input_ids[neg_input_size - 1::neg_input_size]) * -100).unsqueeze(-1).expand(-1, -1, self.num_channels)
            assert (labels_Bx1xC != self.config.codec_bos_value).sum() == 0
            labels_Bx1xC = torch.full_like(labels_Bx1xC, -100)
            model_kwargs["codec_labels"] = torch.cat((model_kwargs["codec_labels"], labels_Bx1xC), dim=1)
        model_kwargs["past_key_values"] = outputs.past_key_values
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

        while dec_step < max_tokens:
            if (eos_countdown_Bx == 0).all():
                break

            current_step_idx = dec_step + 1
            tokens_Bx1xC = dec_output.get_tokens_at(dec_step)
            labels_Bx1xC = dec_output.get_labels_at(dec_step + 1)

            pred_BxC, model_kwargs = self._decoder_step(
                tokens_Bx1xC=tokens_Bx1xC,
                model_kwargs=model_kwargs,
                cfg_scale=cfg_scale,
                neg_input_size=neg_input_size,
                temperature=temperature,
                top_p=top_p,
                top_k=cfg_filter_top_k,
                do_sample=do_sample,
                eos_prob_mul_factor=eos_prob_mul_factor,
                labels_Bx1xC=labels_Bx1xC,
                use_cache=use_cache,
                enable_eos=(min_tokens is None or dec_step >= min_tokens),
            )
            if labels_Bx1xC is not None and (dec_step < debug_guidance_step or debug_guidance_step == -1):
                pred_BxC = labels_Bx1xC[:, 0]

            active_mask_Bx = eos_countdown_Bx != 0
            eos_trigger_Bx = torch.zeros_like(active_mask_Bx)
            if active_mask_Bx.any():
                is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value)
                is_max_len = current_step_idx >= max_tokens - max_delay_pattern
                eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len
            eos_detected_Bx |= eos_trigger_Bx
            start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0)
            if start_countdown_mask_Bx.any():
                eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern
                finished_step_Bx[start_countdown_mask_Bx] = current_step_idx

            padding_mask_Bx = eos_countdown_Bx > 0
            if padding_mask_Bx.any():
                pred_active_BxC = pred_BxC[padding_mask_Bx].clone()
                countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx]
                step_after_eos_Bx = max_delay_pattern - countdown_active_Bx
                step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1)
                delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
                eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_
                pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_
                pred_active_BxC[eos_mask_NxC] = audio_eos_value
                pred_active_BxC[pad_mask_NxC] = audio_pad_value
                pred_BxC[padding_mask_Bx] = pred_active_BxC
                eos_countdown_Bx[padding_mask_Bx] -= 1

            if not bos_over:
                bos_over = all(current_step_idx - prefill_step >= max_delay_pattern for prefill_step in dec_output.prefill_steps)

            dec_output.update_one(pred_BxC, current_step_idx, not bos_over)
            dec_step += 1

        final_step = dec_step + 1
        finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern
        prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device)
        lengths_Bx = finished_step_Bx - prefill_steps_tensor
        lengths_Bx = torch.clamp(lengths_Bx, min=0)
        max_len = lengths_Bx.max().item() + max_delay_pattern

        if max_len > 0:
            num_channels = self.num_channels
            generated_codes = torch.full(
                (batch_size, max_len, num_channels),
                fill_value=audio_pad_value,
                dtype=torch.long,
                device=self.device,
            )

            for i in range(batch_size):
                start_step = dec_output.prefill_steps[i]
                actual_len = lengths_Bx[i].item() + max_delay_pattern
                if actual_len > 0:
                    tokens_to_copy = dec_output.generated_tokens[i, start_step: start_step + actual_len, :]
                    generated_codes[i, :actual_len, :] = tokens_to_copy

            return generated_codes, lengths_Bx
        else:
            logger.warning("Nothing generated for any sequence in the batch.")
            return None, None
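

if __name__ == "__main__":
    # Minimal construction sketch (not part of the original file): build the composite
    # configuration from its defaults and inspect a few MoE / codec fields. Instantiating
    # the full UniMoEAudio model additionally requires the external DCMoE package and a
    # large amount of memory, so only the config is created here.
    cfg = UniMoEAudioConfig()
    print(cfg.model_type)
    print(cfg.text_config.mlp_dynamic_expert_num, cfg.text_config.mlp_fixed_expert_num)
    print(cfg.codec_channels, cfg.codec_vocab_size, cfg.codec_delay_pattern)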