"""Training-loss functions for block-diffusion language models.

Provides the 'dream', 'llada', and plain ('normal') masked-denoising losses,
plus the block-causal additive attention-mask builder they share.
"""
import torch
from utils.util import forward_process_length, shift_logits,forward_process
import torch.nn.functional as F
def compute_loss_by_config(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
    config
):
    """Dispatch to the loss function selected by ``config['training_mode']``.

    Supported modes (default ``'dream'``):
        * ``'dream'``  -> :func:`compute_loss` (shifted logits, ``attention_mask`` kwarg)
        * ``'llada'``  -> :func:`compute_llada_loss` (``attention_bias`` kwarg, no shift)
        * ``'normal'`` -> :func:`compute_normal_loss` (no custom attention mask)

    All other arguments are forwarded verbatim to the selected loss function;
    ``config`` is any mapping supporting ``.get``.

    Returns:
        dict with key ``'loss'`` holding a scalar tensor.

    Raises:
        ValueError: if ``training_mode`` names an unknown mode.
    """
    training_mode = config.get('training_mode', 'dream')
    if training_mode == 'llada':
        return compute_llada_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id
        )
    elif training_mode == 'dream':
        return compute_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id
        )
    elif training_mode == 'normal':
        # compute_normal_loss existed in this module but was previously
        # unreachable from the config; expose it as its own mode.
        return compute_normal_loss(
            input_ids, denoiser, question_length, mask_id, block_size,
            enable_shift, share_steps, self_align, feature_align, self_step, eos_id
        )
    else:
        raise ValueError(f"Unsupported training mode: {training_mode}")
def compute_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """Masked-denoising ("dream" mode) loss with a block-causal attention mask.

    Noises the answer region via ``forward_process_length``, restores the
    prompt tokens verbatim, runs the denoiser under a block-wise causal
    additive float mask, shifts the logits, and computes an
    importance-weighted cross-entropy over the masked positions.

    Args:
        input_ids: LongTensor [B, L] of token ids (prompt + answer).
        denoiser: model returning an object with ``.logits``; when
            ``self_align`` is truthy it must also provide a
            ``disable_adapter()`` context manager (PEFT-style wrapper).
        question_length: LongTensor [B] of per-sample prompt lengths.
        mask_id: id of the mask token used for noising.
        block_size: width of the causal attention blocks after the prompt.
        enable_shift, share_steps, feature_align, self_step: unused here;
            kept for signature parity with the other loss variants.
        self_align: if truthy, distill against the frozen base model
            (adapters disabled) instead of the ground-truth token ids.
        eos_id: EOS token id forwarded to the forward process.

    Returns:
        dict with key ``'loss'`` holding a scalar tensor.
    """
    B, L = input_ids.shape
    # p_mask holds per-token masking probabilities, used below as
    # importance weights (divide by p to debias the masked-token loss).
    noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id=mask_id,prompt_lengths=question_length, block_size=block_size,eos_id=eos_id)
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = (token_positions < question_length.unsqueeze(1))
    # Prompt tokens are never noised: copy them back verbatim.
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)
    # Additive float mask: 0.0 = attend, -inf = blocked (block-causal layout).
    attention_mask=build_custom_float_attention_mask(noisy_batch, question_length, block_size, device=noisy_batch.device)
    attention_mask=attention_mask.to(torch.float16)
    logits=denoiser(noisy_batch,attention_mask=attention_mask).logits
    # Align logits with the target convention expected by the loss.
    logits=shift_logits(logits)
    if self_align:
        # Self-distillation: the target distribution comes from the base
        # model (adapters disabled) run under an all-zero (fully visible)
        # additive mask.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_logits=denoiser(noisy_batch,attention_mask=torch.zeros([1,1,noisy_batch.shape[1],noisy_batch.shape[1]],dtype=torch.float16,device=denoiser.device)).logits
                ref_logits=shift_logits(ref_logits)
                ref_logits = torch.nn.functional.softmax(ref_logits, dim=-1)
        # Soft-target cross-entropy (probability targets), importance-weighted.
        token_loss_2 = F.cross_entropy(logits[masked_indices], ref_logits[masked_indices], reduction='none') / p_mask[masked_indices]
    else:
        # Hard-target cross-entropy against the ground-truth ids.
        # NOTE(review): input_ids is not moved to denoiser.device here —
        # assumes it already lives on the same device as logits; verify
        # against the caller.
        token_loss_2= F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
    losses = {
        'loss': token_loss_2.mean(),
    }
    return losses
def compute_normal_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """Masked-denoising loss without any custom attention mask.

    Same noising pipeline as :func:`compute_loss`, but the denoiser runs
    with its default attention (no block-causal mask is passed) and there
    is no self-alignment branch.

    Args:
        input_ids: LongTensor [B, L] of token ids (prompt + answer).
        denoiser: model returning an object with ``.logits``.
        question_length: LongTensor [B] of per-sample prompt lengths.
        mask_id: id of the mask token used for noising.
        block_size: forwarded to the forward process.
        enable_shift, share_steps, self_align, feature_align, self_step:
            unused here; kept for signature parity with the other variants.
        eos_id: EOS token id forwarded to the forward process.

    Returns:
        dict with key ``'loss'`` holding a scalar tensor.
    """
    B, L = input_ids.shape
    # p_mask = per-token masking probabilities, used as importance weights.
    noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id=mask_id,prompt_lengths=question_length, block_size=block_size,eos_id=eos_id)
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = (token_positions < question_length.unsqueeze(1))
    # Prompt tokens are never noised: copy them back verbatim.
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)
    logits=denoiser(noisy_batch).logits
    logits=shift_logits(logits)
    # Importance-weighted cross-entropy over the masked positions only.
    token_loss_2= F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
    losses = {
        'loss': token_loss_2.mean(),
    }
    return losses
import torch
def compute_llada_loss(
    input_ids,
    denoiser,
    question_length,
    mask_id,
    block_size,
    enable_shift,
    share_steps,
    self_align,
    feature_align,
    self_step,
    eos_id,
):
    """LLaDA-mode loss: like :func:`compute_loss` but with two differences.

    1. The mask token id is hard-coded to 126336 (overriding the ``mask_id``
       argument).
    2. The custom mask is passed as ``attention_bias`` (not
       ``attention_mask``) and the logits are NOT shifted — presumably the
       LLaDA model's forward signature/target convention; verify against
       the model class.

    Args: see :func:`compute_loss` (same signature and semantics otherwise).

    Returns:
        dict with key ``'loss'`` holding a scalar tensor.
    """
    # NOTE(review): silently overrides the caller-supplied mask_id with the
    # LLaDA mask-token id — confirm this is intentional for all callers.
    mask_id=126336
    B, L = input_ids.shape
    # p_mask = per-token masking probabilities, used as importance weights.
    noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id=mask_id,prompt_lengths=question_length, block_size=block_size,eos_id=eos_id)
    token_positions = torch.arange(L, device=noisy_batch.device).expand(B, L)
    prompt_mask = (token_positions < question_length.unsqueeze(1))
    # Prompt tokens are never noised: copy them back verbatim.
    noisy_batch[prompt_mask] = input_ids[prompt_mask]
    noisy_batch = noisy_batch.to(denoiser.device)
    # Additive float mask: 0.0 = attend, -inf = blocked (block-causal layout).
    attention_mask=build_custom_float_attention_mask(noisy_batch, question_length, block_size, device=noisy_batch.device)
    attention_mask=attention_mask.to(torch.float16)
    logits=denoiser(noisy_batch,attention_bias=attention_mask).logits
    if self_align:
        # Self-distillation: target distribution from the base model
        # (adapters disabled) under an all-zero (fully visible) bias.
        with torch.no_grad():
            with denoiser.disable_adapter():
                ref_logits=denoiser(noisy_batch,attention_bias=torch.zeros([1,1,noisy_batch.shape[1],noisy_batch.shape[1]],dtype=torch.float16,device=denoiser.device)).logits
                ref_logits = torch.nn.functional.softmax(ref_logits, dim=-1)
        # Soft-target cross-entropy (probability targets), importance-weighted.
        token_loss_2 = F.cross_entropy(logits[masked_indices], ref_logits[masked_indices], reduction='none') / p_mask[masked_indices]
    else:
        # Hard-target cross-entropy against the ground-truth ids.
        token_loss_2= F.cross_entropy(logits[masked_indices], input_ids[masked_indices], reduction='none') / p_mask[masked_indices]
    losses = {
        'loss': token_loss_2.mean(),
    }
    return losses
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    """Build an additive float attention mask: 0.0 = attend, -inf = blocked.

    Layout per sample ``i``:
      * every token may attend the prompt (columns ``0:prompt_length[i]``);
      * answer tokens are grouped into blocks of ``block_size`` starting at
        ``prompt_length[i]``; a token attends its own block plus all earlier
        blocks (block-causal). Since the blocks are contiguous, that is the
        single span ``prompt_length[i]:block_end``.

    Args:
        input_ids: [B, L] tensor; only its shape is used.
        prompt_length: length-B indexable of per-sample prompt lengths.
        block_size: number of tokens per causal block.
        device: target device for the mask (default: torch's current default).

    Returns:
        Float32 tensor of shape [B, 1, L, L] (broadcastable over heads);
        fixes the previous comment that claimed [L, L].
    """
    B, seq_len = input_ids.shape
    # Start fully blocked; open up the allowed regions below.
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
    for i in range(B):
        p = int(prompt_length[i])
        # Every row may attend the whole prompt.
        attn_mask[i, :, :, :p] = 0.0
        # "Own block + all previous blocks" is the contiguous span
        # [p, block_end) — one slice assignment per block instead of the
        # former O(num_blocks^2) pairwise loop; result is identical.
        num_blocks = (seq_len - p + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = p + b * block_size
            block_end = min(block_start + block_size, seq_len)
            attn_mask[i, :, block_start:block_end, p:block_end] = 0.0
    return attn_mask
if __name__ == "__main__":
    # Smoke test: build the block-causal mask for a tiny random batch
    # and print it for visual inspection.
    demo_len = 10
    demo_ids = torch.randint(0, 100, (2, demo_len))  # dummy token ids
    demo_block_size = 4
    demo_prompt_lens = torch.tensor([2, 4])  # per-sample prompt lengths
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    demo_mask = build_custom_float_attention_mask(
        demo_ids, demo_prompt_lens, demo_block_size, run_device
    )
    print(demo_mask)