# -*- coding: utf-8 -*- # Copyright (c) 2026 Meituan # This code is licensed under the MIT License, for details, see the ./LICENSE file. import os from dataclasses import dataclass from tqdm import tqdm from typing import Optional, Union import torch import torch.nn.functional as F from torch import nn from transformers.cache_utils import Cache from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GenerateNonBeamOutput from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.longcat_flash.modeling_longcat_flash import LongcatFlashForCausalLM from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from .configuration_longcat_next import LongcatNextConfig from .modeling_longcat_ngram import LongcatFlashNgramModel, NgramCache from .modular_longcat_next import CasualDepthTransformerHead from .modular_longcat_next_audio import LongcatNextAudioTokenizer from .modular_longcat_next_visual import LongcatNextVisualTokenizer from .cosy24k_vocoder import Cosy24kVocoder from .image_refiner import ImageRefinerContainer from .refiner_modules import FlowMatchEulerDiscreteScheduler logger = logging.get_logger(__name__) @dataclass class LongcatNextForCausalLMOutputWithPast(CausalLMOutputWithPast): visual_loss: Optional[torch.FloatTensor] = None visual_logits: Optional[torch.FloatTensor] = None visual_ids: Optional[torch.LongTensor] = None audio_loss: Optional[torch.FloatTensor] = None audio_logits: Optional[torch.FloatTensor] = None audio_ids: Optional[torch.LongTensor] = None @dataclass class LongcatNextForCausalLMGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput): visual_ids: Optional[torch.LongTensor] = None audio_ids: Optional[torch.LongTensor] = None audio_text_ids: Optional[torch.LongTensor] = None @dataclass class LongcatNextForCausalLMGenerateEncoderDecoderOutput(GenerateEncoderDecoderOutput): visual_ids: Optional[torch.LongTensor] = None audio_ids: Optional[torch.LongTensor] = None audio_text_ids: Optional[torch.LongTensor] = None @dataclass class LongcatNextForCausalLMGenerationStatus: mode: str = "text" current_image_token_num: int = -1 audio_parallel_decoding: bool = False is_audio_text_end: bool = False is_audio_start: bool = False last_step_mode: str = None def __init__(self, visual_generation_config, audio_generation_config): self.visual_generation_config = visual_generation_config self.h = self.visual_generation_config.custom_params["token_h"] self.w = self.visual_generation_config.custom_params["token_w"] self.anyres_prefix = self.visual_generation_config.custom_params["anyres_prefix"].format(h=self.h, w=self.w) self.audio_generation_config = audio_generation_config self.audio_parallel_decoding = audio_generation_config.audio_parallel_decoding def switch_to(self, modal): assert modal in ["text", "visual", "audio"] self.mode = modal self.current_image_token_num = 0 if modal == "visual" else -1 self.is_audio_text_end = False self.is_audio_start = False @property def is_img_newline(self): return ((self.current_image_token_num + 1) % (self.w + 1)) == 0 and not self.is_img_end @property def is_img_end(self): return (self.current_image_token_num + 1) / (self.w + 1) == self.h class LongcatNextModel(LongcatFlashNgramModel): _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"] config_class = LongcatNextConfig def __init__(self, config): super().__init__(config) self.visual_tokenizer = LongcatNextVisualTokenizer(config) self.audio_tokenizer = LongcatNextAudioTokenizer(config) self._init_multimodal_constants(config) self.post_init() def _init_multimodal_constants(self, config): name2id_dict = { "image_newline_token_id": self.config.visual_config.image_newline_token_id, "image_end_token_id": self.config.visual_config.image_end_token_id, "image_pad_token_id": self.config.visual_config.image_pad_token_id, "audiotext_start_token_id": config.audio_config.audiotext_start_token_id, "audiotext_pad_token_id": self.config.audio_config.audiotext_pad_token_id, "audiogen_end_token_id": config.audio_config.audiogen_end_token_id, "audio_pad_token_id": self.config.audio_config.audio_pad_token_id, } for k, v in name2id_dict.items(): self.register_buffer(k, torch.tensor([v], dtype=torch.long), persistent=False) visual_offset_list = [config.visual_offset] + config.visual_config.vq_config.codebook_sizes[:-1] visual_offset_vals = torch.cumsum(torch.tensor(visual_offset_list, dtype=torch.long), dim=0) self.register_buffer("visual_offset_vals", visual_offset_vals, persistent=False) audio_offset_list = [config.audio_offset] + config.audio_config.vq_config.codebook_sizes[:-1] audio_offset_vals = torch.cumsum(torch.tensor(audio_offset_list, dtype=torch.long), dim=0) self.register_buffer("audio_offset_vals", audio_offset_vals, persistent=False) print(f"{self.visual_offset_vals=}") print(f"{self.audio_offset_vals=}") def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, visual_inputs=None, visual_ids=None, audio_inputs=None, audio_ids=None, audio_text_ids=None, multimodal_generation_status=None, **kwargs ) -> BaseModelOutputWithPast: if input_ids is None: raise ValueError("You must specify input_ids") # Extract N-gram context if available ngram_context = None if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None: ngram_context = past_key_values.ngram_context # assert input_ids.size(0) == 1, "only support bs=1 for now" # but when bs=2, idx=1 is for uncond_image_generation special_visual_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask = self.get_placeholder_mask(input_ids[:1]) # seq-dim if inputs_embeds is None: input_ids[:, special_visual_mask | special_audio_mask | special_audio_text_pad_mask | special_audio_text_start_mask] = 0 filled_text_pad_mask = torch.ones_like(special_audio_mask) audio_text_position_mask = (special_audio_text_pad_mask | special_audio_text_start_mask | special_audio_mask) if audio_text_ids is not None and audio_text_ids.size(1) > 0 and audio_text_position_mask.sum() > 0: filled_text = audio_text_ids[:, -audio_text_position_mask.sum():] filled_text_pad_mask = (filled_text==self.config.audio_config.audiotext_pad_token_id)[0] input_ids[:, audio_text_position_mask] = filled_text input_ids[input_ids == self.config.audio_config.audiotext_pad_token_id] = 0 inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context) inputs_embeds[:, (special_visual_mask | (special_audio_mask & filled_text_pad_mask))] = 0 if special_audio_text_start_mask.sum() > 0: audio_text_start_embedding = self.embed_tokens(self.audiotext_start_token_id) if multimodal_generation_status.last_step_mode is None: # prefill inputs_embeds[:1, special_audio_text_start_mask] += audio_text_start_embedding else: inputs_embeds[:, special_audio_text_start_mask] += audio_text_start_embedding if visual_inputs is not None: visual_ids = self.get_visual_ids(**visual_inputs) # [*seq, lev] if visual_ids is not None and special_visual_mask.sum() > 0: visual_embeddings = self.get_visual_embeddings(visual_ids[-special_visual_mask.sum():]) # -> [seq, dim] if multimodal_generation_status.last_step_mode is None: # prefill inputs_embeds[:1, special_visual_mask] = visual_embeddings.to(inputs_embeds.device) else: inputs_embeds[:, special_visual_mask] = visual_embeddings.to(inputs_embeds.device) if audio_inputs is not None: audio_ids = self.get_audio_ids(**audio_inputs) # -> [*seq, lev] if audio_ids is not None and special_audio_mask.sum() > 0: audio_embeddings = self.get_audio_embeddings(audio_ids[-special_audio_mask.sum():]) # -> [seq, dim] if multimodal_generation_status.last_step_mode is None: # prefill inputs_embeds[:1, special_audio_mask] += audio_embeddings.to(inputs_embeds.device) else: inputs_embeds[:, special_audio_mask] += audio_embeddings.to(inputs_embeds.device) # Initialize NgramCache if needed if use_cache and past_key_values is None: past_key_values = NgramCache(config=self.config) # Update N-gram context if use_cache and isinstance(past_key_values, NgramCache): past_key_values.update_ngram_context(input_ids) return super().forward( input_ids=None, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, cache_position=cache_position, use_cache=use_cache, **kwargs ) def get_visual_ids(self, pixel_values, visual_grid_thw, offset=True): visual_ids = self.visual_tokenizer.encode(pixel_values, visual_grid_thw) if offset: visual_ids += self.visual_offset_vals.to(visual_ids.device) return visual_ids def get_audio_ids(self, audio, encoder_length, bridge_length, offset=True): audio_ids = self.audio_tokenizer.encode(audio, encoder_length, bridge_length) if offset: audio_ids += self.audio_offset_vals.to(audio_ids.device) return audio_ids @torch.no_grad() def decode_visual_ids_and_save( self, visual_ids, save_prefix, token_h, token_w, **kwargs, ): visual_ids -= self.visual_offset_vals.to(visual_ids.device) if not (save_prefix.startswith("./") or save_prefix.startswith("/")): save_prefix = f"./{save_prefix}" os.makedirs(os.path.dirname(save_prefix), exist_ok=True) return self.visual_tokenizer.lazy_decode_and_save(visual_ids, token_h, token_w, f"{save_prefix}_{0}.png") @torch.no_grad() def decode_audio_ids_and_save( self, audio_ids, save_prefix, sampling_rate, wave_concat_overlap, **kwargs, ): audio_ids -= self.audio_offset_vals.to(audio_ids.device) if not (save_prefix.startswith("./") or save_prefix.startswith("/")): save_prefix = f"./{save_prefix}" os.makedirs(os.path.dirname(save_prefix), exist_ok=True) save_path = f"{save_prefix}_{0}.wav" self.audio_tokenizer.lazy_decode_and_save(audio_ids, sampling_rate, wave_concat_overlap, save_path) return [save_path] def get_visual_embeddings(self, visual_ids): visual_embeddings = self.embed_tokens(visual_ids).sum(dim=1) # [seq, lev] -> [seq, lev, dim] -> [seq, dim] visual_embeddings = self.visual_tokenizer.visual_embedding_layer(visual_embeddings) return visual_embeddings def get_audio_embeddings(self, audio_ids): audio_embeddings = self.embed_tokens(audio_ids).sum(dim=1) return audio_embeddings def get_placeholder_mask(self, input_ids: torch.LongTensor): special_image_mask = (input_ids == self.config.visual_config.image_pad_token_id).squeeze(0) special_audio_mask = (input_ids == self.config.audio_config.audio_pad_token_id).squeeze(0) special_audio_text_start_mask = (input_ids == self.config.audio_config.audiotext_start_token_id).squeeze(0) special_audio_text_pad_mask = (input_ids == self.config.audio_config.audiotext_pad_token_id).squeeze(0) return special_image_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask class LongcatNextForCausalLM(LongcatFlashForCausalLM): _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"] _no_split_modules = [ "LongcatFlashDecoderLayer", "CasualDepthTransformerHead", ] config_class = LongcatNextConfig def __init__(self, config): super().__init__(config) self.config = config self.model = LongcatNextModel(config) self.lm_head = nn.Linear(config.hidden_size, config.text_vocab_plus_multimodal_special_token_size, bias=False) self.visual_head = CasualDepthTransformerHead( hidden_size=config.hidden_size, codebook_sizes=config.visual_config.vq_config.codebook_sizes, transformer_layer_num=config.visual_config.image_head_transformer_layers, transformer_dim=config.visual_config.image_head_transformer_dims, transformer_ffn_scale=config.visual_config.image_head_transformer_ffn_scale, ) self.audio_head = CasualDepthTransformerHead( hidden_size=config.hidden_size, codebook_sizes=config.audio_config.vq_config.codebook_sizes, transformer_layer_num=config.audio_config.audio_head_transformer_layers, transformer_dim=config.audio_config.audio_head_transformer_dims, transformer_ffn_scale=config.audio_config.audio_head_transformer_ffn_scale, ) self.post_init() @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, visual_inputs=None, visual_ids=None, audio_inputs=None, audio_ids=None, audio_text_ids=None, multimodal_generation_status: LongcatNextForCausalLMGenerationStatus = None, visual_generation_config: GenerationConfig = None, audio_generation_config: GenerationConfig = None, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" visual_inputs (`BatchFeature`, *optional*): Visual inputs returned by the processor, containing pixel values and grid metadata for image encoding. visual_ids (`torch.LongTensor` of shape `(num_visual_tokens, num_codebooks)`, *optional*): Quantized visual token ids from the visual tokenizer, used to build visual embeddings during generation. audio_inputs (`BatchFeature`, *optional*): Audio inputs returned by the processor, containing mel-spectrogram features and length metadata. audio_ids (`torch.LongTensor` of shape `(num_audio_tokens, num_codebooks)`, *optional*): Quantized audio token ids from the audio tokenizer, used to build audio embeddings during generation. audio_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Token ids for the audio text transcript generated alongside audio tokens. multimodal_generation_status (`LongcatNextForCausalLMGenerationStatus`, *optional*): Stateful object tracking the current multimodal generation mode (text / visual / audio) and associated counters used to route logits to the correct head during auto-regressive decoding. visual_generation_config (`GenerationConfig`, *optional*): Generation configuration for the visual head, controlling sampling parameters such as `temperature`, `top_k`, `top_p`, and custom parameters like `cfg_scale` and `anyres_config`. audio_generation_config (`GenerationConfig`, *optional*): Generation configuration for the audio head, controlling sampling parameters such as `temperature`, `top_k`, `top_p`, `repetition_penalty`, and `audio_parallel_decoding`. """ if multimodal_generation_status.mode == "visual" and visual_generation_config.custom_params["cfg_scale"] != 1.0 and input_ids.size(0) == 1: input_ids = input_ids.repeat((2, 1)) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, visual_inputs=visual_inputs, visual_ids=visual_ids, audio_inputs=audio_inputs, audio_ids=audio_ids, audio_text_ids=audio_text_ids, multimodal_generation_status=multimodal_generation_status, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep slice_hidden_states = hidden_states[:, slice_indices, :] loss, logits = None, None if multimodal_generation_status.mode == "visual" and \ (not multimodal_generation_status.is_img_newline) and (not multimodal_generation_status.is_img_end): visual_ids = self.get_multimodal_logits_and_ids( self.visual_head, visual_ids, slice_hidden_states, self.model.embed_tokens, self.config.visual_config.vq_config.codebook_sizes, self.model.visual_offset_vals, visual_generation_config, ) else: logits = self.lm_head(slice_hidden_states) if multimodal_generation_status.mode == "audio" and multimodal_generation_status.is_audio_start: audio_ids = self.get_multimodal_logits_and_ids( self.audio_head, audio_ids, slice_hidden_states, self.model.embed_tokens, self.config.audio_config.vq_config.codebook_sizes, self.model.audio_offset_vals, audio_generation_config, ) return LongcatNextForCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, visual_ids=visual_ids, audio_ids=audio_ids, ) def get_multimodal_logits_and_ids( self, head_model, multimodal_ids, hidden_states, multimodal_embedding_layer, codebook_sizes, offset_vals, multimodal_generation_config, ): next_token_ids = torch.zeros(hidden_states.size(0), len(codebook_sizes), dtype=torch.long, device=hidden_states.device) multimodal_embedding_layer = multimodal_embedding_layer.to(hidden_states.device) for level, _ in enumerate(codebook_sizes): logits = head_model(hidden_states, next_token_ids, multimodal_embedding_layer, level) # -> (bs, 1, dim) next_token_id = self.inner_sample(logits, multimodal_ids[None, :, level]-offset_vals[level], multimodal_generation_config) # (bs, 1) next_token_id += offset_vals[level] next_token_ids[:, level] = next_token_id return next_token_ids[:1] def inner_sample( self, next_token_logits: torch.Tensor, multimodal_ids: torch.LongTensor, generation_config: GenerationConfig, ) -> torch.Tensor: logits_processor = self._get_logits_processor(generation_config) if "cfg_scale" in generation_config.custom_params and generation_config.custom_params["cfg_scale"] != 1.0: cond_logits, uncond_logits = next_token_logits.chunk(2, dim=0) next_token_logits = generation_config.custom_params["cfg_scale"] * (cond_logits - uncond_logits) + uncond_logits next_token_scores = logits_processor(multimodal_ids, next_token_logits.to(multimodal_ids.device)) if generation_config.do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) return next_tokens @torch.no_grad() def generate(self, inputs=None, **kwargs): """Override to ensure NgramCache is used.""" if "past_key_values" not in kwargs or kwargs["past_key_values"] is None: kwargs["past_key_values"] = NgramCache(config=self.config) return super().generate( inputs=inputs, **kwargs, ) def prepare_inputs_for_generation( self, input_ids, visual_ids, audio_ids, audio_text_ids, multimodal_generation_status, generation_config, attention_mask, cache_position, **kwargs, ): extra_new_tokens = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device) if visual_ids is None: visual_ids = torch.empty(0, len(self.config.visual_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device) if audio_ids is None: audio_ids = torch.empty(0, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device) if audio_text_ids is None: audio_text_ids = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device) def insert_ids(new_ids, _input_ids, _attention_mask, _cache_position, position=0): if position < 0: parts = [_input_ids[:, :position], new_ids, _input_ids[:, position:]] else: parts = [_input_ids, new_ids] _input_ids = torch.cat(parts, dim=1) insert_len = new_ids.size(1) _attention_mask = F.pad(_attention_mask, (0, insert_len), value=1) insert_position = _cache_position[-1] + 1 + torch.arange(insert_len, device=_cache_position.device) _cache_position = torch.cat([_cache_position, insert_position]) return _input_ids, _attention_mask, _cache_position # multimodal generation status change if cache_position[0] != 0: multimodal_generation_status.last_step_mode = multimodal_generation_status.mode if multimodal_generation_status.mode == "visual": multimodal_generation_status.current_image_token_num += 1 if (input_ids[:, -1] == self.config.visual_config.image_start_token_id).all(): multimodal_generation_status.switch_to("visual") anyres_prefix_ids = self.text_tokenizer.encode(multimodal_generation_status.anyres_prefix, return_tensors="pt") anyres_prefix_ids = anyres_prefix_ids.to(input_ids.device) extra_new_tokens = torch.cat([extra_new_tokens, anyres_prefix_ids], dim=1) input_ids, attention_mask, cache_position = insert_ids(anyres_prefix_ids, input_ids, attention_mask, cache_position, position=-1) if input_ids.size(0) == 1: # cfg, change bs=1 -> 2 input_ids = input_ids.repeat((2, input_ids.size(1))) input_ids[1, :-(anyres_prefix_ids.size(-1)+1)] = 0 print(f"change to cfg, input_ids: {input_ids}") attention_mask = attention_mask.repeat((2, attention_mask.size(1))) elif (input_ids[:, -1] == self.config.audio_config.audiogen_start_token_id).all(): multimodal_generation_status.switch_to("audio") elif (input_ids[:, -1] == self.config.audio_config.audiotext_start_token_id).all(): multimodal_generation_status.is_audio_start = True elif ((input_ids[:, -1] == self.config.visual_config.image_end_token_id) | (input_ids[:, -1] == self.config.audio_config.audiogen_end_token_id)).all(): multimodal_generation_status.switch_to("text") model_inputs = super().prepare_inputs_for_generation( input_ids=input_ids, visual_ids=visual_ids, audio_ids=audio_ids, audio_text_ids=audio_text_ids, attention_mask=attention_mask, cache_position=cache_position, **kwargs, ) if model_inputs["cache_position"][0] != 0: model_inputs["visual_inputs"] = None model_inputs["audio_inputs"] = None return model_inputs, multimodal_generation_status, extra_new_tokens def _sample( self, input_ids: torch.LongTensor, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, visual_ids=None, audio_ids=None, audio_text_ids=None, **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed to avoid deadlocking with `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. """ # init values pad_token_id = generation_config._pad_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished batch_size, cur_len = input_ids.shape[:2] this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) model_forward = self.__call__ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config) if compile_forward: os.environ["TOKENIZERS_PARALLELISM"] = "0" # If we use FA2 and a static cache, we cannot compile with fullgraph if self.config._attn_implementation == "flash_attention_2": # only raise warning if the user passed an explicit compile-config if generation_config.compile_config is not None and generation_config.compile_config.fullgraph: logger.warning_once( "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as " "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`." ) generation_config.compile_config.fullgraph = False model_forward = self.get_compiled_call(generation_config.compile_config) if generation_config.prefill_chunk_size is not None: model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) is_prefill = False else: is_prefill = True visual_generation_config = GenerationConfig(**generation_config.visual_generation_config) audio_generation_config = GenerationConfig(**generation_config.audio_generation_config) multimodal_generation_status = LongcatNextForCausalLMGenerationStatus(visual_generation_config, audio_generation_config) pbar = tqdm(iter(int, 1), desc="Generating", unit="tok") while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs, multimodal_generation_status, extra_new_tokens = self.prepare_inputs_for_generation( input_ids, visual_ids, audio_ids, audio_text_ids, multimodal_generation_status, generation_config, **model_kwargs, ) if extra_new_tokens.size(1) > 0: input_ids = torch.cat([input_ids[:, :-1], extra_new_tokens, input_ids[:, -1:]], dim=1) model_kwargs["attention_mask"] = model_inputs["attention_mask"] model_kwargs["cache_position"] = model_inputs["cache_position"] if multimodal_generation_status.mode == "text" and multimodal_generation_status.last_step_mode == "visual": next_tokens = generation_config._eos_token_tensor input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) break visual_ids = model_inputs["visual_ids"] audio_ids = model_inputs["audio_ids"] audio_text_ids = model_inputs["audio_text_ids"] if is_prefill: outputs = self(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config) is_prefill = False else: outputs = model_forward(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config) # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, num_new_tokens=1, ) if synced_gpus and this_peer_finished: continue # multimodal generation if multimodal_generation_status.mode == "text" or \ (multimodal_generation_status.mode == "audio" and not multimodal_generation_status.is_audio_text_end): # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_logits: raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # token selection if do_sample: probs = nn.functional.softmax(next_token_scores, dim=-1) # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) # audio_text_ids done if multimodal_generation_status.mode == "audio" and (next_tokens == self.config.audio_config.audiotext_pad_token_id).all(): multimodal_generation_status.is_audio_text_end = True elif multimodal_generation_status.mode == "visual": if multimodal_generation_status.is_img_end: next_tokens = self.model.image_end_token_id.to(input_ids.device) elif multimodal_generation_status.is_img_newline: next_tokens = self.model.image_newline_token_id.to(input_ids.device) else: visual_ids = torch.cat([visual_ids, outputs.visual_ids], dim=0) # [seq, lev] next_tokens = self.model.image_pad_token_id.to(input_ids.device) else: # mode == "audio" and multimodal_generation_status.is_audio_text_end next_tokens = self.model.audio_pad_token_id.to(input_ids.device) if multimodal_generation_status.mode == "audio": # audio_text_ids update audio_text_next_tokens = self.model.audiotext_pad_token_id.to(input_ids.device) if not multimodal_generation_status.is_audio_text_end: audio_text_next_tokens, next_tokens = next_tokens, audio_text_next_tokens audio_text_ids = torch.cat((audio_text_ids, audio_text_next_tokens[:, None]), dim=1) # audio_ids update if multimodal_generation_status.is_audio_start: if outputs.audio_ids[-1, 0] == (self.model.audio_offset_vals[1]): # offset + (level_1_len) next_tokens = self.model.audiogen_end_token_id.to(input_ids.device) else: next_tokens = self.model.audio_pad_token_id.to(input_ids.device) audio_ids = torch.cat([audio_ids, outputs.audio_ids], dim=0) elif (multimodal_generation_status.audio_parallel_decoding) or \ (not multimodal_generation_status.audio_parallel_decoding and multimodal_generation_status.is_audio_text_end): next_tokens = self.model.audiotext_start_token_id.to(input_ids.device) # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) # TODO: streaming mm ids if streamer is not None: streamer.put(next_tokens.cpu()) unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 cur_len += 1 # This is needed to properly delete outputs.logits which may be very large for first iteration # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration del outputs pbar.update(1) pbar.set_postfix({ "recent_5toks": f"{input_ids[:, -5:].tolist()}", }) pbar.close() if streamer is not None: streamer.end() if return_dict_in_generate: if self.config.is_encoder_decoder: return LongcatNextForCausalLMGenerateEncoderDecoderOutput( sequences=input_ids, scores=scores, logits=raw_logits, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), visual_ids=visual_ids, audio_ids=audio_ids, audio_text_ids=audio_text_ids, ) else: return LongcatNextForCausalLMGenerateDecoderOnlyOutput( sequences=input_ids, scores=scores, logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), visual_ids=visual_ids, audio_ids=audio_ids, audio_text_ids=audio_text_ids, ) else: return input_ids, visual_ids, audio_ids, audio_text_ids __all__ = ["LongcatNextModel", "LongcatNextForCausalLM"]