Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import logging | |
| import math | |
| import sys | |
| import warnings | |
| from abc import abstractmethod | |
| from collections import defaultdict | |
| from functools import partial | |
| from typing import ( | |
| Callable, | |
| Dict, | |
| Iterable, | |
| List, | |
| NamedTuple, | |
| Optional, | |
| Sequence, | |
| Set, | |
| Tuple, | |
| cast, | |
| ) | |
| from dataclasses import fields | |
| from typing import List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.backends.cuda | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import einsum | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers.models.auto import AutoModel, AutoConfig, AutoModelForCausalLM | |
| from transformers.cache_utils import Cache | |
| from PIL import Image | |
| from .configuration_llada import ( | |
| LLaDAConfig, | |
| StrEnum, | |
| InitFnType, | |
| ActivationType, | |
| BlockType, | |
| LayerNormType, | |
| ModelConfig, | |
| ActivationCheckpointingStrategy, | |
| ) | |
| from .modeling_llada import LLaDAModelLM | |
| from .modeling_video_encoder import VideoEncoder | |
| from .sampling import cosine_schedule, mask_by_random_topk | |
| from transformers import PretrainedConfig | |
def calculate_mmu_style_loss(logits_batch, labels_batch, masked_indices_batch, p_mask, answer_lengths, output_size, device):
    """Cross-entropy over the selected positions, reweighted per token.

    Only the positions picked by ``masked_indices_batch`` contribute. Each
    token loss is divided by its entry of ``p_mask`` (clamped to >= 1e-4) and
    by its entry of ``answer_lengths`` (clamped to >= 1); the total is then
    divided by ``logits_batch.shape[0]``.

    Args:
        logits_batch: Logits; the selected rows are viewed as
            ``(-1, output_size)``.
        labels_batch: Target ids indexed the same way; ``-100`` is ignored.
        masked_indices_batch: Index used to select the positions that count
            (presumably a boolean mask over batch and sequence; verify
            against the caller).
        p_mask: Per-position values indexed by ``masked_indices_batch``.
        answer_lengths: Per-position values indexed by
            ``masked_indices_batch``.
        output_size: Size of the last logits dimension.
        device: Device ``p_mask`` and ``answer_lengths`` are moved to.

    Returns:
        A scalar tensor. A zero scalar is returned when the batch dimension
        of ``logits_batch`` is empty.
    """
    batch_size = logits_batch.shape[0]
    if batch_size == 0:
        return logits_batch.new_zeros(())

    # Clamping keeps both divisors away from zero.
    mask_prob = p_mask.to(device)[masked_indices_batch].clamp(min=1e-4)
    lengths = answer_lengths.to(device)[masked_indices_batch].clamp(min=1)

    picked_logits = logits_batch[masked_indices_batch].contiguous().view(-1, output_size)
    picked_labels = labels_batch[masked_indices_batch].contiguous().view(-1)

    token_loss = F.cross_entropy(
        picked_logits,
        picked_labels,
        ignore_index=-100,
        reduction='none',
    )
    token_loss = token_loss / mask_prob
    return torch.sum(token_loss / lengths) / batch_size
def calculate_t2s_loss(
    logits_batch,
    labels_batch,
    masked_indices_batch,
    p_mask,
    answer_lengths,
    vocab_start,
    codebook_size,
    eoa_token_id,
    eos_token_id,
    device,
    ignore_index=-100,
):
    """Cross-entropy restricted to the speech codebook plus EOA and EOS.

    For the positions picked by ``masked_indices_batch`` the logits are cut
    down to ``codebook_size + 2`` classes: the slice
    ``[vocab_start, vocab_start + codebook_size)`` followed by the EOA column
    and the EOS column. Labels are remapped to the same layout; any label
    that is neither a codebook id, EOA nor EOS becomes ``ignore_index``.
    Each token loss is divided by its ``p_mask`` entry (clamped to >= 1e-4)
    and its ``answer_lengths`` entry (clamped to >= 1.0), summed, and divided
    by ``logits_batch.shape[0]``. The arithmetic is done in float32.

    Returns:
        A scalar tensor in ``logits_batch``'s dtype. A zero scalar is
        returned when the batch is empty or no position is selected.
    """
    batch_size = logits_batch.shape[0]
    if batch_size == 0:
        return logits_batch.new_zeros(())

    picked_logits = logits_batch[masked_indices_batch]
    picked_labels = labels_batch[masked_indices_batch].to(torch.long)
    if picked_logits.shape[0] == 0:
        return logits_batch.new_zeros(())

    work_dtype = torch.float32
    speech_end = vocab_start + codebook_size
    logits_fp32 = picked_logits.to(dtype=work_dtype)
    # Class layout: [codebook entries..., EOA, EOS].
    class_logits = torch.cat(
        [
            logits_fp32[:, vocab_start:speech_end],
            logits_fp32[:, eoa_token_id : eoa_token_id + 1],
            logits_fp32[:, eos_token_id : eos_token_id + 1],
        ],
        dim=-1,
    )

    mask_prob = p_mask.to(device=device, dtype=work_dtype)[masked_indices_batch].clamp(min=1e-4)
    lengths = answer_lengths.to(device=device, dtype=work_dtype)[masked_indices_batch].clamp(min=1.0)

    # Remap absolute label ids into the reduced class layout; later fills
    # take precedence over earlier ones, as in sequential assignment.
    in_codebook = (picked_labels >= vocab_start) & (picked_labels < speech_end)
    targets = torch.where(
        in_codebook,
        picked_labels - vocab_start,
        torch.full_like(picked_labels, ignore_index),
    )
    targets = targets.masked_fill(picked_labels == eoa_token_id, codebook_size)
    targets = targets.masked_fill(picked_labels == eos_token_id, codebook_size + 1)

    per_token = F.cross_entropy(
        class_logits,
        targets,
        ignore_index=ignore_index,
        reduction='none',
    )
    per_token = per_token / mask_prob
    per_token = per_token / lengths
    total = torch.sum(per_token) / batch_size
    return total.to(dtype=logits_batch.dtype)
def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits for Gumbel-max style categorical sampling.

    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves
    perplexity score but reduces generation quality, so the computation is
    carried out in float64.

    With ``temperature == 0`` the input is returned unchanged. Otherwise the
    result is ``exp(logits) / (-log(u)) ** temperature`` with ``u`` drawn
    uniformly from [0, 1) per element.
    '''
    if temperature == 0:
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    denominator = torch.log(uniform).neg().pow(temperature)
    return torch.exp(logits64) / denominator
def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many tokens to unmask at every step, per batch row.

    In the reverse process, the interval [0, 1] is uniformly discretized into
    ``steps`` intervals. Because LLaDA employs a linear noise schedule (as
    defined in Eq. (8)), the expected number of tokens transitioned at each
    step should be consistent.

    Each row's mask count is split evenly over ``steps``; when it does not
    divide evenly, the first ``count % steps`` steps receive one extra token.

    Returns:
        An int64 tensor of shape ``(mask_index.size(0), steps)``.
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    even_share = masked_per_row // steps
    leftover = masked_per_row % steps

    schedule = torch.zeros(
        masked_per_row.size(0), steps, device=mask_index.device, dtype=torch.int64
    ) + even_share

    # Steps whose index is below the row's remainder get one extra token.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    gets_extra = step_ids < leftover
    return schedule + gets_extra.to(schedule.dtype)
class OMadaConfig(PretrainedConfig):
    """Configuration for the "omada" model type.

    All keyword arguments are forwarded to ``PretrainedConfig``. A fixed set
    of OMada-specific keys is then assigned onto the instance whenever the
    caller supplied them.
    """

    model_type = "omada"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Keys copied verbatim from kwargs onto the config, in this order.
        recognised = (
            "vocab_size",
            "llm_vocab_size",
            "llm_model_path",
            "codebook_size",
            "num_vq_tokens",
            "num_new_special_tokens",
            "gradient_checkpointing",
            "new_vocab_size",
        )
        for name in recognised:
            if name not in kwargs:
                continue
            setattr(self, name, kwargs[name])
| class OMadaModelLM(LLaDAModelLM): | |
| config_class = OMadaConfig | |
| base_model_prefix = "model" | |
def __init__(self, config: OMadaConfig, *args, **kwargs):
    """Build the model from ``config``.

    The configuration is echoed to stdout, then every argument is forwarded
    unchanged to the parent constructor.
    """
    banner = f"Initializing OMadaModelLM with config: {config}"
    print(banner)
    super().__init__(config, *args, **kwargs)
    # The token-embedding resize below is currently disabled (commented out):
    # print(f"Resizing token embeddings to {config.new_vocab_size}")
    # self.resize_token_embeddings(config.new_vocab_size)
def t2i_generate(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,  # ideal number of steps is 18 in maskgit paper
    guidance_scale=0,
    noise_schedule=cosine_schedule,
    generator: torch.Generator = None,
    config=None,
    seq_len=1024,
    mask_token_id=126336,
    resolution=512,
    codebook_size=8192,
    **kwargs,
):
    """
    Fill the masked image-token slots of ``input_ids`` by iterative decoding.

    Generate 1:1 similar to the original MaskGit repo
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

    Every step runs the model, samples one code per image slot, keeps the
    sample only where the slot was still masked, and then re-masks the slots
    selected by ``mask_by_random_topk`` for the next round.

    Args:
        input_ids: Token ids. The image slots are the ``seq_len`` positions
            immediately before the last position. **Modified in place.**
        uncond_input_ids: Optional unconditional sequence. Classifier-free
            guidance is applied only when this is given and
            ``guidance_scale > 0``.
        attention_mask: Mask for ``input_ids``; combined pairwise into the
            attention bias passed to the model.
        uncond_attention_mask: Mask for ``uncond_input_ids`` (guidance only).
        temperature: Starting temperature handed to ``mask_by_random_topk``.
        timesteps: Number of decoding steps.
        guidance_scale: Classifier-free guidance weight.
        noise_schedule: Callable mapping the progress ratio (a tensor in
            (0, 1]) to the mask ratio for the next round.
        generator: Optional ``torch.Generator`` used for sampling.
        config: Unused.
        seq_len: Number of image-token slots.
        mask_token_id: Id marking a still-masked slot.
        resolution: ``resolution + 1`` leading positions of
            ``uncond_input_ids`` are kept as the unconditional prefix.
        codebook_size: Number of image codebook entries read from the logits.
        **kwargs: Must contain ``uni_prompting``, an object exposing
            ``text_tokenizer`` (its ``len`` gives the image-vocab offset).

    Returns:
        The ids sampled in the last step, relative to the image codebook,
        with the shape of the sliced logits minus the vocabulary dimension.

    Raises:
        ValueError: If ``uni_prompting`` is missing from ``kwargs``.
    """
    uni_prompting = kwargs.get("uni_prompting", None)
    if uni_prompting is None:
        raise ValueError("uni_prompting object must be provided in kwargs.")

    num_vq_tokens = seq_len
    num_new_special_tokens = 0

    # Evaluated once instead of on every decoding step.
    image_vocab_start = len(uni_prompting.text_tokenizer) + num_new_special_tokens
    image_vocab_end = image_vocab_start + codebook_size
    image_slots = slice(-(num_vq_tokens + 1), -1)

    # Begin with the image slots expressed relative to the image codebook;
    # masked slots keep the mask id.
    input_ids_minus_lm_vocab_size = input_ids[:, image_slots].clone()
    input_ids_minus_lm_vocab_size = torch.where(
        input_ids_minus_lm_vocab_size == mask_token_id,
        mask_token_id,
        input_ids_minus_lm_vocab_size - image_vocab_start,
    )

    # for classifier-free guidance
    if uncond_input_ids is not None:
        uncond_prefix = uncond_input_ids[:, :resolution + 1]

    for step in range(timesteps):
        if uncond_input_ids is not None and guidance_scale > 0:
            uncond_input_ids = torch.cat(
                [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
            model_input = torch.cat([input_ids, uncond_input_ids])
            all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
            attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(model_input, attention_bias=attention_bias).logits
            cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
            # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            # it seems that muse has a different cfg setting
            logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
        else:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(input_ids, attention_bias=attention_bias).logits

        # Keep only the image slots and the image-codebook vocabulary window.
        logits = logits[:, image_slots, image_vocab_start:image_vocab_end]

        probs = logits.softmax(dim=-1)
        sampled = probs.reshape(-1, logits.size(-1))
        sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

        unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
        sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

        # Defines the mask ratio for the next round. The number to mask out is
        # determined by mask_ratio * unknown_number_in_the_beginning.
        ratio = 1.0 * (step + 1) / timesteps
        mask_ratio = noise_schedule(torch.tensor(ratio))

        # Computes the probabilities of each selected tokens.
        selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
        selected_probs = selected_probs.squeeze(-1)
        # Ignores the tokens given in the input by overwriting their confidence.
        selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

        # Gets mask lens for each sample in the batch according to the mask ratio.
        mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
        # Keeps at least one of prediction in this round and also masks out at least
        # one and for the next iteration
        mask_len = torch.max(
            torch.tensor([1], device=logits.device),
            torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
        )

        # Adds noise for randomness.
        # NOTE(review): ``temperature`` is overwritten here, so the decay
        # compounds from one step to the next instead of being applied to the
        # initial value each time. Left as is because changing it would alter
        # sampling; confirm whether this matches the intended schedule.
        temperature = temperature * (1.0 - ratio)
        masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

        # Masks tokens with lower confidence (writes back into the caller's tensor).
        input_ids[:, image_slots] = torch.where(
            masking, mask_token_id, sampled_ids + image_vocab_start
        )
        input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

    return sampled_ids
def t2s_generate(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,
    guidance_scale=0,
    noise_schedule=None,
    generator: torch.Generator = None,
    config=None,
    seq_len=256,
    mask_token_id=126336,
    **kwargs,
):
    """
    Fill the trailing masked speech slots of ``input_ids`` by iterative decoding.

    The speech region is the last ``n`` positions, where ``n`` is the largest
    per-row count of ``mask_token_id``. At every step the model is run, one
    class is sampled per slot from the speech codebook plus EOA and EOS, and
    the slots selected by ``mask_by_random_topk`` are re-masked for the next
    round.

    Args:
        input_ids: Token ids; the speech slots are **modified in place**.
        uncond_input_ids: Optional unconditional sequence. Its prefix up to
            and including the first ``<|soa|>`` token of row 0 is used for
            classifier-free guidance when ``guidance_scale > 0``.
        attention_mask: Mask for ``input_ids``.
        uncond_attention_mask: Mask for ``uncond_input_ids`` (guidance only).
        temperature: Starting temperature handed to ``mask_by_random_topk``.
        timesteps: Number of decoding steps.
        guidance_scale: Classifier-free guidance weight.
        noise_schedule: Callable mapping the progress ratio tensor to the
            mask ratio. ``None`` selects ``cosine_schedule``.
        generator: Optional ``torch.Generator`` used for sampling.
        config: Unused.
        seq_len: Unused; the region length is derived from the mask count.
        mask_token_id: Id marking a still-masked slot.
        **kwargs: Must contain ``uni_prompting`` exposing ``text_tokenizer``
            and ``sptids_dict``.

    Returns:
        A list with one 1-D tensor per batch row holding codebook-relative
        speech ids, cut at the first EOA/EOS and with mask ids removed.

    Raises:
        ValueError: If ``uni_prompting`` is missing or ``input_ids`` contains
            no mask token.
    """
    uni_prompting = kwargs.get("uni_prompting", None)
    if uni_prompting is None:
        raise ValueError("uni_prompting object must be provided in kwargs.")

    # The signature defaults to None, which previously failed with
    # "'NoneType' object is not callable" on the first step.
    if noise_schedule is None:
        noise_schedule = cosine_schedule

    eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item()
    eos_token_id = uni_prompting.text_tokenizer.eos_token_id

    num_vq_tokens = (input_ids == mask_token_id).sum(dim=-1).max().item()
    if num_vq_tokens == 0:
        raise ValueError("No mask tokens found in input_ids.")

    speech_vocab_start_idx = len(uni_prompting.text_tokenizer) + 8192
    speech_vocab_end_idx = speech_vocab_start_idx + 4096

    # Relative class layout used for sampling:
    #   VQ Codes: 0 ~ 4095
    #   EOA: 4096
    #   EOS: 4097
    vq_code_relative_eoa_id = 4096
    vq_code_relative_eos_id = 4097

    input_ids_relative = input_ids[:, -(num_vq_tokens):].clone()
    input_ids_relative = torch.where(
        input_ids_relative == mask_token_id,
        mask_token_id,
        input_ids_relative - speech_vocab_start_idx
    )

    if uncond_input_ids is not None:
        soa_token_id = uni_prompting.sptids_dict['<|soa|>'][0].item()
        start_gen_idx = (uncond_input_ids[0] == soa_token_id).nonzero(as_tuple=True)[0][0].item() + 1
        uncond_prefix = uncond_input_ids[:, :start_gen_idx]

    for step in range(timesteps):
        if uncond_input_ids is not None and guidance_scale > 0:
            uncond_input_ids = torch.cat([uncond_prefix, input_ids[:, start_gen_idx:]], dim=1)
            model_input = torch.cat([input_ids, uncond_input_ids])
            all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
            attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(model_input, attention_bias=attention_bias).logits
            cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
            logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
        else:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(input_ids, attention_bias=attention_bias).logits

        logits_vq = logits[:, -(num_vq_tokens):, speech_vocab_start_idx:speech_vocab_end_idx]
        logits_eoa = logits[:, -(num_vq_tokens):, eoa_token_id:eoa_token_id + 1]
        logits_eos = logits[:, -(num_vq_tokens):, eos_token_id:eos_token_id + 1]
        combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1)

        probs = combined_logits.softmax(dim=-1)
        sampled = probs.reshape(-1, combined_logits.size(-1))
        sampled_ids_relative = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*combined_logits.shape[:-1])

        # Slots that were already decided keep their previous value.
        unknown_map = input_ids_relative == mask_token_id
        sampled_ids_relative = torch.where(unknown_map, sampled_ids_relative, input_ids_relative)

        ratio = 1.0 * (step + 1) / timesteps
        mask_ratio = noise_schedule(torch.tensor(ratio, device=logits.device))

        selected_probs = torch.gather(probs, -1, sampled_ids_relative.long()[..., None]).squeeze(-1)
        # Decided slots get maximal confidence so they are not re-masked.
        selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

        mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
        # At least one slot is re-masked and at least one is kept per round.
        mask_len = torch.max(
            torch.tensor([1], device=logits.device),
            torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
        )

        # NOTE(review): ``temperature`` is overwritten, so the decay compounds
        # across steps. Left unchanged because it affects sampling; confirm.
        temperature = temperature * (1.0 - ratio)
        masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

        # Write absolute ids back into the caller's tensor.
        input_ids[:, -(num_vq_tokens):] = torch.where(
            masking,
            mask_token_id,
            torch.where(
                sampled_ids_relative == vq_code_relative_eos_id,
                eos_token_id,
                torch.where(
                    sampled_ids_relative == vq_code_relative_eoa_id,
                    eoa_token_id,
                    sampled_ids_relative + speech_vocab_start_idx
                )
            )
        )
        input_ids_relative = torch.where(masking, mask_token_id, sampled_ids_relative)

    final_output_ids = []
    for seq in input_ids_relative:
        # Only EOA/EOS end the sequence. The mask id is numerically larger
        # than the EOA id, so without the explicit exclusion a leftover
        # masked slot (mask_len is forced to >= 1 on every step) would be
        # mistaken for a terminator and truncate the output.
        is_terminator = (seq >= vq_code_relative_eoa_id) & (seq != mask_token_id)
        terminator_positions = is_terminator.nonzero(as_tuple=True)[0]
        if terminator_positions.numel() > 0:
            seq = seq[:terminator_positions[0]]
        final_output_ids.append(seq[seq != mask_token_id])

    return final_output_ids
def t2s_generate_mmu_like(
    self,
    input_ids: torch.LongTensor,
    max_new_tokens: Optional[int] = None,
    steps: int = 256,
    block_length: int = 128,
    temperature: float = 0.0,
    cfg_scale: float = 0.0,
    mask_token_id: int = 126336,
    attention_mask: Optional[torch.LongTensor] = None,
    uni_prompting=None,
    codebook_size: Optional[int] = None,
    audio_codebook_size: int = 4096,
):
    """
    Generate speech tokens with MMU-style block-wise refinement.

    Assumes the speech region within ``input_ids`` is contiguous and filled
    with ``mask_token_id`` prior to generation. The region is located from
    the first batch row; every row must contain the same number of masks.

    The first ``max_new_tokens`` masked positions are processed in blocks of
    ``block_length``. Within a block, each inner step samples a class for
    every position and commits the highest-probability samples among the
    still-masked positions, following the per-step quota computed by
    ``get_num_transfer_tokens``.

    Returns:
        A list with one 1-D tensor per batch row containing ids relative to
        the speech codebook, cut at the first EOA/EOS (or remaining mask) and
        with mask ids removed.

    Raises:
        ValueError: If ``uni_prompting`` is missing, ``block_length`` is not
            positive, no mask token is present, or rows differ in mask count.
    """
    if uni_prompting is None:
        raise ValueError("uni_prompting must be provided")
    if block_length <= 0:
        raise ValueError("block_length must be positive")

    batch_size, _ = input_ids.shape
    device = input_ids.device

    is_masked = input_ids == mask_token_id
    if not is_masked.any():
        raise ValueError("No mask tokens detected for T2S generation")

    mask_cols = torch.where(is_masked[0])[0]
    region_start = mask_cols[0].item()
    region_len = mask_cols.numel()

    masks_per_row = is_masked.sum(dim=1)
    if not torch.all(masks_per_row == masks_per_row[0]):
        raise ValueError("All batch items must contain the same number of masked speech tokens for MMU-like generation")

    max_new_tokens = region_len if max_new_tokens is None else min(max_new_tokens, region_len)
    block_length = max(1, min(block_length, max_new_tokens))
    num_blocks = math.ceil(max_new_tokens / block_length)
    inner_steps = max(1, steps // num_blocks)

    if codebook_size is not None:
        codebook_base = codebook_size
    else:
        codebook_base = getattr(self.config, "codebook_size", 8192)
    speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_base
    speech_vocab_end = speech_vocab_start + audio_codebook_size

    eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item()
    eos_token_id = uni_prompting.text_tokenizer.eos_token_id
    # Relative class layout: [0, audio_codebook_size) codes, then EOA, then EOS.
    rel_eoa_id = audio_codebook_size
    rel_eos_id = audio_codebook_size + 1

    work = input_ids.clone()
    attention_bias = None
    if attention_mask is not None:
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)

    speech_indices = mask_cols[:max_new_tokens]

    for block_idx in range(num_blocks):
        lo = block_idx * block_length
        hi = min(lo + block_length, max_new_tokens)
        curr_indices = speech_indices[lo:hi]
        block_width = curr_indices.numel()
        if block_width == 0:
            continue

        transfer_plan = get_num_transfer_tokens(is_masked[:, curr_indices], inner_steps)

        for inner_step in range(inner_steps):
            if cfg_scale > 0.0:
                # The unconditional branch sees the speech positions re-masked.
                un_cond = work.clone()
                un_cond[:, speech_indices] = mask_token_id
                stacked = torch.cat([work, un_cond], dim=0)
                att_bias = (
                    None
                    if attention_bias is None
                    else torch.cat([attention_bias, attention_bias], dim=0)
                )
                logits = self(stacked, attention_bias=att_bias).logits
                cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
                logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
            else:
                logits = self(work, attention_bias=attention_bias).logits

            logits_block = logits.index_select(1, curr_indices.to(device))
            combined_logits = torch.cat(
                [
                    logits_block[:, :, speech_vocab_start:speech_vocab_end],
                    logits_block[:, :, eoa_token_id:eoa_token_id + 1],
                    logits_block[:, :, eos_token_id:eos_token_id + 1],
                ],
                dim=-1,
            )
            if temperature > 0.0:
                combined_logits = combined_logits / max(temperature, 1e-5)
            probs = F.softmax(combined_logits, dim=-1)

            sampled = torch.multinomial(
                probs.view(-1, probs.size(-1)), 1
            ).view(batch_size, block_width)
            selected_probs = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)

            # Map relative classes back to absolute vocabulary ids.
            sampled_absolute = (sampled + speech_vocab_start).masked_fill(
                sampled == rel_eoa_id, eoa_token_id
            ).masked_fill(
                sampled == rel_eos_id, eos_token_id
            )

            current_vals = work.index_select(1, curr_indices)
            still_masked = current_vals == mask_token_id
            # Already-decided positions can never be picked by topk.
            confidence = selected_probs.masked_fill(~still_masked, float('-inf'))

            commit = torch.zeros_like(still_masked, dtype=torch.bool)
            remaining = still_masked.sum(dim=1).tolist()
            planned = transfer_plan[:, inner_step].tolist()
            for row in range(batch_size):
                k = min(int(planned[row]), remaining[row])
                if k <= 0:
                    continue
                _, top_idx = torch.topk(confidence[row], k=k, largest=True)
                commit[row, top_idx] = True

            proposals = sampled_absolute.masked_fill(~commit, mask_token_id)
            new_block = torch.where(still_masked, proposals, current_vals)
            work[:, curr_indices] = new_block

            block_still_masked = new_block == mask_token_id
            is_masked[:, curr_indices] = block_still_masked
            if not block_still_masked.any():
                break

    final_outputs = []
    audio_region = work[:, region_start:region_start + region_len]
    for seq in audio_region:
        # Later fills win: mask id over EOA over EOS over plain codes.
        relative = (seq - speech_vocab_start).masked_fill(
            seq == eos_token_id, rel_eos_id
        ).masked_fill(
            seq == eoa_token_id, rel_eoa_id
        ).masked_fill(
            seq == mask_token_id, mask_token_id
        )
        stop_positions = (relative >= rel_eoa_id).nonzero(as_tuple=True)[0]
        if stop_positions.numel() > 0:
            relative = relative[:stop_positions[0]]
        final_outputs.append(relative[relative != mask_token_id])
    return final_outputs
def t2s_generate_mmu_like_stream(
    self,
    input_ids: torch.LongTensor,
    max_new_tokens: Optional[int] = None,
    steps: int = 256,
    block_length: int = 128,
    temperature: float = 0.0,
    cfg_scale: float = 0.0,
    mask_token_id: int = 126336,
    attention_mask: Optional[torch.LongTensor] = None,
    uni_prompting=None,
    codebook_size: Optional[int] = None,
    audio_codebook_size: int = 4096,
    update_every: Optional[int] = None,
):
    """
    Stream speech token generation. Yields intermediate token lists.

    Runs block-wise refinement over the masked speech region: the first
    ``max_new_tokens`` masked positions are processed in blocks of
    ``block_length`` and, on each inner step, the highest-probability samples
    among the still-masked positions are committed according to the quota
    from ``get_num_transfer_tokens``.

    Snapshots are emitted as follows:

    * ``update_every`` > 0: every ``update_every`` global steps and on the
      last scheduled step.
    * otherwise: once per block, on the step that completes the block (or
      on its last inner step).
    * in both modes a final snapshot is emitted after the loops if the most
      recent step was not already yielded, so the consumer always receives
      the finished state even when a block completes early.

    Yields:
        ``(tokens, status)`` where ``tokens`` is a list with one 1-D tensor
        per batch row of ids relative to the speech codebook (cut at the
        first EOA/EOS or remaining mask, mask ids removed) and ``status`` is
        the string ``"Step <global_step>/<total_steps>"``.

    Raises:
        ValueError: On first iteration, if ``uni_prompting`` is missing,
            ``block_length`` is not positive, no mask token is present, or
            rows differ in mask count.
    """
    if uni_prompting is None:
        raise ValueError("uni_prompting must be provided")
    if block_length <= 0:
        raise ValueError("block_length must be positive")

    batch_size, _ = input_ids.shape
    device = input_ids.device

    mask_positions_full = (input_ids == mask_token_id)
    if not mask_positions_full.any():
        raise ValueError("No mask tokens detected for T2S generation")

    # The speech region is located from the first batch row.
    mask_cols = torch.where(mask_positions_full[0])[0]
    speech_region_start = mask_cols[0].item()
    speech_region_len = mask_cols.numel()

    mask_counts = mask_positions_full.sum(dim=1)
    if not torch.all(mask_counts == mask_counts[0]):
        raise ValueError("All batch items must contain the same number of masked speech tokens for MMU-like generation")

    if max_new_tokens is None:
        max_new_tokens = speech_region_len
    else:
        max_new_tokens = min(max_new_tokens, speech_region_len)

    block_length = max(1, min(block_length, max_new_tokens))
    num_blocks = math.ceil(max_new_tokens / block_length)
    inner_steps = max(1, steps // num_blocks)

    codebook_base = codebook_size if codebook_size is not None else getattr(self.config, "codebook_size", 8192)
    speech_vocab_start = len(uni_prompting.text_tokenizer) + codebook_base
    speech_vocab_end = speech_vocab_start + audio_codebook_size

    eoa_token_id = uni_prompting.sptids_dict['<|eoa|>'][0].item()
    eos_token_id = uni_prompting.text_tokenizer.eos_token_id
    # Relative class layout: [0, audio_codebook_size) codes, then EOA, then EOS.
    vq_code_relative_eoa_id = audio_codebook_size
    vq_code_relative_eos_id = audio_codebook_size + 1

    work = input_ids.clone()
    attention_bias = None
    if attention_mask is not None:
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)

    speech_indices = mask_cols[:max_new_tokens]
    total_steps = num_blocks * inner_steps
    global_step = 0
    last_yielded_step = 0

    def _extract_relative_tokens(work_tensor: torch.Tensor):
        """Convert the speech region of ``work_tensor`` to relative ids per row."""
        audio_slice = slice(speech_region_start, speech_region_start + speech_region_len)
        audio_region = work_tensor[:, audio_slice]
        final_outputs = []
        for seq in audio_region:
            mask_tensor = seq.new_full(seq.shape, mask_token_id)
            rel_eoa = seq.new_full(seq.shape, vq_code_relative_eoa_id)
            rel_eos = seq.new_full(seq.shape, vq_code_relative_eos_id)
            relative = torch.where(
                seq == mask_token_id,
                mask_tensor,
                torch.where(
                    seq == eoa_token_id,
                    rel_eoa,
                    torch.where(
                        seq == eos_token_id,
                        rel_eos,
                        seq - speech_vocab_start
                    )
                )
            )
            # Anything at or above the relative EOA id ends the sequence.
            eoa_positions = (relative >= vq_code_relative_eoa_id).nonzero(as_tuple=True)[0]
            if eoa_positions.numel() > 0:
                relative = relative[:eoa_positions[0]]
            final_outputs.append(relative[relative != mask_token_id])
        return final_outputs

    for block_idx in range(num_blocks):
        block_start = block_idx * block_length
        block_end = min(block_start + block_length, max_new_tokens)
        curr_indices = speech_indices[block_start:block_end]
        if curr_indices.numel() == 0:
            continue

        block_mask = mask_positions_full[:, curr_indices]
        num_transfer_tokens = get_num_transfer_tokens(block_mask, inner_steps)

        for inner_step in range(inner_steps):
            if cfg_scale > 0.0:
                # The unconditional branch sees the speech positions re-masked.
                un_cond = work.clone()
                un_cond[:, speech_indices] = mask_token_id
                stacked = torch.cat([work, un_cond], dim=0)
                if attention_bias is not None:
                    att_bias = torch.cat([attention_bias, attention_bias], dim=0)
                else:
                    att_bias = None
                logits = self(stacked, attention_bias=att_bias).logits
                cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
                logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
            else:
                logits = self(work, attention_bias=attention_bias).logits

            logits_block = logits.index_select(1, curr_indices.to(device))
            logits_vq = logits_block[:, :, speech_vocab_start:speech_vocab_end]
            logits_eoa = logits_block[:, :, eoa_token_id:eoa_token_id + 1]
            logits_eos = logits_block[:, :, eos_token_id:eos_token_id + 1]
            combined_logits = torch.cat([logits_vq, logits_eoa, logits_eos], dim=-1)

            if temperature > 0.0:
                combined_logits = combined_logits / max(temperature, 1e-5)
            probs = F.softmax(combined_logits, dim=-1)

            sampled = torch.multinomial(
                probs.view(-1, probs.size(-1)), 1
            ).view(batch_size, curr_indices.numel())
            selected_probs = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)

            # Map relative classes back to absolute vocabulary ids.
            eos_tensor = sampled.new_full(sampled.shape, eos_token_id)
            eoa_tensor = sampled.new_full(sampled.shape, eoa_token_id)
            sampled_absolute = torch.where(
                sampled == vq_code_relative_eos_id,
                eos_tensor,
                torch.where(
                    sampled == vq_code_relative_eoa_id,
                    eoa_tensor,
                    sampled + speech_vocab_start
                )
            )

            current_block_vals = work.index_select(1, curr_indices)
            mask_current = current_block_vals == mask_token_id
            # Already-decided positions can never be picked by topk.
            confidence = torch.where(
                mask_current,
                selected_probs,
                torch.full_like(selected_probs, float('-inf'))
            )

            finalize = torch.zeros_like(mask_current, dtype=torch.bool)
            for b in range(batch_size):
                available = mask_current[b].sum().item()
                if available == 0:
                    continue
                transfer = min(int(num_transfer_tokens[b, inner_step].item()), available)
                if transfer <= 0:
                    continue
                _, idxs = torch.topk(confidence[b], k=transfer, largest=True)
                finalize[b, idxs] = True

            mask_fill = sampled_absolute.new_full(sampled_absolute.shape, mask_token_id)
            updates = torch.where(finalize, sampled_absolute, mask_fill)
            new_block = torch.where(mask_current, updates, current_block_vals)
            work[:, curr_indices] = new_block

            block_still_masked = new_block == mask_token_id
            mask_positions_full[:, curr_indices] = block_still_masked
            block_done = not block_still_masked.any()

            global_step += 1
            if update_every is not None and update_every > 0:
                should_yield = global_step % update_every == 0 or global_step == total_steps
            else:
                # A block may finish before its last inner step; emit its
                # final state then, otherwise the early break below would
                # skip the per-block snapshot entirely.
                should_yield = (
                    block_done
                    or inner_step == inner_steps - 1
                    or global_step == total_steps
                )

            if should_yield:
                last_yielded_step = global_step
                yield _extract_relative_tokens(work), f"Step {global_step}/{total_steps}"

            if block_done:
                break

    # Guarantee the consumer sees the finished state even when early
    # completion means the last executed step was not on a yield boundary.
    if global_step != last_yielded_step:
        yield _extract_relative_tokens(work), f"Step {global_step}/{total_steps}"
def t2s_fixed_generate(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,
    guidance_scale=0,
    noise_schedule=None,
    generator: torch.Generator = None,
    config=None,
    seq_len=256,
    mask_token_id=126336,
    **kwargs,
):
    """
    Iteratively fill the masked codebook slots at the end of ``input_ids``.

    The loop is modelled on MaskGIT parallel decoding:
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

    At every step the model is run, one token is sampled per slot, and the
    least confident samples are re-masked for the next step.

    Args:
        input_ids: Prompt followed by the generation slots. The slots are
            ``input_ids[:, -(seq_len + 1):-1]``. This tensor is updated
            IN PLACE at every step.
        uncond_input_ids: Unconditional prompt for classifier-free guidance.
            Only used when ``guidance_scale > 0``.
        attention_mask: Mask for ``input_ids``; combined into a pairwise bias.
        uncond_attention_mask: Mask for ``uncond_input_ids``.
        temperature: Initial randomness passed to ``mask_by_random_topk``.
        timesteps: Number of decoding steps.
        guidance_scale: Classifier-free guidance weight; ``0`` disables it.
        noise_schedule: Callable mapping a progress ratio tensor to a mask
            ratio. Defaults to ``cosine_schedule`` when ``None``.
        generator: Optional ``torch.Generator`` for reproducible sampling.
        config: Unused; kept for interface compatibility.
        seq_len: Number of codebook tokens to generate.
        mask_token_id: Id of the mask token in ``input_ids``.
        **kwargs: Must contain ``uni_prompting``, which provides
            ``text_tokenizer`` and ``sptids_dict``.

    Returns:
        Tensor of sampled codebook ids (relative to the codebook start) from
        the last step, one per generation slot.

    Raises:
        ValueError: If ``uni_prompting`` is not supplied.
    """
    uni_prompting = kwargs.get("uni_prompting", None)
    if uni_prompting is None:
        raise ValueError("t2s_fixed_generate requires a `uni_prompting` keyword argument.")
    # The signature default is None; fall back to the schedule i2i_generate uses
    # instead of failing on `None(...)` inside the loop.
    if noise_schedule is None:
        noise_schedule = cosine_schedule

    num_vq_tokens = seq_len
    num_new_special_tokens = 0
    # First vocabulary index of the codebook being generated; the 8192 entries
    # right after the text vocabulary are skipped (presumably another
    # modality's codebook -- TODO confirm).
    codebook_offset = len(uni_prompting.text_tokenizer) + num_new_special_tokens + 8192
    codebook_size = 4096
    gen_slots = slice(-(num_vq_tokens + 1), -1)

    # Working copy of the slots in codebook-relative ids; masked slots keep
    # mask_token_id so they can still be recognised.
    input_ids_minus_lm_vocab_size = input_ids[:, gen_slots].clone()
    input_ids_minus_lm_vocab_size = torch.where(
        input_ids_minus_lm_vocab_size == mask_token_id,
        mask_token_id,
        input_ids_minus_lm_vocab_size - codebook_offset,
    )

    # for classifier-free guidance
    if uncond_input_ids is not None:
        soa_id = uni_prompting.sptids_dict['<|soa|>'][0].item()
        start_gen_idx = (uncond_input_ids[0] == soa_id).nonzero(as_tuple=True)[0][0].item() + 1
        uncond_prefix = uncond_input_ids[:, :start_gen_idx]

    for step in range(timesteps):
        if uncond_input_ids is not None and guidance_scale > 0:
            # Keep the unconditional prefix but share the current generation state.
            uncond_input_ids = torch.cat(
                [uncond_prefix, input_ids[:, start_gen_idx:]], dim=1)
            model_input = torch.cat([input_ids, uncond_input_ids])
            all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
            attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(model_input, attention_bias=attention_bias).logits
            cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
            # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            # it seems that muse has a different cfg setting
            logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
        else:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(input_ids, attention_bias=attention_bias).logits
        # Keep only the generation slots and the codebook columns.
        logits = logits[:, gen_slots, codebook_offset:codebook_offset + codebook_size]

        probs = logits.softmax(dim=-1)
        sampled = probs.reshape(-1, logits.size(-1))
        sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

        unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
        # Slots that are already decided keep their token.
        sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
        # Defines the mask ratio for the next round. The number to mask out is
        # determined by mask_ratio * unknown_number_in_the_beginning.
        ratio = 1.0 * (step + 1) / timesteps
        mask_ratio = noise_schedule(torch.tensor(ratio))
        # Computes the probabilities of each selected tokens.
        selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
        selected_probs = selected_probs.squeeze(-1)
        # Ignores the tokens given in the input by overwriting their confidence.
        selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
        # Gets mask lens for each sample in the batch according to the mask ratio.
        mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
        # Keeps at least one of prediction in this round and also masks out at least
        # one and for the next iteration
        mask_len = torch.max(
            torch.tensor([1], device=logits.device),
            torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
        )
        # Adds noise for randomness.
        # NOTE(review): `temperature` is reassigned, so the decay compounds
        # across steps. The linked reference appears to scale a fixed base
        # temperature instead -- confirm before changing; kept as-is so
        # sampling results stay the same.
        temperature = temperature * (1.0 - ratio)
        masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
        # Masks tokens with lower confidence.
        input_ids[:, gen_slots] = torch.where(masking, mask_token_id, sampled_ids + codebook_offset)
        input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
    return sampled_ids
def i2i_generate(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,  # ideal number of steps is 18 in maskgit paper
    guidance_scale=0,
    noise_schedule=cosine_schedule,
    generator: torch.Generator = None,
    config=None,
    seq_len=1024,
    mask_token_id=126336,
    resolution=512,
    codebook_size=8192,
    **kwargs,
):
    """
    Iteratively fill the masked image-token slots at the end of ``input_ids``.

    The loop is modelled on MaskGIT parallel decoding:
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

    At every step the model is run, one token is sampled per slot, and the
    least confident samples are re-masked for the next step.

    Args:
        input_ids: Prompt followed by the generation slots. The slots are
            ``input_ids[:, -(seq_len + 1):-1]``. This tensor is updated
            IN PLACE at every step.
        uncond_input_ids: Unconditional prompt for classifier-free guidance.
            Only used when ``guidance_scale > 0``.
        attention_mask: Mask for ``input_ids``; combined into a pairwise bias.
        uncond_attention_mask: Mask for ``uncond_input_ids``.
        temperature: Initial randomness passed to ``mask_by_random_topk``.
        timesteps: Number of decoding steps.
        guidance_scale: Classifier-free guidance weight; ``0`` disables it.
        noise_schedule: Callable mapping a progress ratio tensor to a mask ratio.
        generator: Optional ``torch.Generator`` for reproducible sampling.
        config: Unused; kept for interface compatibility.
        seq_len: Number of image tokens to generate.
        mask_token_id: Id of the mask token in ``input_ids``.
        resolution: The first ``resolution + 1`` positions of
            ``uncond_input_ids`` are kept as the unconditional prefix; the rest
            is replaced by the current ``input_ids`` tail at every step.
        codebook_size: Number of logit columns, starting right after the text
            vocabulary, that are sampled from.
        **kwargs: Must contain ``uni_prompting``, which provides ``text_tokenizer``.

    Returns:
        Tensor of sampled codebook ids (relative to the codebook start) from
        the last step, one per generation slot.

    Raises:
        ValueError: If ``uni_prompting`` is not supplied.
    """
    uni_prompting = kwargs.get("uni_prompting", None)
    if uni_prompting is None:
        raise ValueError("i2i_generate requires a `uni_prompting` keyword argument.")

    num_vq_tokens = seq_len
    num_new_special_tokens = 0
    # First vocabulary index of the image codebook (directly after the text vocab).
    codebook_offset = len(uni_prompting.text_tokenizer) + num_new_special_tokens
    gen_slots = slice(-(num_vq_tokens + 1), -1)

    # Working copy of the slots in codebook-relative ids; masked slots keep
    # mask_token_id so they can still be recognised.
    input_ids_minus_lm_vocab_size = input_ids[:, gen_slots].clone()
    input_ids_minus_lm_vocab_size = torch.where(
        input_ids_minus_lm_vocab_size == mask_token_id,
        mask_token_id,
        input_ids_minus_lm_vocab_size - codebook_offset,
    )

    # for classifier-free guidance
    if uncond_input_ids is not None:
        uncond_prefix = uncond_input_ids[:, :resolution + 1]

    for step in range(timesteps):
        if uncond_input_ids is not None and guidance_scale > 0:
            # Keep the unconditional prefix but share the current generation state.
            uncond_input_ids = torch.cat(
                [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
            model_input = torch.cat([input_ids, uncond_input_ids])
            all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
            attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(model_input, attention_bias=attention_bias).logits
            cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
            # logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            # it seems that muse has a different cfg setting
            logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
        else:
            attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
            logits = self(input_ids, attention_bias=attention_bias).logits
        # Keep only the generation slots and the image-codebook columns.
        logits = logits[:, gen_slots, codebook_offset:codebook_offset + codebook_size]

        probs = logits.softmax(dim=-1)
        sampled = probs.reshape(-1, logits.size(-1))
        sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

        unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
        # Slots that are already decided keep their token.
        sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)
        # Defines the mask ratio for the next round. The number to mask out is
        # determined by mask_ratio * unknown_number_in_the_beginning.
        ratio = 1.0 * (step + 1) / timesteps
        mask_ratio = noise_schedule(torch.tensor(ratio))
        # Computes the probabilities of each selected tokens.
        selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
        selected_probs = selected_probs.squeeze(-1)
        # Ignores the tokens given in the input by overwriting their confidence.
        selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
        # Gets mask lens for each sample in the batch according to the mask ratio.
        mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
        # Keeps at least one of prediction in this round and also masks out at least
        # one and for the next iteration
        mask_len = torch.max(
            torch.tensor([1], device=logits.device),
            torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
        )
        # Adds noise for randomness.
        # NOTE(review): `temperature` is reassigned, so the decay compounds
        # across steps. The linked reference appears to scale a fixed base
        # temperature instead -- confirm before changing; kept as-is so
        # sampling results stay the same.
        temperature = temperature * (1.0 - ratio)
        masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)
        # Masks tokens with lower confidence.
        input_ids[:, gen_slots] = torch.where(masking, mask_token_id, sampled_ids + codebook_offset)
        input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
    return sampled_ids
| # def forward_process( | |
| # self, | |
| # input_ids, | |
| # labels, | |
| # batch_size_t2i=0, | |
| # batch_size_lm=0, | |
| # batch_size_mmu=0, | |
| # batch_size_v2t=0, | |
| # batch_size_s2t=0, | |
| # batch_size_t2s=0, | |
| # max_seq_length=128, | |
| # p_mask_lm=None, | |
| # p_mask_mmu=None, | |
| # p_mask_vid=None, | |
| # p_mask_s2t=None, | |
| # p_mask_t2s=None, | |
| # answer_lengths=None, | |
| # t2i_masks=None, | |
| # answer_lengths_lm=None | |
| # ): | |
| # # attention bias, True for batch_size, 1, seq_len, seq_len | |
| # attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]) | |
| # attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) | |
| # attention_bias[:batch_size_t2i] = attention_bias_t2i | |
| # logits = self(input_ids, attention_bias=attention_bias).logits | |
| # self.output_size = logits.shape[-1] | |
| # if batch_size_t2i == 0: | |
| # loss_t2i = torch.tensor(0.0, device=input_ids.device) | |
| # else: | |
| # loss_t2i = F.cross_entropy( | |
| # logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), | |
| # labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, | |
| # ) | |
| # masked_indices = input_ids == self.config.mask_token_id | |
| # masked_indices_lm = masked_indices[batch_size_t2i:batch_size_t2i + batch_size_lm] | |
| # masked_indices_mmu = masked_indices[-batch_size_mmu:] | |
| # p_mask_lm = p_mask_lm.to(masked_indices_lm.device) | |
| # p_mask_mmu = p_mask_mmu.to(masked_indices_mmu.device) | |
| # answer_lengths = answer_lengths.to(masked_indices_mmu.device) | |
| # loss_lm = F.cross_entropy( | |
| # logits[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1, self.output_size), | |
| # labels[batch_size_t2i:batch_size_t2i + batch_size_lm][masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' | |
| # )/p_mask_lm[masked_indices_lm] | |
| # if answer_lengths_lm is not None: | |
| # loss_lm = torch.sum(loss_lm / answer_lengths_lm[masked_indices_lm]) / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0]) | |
| # else: | |
| # loss_lm = loss_lm.sum() / (logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[0] * logits[batch_size_t2i:batch_size_t2i + batch_size_lm].shape[1]) | |
| # loss_mmu = F.cross_entropy( | |
| # logits[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1, self.output_size), | |
| # labels[-batch_size_mmu:][masked_indices_mmu].contiguous().view(-1), ignore_index=-100, reduction='none' | |
| # )/p_mask_mmu[masked_indices_mmu] | |
| # loss_mmu = torch.sum(loss_mmu/answer_lengths[masked_indices_mmu]) / (logits[-batch_size_mmu:].shape[0]) | |
| # return logits, loss_t2i, loss_lm, loss_mmu | |
| # def forward_process( | |
| # self, | |
| # input_ids, | |
| # labels, | |
| # batch_size_t2i=0, | |
| # batch_size_lm=0, | |
| # batch_size_mmu=0, | |
| # batch_size_v2t=0, | |
| # batch_size_s2t=0, | |
| # batch_size_t2s=0, | |
| # max_seq_length=128, | |
| # p_mask_lm=None, | |
| # p_mask_mmu=None, | |
| # p_mask_vid=None, | |
| # p_mask_s2t=None, | |
| # p_mask_t2s=None, | |
| # answer_lengths_lm=None, | |
| # answer_lengths_mmu=None, | |
| # answer_lengths_vid=None, | |
| # answer_lengths_s2t=None, | |
| # answer_lengths_t2s=None, | |
| # t2i_masks=None, | |
| # t2s_vocab_start=None, | |
| # t2s_codebook_size=None, | |
| # t2s_special_token_ids=None | |
| # ): | |
| # # --- 1. Attention Bias Setup (no changes) --- | |
| # attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=input_ids.device) | |
| # if batch_size_t2i > 0 and t2i_masks is not None: | |
| # attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1) | |
| # attention_bias[:batch_size_t2i] = attention_bias_t2i | |
| # # --- 2. Model Forward Pass (no changes) --- | |
| # logits = self(input_ids, attention_bias=attention_bias).logits | |
| # self.output_size = logits.shape[-1] | |
| # # --- 3. Loss Calculation --- | |
| # device = input_ids.device | |
| # zero_loss = torch.tensor(0.0, device=device) | |
| # # Calculate masked indices for the entire batch | |
| # masked_indices = (input_ids == self.config.mask_token_id) | |
| # current_idx = 0 | |
| # # --- T2I Loss --- | |
| # if batch_size_t2i > 0: | |
| # loss_t2i = F.cross_entropy( | |
| # logits[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size), | |
| # labels[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100, | |
| # ) | |
| # else: | |
| # loss_t2i = zero_loss | |
| # current_idx += batch_size_t2i | |
| # # --- LM Loss --- | |
| # if batch_size_lm > 0: | |
| # start, end = current_idx, current_idx + batch_size_lm | |
| # logits_lm, labels_lm = logits[start:end], labels[start:end] | |
| # masked_indices_lm = masked_indices[start:end] | |
| # loss_lm = F.cross_entropy( | |
| # logits_lm[masked_indices_lm].contiguous().view(-1, self.output_size), | |
| # labels_lm[masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none' | |
| # ) / p_mask_lm.to(device)[masked_indices_lm] | |
| # if answer_lengths_lm is not None: | |
| # loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / logits_lm.shape[0] | |
| # else: | |
| # loss_lm = loss_lm.sum() / logits_lm.shape[0] | |
| # else: | |
| # loss_lm = zero_loss | |
| # current_idx += batch_size_lm | |
| # # --- MMU Loss --- | |
| # if batch_size_mmu > 0: | |
| # start, end = current_idx, current_idx + batch_size_mmu | |
| # loss_mmu = calculate_mmu_style_loss( | |
| # logits[start:end], labels[start:end], masked_indices[start:end], | |
| # p_mask_mmu, answer_lengths_mmu, self.output_size, device | |
| # ) | |
| # else: | |
| # loss_mmu = zero_loss | |
| # current_idx += batch_size_mmu | |
| # # --- VID (V2T) Loss --- | |
| # if batch_size_v2t > 0: | |
| # start, end = current_idx, current_idx + batch_size_v2t | |
| # loss_vid = calculate_mmu_style_loss( | |
| # logits[start:end], labels[start:end], masked_indices[start:end], | |
| # p_mask_vid, answer_lengths_vid, self.output_size, device | |
| # ) | |
| # else: | |
| # loss_vid = zero_loss | |
| # current_idx += batch_size_v2t | |
| # # --- S2T Loss --- | |
| # if batch_size_s2t > 0: | |
| # start, end = current_idx, current_idx + batch_size_s2t | |
| # loss_s2t = calculate_mmu_style_loss( | |
| # logits[start:end], labels[start:end], masked_indices[start:end], | |
| # p_mask_s2t, answer_lengths_s2t, self.output_size, device | |
| # ) | |
| # else: | |
| # loss_s2t = zero_loss | |
| # current_idx += batch_size_s2t | |
| # # --- T2S Loss --- | |
| # if batch_size_t2s > 0: | |
| # start, end = current_idx, current_idx + batch_size_t2s | |
| # if ( | |
| # t2s_vocab_start is not None | |
| # and t2s_codebook_size is not None | |
| # and t2s_special_token_ids is not None | |
| # ): | |
| # eoa_id = t2s_special_token_ids.get('eoa') | |
| # eos_id = t2s_special_token_ids.get('eos') | |
| # else: | |
| # eoa_id = eos_id = None | |
| # if eoa_id is not None and eos_id is not None: | |
| # loss_t2s = calculate_t2s_loss( | |
| # logits[start:end], | |
| # labels[start:end], | |
| # masked_indices[start:end], | |
| # p_mask_t2s, | |
| # answer_lengths_t2s, | |
| # t2s_vocab_start, | |
| # t2s_codebook_size, | |
| # eoa_id, | |
| # eos_id, | |
| # device, | |
| # ignore_index=-100, | |
| # ) | |
| # else: | |
| # loss_t2s = calculate_mmu_style_loss( | |
| # logits[start:end], labels[start:end], masked_indices[start:end], | |
| # p_mask_t2s, answer_lengths_t2s, self.output_size, device | |
| # ) | |
| # else: | |
| # loss_t2s = zero_loss | |
| # current_idx += batch_size_t2s | |
| # return logits, loss_t2i, loss_lm, loss_mmu, loss_vid, loss_s2t, loss_t2s | |
def forward_process(
    self,
    input_ids,
    labels,
    batch_size_t2i=0,
    batch_size_i2i=0,
    batch_size_lm=0,
    batch_size_mmu=0,
    batch_size_v2t=0,
    batch_size_v2s=0,
    batch_size_s2t=0,
    batch_size_s2s=0,
    batch_size_t2s=0,
    max_seq_length=128,
    p_mask_lm=None,
    p_mask_mmu=None,
    p_mask_vid=None,
    p_mask_v2s=None,
    p_mask_s2t=None,
    p_mask_s2s=None,
    p_mask_t2s=None,
    answer_lengths_lm=None,
    answer_lengths_mmu=None,
    answer_lengths_vid=None,
    answer_lengths_v2s=None,
    answer_lengths_s2t=None,
    answer_lengths_s2s=None,
    answer_lengths_t2s=None,
    t2i_masks=None,
    attention_masks_i2i=None,
    t2s_vocab_start=None,
    t2s_codebook_size=None,
    t2s_special_token_ids=None,
    text_vocab_size_override=None
):
    """
    Run one forward pass over a concatenated multi-task batch and compute one loss per task.

    Rows of ``input_ids`` / ``labels`` are consumed in this fixed order, each
    task taking ``batch_size_<task>`` consecutive rows:
    t2i, i2i, lm, mmu, v2t, v2s, s2t, s2s, t2s.
    A task whose batch size is 0 contributes a zero loss.

    Args:
        input_ids: Token ids of the whole batch; indexed as (batch, sequence).
        labels: Target ids aligned with ``input_ids``; ``-100`` is ignored.
        batch_size_*: Number of rows belonging to each task.
        max_seq_length: For t2i, positions ``max_seq_length + 1:`` are scored.
        p_mask_*: Per-position mask probabilities used to re-weight the
            masked-token losses of the corresponding task.
        answer_lengths_*: Per-position answer lengths used to normalise the
            masked-token losses of the corresponding task.
        t2i_masks: Attention mask for the t2i rows.
        attention_masks_i2i: Attention mask for the i2i rows.
        t2s_vocab_start, t2s_codebook_size, t2s_special_token_ids: Forwarded
            to ``calculate_t2s_loss`` for the v2s / s2s / t2s tasks.
        text_vocab_size_override: Size of the text vocabulary. The image
            codebook is taken to start at this index. When ``None`` the
            codebook slice is unknown, so t2i / i2i fall back to a
            cross-entropy over the full vocabulary and the LM loss is not
            restricted.

    Returns:
        ``(logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_v2s,
        loss_s2t, loss_s2s, loss_t2s)``.
    """
    device = input_ids.device

    # --- 1. Attention bias: all-ones by default, pairwise mask for t2i / i2i rows ---
    attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=device)
    if batch_size_t2i > 0 and t2i_masks is not None:
        attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
        attention_bias[:batch_size_t2i] = attention_bias_t2i
    if batch_size_i2i > 0 and attention_masks_i2i is not None:
        start_i2i = batch_size_t2i
        end_i2i = start_i2i + batch_size_i2i
        attn_mask = attention_masks_i2i.to(device)
        if attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.bool()
        attention_bias_i2i = (attn_mask[:, :, None] & attn_mask[:, None, :]).unsqueeze(1)
        attention_bias[start_i2i:end_i2i] = attention_bias_i2i

    # --- 2. Model forward pass ---
    logits = self(input_ids, attention_bias=attention_bias).logits
    self.output_size = logits.shape[-1]

    # --- 3. Loss calculation ---
    zero_loss = torch.tensor(0.0, device=device)
    # Masked positions for the entire batch
    masked_indices = (input_ids == self.config.mask_token_id)

    text_vocab_size = text_vocab_size_override
    image_vocab_size = getattr(self.config, "codebook_size", 0)
    image_vocab_start = text_vocab_size
    # Without a text vocab size the image slice cannot be located; the previous
    # unconditional `None + int` made every call with the default argument fail.
    if image_vocab_start is None:
        image_vocab_end = None
    else:
        image_vocab_end = min(image_vocab_start + image_vocab_size, logits.shape[-1])

    def _image_token_loss(task, task_logits, task_labels):
        """Cross-entropy for an image task, restricted to the image codebook when it is known."""
        if image_vocab_start is None:
            # Full-vocabulary loss, same formulation as forward_t2i.
            return F.cross_entropy(
                task_logits.contiguous().view(-1, self.output_size),
                task_labels.contiguous().view(-1),
                ignore_index=-100,
            )
        if image_vocab_size <= 0:
            warnings.warn(f"{task} encountered non-positive image vocab size; skipping loss.")
            return zero_loss
        effective_vocab = image_vocab_end - image_vocab_start
        if effective_vocab <= 0:
            warnings.warn(f"{task} effective image vocab is invalid; skipping loss.")
            return zero_loss
        valid_mask = (task_labels >= image_vocab_start) & (task_labels < image_vocab_end)
        if not valid_mask.any():
            warnings.warn(f"{task} labels contain no valid image tokens; skipping loss.")
            return zero_loss
        # Labels outside the image range become -100 and are ignored; the rest
        # are shifted so that they index into the sliced logits.
        labels_relative = torch.full_like(task_labels, -100)
        labels_relative[valid_mask] = task_labels[valid_mask] - image_vocab_start
        logits_slice = task_logits[..., image_vocab_start:image_vocab_end]
        return F.cross_entropy(
            logits_slice.contiguous().view(-1, effective_vocab),
            labels_relative.contiguous().view(-1),
            ignore_index=-100,
        )

    current_idx = 0

    # --- T2I Loss ---
    if batch_size_t2i > 0:
        loss_t2i = _image_token_loss(
            "t2i",
            logits[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:],
            labels[current_idx:current_idx + batch_size_t2i, max_seq_length + 1:],
        )
    else:
        loss_t2i = zero_loss
    current_idx += batch_size_t2i

    # --- I2I Loss ---
    if batch_size_i2i > 0:
        start, end = current_idx, current_idx + batch_size_i2i
        loss_i2i = _image_token_loss("i2i", logits[start:end], labels[start:end])
    else:
        loss_i2i = zero_loss
    current_idx += batch_size_i2i

    # --- LM Loss ---
    if batch_size_lm > 0:
        start, end = current_idx, current_idx + batch_size_lm
        logits_lm, labels_lm = logits[start:end], labels[start:end]
        masked_indices_lm = masked_indices[start:end]
        selected_logits_lm = logits_lm[masked_indices_lm]
        effective_vocab_lm = selected_logits_lm.shape[-1]
        # Restrict the softmax to the text vocabulary when its size is known.
        if text_vocab_size and text_vocab_size < self.output_size:
            effective_vocab_lm = min(text_vocab_size, selected_logits_lm.shape[-1])
            selected_logits_lm = selected_logits_lm[:, :effective_vocab_lm]
        loss_lm = F.cross_entropy(
            selected_logits_lm.contiguous().view(-1, effective_vocab_lm),
            labels_lm[masked_indices_lm].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask_lm.to(device)[masked_indices_lm]
        if answer_lengths_lm is not None:
            loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / logits_lm.shape[0]
        else:
            loss_lm = loss_lm.sum() / logits_lm.shape[0]
    else:
        loss_lm = zero_loss
    current_idx += batch_size_lm

    # --- MMU Loss ---
    if batch_size_mmu > 0:
        start, end = current_idx, current_idx + batch_size_mmu
        loss_mmu = calculate_mmu_style_loss(
            logits[start:end], labels[start:end], masked_indices[start:end],
            p_mask_mmu, answer_lengths_mmu, self.output_size, device,
        )
    else:
        loss_mmu = zero_loss
    current_idx += batch_size_mmu

    # --- VID (V2T) Loss ---
    if batch_size_v2t > 0:
        start, end = current_idx, current_idx + batch_size_v2t
        loss_vid = calculate_mmu_style_loss(
            logits[start:end], labels[start:end], masked_indices[start:end],
            p_mask_vid, answer_lengths_vid, self.output_size, device,
        )
    else:
        loss_vid = zero_loss
    current_idx += batch_size_v2t

    # --- V2S Loss ---
    if batch_size_v2s > 0:
        start, end = current_idx, current_idx + batch_size_v2s
        if (
            t2s_vocab_start is None
            or t2s_codebook_size is None
            or t2s_special_token_ids is None
        ):
            warnings.warn("v2s missing t2s vocab configuration; skipping loss.")
            loss_v2s = zero_loss
        elif answer_lengths_v2s is None or not (answer_lengths_v2s > 0).any():
            warnings.warn("v2s encountered empty answer lengths; skipping loss.")
            loss_v2s = zero_loss
        else:
            eoa_id = t2s_special_token_ids.get('eoa')
            eos_id = t2s_special_token_ids.get('eos')
            loss_v2s = calculate_t2s_loss(
                logits[start:end],
                labels[start:end],
                masked_indices[start:end],
                p_mask_v2s,
                answer_lengths_v2s,
                t2s_vocab_start,
                t2s_codebook_size,
                eoa_id,
                eos_id,
                device,
                ignore_index=-100,
            )
    else:
        loss_v2s = zero_loss
    current_idx += batch_size_v2s

    # --- S2T Loss ---
    if batch_size_s2t > 0:
        start, end = current_idx, current_idx + batch_size_s2t
        loss_s2t = calculate_mmu_style_loss(
            logits[start:end], labels[start:end], masked_indices[start:end],
            p_mask_s2t, answer_lengths_s2t, self.output_size, device,
        )
    else:
        loss_s2t = zero_loss
    current_idx += batch_size_s2t

    # --- S2S Loss ---
    if batch_size_s2s > 0:
        start, end = current_idx, current_idx + batch_size_s2s
        if (
            t2s_vocab_start is None
            or t2s_codebook_size is None
            or t2s_special_token_ids is None
            or p_mask_s2s is None
            or answer_lengths_s2s is None
        ):
            warnings.warn("s2s missing t2s vocab configuration or masks; skipping loss.")
            loss_s2s = zero_loss
        elif not (answer_lengths_s2s > 0).any():
            warnings.warn("s2s encountered empty answer lengths; skipping loss.")
            loss_s2s = zero_loss
        else:
            eoa_id = t2s_special_token_ids.get('eoa')
            eos_id = t2s_special_token_ids.get('eos')
            loss_s2s = calculate_t2s_loss(
                logits[start:end],
                labels[start:end],
                masked_indices[start:end],
                p_mask_s2s,
                answer_lengths_s2s,
                t2s_vocab_start,
                t2s_codebook_size,
                eoa_id,
                eos_id,
                device,
                ignore_index=-100,
            )
    else:
        loss_s2s = zero_loss
    current_idx += batch_size_s2s

    # --- T2S Loss ---
    if batch_size_t2s > 0:
        start, end = current_idx, current_idx + batch_size_t2s
        if (
            t2s_vocab_start is not None
            and t2s_codebook_size is not None
            and t2s_special_token_ids is not None
        ):
            eoa_id = t2s_special_token_ids.get('eoa')
            eos_id = t2s_special_token_ids.get('eos')
        else:
            eoa_id = eos_id = None
        if eoa_id is not None and eos_id is not None:
            loss_t2s = calculate_t2s_loss(
                logits[start:end],
                labels[start:end],
                masked_indices[start:end],
                p_mask_t2s,
                answer_lengths_t2s,
                t2s_vocab_start,
                t2s_codebook_size,
                eoa_id,
                eos_id,
                device,
                ignore_index=-100,
            )
        else:
            # Without the speech-vocab configuration, use the generic masked loss.
            loss_t2s = calculate_mmu_style_loss(
                logits[start:end], labels[start:end], masked_indices[start:end],
                p_mask_t2s, answer_lengths_t2s, self.output_size, device
            )
    else:
        loss_t2s = zero_loss
    current_idx += batch_size_t2s

    return logits, loss_t2i, loss_i2i, loss_lm, loss_mmu, loss_vid, loss_v2s, loss_s2t, loss_s2s, loss_t2s
def forward_process_with_r2i(
    self,
    input_ids,
    labels,
    t2i_masks=None,
    max_seq_length=128,
    batch_size_t2i=0,
    batch_size_lm=0,
    batch_size_mmu=0,
    batch_size_r2i=0,
    p_mask_lm=None,
    p_mask_mmu=None,
    p_mask_r2i=None,
    answer_lengths=None,
    answer_lengths_lm=None,
    answer_lengths_r2i=None,
):
    """
    Forward pass over a batch laid out as [t2i | lm | mmu | r2i] rows, returning one loss per task.

    Args:
        input_ids: Token ids of the whole batch; indexed as (batch, sequence).
        labels: Target ids aligned with ``input_ids``; ``-100`` is ignored.
        t2i_masks: Attention mask for the t2i rows; when ``None`` those rows
            keep the all-ones attention bias.
        max_seq_length: For t2i, positions ``max_seq_length + 1:`` are scored.
        batch_size_t2i, batch_size_lm, batch_size_mmu, batch_size_r2i: Number
            of consecutive rows per task. A task with 0 rows yields a zero loss.
        p_mask_lm, p_mask_mmu, p_mask_r2i: Per-position mask probabilities; the
            per-token loss of each task is divided by them.
        answer_lengths: Per-position answer lengths for the mmu rows.
        answer_lengths_lm: Per-position answer lengths for the lm rows. When
            ``None`` the lm loss is averaged over rows * sequence length.
        answer_lengths_r2i: Per-position answer lengths for the r2i rows.

    Returns:
        ``(logits, loss_t2i, loss_lm, loss_mmu, loss_r2i)``.
    """
    device = input_ids.device

    # attention bias, True for batch_size, 1, seq_len, seq_len
    # Built on the input device so it can be consumed next to input_ids.
    attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=device)
    if batch_size_t2i > 0 and t2i_masks is not None:
        t2i_masks = t2i_masks.to(device)
        attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
        attention_bias[:batch_size_t2i] = attention_bias_t2i

    logits = self(input_ids, attention_bias=attention_bias).logits
    self.output_size = logits.shape[-1]
    zero_loss = torch.tensor(0.0, device=device)

    # t2i loss
    if batch_size_t2i == 0:
        loss_t2i = zero_loss
    else:
        loss_t2i = F.cross_entropy(
            logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
            labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
        )

    # Row ranges of the masked-token tasks.
    start_lm = batch_size_t2i
    end_lm = start_lm + batch_size_lm
    start_mmu = end_lm
    end_mmu = start_mmu + batch_size_mmu
    start_r2i = end_mmu
    end_r2i = start_r2i + batch_size_r2i

    masked_indices = input_ids == self.config.mask_token_id

    def _per_token_loss(start, end, p_mask):
        """Cross-entropy at the masked positions of rows [start, end), divided by p_mask."""
        segment_mask = masked_indices[start:end]
        token_loss = F.cross_entropy(
            logits[start:end][segment_mask].contiguous().view(-1, self.output_size),
            labels[start:end][segment_mask].contiguous().view(-1), ignore_index=-100, reduction='none'
        ) / p_mask.to(device)[segment_mask]
        return token_loss, segment_mask

    # llada (lm) loss
    if batch_size_lm > 0:
        loss_lm, masked_indices_lm = _per_token_loss(start_lm, end_lm, p_mask_lm)
        rows_lm = logits[start_lm:end_lm]
        # The None check must precede any use of answer_lengths_lm; otherwise
        # the fallback branch can never run.
        if answer_lengths_lm is not None:
            loss_lm = torch.sum(loss_lm / answer_lengths_lm.to(device)[masked_indices_lm]) / rows_lm.shape[0]
        else:
            loss_lm = loss_lm.sum() / (rows_lm.shape[0] * rows_lm.shape[1])
    else:
        # An empty segment would otherwise divide by a zero row count.
        loss_lm = zero_loss

    # mmu loss
    if batch_size_mmu > 0:
        loss_mmu, masked_indices_mmu = _per_token_loss(start_mmu, end_mmu, p_mask_mmu)
        loss_mmu = torch.sum(loss_mmu / answer_lengths.to(device)[masked_indices_mmu]) / logits[start_mmu:end_mmu].shape[0]
    else:
        loss_mmu = zero_loss

    # r2i loss
    if batch_size_r2i > 0:
        loss_r2i, masked_indices_r2i = _per_token_loss(start_r2i, end_r2i, p_mask_r2i)
        loss_r2i = torch.sum(loss_r2i / answer_lengths_r2i.to(device)[masked_indices_r2i]) / logits[start_r2i:end_r2i].shape[0]
    else:
        loss_r2i = zero_loss

    return logits, loss_t2i, loss_lm, loss_mmu, loss_r2i
def forward_t2i(
    self,
    input_ids,
    labels,
    batch_size_t2i=0,
    max_seq_length=128,
    t2i_masks=None
):
    """
    Compute the text-to-image cross-entropy loss for the first ``batch_size_t2i`` rows.

    Args:
        input_ids: Token ids; indexed as (batch, sequence).
        labels: Target ids aligned with ``input_ids``; ``-100`` is ignored.
        batch_size_t2i: Number of leading rows that belong to the t2i task.
            With 0 rows the loss is a zero tensor.
        max_seq_length: Positions ``max_seq_length + 1:`` are scored.
        t2i_masks: Attention mask for the t2i rows; when ``None`` the
            all-ones attention bias is used.

    Returns:
        Scalar loss tensor.
    """
    device = input_ids.device
    # attention bias, True for batch_size, 1, seq_len, seq_len
    # Created on the input device, as forward_process does.
    attention_bias = torch.ones(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1], device=device)
    if batch_size_t2i > 0 and t2i_masks is not None:
        t2i_masks = t2i_masks.to(device)
        attention_bias_t2i = (t2i_masks[:, :, None] & t2i_masks[:, None, :]).bool().unsqueeze(1)
        attention_bias[:batch_size_t2i] = attention_bias_t2i

    logits = self(input_ids, attention_bias=attention_bias).logits
    self.output_size = logits.shape[-1]

    if batch_size_t2i == 0:
        # Mean cross-entropy over zero rows is NaN; report a zero loss instead,
        # as forward_process_with_r2i does.
        return torch.tensor(0.0, device=device)

    loss_t2i = F.cross_entropy(
        logits[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1, self.output_size),
        labels[:batch_size_t2i, max_seq_length + 1:].contiguous().view(-1), ignore_index=-100,
    )
    return loss_t2i
| # Temp | |
def forward_i2i(self, input_ids, attention_mask, labels):
    """
    Forward pass for the I2I task.

    Runs the model on ``input_ids`` with ``attention_mask`` and scores every
    position against ``labels`` with cross-entropy; ``-100`` labels are ignored.

    Args:
        input_ids: Token ids passed to the model.
        attention_mask: Passed to the model unchanged.
        labels: Target ids with the same leading shape as the logits.

    Returns:
        ``(logits, loss)``.
    """
    outputs = self(
        input_ids=input_ids,
        attention_mask=attention_mask
    )
    logits = outputs.logits
    # reshape (not view) so sliced or transposed, non-contiguous inputs are accepted.
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100
    )
    return logits, loss
| # Temp | |
def forward_s2t(
    self,
    input_ids,
    labels,
    batch_size_s2t=0,
    max_seq_length=128,
    p_mask_s2t=None,
    answer_lengths=None,
):
    """
    Masked-token loss for the trailing ``batch_size_s2t`` rows of the batch.

    The model is run with an all-ones attention bias. The cross-entropy at each
    masked position is divided by ``p_mask_s2t`` and ``answer_lengths`` at that
    position, summed, and divided by the number of selected rows.

    Note that the rows are selected with ``[-batch_size_s2t:]``, so the default
    ``batch_size_s2t=0`` selects every row. ``max_seq_length`` is not used.

    Returns:
        ``(logits, loss_s2t)``.
    """
    n_rows, n_tokens = input_ids.shape[0], input_ids.shape[1]
    # attention bias, True for batch_size, 1, seq_len, seq_len
    full_attention = torch.ones(n_rows, 1, n_tokens, n_tokens, device=input_ids.device)
    logits = self(input_ids, attention_bias=full_attention).logits
    self.output_size = logits.shape[-1]

    tail_rows = slice(-batch_size_s2t, None)
    is_masked = (input_ids == self.config.mask_token_id)[tail_rows]

    # Bring the weighting tensors next to the mask before indexing with it.
    p_mask_s2t = p_mask_s2t.to(is_masked.device)
    answer_lengths = answer_lengths.to(is_masked.device)

    tail_logits = logits[tail_rows]
    predictions = tail_logits[is_masked].contiguous().view(-1, self.output_size)
    targets = labels[tail_rows][is_masked].contiguous().view(-1)
    token_loss = F.cross_entropy(predictions, targets, ignore_index=-100, reduction='none')
    token_loss = token_loss / p_mask_s2t[is_masked]

    row_count = tail_logits.shape[0]
    loss_s2t = torch.sum(token_loss / answer_lengths[is_masked]) / row_count
    return logits, loss_s2t
def forward_t2s(
    self,
    input_ids,
    labels,
    batch_size_t2s=0,
    max_seq_length=128,
    p_mask_t2s=None,
    answer_lengths=None,
):
    """Forward pass for text-to-speech (T2S) diffusion LM training.

    Args:
        input_ids: (B, L) input token ids (text + [MASK] * len(speech)).
        labels: (B, L) target speech codebook token ids.
        batch_size_t2s: Number of trailing rows that belong to the T2S task
            (for multitask batches). NOTE: ``0`` makes the slice ``[-0:]``,
            which selects every row.
        max_seq_length: Prompt (text) length. Unused here.
        p_mask_t2s: Mask probability per position; indexed with the boolean
            mask of the selected rows, so it must have that mask's shape.
        answer_lengths: Target length per position; indexed the same way,
            so a per-row length has to be expanded along L by the caller.

    Returns:
        ``(logits, loss_t2s)``; ``logits`` covers the whole batch.
    """
    n_rows, n_cols = input_ids.shape[0], input_ids.shape[1]
    # All-ones bias: every position may attend to every other position.
    full_bias = torch.ones(n_rows, 1, n_cols, n_cols, device=input_ids.device)
    logits = self(input_ids, attention_bias=full_bias).logits
    self.output_size = logits.shape[-1]

    tail = slice(-batch_size_t2s, None)
    is_masked = (input_ids == self.config.mask_token_id)[tail]
    p_mask_t2s = p_mask_t2s.to(is_masked.device)
    answer_lengths = answer_lengths.to(is_masked.device)

    picked_logits = logits[tail][is_masked].contiguous().view(-1, self.output_size)
    picked_labels = labels[tail][is_masked].contiguous().view(-1)
    token_nll = F.cross_entropy(
        picked_logits, picked_labels, ignore_index=-100, reduction='none'
    )
    # Importance-weight each masked token by 1 / p_mask, then by 1 / answer length.
    token_nll = token_nll / p_mask_t2s[is_masked]
    loss_t2s = torch.sum(token_nll / answer_lengths[is_masked]) / logits[tail].shape[0]
    return logits, loss_t2s
def forward_v2t(
    self,
    input_ids,
    labels,
    batch_size_v2t=0,
    max_seq_length=128,
    p_mask_v2t=None,
    answer_lengths=None,
):
    """Video-to-text (V2T) diffusion LM training loss.

    Args:
        input_ids: 2-D token ids; positions equal to
            ``self.config.mask_token_id`` are the ones scored.
        labels: Target ids aligned with ``input_ids``.
        batch_size_v2t: Number of leading rows that belong to the V2T task.
        max_seq_length: Unused here.
        p_mask_v2t: Per-position mask probability; indexed with the boolean
            mask of the selected rows, so it must have that mask's shape.
        answer_lengths: Per-position answer length; indexed the same way.

    Returns:
        ``(logits, loss_v2t)``; ``logits`` covers the whole batch.
    """
    n_rows, n_cols = input_ids.shape[0], input_ids.shape[1]
    # All-ones bias: every position may attend to every other position.
    full_bias = torch.ones(n_rows, 1, n_cols, n_cols, device=input_ids.device)
    logits = self(input_ids, attention_bias=full_bias).logits
    self.output_size = logits.shape[-1]

    head = slice(None, batch_size_v2t)
    is_masked = (input_ids == self.config.mask_token_id)[head]
    p_mask_v2t = p_mask_v2t.to(is_masked.device)
    answer_lengths = answer_lengths.to(is_masked.device)

    picked_logits = logits[head][is_masked].contiguous().view(-1, self.output_size)
    picked_labels = labels[head][is_masked].contiguous().view(-1)
    token_nll = F.cross_entropy(
        picked_logits, picked_labels, ignore_index=-100, reduction='none'
    )
    # Importance-weight each masked token by 1 / p_mask, then by 1 / answer length.
    token_nll = token_nll / p_mask_v2t[is_masked]
    loss_v2t = torch.sum(token_nll / answer_lengths[is_masked]) / (logits[head].shape[0])
    return logits, loss_v2t
def forward_v2t_encoder(
    self,
    input_ids,
    labels,
    batch_size_v2t=0,
    max_seq_length=128,
    p_mask_v2t=None,
    answer_lengths=None,
):
    """Video-to-text (V2T) diffusion LM training loss.

    The previous version also computed token embeddings through
    ``super().model.transformer.wte`` and discarded them. A ``super()`` proxy
    does not resolve ``nn.Module`` submodules, so that line could only waste
    work or raise; it has been removed.

    Args:
        input_ids: 2-D token ids; positions equal to
            ``self.config.mask_token_id`` are the ones scored.
        labels: Target ids aligned with ``input_ids``.
        batch_size_v2t: Number of leading rows that belong to the V2T task.
        max_seq_length: Unused here.
        p_mask_v2t: Per-position mask probability; indexed with the boolean
            mask of the selected rows, so it must have that mask's shape.
        answer_lengths: Per-position answer length; indexed the same way.

    Returns:
        ``(logits, loss_v2t)``; ``logits`` covers the whole batch.
    """
    # All-ones bias: every position may attend to every other position.
    attention_bias = torch.ones(
        input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1],
        device=input_ids.device,
    )
    logits = self(input_ids, attention_bias=attention_bias).logits
    self.output_size = logits.shape[-1]

    masked_indices = input_ids == self.config.mask_token_id
    masked_indices_v2t = masked_indices[:batch_size_v2t]
    p_mask_v2t = p_mask_v2t.to(masked_indices_v2t.device)
    answer_lengths = answer_lengths.to(masked_indices_v2t.device)

    # Per-token NLL on masked positions, importance-weighted by 1 / p_mask.
    loss_v2t = F.cross_entropy(
        logits[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1, self.output_size),
        labels[:batch_size_v2t][masked_indices_v2t].contiguous().view(-1),
        ignore_index=-100,
        reduction='none',
    ) / p_mask_v2t[masked_indices_v2t]
    # Normalise by answer length per token, then by the number of V2T rows.
    loss_v2t = torch.sum(loss_v2t / answer_lengths[masked_indices_v2t]) / (logits[:batch_size_v2t].shape[0])
    return logits, loss_v2t
def forward_v2s(
    self,
    input_ids,
    labels,
    batch_size_v2s=0,
    max_seq_length: int = 128,
    p_mask_v2s=None,
    answer_lengths=None,
    t2s_vocab_start: Optional[int] = None,
    t2s_codebook_size: Optional[int] = None,
    t2s_special_token_ids: Optional[Dict[str, int]] = None,
):
    """Video-to-speech (V2S) diffusion LM training loss.

    Runs the model with an all-ones attention bias and delegates the loss on
    the leading ``batch_size_v2s`` rows to ``calculate_t2s_loss``.

    Args:
        input_ids: 2-D token ids; positions equal to
            ``self.config.mask_token_id`` are treated as masked.
        labels: Target ids aligned with ``input_ids``.
        batch_size_v2s: Number of leading rows that belong to the V2S task.
            When ``0`` a zero loss tensor is returned.
        max_seq_length: Unused here.
        p_mask_v2s: Mask probabilities, forwarded to the loss helper.
        answer_lengths: Answer lengths, forwarded to the loss helper.
        t2s_vocab_start: Forwarded to the loss helper.
        t2s_codebook_size: Forwarded to the loss helper.
        t2s_special_token_ids: Mapping read for its ``'eoa'`` and ``'eos'``
            entries, and only when all three ``t2s_*`` arguments are given.

    Returns:
        ``(logits, loss_v2s)``; ``logits`` covers the whole batch.
    """
    n_rows, n_cols = input_ids.shape[0], input_ids.shape[1]
    full_bias = torch.ones(n_rows, 1, n_cols, n_cols, device=input_ids.device)
    logits = self(input_ids, attention_bias=full_bias).logits
    self.output_size = logits.shape[-1]

    is_masked = (input_ids == self.config.mask_token_id)[:batch_size_v2s]

    # No V2S rows in this batch: nothing to score.
    if batch_size_v2s == 0:
        return logits, torch.tensor(0.0, device=input_ids.device)

    p_mask_v2s = p_mask_v2s.to(is_masked.device)
    answer_lengths = answer_lengths.to(is_masked.device)

    eoa_id = None
    eos_id = None
    t2s_settings = (t2s_vocab_start, t2s_codebook_size, t2s_special_token_ids)
    if all(setting is not None for setting in t2s_settings):
        eoa_id = t2s_special_token_ids.get('eoa')
        eos_id = t2s_special_token_ids.get('eos')

    loss_v2s = calculate_t2s_loss(
        logits[:batch_size_v2s],
        labels[:batch_size_v2s],
        is_masked,
        p_mask_v2s,
        answer_lengths,
        t2s_vocab_start,
        t2s_codebook_size,
        eoa_id,
        eos_id,
        input_ids.device,
        ignore_index=-100,
    )
    return logits, loss_v2s
| # def forward_i2i(self, input_ids, attention_mask, labels, max_prompt_length): | |
| # """ | |
| # Forward pass for the I2I task. | |
| # """ | |
| # outputs = self( | |
| # input_ids=input_ids, | |
| # attention_mask=attention_mask | |
| # ) | |
| # logits = outputs.logits | |
| # logits_for_loss = logits[:, max_prompt_length:].contiguous() | |
| # labels_for_loss = labels[:, max_prompt_length:].contiguous() | |
| # loss = F.cross_entropy( | |
| # logits_for_loss.view(-1, logits_for_loss.size(-1)), | |
| # labels_for_loss.view(-1), | |
| # ignore_index=-100 | |
| # ) | |
| # return logits, loss | |
def mmu_generate(
    self,
    idx=None,
    input_embeddings=None,
    max_new_tokens=128,
    steps=128,
    block_length=128,
    temperature=0.0,
    top_k=None,
    eot_token=None,
    cfg_scale=0.0,
    remasking='low_confidence',
    mask_id=126336,
    attention_mask=None,
):
    """Fill ``max_new_tokens`` masked positions after the prompt, block by block.

    The prompt ``idx`` is extended with ``max_new_tokens`` copies of
    ``mask_id``. The extension is split into blocks of ``block_length``; for
    each block the model is run ``steps // num_blocks`` times and, per step,
    the highest-confidence predictions among still-masked positions up to the
    end of the current block are committed.

    Most likely you want the model in ``eval()`` mode for this.

    Args:
        idx: (b, t) LongTensor prompt. Required (``None`` is not supported).
        input_embeddings: Unused.
        max_new_tokens: Number of positions to generate; must be a multiple
            of ``block_length``.
        steps: Total denoising steps; must be a multiple of the block count.
        block_length: Size of each generation block.
        temperature: Forwarded to ``add_gumbel_noise``.
        top_k: Unused.
        eot_token: Unused.
        cfg_scale: When > 0, classifier-free guidance is applied against a
            copy of the sequence whose prompt is fully masked. The guidance
            forward pass does not receive ``attention_bias``.
        remasking: ``'low_confidence'`` (softmax probability of the chosen
            token) or ``'random'`` (uniform scores).
        mask_id: Token id marking masked positions.
        attention_mask: Optional integer/bool mask; only used (as a pairwise
            attention bias) when it contains a zero.
            # NOTE(review): must then cover the extended length t + max_new_tokens -- confirm with caller.

    Returns:
        (b, t + max_new_tokens) LongTensor with the generated tokens.

    Raises:
        NotImplementedError: Unknown ``remasking`` strategy.
    """
    if attention_mask is not None and 0.0 in attention_mask:
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
    else:
        attention_bias = None

    batch_size, prompt_len = idx.shape[0], idx.shape[1]
    x = torch.full((batch_size, prompt_len + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
    x[:, :prompt_len] = idx.clone()
    prompt_index = (x != mask_id)

    assert max_new_tokens % block_length == 0
    num_blocks = max_new_tokens // block_length
    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = prompt_len + (num_block + 1) * block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        # Indexed below as [row, step] to get how many tokens to commit.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = self(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = self(x, attention_bias=attention_bias).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions beyond the current block must never be selected.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x
def s2t_generate(
    self,
    idx=None,
    input_embeddings=None,
    max_new_tokens=128,
    steps=128,
    block_length=128,
    temperature=0.0,
    top_k=None,
    eot_token=None,
    cfg_scale=0.0,
    remasking='low_confidence',
    mask_id=126336,
    attention_mask=None,
):
    """Speech-to-text generation; same decoding procedure as ``mmu_generate``.

    The previous body was a line-for-line copy of ``mmu_generate``, so this
    method now forwards every argument to it unchanged.

    Args:
        idx: (b, t) LongTensor prompt.
        input_embeddings, max_new_tokens, steps, block_length, temperature,
        top_k, eot_token, cfg_scale, remasking, mask_id, attention_mask:
            Forwarded as-is to ``mmu_generate``.

    Returns:
        Whatever ``mmu_generate`` returns for the same arguments.
    """
    return self.mmu_generate(
        idx=idx,
        input_embeddings=input_embeddings,
        max_new_tokens=max_new_tokens,
        steps=steps,
        block_length=block_length,
        temperature=temperature,
        top_k=top_k,
        eot_token=eot_token,
        cfg_scale=cfg_scale,
        remasking=remasking,
        mask_id=mask_id,
        attention_mask=attention_mask,
    )
def ti2ti_generate(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,
    timesteps_text: int | None = None,
    timesteps_image: int | None = None,
    guidance_scale=0,
    noise_schedule=cosine_schedule,
    generator: torch.Generator = None,
    config=None,
    seq_len=1024,
    mask_token_id=126336,
    resolution=512,
    codebook_size=8192,
    uni_prompting=None,
    **kwargs,
):
    """TI2TI generation: fill masked text and image tokens in one sequence.

    Text and image regions are located once, from the first row of the batch,
    and then filled in a shared loop. Text tokens are committed gradually
    (highest sampled probability first, following an even per-step plan);
    all still-masked image tokens are replaced by samples in a single step.

    Args:
        input_ids: (B, L) token ids containing ``mask_token_id`` placeholders.
        uncond_input_ids: Optional unconditional sequence for classifier-free
            guidance; used only when ``guidance_scale > 0``.
        attention_mask: (B, L) integer/bool mask; required.
        uncond_attention_mask: Mask for ``uncond_input_ids`` (all ones if None).
        temperature: Unused.
        timesteps: Default step count for both modalities.
        timesteps_text: Step count for text (defaults to ``timesteps``).
        timesteps_image: Step count for images (defaults to ``timesteps``).
        guidance_scale: Classifier-free guidance strength.
        noise_schedule: Unused.
        generator: Optional RNG passed to ``torch.multinomial`` so sampling is
            reproducible; it must live on the same device as the logits.
        config: Unused.
        seq_len: Number of image tokens per image (fallback layout only).
        mask_token_id: Token id marking masked positions.
        resolution: Unused.
        codebook_size: Size of the image vocabulary, which is assumed to
            start right after the text vocabulary.
        uni_prompting: Object providing ``text_tokenizer``, ``max_text_len``,
            ``sptids_dict`` and optionally ``pad_id``; required.

    Returns:
        ``(filled_tokens, decoded_texts)``: the filled (B, L) sequence and one
        decoded string per row built from all ids below the text vocab size.

    Raises:
        ValueError: ``input_ids``, ``attention_mask`` or ``uni_prompting`` is missing.
    """
    if input_ids is None or attention_mask is None:
        raise ValueError("input_ids and attention_mask are required for ti2ti_generate.")
    if uni_prompting is None:
        raise ValueError("uni_prompting is required for ti2ti_generate.")

    device = input_ids.device
    text_vocab_size = len(uni_prompting.text_tokenizer)
    image_vocab_start = text_vocab_size
    image_vocab_end = image_vocab_start + codebook_size
    timesteps_text = timesteps if timesteps_text is None else timesteps_text
    timesteps_image = timesteps if timesteps_image is None else timesteps_image

    seq = input_ids.clone()
    # attention_mask is guaranteed non-None by the guard above.
    attn = attention_mask

    use_guidance = uncond_input_ids is not None and guidance_scale > 0
    if use_guidance:
        seq_uncond = uncond_input_ids.clone()
        if uncond_attention_mask is None:
            attn_uncond = torch.ones_like(seq_uncond, dtype=torch.long)
        else:
            attn_uncond = uncond_attention_mask
    else:
        seq_uncond = None
        attn_uncond = None

    total_len = seq.shape[1]

    def _uniform_transfer_plan(mask_bool: torch.Tensor, steps_count: int) -> Optional[torch.Tensor]:
        """Evenly divide each row's masked-token count across ``steps_count`` steps."""
        if steps_count is None or steps_count <= 0:
            return None
        mask_num = mask_bool.sum(dim=1, keepdim=True)
        if mask_num.numel() == 0:
            return None
        base = mask_num // steps_count
        remainder = mask_num % steps_count
        step_ids = torch.arange(steps_count, device=mask_bool.device)
        # The first `remainder` steps of each row take one extra token.
        return base + (step_ids < remainder).to(torch.int64)

    prompt_block_len = uni_prompting.max_text_len
    soi_id = int(uni_prompting.sptids_dict.get("<|soi|>", torch.tensor([-1]))[0].item())
    eoi_id = int(uni_prompting.sptids_dict.get("<|eoi|>", torch.tensor([-1]))[0].item())
    pad_id = int(getattr(uni_prompting, "pad_id", 0))

    def _locate_blocks(sample_seq: torch.Tensor, sample_attn: Optional[torch.Tensor]):
        """Return (img_start, img_end, text_start, text_end) for one sequence."""
        # Find the second (target) soi/eoi pair; fall back to the template formula.
        soi_positions = (sample_seq == soi_id).nonzero(as_tuple=True)[0]
        eoi_positions = (sample_seq == eoi_id).nonzero(as_tuple=True)[0]
        tgt_soi = None
        tgt_eoi = None
        if soi_positions.numel() >= 2:
            tgt_soi = int(soi_positions[1].item())
            later_eoi = eoi_positions[eoi_positions > tgt_soi]
            if later_eoi.numel() > 0:
                tgt_eoi = int(later_eoi[0].item())
        if tgt_soi is None or tgt_eoi is None:
            # Fallback: derive positions from the left-pad offset and fixed layout.
            non_pad = (sample_seq != pad_id).nonzero(as_tuple=True)
            pad_offset = int(non_pad[0][0].item()) if len(non_pad) > 0 and non_pad[0].numel() > 0 else 0
            tgt_soi = pad_offset + 1 + 1 + seq_len + 1 + prompt_block_len + 1  # soi before target img
            tgt_eoi = tgt_soi + seq_len + 1  # eoi after target img

        img_start_local = tgt_soi + 1
        img_end_local = min(tgt_eoi, sample_seq.size(0))

        text_start_local = tgt_eoi + 1
        text_end_local = text_start_local + prompt_block_len
        if sample_attn is not None:
            # Text ends at the last attended position after the target image.
            attended = (sample_attn[text_start_local:] != 0).nonzero(as_tuple=True)
            if len(attended) > 0 and attended[0].numel() > 0:
                text_end_local = text_start_local + int(attended[0][-1].item()) + 1
        text_end_local = min(text_end_local, sample_seq.size(0))
        return img_start_local, img_end_local, text_start_local, text_end_local

    # NOTE: block positions come from row 0 only and are reused for every row.
    img_start, img_end, text_start, text_end = _locate_blocks(seq[0], attn[0])

    text_indices = torch.arange(total_len, device=device)
    img_span = (text_indices >= img_start) & (text_indices < img_end)
    text_span = (text_indices >= text_start) & (text_indices < text_end)

    initial_text_mask = (seq == mask_token_id) & text_span
    text_transfer_plan = _uniform_transfer_plan(initial_text_mask, timesteps_text)
    text_step_idx = 0

    # The attention masks never change, so the pairwise biases are built once.
    attn_bias = (attn[:, :, None] & attn[:, None, :]).bool().unsqueeze(1)
    attn_bias_uncond = None
    if use_guidance:
        attn_bias_uncond = (attn_uncond[:, :, None] & attn_uncond[:, None, :]).bool().unsqueeze(1)

    # Simultaneous fill: at each step, update image/text masks that still remain.
    max_steps = max(timesteps_image, timesteps_text)
    for step in range(max_steps):
        mask_map = seq == mask_token_id
        img_mask = (mask_map & img_span) if step < timesteps_image else None
        text_mask = (mask_map & text_span) if step < timesteps_text else None

        img_pending = img_mask is not None and img_mask.any()
        text_pending = text_mask is not None and text_mask.any()
        if not (img_pending or text_pending):
            break

        logits_cond = self(seq, attention_bias=attn_bias).logits
        if use_guidance:
            logits_uncond = self(seq_uncond, attention_bias=attn_bias_uncond).logits
            logits = logits_uncond + (guidance_scale + 1.0) * (logits_cond - logits_uncond)
        else:
            logits = logits_cond

        if text_pending:
            logits_text = logits[..., :text_vocab_size]
            probs_text = logits_text.softmax(dim=-1)
            sampled_text = torch.multinomial(
                probs_text.view(-1, text_vocab_size),
                1,
                replacement=False,
                generator=generator,
            ).view(*logits_text.shape[:2])
            sampled_probs = torch.gather(
                probs_text, dim=-1, index=sampled_text.unsqueeze(-1)
            ).squeeze(-1)

            candidate_seq = torch.where(text_mask, sampled_text, seq)
            # Non-masked positions get -inf so top-k never selects them.
            confidence = torch.full_like(sampled_probs, float("-inf"))
            confidence = torch.where(text_mask, sampled_probs, confidence)

            if text_transfer_plan is not None and text_step_idx < text_transfer_plan.shape[1]:
                transfer_counts = text_transfer_plan[:, text_step_idx]
            else:
                transfer_counts = text_mask.sum(dim=1)

            transfer_mask = torch.zeros_like(text_mask, dtype=torch.bool)
            for b_idx in range(seq.shape[0]):
                mask_count = int(text_mask[b_idx].sum().item())
                if mask_count == 0:
                    continue
                k = int(min(max(transfer_counts[b_idx].item(), 0), mask_count))
                if k <= 0:
                    continue
                _, top_idx = torch.topk(confidence[b_idx], k=k)
                transfer_mask[b_idx, top_idx] = True

            if transfer_mask.any():
                seq = torch.where(transfer_mask, candidate_seq, seq)
            text_step_idx += 1

        if img_pending:
            logits_img = logits[..., image_vocab_start:image_vocab_end]
            probs_img = logits_img.softmax(dim=-1)
            sampled_img = torch.multinomial(
                probs_img.view(-1, codebook_size),
                1,
                replacement=False,
                generator=generator,
            ).view(*logits_img.shape[:2]) + image_vocab_start
            seq = torch.where(img_mask, sampled_img, seq)

        if use_guidance:
            # Mirror this step's updates into the unconditional sequence.
            updated_mask = torch.zeros_like(seq, dtype=torch.bool)
            if img_mask is not None:
                updated_mask |= img_mask
            if text_mask is not None:
                updated_mask |= text_mask
            seq_uncond = torch.where(updated_mask, seq, seq_uncond)

    # Decode text tokens from the filled sequence.
    pred_texts = []
    for row in seq:
        text_tokens = [int(t) for t in row.tolist() if 0 <= t < text_vocab_size]
        pred_texts.append(uni_prompting.text_tokenizer.decode(text_tokens, skip_special_tokens=True))
    return seq, pred_texts
def mmu_generate_fast(
    self,
    idx=None,
    input_embeddings=None,
    max_new_tokens=128,
    steps=128,
    block_length=128,
    temperature=0.0,
    top_k=None,
    eot_token=None,
    cfg_scale=0.0,
    remasking='low_confidence',
    mask_id=126336,
    attention_mask=None,
):
    """Block-wise masked decoding with an optional early stop on ``eot_token``.

    The prompt ``idx`` is extended with ``max_new_tokens`` copies of
    ``mask_id`` and filled block by block. When ``eot_token`` is given and
    the last position of the current block equals it in every row,
    generation stops; later blocks then remain ``mask_id``.

    Most likely you want the model in ``eval()`` mode for this.

    Args:
        idx: (b, t) LongTensor prompt. Required (``None`` is not supported).
        input_embeddings: Unused.
        max_new_tokens: Number of positions to generate; must be a multiple
            of ``block_length``.
        steps: Total denoising steps; must be a multiple of the block count.
        block_length: Size of each generation block.
        temperature: Forwarded to ``add_gumbel_noise``.
        top_k: Unused.
        eot_token: Optional token id that triggers the early stop.
        cfg_scale: When > 0, classifier-free guidance is applied against a
            copy of the sequence whose prompt is fully masked. The guidance
            forward pass does not receive ``attention_bias``.
        remasking: ``'low_confidence'`` or ``'random'``.
        mask_id: Token id marking masked positions.
        attention_mask: Optional integer/bool mask; only used (as a pairwise
            attention bias) when it contains a zero.

    Returns:
        (b, t + max_new_tokens) LongTensor with the generated tokens.

    Raises:
        NotImplementedError: Unknown ``remasking`` strategy.
    """
    if attention_mask is not None and 0.0 in attention_mask:
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).bool().unsqueeze(1)
    else:
        attention_bias = None

    batch_size, prompt_len = idx.shape[0], idx.shape[1]
    x = torch.full((batch_size, prompt_len + max_new_tokens), mask_id, dtype=torch.long).to(self.device)
    x[:, :prompt_len] = idx.clone()
    prompt_index = (x != mask_id)

    assert max_new_tokens % block_length == 0
    num_blocks = max_new_tokens // block_length
    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = prompt_len + (num_block + 1) * block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        # Indexed below as [row, step] to get how many tokens to commit.
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = self(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = self(x, attention_bias=attention_bias).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions beyond the current block must never be selected.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

        # NOTE(review): original indentation was lost; this check is assumed to
        # run once per finished block (breaking the block loop) -- confirm.
        if eot_token is not None:
            last_token_index_in_current_block = block_end - 1
            if last_token_index_in_current_block < x.shape[1]:
                tokens_at_block_end = x[:, last_token_index_in_current_block]
                if torch.all(tokens_at_block_end == eot_token):
                    break

    return x
def t2i_generate_decoding_stepwise(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,  # ideal number of steps is 18 in maskgit paper
    guidance_scale=0,
    noise_schedule=cosine_schedule,
    generator: torch.Generator = None,
    config=None,
    seq_len=1024,
    mask_token_id=126336,
    resolution=512,
    codebook_size=8192,
    vq_model=None,
    **kwargs,
):
    """Stepwise MaskGIT-style image decoding that yields an image per step.

    Generate 1:1 similar to the original MaskGit repo
    https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

    This is a generator: each step yields ``(PIL image, "Step i/N")`` decoded
    from the first row's current tokens, and the final ``sampled_ids`` is the
    generator's return value. ``input_ids`` is modified in place.

    Args:
        input_ids: (B, L) token ids; the image tokens occupy the
            ``seq_len`` positions before the last one.
        uncond_input_ids: Optional unconditional sequence for classifier-free
            guidance; used only when ``guidance_scale > 0``.
        attention_mask: (B, L) integer/bool mask for ``input_ids``.
        uncond_attention_mask: Mask for the unconditional sequence.
        temperature: Initial masking-noise temperature. It is overwritten
            each step with ``temperature * (1 - ratio)``, so the annealing
            compounds across steps.
        timesteps: Number of decoding steps.
        guidance_scale: Classifier-free guidance strength.
        noise_schedule: Maps the progress ratio to the next mask ratio.
        generator: Optional RNG for sampling and masking.
        config: Unused.
        seq_len: Number of image (VQ) tokens.
        mask_token_id: Token id marking masked positions.
        resolution: Used as the length (minus one) of the unconditional
            prefix that is kept when rebuilding ``uncond_input_ids``.
        codebook_size: Size of the image vocabulary, which is assumed to
            start right after the text vocabulary.
        vq_model: Object whose ``decode_code`` turns VQ indices into images.
        **kwargs: Must contain ``uni_prompting`` (provides ``text_tokenizer``).

    Yields:
        ``(PIL.Image, status string)`` once per step.

    Returns:
        The VQ indices sampled in the last step (``None`` if ``timesteps <= 0``).

    Raises:
        ValueError: ``vq_model`` or ``uni_prompting`` is missing.
    """
    uni_prompting = kwargs.get("uni_prompting", None)
    if vq_model is None:
        raise ValueError("vq_model is required for stepwise decoding.")
    if uni_prompting is None:
        raise ValueError("uni_prompting is required for stepwise decoding.")

    num_vq_tokens = seq_len
    num_new_special_tokens = 0
    # Image token ids are offset by the text vocabulary (plus extra specials).
    vocab_offset = len(uni_prompting.text_tokenizer) + num_new_special_tokens
    image_slice = slice(-(num_vq_tokens + 1), -1)

    # Begin with the image token ids expressed in codebook space.
    input_ids_minus_lm_vocab_size = input_ids[:, image_slice].clone()
    input_ids_minus_lm_vocab_size = torch.where(
        input_ids_minus_lm_vocab_size == mask_token_id,
        mask_token_id,
        input_ids_minus_lm_vocab_size - vocab_offset,
    )

    # For classifier-free guidance.
    use_guidance = uncond_input_ids is not None and guidance_scale > 0
    if uncond_input_ids is not None:
        uncond_prefix = uncond_input_ids[:, :resolution + 1]

    # The attention masks never change, so the pairwise bias is built once.
    if use_guidance:
        all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
    else:
        all_attention_mask = attention_mask
    attention_bias = (all_attention_mask[:, :, None] & all_attention_mask[:, None, :]).bool().unsqueeze(1)

    sampled_ids = None
    for step in range(timesteps):
        if use_guidance:
            uncond_input_ids = torch.cat(
                [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
            model_input = torch.cat([input_ids, uncond_input_ids])
            logits = self(model_input, attention_bias=attention_bias).logits
            cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
            # it seems that muse has a different cfg setting
            logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
        else:
            logits = self(input_ids, attention_bias=attention_bias).logits
        logits = logits[:, image_slice, vocab_offset:vocab_offset + codebook_size]

        probs = logits.softmax(dim=-1)
        sampled = probs.reshape(-1, logits.size(-1))
        sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

        # Keep tokens that were already known; only masked positions are sampled.
        unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
        sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

        # Decode the current tokens into a preview image for this step.
        current_image_vq_indices = sampled_ids.clone()
        current_image_vq_indices = torch.clamp(current_image_vq_indices, 0, codebook_size - 1)
        current_image = vq_model.decode_code(current_image_vq_indices)
        images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0)
        images *= 255.0
        images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        pil_images = Image.fromarray(images[0])
        yield pil_images, f"Step {step + 1}/{timesteps}"

        # Defines the mask ratio for the next round. The number to mask out is
        # determined by mask_ratio * unknown_number_in_the_beginning.
        ratio = 1.0 * (step + 1) / timesteps
        mask_ratio = noise_schedule(torch.tensor(ratio))

        # Computes the probabilities of each selected token.
        selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
        selected_probs = selected_probs.squeeze(-1)
        # Ignores the tokens given in the input by overwriting their confidence.
        selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

        # Gets mask lens for each sample in the batch according to the mask ratio.
        mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
        # Keeps at least one prediction in this round and also masks out at
        # least one for the next iteration.
        mask_len = torch.max(
            torch.tensor([1], device=logits.device),
            torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
        )

        # Adds noise for randomness.
        temperature = temperature * (1.0 - ratio)
        masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

        # Masks tokens with lower confidence (writes back into the caller's input_ids).
        input_ids[:, image_slice] = torch.where(masking, mask_token_id, sampled_ids + vocab_offset)
        input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)

    return sampled_ids
def i2i_generate_decoding_stepwise(
    self,
    input_ids: torch.LongTensor = None,
    uncond_input_ids: torch.LongTensor = None,
    attention_mask=None,
    uncond_attention_mask=None,
    temperature=1.0,
    timesteps=18,  # ideal number of steps is 18 in maskgit paper
    guidance_scale=0,
    noise_schedule=cosine_schedule,
    generator: torch.Generator = None,
    config=None,
    seq_len=1024,
    mask_token_id=126336,
    resolution=512,
    codebook_size=8192,
    vq_model=None,
    **kwargs,
):
    """Stepwise i2i decoding that yields an intermediate image after every step.

    This is a generator. Each step runs one forward pass of ``self``, samples
    codebook ids for the image-token slots, decodes them with ``vq_model`` and
    yields a preview. It then re-masks the positions selected by
    ``mask_by_random_topk`` so they are predicted again in the next step.

    Args:
        input_ids: Token ids. Positions ``[-(seq_len + 1):-1]`` along dim 1
            are treated as the image-token slots. **Modified in place** on
            every step.
        uncond_input_ids: Optional unconditional sequence for classifier-free
            guidance. Only its first ``resolution + 1`` positions are kept;
            the remainder is rebuilt from ``input_ids`` at every step. Ignored
            unless ``guidance_scale > 0``.
        attention_mask: Mask for ``input_ids``; combined with itself via ``&``
            to build the pairwise attention bias, so it must support ``&``
            (assumes shape (batch, seq) - TODO confirm).
        uncond_attention_mask: Mask for the unconditional sequence; only used
            on the guidance path.
        temperature: Initial temperature handed to ``mask_by_random_topk``.
            It is multiplied by ``(1 - ratio)`` after every step, i.e. the
            decay accumulates across steps.
        timesteps: Number of decoding steps; must be at least 1.
        guidance_scale: Classifier-free guidance weight; ``<= 0`` disables
            the unconditional pass.
        noise_schedule: Callable mapping the progress ratio (a 0-dim tensor)
            to the mask ratio for the next step.
        generator: Optional RNG forwarded to ``torch.multinomial`` and
            ``mask_by_random_topk``.
        config: Unused; kept for interface compatibility.
        seq_len: Number of image-token slots.
        mask_token_id: Id marking a still-unknown slot.
        resolution: ``resolution + 1`` is the split index between the kept
            unconditional prefix and the part copied from ``input_ids``.
        codebook_size: Number of codebook entries; logits are restricted to
            this many ids after the text vocabulary.
        vq_model: Object exposing ``decode_code(ids)``. Required.
        **kwargs: Must contain ``uni_prompting``; ``len()`` of its
            ``text_tokenizer`` is the offset between codebook ids and
            model vocabulary ids.

    Yields:
        ``(image, status)`` where ``image`` is a ``PIL.Image`` built from the
        first batch element of the decoded output and ``status`` is
        ``"Step <i>/<timesteps>"``.

    Returns:
        The codebook ids sampled in the last step (delivered through
        ``StopIteration.value``, since this is a generator).

    Raises:
        ValueError: If ``vq_model`` or ``uni_prompting`` is missing. As with
            any generator, this is raised on the first ``next()`` call.
    """
    if vq_model is None:
        raise ValueError("vq_model is required for stepwise decoding.")
    uni_prompting = kwargs.get("uni_prompting", None)
    if uni_prompting is None:
        # Fail with the same kind of error as for vq_model instead of an
        # AttributeError on None further down.
        raise ValueError("uni_prompting is required for stepwise decoding.")

    num_vq_tokens = seq_len
    num_new_special_tokens = 0
    # Offset between codebook ids and ids in the model vocabulary. It does not
    # change during decoding, so compute it once.
    vocab_offset = len(uni_prompting.text_tokenizer) + num_new_special_tokens
    # Image-token slots: the last `num_vq_tokens` positions before the final token.
    image_slots = slice(-(num_vq_tokens + 1), -1)

    # Current image tokens expressed as codebook ids; unknown slots keep mask_token_id.
    slot_ids = input_ids[:, image_slots].clone()
    input_ids_minus_lm_vocab_size = torch.where(
        slot_ids == mask_token_id,
        mask_token_id,
        slot_ids - vocab_offset,
    )

    # Neither operand changes inside the loop, so decide the guidance path once.
    use_cfg = uncond_input_ids is not None and guidance_scale > 0
    if use_cfg:
        uncond_prefix = uncond_input_ids[:, :resolution + 1]

    for step in range(timesteps):
        if use_cfg:
            # Keep the unconditional prefix, refresh the rest from the current sequence.
            uncond_input_ids = torch.cat(
                [uncond_prefix, input_ids[:, resolution + 1:]], dim=1)
            model_input = torch.cat([input_ids, uncond_input_ids])
            all_attention_mask = torch.cat([attention_mask, uncond_attention_mask], dim=0)
            attention_bias = (
                all_attention_mask[:, :, None] & all_attention_mask[:, None, :]
            ).bool().unsqueeze(1)
            logits = self(model_input, attention_bias=attention_bias).logits
            cond_logits, uncond_logits = torch.chunk(logits, 2, dim=0)
            logits = (1 + guidance_scale) * cond_logits - guidance_scale * uncond_logits
        else:
            attention_bias = (
                attention_mask[:, :, None] & attention_mask[:, None, :]
            ).bool().unsqueeze(1)
            logits = self(input_ids, attention_bias=attention_bias).logits

        # Restrict to the image slots and to the codebook part of the vocabulary.
        logits = logits[:, image_slots, vocab_offset:vocab_offset + codebook_size]

        probs = logits.softmax(dim=-1)
        flat_probs = probs.reshape(-1, logits.size(-1))
        sampled_ids = torch.multinomial(flat_probs, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

        # Only still-masked slots take the fresh sample; known slots keep their id.
        unknown_map = input_ids_minus_lm_vocab_size == mask_token_id
        sampled_ids = torch.where(unknown_map, sampled_ids, input_ids_minus_lm_vocab_size)

        # torch.clamp is out-of-place, so no defensive copy is needed here.
        current_image_vq_indices = torch.clamp(sampled_ids, 0, codebook_size - 1)
        current_image = vq_model.decode_code(current_image_vq_indices)
        # presumably decode_code returns values in [-1, 1] laid out as
        # (batch, channels, height, width) - TODO confirm against vq_model
        images = torch.clamp((current_image + 1.0) / 2.0, min=0.0, max=1.0)
        images *= 255.0
        images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
        pil_images = Image.fromarray(images[0])
        yield pil_images, f"Step {step + 1}/{timesteps}"

        # Mask ratio for the next round, driven by the progress so far.
        ratio = 1.0 * (step + 1) / timesteps
        mask_ratio = noise_schedule(torch.tensor(ratio))

        # Probability of each chosen id; already-known slots get the largest
        # finite value of the dtype in place of their probability.
        selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]).squeeze(-1)
        selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

        # Bounded below by 1 and above by (number of unknown slots - 1).
        mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(logits.device)
        mask_len = torch.max(
            torch.tensor([1], device=logits.device),
            torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
        )

        # NOTE(review): the decay is applied to the already-decayed value, so it
        # compounds over steps rather than being recomputed from the initial
        # temperature. Kept as is to preserve sampling behavior - confirm intent.
        temperature = temperature * (1.0 - ratio)
        masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

        # Write the result back into the caller's sequence (in model-vocabulary
        # ids) and keep the codebook-id view in sync for the next step.
        input_ids[:, image_slots] = torch.where(
            masking,
            mask_token_id,
            sampled_ids + vocab_offset,
        )
        input_ids_minus_lm_vocab_size = torch.where(masking, mask_token_id, sampled_ids)
    return sampled_ids
# Register the "omada" model type with the transformers Auto* factories:
# the config class first, then the model class for each model factory.
AutoConfig.register("omada", OMadaConfig)
for _auto_factory in (AutoModelForCausalLM, AutoModel):
    _auto_factory.register(OMadaConfig, OMadaModelLM)
del _auto_factory  # keep the loop variable out of the module namespace