import torch
import torch.nn.functional as F

from utils.util import forward_process_length, shift_logits, forward_process


def compute_loss_by_config(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
    config,
):
    """Select the loss function based on the config file."""
    training_mode = config.get('training_mode', 'dream')
    if training_mode == 'llada':
        return compute_llada_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id,
        )
    elif training_mode == 'dream':
        return compute_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id,
        )
    else:
        raise ValueError(f"Unsupported training mode: {training_mode}")


def compute_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """Masked block-diffusion loss for the Dream-style denoiser (shifted logits)."""
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )

    # Keep the prompt clean: restore the original tokens at all prompt positions.
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)

    attention_mask = build_custom_float_attention_mask(
        noisy_batch, question_length, block_size, device=noisy_batch.device
    ).to(torch.float16)
    logits = denoiser(noisy_batch, attention_mask=attention_mask).logits
    logits = shift_logits(logits)

    if self_align:
        # Self-alignment: distill toward the frozen base model by running the
        # denoiser with its adapter disabled and a fully visible (all-zero) mask.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_logits = denoiser(
                    noisy_batch,
                    attention_mask=torch.zeros(
                        [1, 1, noisy_batch.shape[1], noisy_batch.shape[1]],
                        dtype=torch.float16, device=denoiser.device,
                    ),
                ).logits
                ref_logits = shift_logits(ref_logits)
                ref_logits = F.softmax(ref_logits, dim=-1)
        token_loss = F.cross_entropy(
            logits[masked_indices], ref_logits[masked_indices], reduction='none'
        ) / p_mask[masked_indices]
    else:
        token_loss = F.cross_entropy(
            logits[masked_indices], input_ids[masked_indices], reduction='none'
        ) / p_mask[masked_indices]

    return {'loss': token_loss.mean()}


def compute_normal_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """Same masked loss as compute_loss, but with the model's default attention."""
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )

    # Keep the prompt clean: restore the original tokens at all prompt positions.
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)

    logits = denoiser(noisy_batch).logits
    logits = shift_logits(logits)

    token_loss = F.cross_entropy(
        logits[masked_indices], input_ids[masked_indices], reduction='none'
    ) / p_mask[masked_indices]

    return {'loss': token_loss.mean()}
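# All of the loss functions in this module share the same core objective: a
# cross-entropy over the masked positions, reweighted by 1 / p_mask (each
# token's masking probability), the standard unbiased estimator in
# masked-diffusion training. The sketch below illustrates just that weighting
# with toy tensors; the shapes and values are hypothetical stand-ins, not the
# real forward_process_length output.
def _masked_ce_weighting_sketch():
    torch.manual_seed(0)
    B, L, V = 2, 8, 16
    logits = torch.randn(B, L, V)
    targets = torch.randint(0, V, (B, L))
    p_mask = torch.full((B, L), 0.5)             # per-token masking probability
    masked_indices = torch.rand(B, L) < p_mask   # which tokens were actually masked
    # Loss only on masked tokens; dividing by p_mask upweights rarely-masked ones.
    token_loss = F.cross_entropy(
        logits[masked_indices], targets[masked_indices], reduction='none'
    ) / p_mask[masked_indices]
    return token_loss.mean()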
def compute_llada_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """Masked block-diffusion loss for LLaDA: the float mask is passed via
    `attention_bias` and the logits are not shifted."""
    mask_id = 126336  # LLaDA's reserved mask token id; overrides the argument.
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )

    # Keep the prompt clean: restore the original tokens at all prompt positions.
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)

    attention_mask = build_custom_float_attention_mask(
        noisy_batch, question_length, block_size, device=noisy_batch.device
    ).to(torch.float16)
    logits = denoiser(noisy_batch, attention_bias=attention_mask).logits

    if self_align:
        # Self-alignment: distill toward the frozen base model by running the
        # denoiser with its adapter disabled and a fully visible (all-zero) mask.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_logits = denoiser(
                    noisy_batch,
                    attention_bias=torch.zeros(
                        [1, 1, noisy_batch.shape[1], noisy_batch.shape[1]],
                        dtype=torch.float16, device=denoiser.device,
                    ),
                ).logits
                ref_logits = F.softmax(ref_logits, dim=-1)
        token_loss = F.cross_entropy(
            logits[masked_indices], ref_logits[masked_indices], reduction='none'
        ) / p_mask[masked_indices]
    else:
        token_loss = F.cross_entropy(
            logits[masked_indices], input_ids[masked_indices], reduction='none'
        ) / p_mask[masked_indices]

    return {'loss': token_loss.mean()}


def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    """Build an additive float attention mask of shape [B, 1, seq_len, seq_len]:
    0.0 where attention is allowed, -inf where it is not. Every token sees the
    full prompt; the completion is split into blocks with full attention inside
    a block and causal attention across blocks."""
    B, seq_len = input_ids.shape
    # Initialize everything to -inf (disallowed).
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'),
                           dtype=torch.float32, device=device)
    for i in range(B):
        # 1. Prompt region: every token may attend to the entire prompt.
        attn_mask[i, :, :, :prompt_length[i]] = 0.0
        # 2. Block partition: split the completion into blocks starting at prompt_length.
        num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = prompt_length[i] + b * block_size
            block_end = min(block_start + block_size, seq_len)
            # Full attention within the block.
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
            # Causal attention across blocks: only earlier blocks are visible.
            for prev_b in range(b):
                prev_start = prompt_length[i] + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
    return attn_mask


if __name__ == "__main__":
    seq_len = 10
    input_ids = torch.randint(0, 100, (2, seq_len))  # example input
    block_size = 4
    prompt_length = torch.tensor([2, 4])  # example prompt lengths
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attn_mask = build_custom_float_attention_mask(input_ids, prompt_length, block_size, device)
    print(attn_mask)
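# A structural sanity check for the mask builder above (a verification sketch,
# not part of the training path). It asserts the properties the builder is
# meant to guarantee: every query attends to the full prompt, and no query in
# block b attends to a key in any later block.
def check_block_causal(attn_mask, prompt_length, block_size):
    B, _, L, _ = attn_mask.shape
    for i in range(B):
        p = int(prompt_length[i])
        # Every query row sees the entire prompt.
        assert (attn_mask[i, 0, :, :p] == 0.0).all()
        for q in range(p, L):
            # Keys in blocks after the query's own block must stay at -inf.
            q_block = (q - p) // block_size
            next_start = p + (q_block + 1) * block_size
            if next_start < L:
                assert torch.isinf(attn_mask[i, 0, q, next_start:]).all()


if __name__ == "__main__":
    # Reuses attn_mask, prompt_length, block_size left at module scope by the demo above.
    check_block_causal(attn_mask, prompt_length, block_size)
    print("block-causal mask check passed")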