import functools
import logging
import math
from typing import List, Dict, Tuple, Optional

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from .modeling_llada import LLaDAModelLM
from .configuration_llada import LLaDAConfig

__all__ = ["LLaDAForMultiModalGeneration"]


def create_attention_mask(original_lengths, max_tokens, device):
    """Build a (batch, max_tokens) boolean padding mask: True for real tokens."""
    batch_size = len(original_lengths)
    attention_mask = torch.zeros(batch_size, max_tokens, dtype=torch.bool, device=device)
    for i, length in enumerate(original_lengths):
        attention_mask[i, :length] = True
    return attention_mask


class LLaDAForMultiModalGeneration(LLaDAModelLM):
    config_class = LLaDAConfig
    base_model_prefix = "model"

    # Special-token IDs that delimit answer and image regions in a sequence.
    IMAGE_START_TOKEN = 126349
    IMAGE_END_TOKEN = 126350
    ANSWER_START_TOKEN = 126354
    ANSWER_END_TOKEN = 126355
    BREAKLINE_TOKEN = 126084
    MASK_TOKEN = 126336
    PAD_TOKEN = 126339

    def __init__(self, config: LLaDAConfig, *args, **kwargs):
        print(f"Initializing LLaDAForMultiModalGeneration with config: {config}")
        super().__init__(config, *args, **kwargs)
        self._debug_step = 0

    def forward(
        self,
        input_ids=None,
        labels=None,
        infer=False,
        use_cache=False,
        return_dict=False,
        compute_separate_losses=True,
        t=None,
        text_coeff=1.0,   # accepted for API compatibility; not applied below
        image_coeff=1.0,  # accepted for API compatibility; not applied below
    ):
        if infer:
            input_ids = input_ids.tolist()

        # Right-pad every example to the batch maximum, then build the
        # pairwise (batch, 1, seq, seq) attention bias from the padding mask.
        max_tokens = max(len(example) for example in input_ids)
        original_lengths = [len(example) for example in input_ids]
        input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids]
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device)
        attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)

        output = LLaDAModelLM.forward(
            self, input_ids=input_ids, attention_bias=attention_bias, use_cache=use_cache
        )
        if infer:
            return output

        if labels is None:
            if return_dict:
                return {'logits': output.logits}
            return output.logits

        labels = [label + [-100] * (max_tokens - len(label)) for label in labels]
        labels = torch.tensor(labels, dtype=torch.int64, device=self.device)

        logits = output.logits
        batch_size = logits.shape[0]

        # Per-token cross-entropy, kept unreduced so it can be re-aggregated
        # over the text and image regions separately below.
        unscaled_loss = F.cross_entropy(
            logits.contiguous().view(-1, logits.shape[-1]),
            labels.contiguous().view(-1),
            ignore_index=-100,
            reduction='none',
        ).view(batch_size, -1)

        valid_mask = (labels != -100)
        if valid_mask.sum() > 0:
            interleave_loss = unscaled_loss[valid_mask].mean()
        else:
            interleave_loss = torch.tensor(0.0, device=self.device)

        if compute_separate_losses:
            self._debug_step += 1
            debug_this_step = (self._debug_step <= 3)
            if debug_this_step:
                print(f"\n{'=' * 80}")
                print(f"DEBUG Step {self._debug_step}")
                print(f"{'=' * 80}")

            text_loss_list = []
            image_loss_list = []
            for b in range(batch_size):
                # Locate the answer region [answer_start, answer_end); fall back
                # to the unpadded length when no answer-end token is present.
                answer_start_positions = (input_ids[b] == self.ANSWER_START_TOKEN).nonzero(as_tuple=True)[0]
                if len(answer_start_positions) == 0:
                    continue
                answer_start = answer_start_positions[0].item()
                answer_end_in_search = (input_ids[b, answer_start:] == self.ANSWER_END_TOKEN).nonzero(as_tuple=True)[0]
                if len(answer_end_in_search) > 0:
                    answer_end = answer_start + answer_end_in_search[0].item()
                else:
                    answer_end = original_lengths[b]

                # If the answer embeds an image, split its tokens out of the
                # text loss. (A sample with an image start but no image end
                # contributes no separated losses.)
                answer_region_input = input_ids[b, answer_start:answer_end]
                image_start_in_answer = (answer_region_input == self.IMAGE_START_TOKEN).nonzero(as_tuple=True)[0]
                if len(image_start_in_answer) > 0:
                    image_start_pos = answer_start + image_start_in_answer[0].item()
                    image_end_search = input_ids[b, image_start_pos:]
                    image_end_in_search = (image_end_search == self.IMAGE_END_TOKEN).nonzero(as_tuple=True)[0]
                    if len(image_end_in_search) > 0:
                        image_end_pos = image_start_pos + image_end_in_search[0].item()
                        # Image tokens, excluding breakline separators.
                        for pos in range(image_start_pos + 1, image_end_pos):
                            if input_ids[b, pos] != self.BREAKLINE_TOKEN:
                                image_loss_list.append(unscaled_loss[b, pos])
                        # Text tokens after the image, up to the answer end.
                        for pos in range(image_end_pos + 1, answer_end):
                            if labels[b, pos] != -100:
                                text_loss_list.append(unscaled_loss[b, pos])
                else:
                    # Text-only answer.
                    for pos in range(answer_start + 1, answer_end):
                        if labels[b, pos] != -100:
                            text_loss_list.append(unscaled_loss[b, pos])

            if debug_this_step:
                print(f"Total text_loss_list length: {len(text_loss_list)}")
                print(f"Total image_loss_list length: {len(image_loss_list)}")
                if len(text_loss_list) > 0:
                    non_zero_text = [l.item() for l in text_loss_list if l.item() > 0]
                    print(f"Non-zero text losses count: {len(non_zero_text)}/{len(text_loss_list)}")
                    print(f"Sample non-zero text losses: {non_zero_text[:5]}")
                if len(image_loss_list) > 0:
                    non_zero_image = [l.item() for l in image_loss_list if l.item() > 0]
                    print(f"Non-zero image losses count: {len(non_zero_image)}/{len(image_loss_list)}")
                    print(f"Sample non-zero image losses: {non_zero_image[:5]}")
                print(f"{'=' * 80}\n")

            if len(text_loss_list) > 0:
                text_loss = torch.stack(text_loss_list).mean()
            else:
                text_loss = torch.tensor(0.0, device=self.device)
            if len(image_loss_list) > 0:
                image_loss = torch.stack(image_loss_list).mean()
            else:
                image_loss = torch.tensor(0.0, device=self.device)

            # Optionally rescale the text loss by the mean of t, clamped to
            # avoid blow-up for very small values.
            if t is not None and len(text_loss_list) > 0:
                text_loss = text_loss / t.mean().clamp(min=0.01)

            if return_dict:
                return {
                    'logits': logits,
                    'loss': interleave_loss,
                    'interleave_loss': interleave_loss,
                    'text_loss': text_loss,
                    'image_loss': image_loss,
                    'labels': labels,
                }
            return interleave_loss, {
                'text_loss': text_loss,
                'image_loss': image_loss,
                'interleave_loss': interleave_loss,
            }
        else:
            if return_dict:
                return {'logits': logits, 'loss': interleave_loss, 'labels': labels}
            return interleave_loss

    def get_fsdp_wrap_module_list(self) -> List:
        modules = [*list(self.model.transformer.blocks), self.model.transformer.ff_out]
        return modules

    def get_checkpointing_wrap_module_list(self) -> List:
        return list(self.model.transformer.blocks)
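

if __name__ == "__main__":
    # Minimal sanity checks; a sketch for local testing only, requiring no
    # model weights. Because this module uses relative imports, run it as
    # `python -m <package>.<module>` rather than invoking the file directly.

    # 1) create_attention_mask marks real tokens True and padding False.
    mask = create_attention_mask([3, 5], 5, torch.device("cpu"))
    assert mask.tolist() == [
        [True, True, True, False, False],
        [True, True, True, True, True],
    ]
    # forward() expands this padding mask into a (batch, 1, seq, seq) bias.
    bias = (mask[:, :, None] & mask[:, None, :]).unsqueeze(1)
    assert bias.shape == (2, 1, 5, 5)

    # 2) The region search used by the loss separation, replayed on a
    #    hypothetical dummy sequence built from the class token constants.
    C = LLaDAForMultiModalGeneration
    seq = torch.tensor([
        C.ANSWER_START_TOKEN, 11, C.IMAGE_START_TOKEN, 21, 22,
        C.IMAGE_END_TOKEN, 12, C.ANSWER_END_TOKEN,
    ])
    answer_start = (seq == C.ANSWER_START_TOKEN).nonzero(as_tuple=True)[0][0].item()
    image_start = (seq == C.IMAGE_START_TOKEN).nonzero(as_tuple=True)[0][0].item()
    assert (answer_start, image_start) == (0, 2)
    print("sanity checks passed")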