import torch
import torch.nn.functional as F

from utils.util import forward_process_length, shift_logits, forward_process
def compute_loss_by_config(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
    config,
):
    """Dispatch to the loss function selected by config['training_mode']."""
    training_mode = config.get('training_mode', 'dream')
    if training_mode == 'llada':
        return compute_llada_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id
        )
    elif training_mode == 'dream':
        return compute_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id
        )
    else:
        raise ValueError(f"Unsupported training mode: {training_mode}")
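
# Hedged usage sketch: the dispatcher only reads config['training_mode'] and
# forwards everything else unchanged, so any mapping with .get() works. The
# values below are illustrative placeholders, not the project's real schema.
#
#     config = {'training_mode': 'llada'}   # omit the key to default to 'dream'
#     losses = compute_loss_by_config(
#         input_ids, denoiser, question_length, mask_id, block_size,
#         enable_shift, share_steps, self_align, feature_align,
#         self_step, eos_id, config,
#     )
#     losses['loss'].backward()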
def compute_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    B, L = input_ids.shape
    # Apply the forward (noising) process, then restore the prompt tokens so
    # only response positions are ever masked.
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)
    # Block-causal additive mask: 0.0 where attention is allowed, -inf elsewhere.
    attention_mask = build_custom_float_attention_mask(
        noisy_batch, question_length, block_size, device=noisy_batch.device,
    )
    attention_mask = attention_mask.to(torch.float16)
    logits = denoiser(noisy_batch, attention_mask=attention_mask).logits
    logits = shift_logits(logits)
    if self_align:
        # Self-alignment: distill toward the frozen base model (adapters
        # disabled) run with a fully visible (all-zero) attention bias.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_logits = denoiser(
                    noisy_batch,
                    attention_mask=torch.zeros(
                        [1, 1, noisy_batch.shape[1], noisy_batch.shape[1]],
                        dtype=torch.float16, device=denoiser.device,
                    ),
                ).logits
                ref_logits = shift_logits(ref_logits)
                ref_logits = F.softmax(ref_logits, dim=-1)
        # Soft-target cross-entropy (class-probability targets require
        # PyTorch >= 1.10), importance-weighted by the masking probability.
        token_loss = F.cross_entropy(
            logits[masked_indices], ref_logits[masked_indices], reduction='none'
        ) / p_mask[masked_indices]
    else:
        token_loss = F.cross_entropy(
            logits[masked_indices], input_ids[masked_indices], reduction='none'
        ) / p_mask[masked_indices]
    losses = {
        'loss': token_loss.mean(),
    }
    return losses
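
# Hedged sketch of the weighting used above: with per-token masking
# probability p, each masked token's cross-entropy is divided by p so the
# Monte-Carlo estimate of the denoising loss stays unbiased across noise
# levels. All tensors here are toy placeholders, not pipeline data.
def _demo_weighted_masked_ce(seed=0, seq_len=6, vocab=8):
    torch.manual_seed(seed)
    logits = torch.randn(seq_len, vocab)
    targets = torch.randint(0, vocab, (seq_len,))
    p_mask = torch.full((seq_len,), 0.5)     # per-token masking probability
    masked = torch.rand(seq_len) < p_mask    # which positions got masked
    ce = F.cross_entropy(logits[masked], targets[masked], reduction='none')
    return (ce / p_mask[masked]).mean()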
def compute_normal_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """Same as compute_loss, but with the model's default attention and
    without the self-alignment branch."""
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)
    logits = denoiser(noisy_batch).logits
    logits = shift_logits(logits)
    token_loss = F.cross_entropy(
        logits[masked_indices], input_ids[masked_indices], reduction='none'
    ) / p_mask[masked_indices]
    losses = {
        'loss': token_loss.mean(),
    }
    return losses
def compute_llada_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    # LLaDA reserves 126336 as its [MASK] token; this overrides the mask_id
    # argument unconditionally.
    mask_id = 126336
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)
    attention_mask = build_custom_float_attention_mask(
        noisy_batch, question_length, block_size, device=noisy_batch.device,
    )
    attention_mask = attention_mask.to(torch.float16)
    # LLaDA takes the additive mask via `attention_bias` and predicts masked
    # tokens in place, so no logit shift is applied.
    logits = denoiser(noisy_batch, attention_bias=attention_mask).logits
    if self_align:
        # Self-alignment against the frozen base model, run fully visible.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_logits = denoiser(
                    noisy_batch,
                    attention_bias=torch.zeros(
                        [1, 1, noisy_batch.shape[1], noisy_batch.shape[1]],
                        dtype=torch.float16, device=denoiser.device,
                    ),
                ).logits
                ref_logits = F.softmax(ref_logits, dim=-1)
        token_loss = F.cross_entropy(
            logits[masked_indices], ref_logits[masked_indices], reduction='none'
        ) / p_mask[masked_indices]
    else:
        token_loss = F.cross_entropy(
            logits[masked_indices], input_ids[masked_indices], reduction='none'
        ) / p_mask[masked_indices]
    losses = {
        'loss': token_loss.mean(),
    }
    return losses
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    B, seq_len = input_ids.shape
    # Initialize everything to -inf (attention disallowed).
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
    for i in range(B):
        # 1. Prompt: every token may attend to the whole prompt.
        attn_mask[i, :, :, :prompt_length[i]] = 0.0
        # 2. Partition the response (everything after the prompt) into blocks.
        num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = prompt_length[i] + b * block_size
            block_end = min(block_start + block_size, seq_len)
            # Full attention within a block.
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
            # Causal attention across blocks: a block sees only earlier blocks.
            for prev_b in range(b):
                prev_start = prompt_length[i] + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
    return attn_mask  # [B, 1, seq_len, seq_len] float: 0.0 allowed, -inf disallowed
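
# Hedged alternative, not used above: the same block-causal pattern can be
# built without Python loops. This vectorized sketch assumes the same
# semantics (prompt visible everywhere, full attention within a block,
# causal attention across blocks); the helper name is ours, not the project's.
def build_block_causal_mask_vectorized(input_ids, prompt_length, block_size, device=None):
    B, seq_len = input_ids.shape
    prompt_length = prompt_length.to(device)
    pos = torch.arange(seq_len, device=device)
    # Block index per position: prompt tokens get -1; response tokens get
    # 0, 1, 2, ... in groups of block_size counted from the prompt boundary.
    rel = pos.unsqueeze(0) - prompt_length.unsqueeze(1)  # [B, seq_len]
    blk = torch.where(rel >= 0, rel // block_size, torch.full_like(rel, -1))
    # Query i may attend key j iff blk[j] <= blk[i]: prompt keys are visible
    # everywhere; a block sees itself and all earlier blocks.
    allowed = blk.unsqueeze(-1) >= blk.unsqueeze(1)  # [B, seq_len, seq_len]
    attn_mask = torch.full((B, seq_len, seq_len), float('-inf'), device=device)
    attn_mask = attn_mask.masked_fill(allowed, 0.0)
    return attn_mask.unsqueeze(1)  # [B, 1, seq_len, seq_len]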
if __name__ == "__main__":
    seq_len = 10
    input_ids = torch.randint(0, 100, (2, seq_len))  # example input
    block_size = 4
    prompt_length = torch.tensor([2, 4])  # example prompt lengths
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attn_mask = build_custom_float_attention_mask(input_ids, prompt_length, block_size, device)
    print(attn_mask)
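    # Hedged sanity check: the vectorized sketch should reproduce the loop
    # version exactly (expected output: True).
    attn_mask_vec = build_block_causal_mask_vectorized(input_ids, prompt_length, block_size, device)
    print(torch.equal(attn_mask, attn_mask_vec))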