import torch
import torch.nn.functional as F

from utils.util import forward_process_length, shift_logits

def compute_loss_by_config(
input_ids,
denoiser,
question_length,
mask_id,
block_size,
enable_shift,
share_steps,
self_align,
feature_align,
self_step,
eos_id,
config
):
"""Select different loss functions based on config file"""
training_mode = config.get('training_mode', 'dream')
if training_mode == 'llada':
import ipdb; ipdb.set_trace()
return compute_llada_loss(
input_ids, denoiser, question_length, mask_id, block_size,
enable_shift, share_steps, self_align, feature_align, self_step, eos_id
)
elif training_mode == 'dream':
return compute_loss(
input_ids, denoiser, question_length, mask_id, block_size,
enable_shift, share_steps, self_align, feature_align, self_step, eos_id
)
else:
raise ValueError(f"Unsupported training mode: {training_mode}")
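
# A minimal usage sketch for the dispatcher above. `batch`, `tokenizer`, and the
# config layout are illustrative assumptions, not a verified training loop:
#
#     config = {'training_mode': 'dream'}
#     losses = compute_loss_by_config(
#         input_ids=batch['input_ids'],             # [B, L] token ids
#         denoiser=denoiser,                        # PEFT-wrapped diffusion LM
#         question_length=batch['prompt_lengths'],  # [B] per-sample prompt length
#         mask_id=tokenizer.mask_token_id,
#         block_size=32, enable_shift=True, share_steps=1,
#         self_align=False, feature_align=False, self_step=1,
#         eos_id=tokenizer.eos_token_id, config=config,
#     )
#     losses['loss'].backward()
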
def compute_loss(
input_ids,
denoiser,
question_length,
mask_id,
block_size,
enable_shift,
share_steps,
self_align,
feature_align,
self_step,
eos_id,
):
    """Masked-diffusion loss with a block-causal attention mask (dream mode)."""
    B, L = input_ids.shape
    # Corrupt the batch; prompt positions are restored to clean tokens below.
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]  # keep the prompt clean
    noisy_batch = noisy_batch.to(denoiser.device)
    # Block-causal float mask: full attention within a block, causal across blocks.
    attention_mask = build_custom_float_attention_mask(
        noisy_batch, question_length, block_size, device=noisy_batch.device,
    )
    attention_mask = attention_mask.to(torch.float16)
    logits = denoiser(noisy_batch, attention_mask=attention_mask).logits
    logits = shift_logits(logits)
    if self_align:
        # Self-alignment: distill toward the frozen base model (adapter
        # disabled via the `disable_adapter` context), run with a fully
        # bidirectional (all-zero) attention bias.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_bias = torch.zeros(
                    [1, 1, noisy_batch.shape[1], noisy_batch.shape[1]],
                    dtype=torch.float16, device=denoiser.device,
                )
                ref_logits = denoiser(noisy_batch, attention_mask=ref_bias).logits
                ref_logits = shift_logits(ref_logits)
                ref_probs = F.softmax(ref_logits, dim=-1)
        # Soft-target cross-entropy against the reference distribution.
        token_loss = F.cross_entropy(logits[masked_indices], ref_probs[masked_indices], reduction='none') / p_mask[masked_indices]
    else:
        token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
    losses = {
        'loss': token_loss.mean(),
    }
return losses
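
# For intuition: dividing by p_mask is the standard masked-diffusion ELBO
# reweighting. Positions masked at a low mask ratio appear rarely in the
# training signal, so their loss is up-weighted by 1/p. A self-contained toy
# sketch (illustrative values only):
#
#     logits = torch.randn(5, 100)                   # 5 masked positions, vocab 100
#     targets = torch.randint(0, 100, (5,))
#     p = torch.tensor([0.1, 0.5, 0.9, 0.3, 0.7])    # per-position mask prob
#     nll = F.cross_entropy(logits, targets, reduction='none')
#     loss = (nll / p).mean()                        # rarely-masked tokens count more
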
def compute_normal_loss(
input_ids,
denoiser,
question_length,
mask_id,
block_size,
enable_shift,
share_steps,
self_align,
feature_align,
self_step,
eos_id,
):
    """Masked-diffusion loss with the model's default (full) attention."""
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]  # keep the prompt clean
    noisy_batch = noisy_batch.to(denoiser.device)
    logits = denoiser(noisy_batch).logits
    logits = shift_logits(logits)
    token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
    losses = {
        'loss': token_loss.mean(),
    }
return losses
def compute_llada_loss(
input_ids,
denoiser,
question_length,
mask_id,
block_size,
enable_shift,
share_steps,
self_align,
feature_align,
self_step,
eos_id,
):
    """Masked-diffusion loss for LLaDA-style models (in-place prediction)."""
    # NOTE: overrides the `mask_id` argument with LLaDA's mask token id.
    mask_id = 126336
    B, L = input_ids.shape
    noisy_batch, masked_indices, p_mask = forward_process_length(
        input_ids, mask_id=mask_id, prompt_lengths=question_length,
        block_size=block_size, eos_id=eos_id,
    )
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = token_positions < question_length.unsqueeze(1)
    noisy_batch[prompt_mask] = input_ids[prompt_mask]  # keep the prompt clean
    noisy_batch = noisy_batch.to(denoiser.device)
    attention_mask = build_custom_float_attention_mask(
        noisy_batch, question_length, block_size, device=noisy_batch.device,
    )
    attention_mask = attention_mask.to(torch.float16)
    # LLaDA takes the float mask via `attention_bias` and predicts masked
    # tokens in place, so no logit shift is applied.
    logits = denoiser(noisy_batch, attention_bias=attention_mask).logits
    if self_align:
        # Self-alignment: distill toward the frozen base model (adapter
        # disabled), run with a fully bidirectional (all-zero) bias.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_bias = torch.zeros(
                    [1, 1, noisy_batch.shape[1], noisy_batch.shape[1]],
                    dtype=torch.float16, device=denoiser.device,
                )
                ref_logits = denoiser(noisy_batch, attention_bias=ref_bias).logits
                ref_probs = F.softmax(ref_logits, dim=-1)
        token_loss = F.cross_entropy(logits[masked_indices], ref_probs[masked_indices], reduction='none') / p_mask[masked_indices]
    else:
        token_loss = F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
    losses = {
        'loss': token_loss.mean(),
    }
return losses
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    """Build an additive float attention mask with block-causal structure.

    Every token may attend to the full prompt; response tokens are grouped
    into blocks of ``block_size`` with full attention inside a block and
    causal attention across blocks (a block sees only earlier blocks).
    """
    B, seq_len = input_ids.shape
    # Start fully disallowed (-inf everywhere).
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
    for i in range(B):
        # 1. Prompt: every token may attend to the whole prompt.
        attn_mask[i, :, :, :prompt_length[i]] = 0.0
        # 2. Blocks: partition the response (from prompt_length onward) into blocks.
        num_blocks = (seq_len - prompt_length[i] + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = prompt_length[i] + b * block_size
            block_end = min(block_start + block_size, seq_len)
            # Full attention within the block.
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
            # Causal attention across blocks: the current block sees earlier blocks only.
            for prev_b in range(b):
                prev_start = prompt_length[i] + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
    return attn_mask  # [B, 1, seq_len, seq_len]; 0.0 = allowed, -inf = disallowed
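
# The resulting pattern for, e.g., seq_len=6, prompt_length=2, block_size=2
# (x = 0.0 / allowed, . = -inf / disallowed); rows are queries, columns keys:
#
#   keys:      p  p | b0 b0 | b1 b1
#   prompt     x  x   .  .    .  .
#   prompt     x  x   .  .    .  .
#   block 0    x  x   x  x    .  .
#   block 0    x  x   x  x    .  .
#   block 1    x  x   x  x    x  x
#   block 1    x  x   x  x    x  x
#
# Every query sees the full prompt, its own block, and all earlier blocks.
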
if __name__ == "__main__":
    seq_len = 10
    input_ids = torch.randint(0, 100, (2, seq_len))  # example input
    block_size = 4
    prompt_length = torch.tensor([2, 4])  # example prompt lengths
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attn_mask = build_custom_float_attention_mask(input_ids, prompt_length, block_size, device)
    print(attn_mask)
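    # Sanity checks (illustrative): the prompt is visible to every query, and
    # the first response block of sample 0 cannot see into the second block.
    b0_start = int(prompt_length[0])
    assert torch.all(attn_mask[0, 0, :, :b0_start] == 0.0)
    assert torch.isinf(attn_mask[0, 0, b0_start, b0_start + block_size])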