import torch
from torch.distributions import Uniform


def forward_process_block_fixed_p(x, mask_id, p_mask):
    """Mask tokens of x independently with probability p_mask (float, (B,) or (B, 1))."""
    B, L = x.shape
    if isinstance(p_mask, float):
        p_mask = torch.full((B, 1), p_mask, device=x.device)
    elif p_mask.ndim == 1:
        p_mask = p_mask[:, None]
    rand = torch.rand((B, L), device=x.device)
    mask = rand < p_mask  # (B, L) bool: True where the token gets masked
    x_masked = torch.where(mask, mask_id, x)
    return x_masked, mask

def generate_monotonic_pmasks(batch_size, max_blocks, device):
    """
    Generate a (B, max_blocks) tensor of non-decreasing mask ratios per row:
    the first block's ratio is drawn uniformly from [0.2, 0.7], and each later
    block adds a non-negative increment so the final ratio stays at or below ~0.7.
    """
    # Mask ratio of the first block: uniform in [0.2, 0.7].
    p0 = torch.rand(batch_size, 1, device=device) / 2 + 0.2
    # Non-negative increments for the remaining blocks, scaled so their cumulative
    # sum cannot push the ratio above 0.7 (clamped to 1.0 below as a safeguard).
    increments = torch.rand(batch_size, max_blocks - 1, device=device) * (0.7 - p0) / max(max_blocks - 1, 1)
    # Cumulative sum keeps each row non-decreasing.
    cum_increments = torch.cumsum(increments, dim=1)
    p_masks = torch.cat([p0, p0 + cum_increments], dim=1)
    p_masks = torch.clamp(p_masks, max=1.0)
    return p_masks  # (B, max_blocks)
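
# Illustrative sanity check (not part of the original file): verifies the two
# properties generate_monotonic_pmasks is meant to guarantee, i.e. every row is
# non-decreasing and stays within [0, 1]. The helper name is hypothetical.
def check_monotonic_pmasks(p_masks):
    """Return True if each row of p_masks is non-decreasing and lies in [0, 1]."""
    non_decreasing = bool((p_masks[:, 1:] >= p_masks[:, :-1]).all())
    in_range = bool(((p_masks >= 0) & (p_masks <= 1)).all())
    return non_decreasing and in_range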

def forward_process_length(input_ids, mask_id, block_size, prompt_lengths, eos_id=None):
    """
    Block-wise forward (noising) process: the non-prompt region of each sample is
    split into blocks of `block_size`, and later blocks are masked with a higher
    (non-decreasing) probability drawn by generate_monotonic_pmasks.

    Args:
        input_ids: (B, L) token ids.
        mask_id: id of the [MASK] token.
        block_size: number of tokens per block.
        prompt_lengths: (B,) prompt length per sample; prompt tokens are never masked.
        eos_id: optional EOS token id (currently unused).
    Returns:
        noisy_batch, masked_indices, p_mask_tensor
    """
    B, L = input_ids.shape
    device = input_ids.device
    noisy_batch = input_ids.clone()
    eos_indices = (input_ids == eos_id) if eos_id is not None else None  # currently unused
    masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
    p_mask_tensor = torch.zeros((B, L), device=device)
    # Number of blocks in the non-prompt region of each sample.
    non_prompt_lens = L - prompt_lengths
    full_blocks = non_prompt_lens // block_size
    remainders = non_prompt_lens % block_size
    total_blocks = full_blocks + (remainders > 0).long()
    max_blocks = total_blocks.max().item()
    # Per-sample, per-block mask ratios: non-decreasing, with a random first value.
    p_masks = generate_monotonic_pmasks(B, max_blocks, device)  # (B, max_blocks)
    for i in range(B):
        prompt_len = prompt_lengths[i].item()
        num_blocks = total_blocks[i].item()
        start_block = 0  # index of the first block to noise (always 0 here)
        for block_idx in range(num_blocks):
            if block_idx < start_block:
                continue
            start = prompt_len + block_idx * block_size
            end = min(start + block_size, L)
            p_block = p_masks[i, block_idx - start_block].item()
            block = noisy_batch[i, start:end].unsqueeze(0)
            masked_block, mask = forward_process_block_fixed_p(block, mask_id, p_block)
            noisy_batch[i, start:end] = masked_block.squeeze(0)
            masked_indices[i, start:end] = mask.squeeze(0)
            # if torch.all(input_ids[i, start:end] == eos_id):
            #     masked_indices[i, start:end] = False
            p_mask_tensor[i, start:end] = p_block
    return noisy_batch, masked_indices, p_mask_tensor

# Legacy version: block mask probability increases linearly from p_min to p_max.
# def forward_process_length(input_ids, mask_id, block_size, prompt_lengths, p_min=0.2, p_max=0.9):
#     """
#     Returns the actual per-token mask probability (non-prompt region); 0 elsewhere.
#     """
#     B, L = input_ids.shape
#     device = input_ids.device
#     noisy_batch = input_ids.clone()
#     masked_indices = torch.zeros_like(input_ids, dtype=torch.bool)
#     p_mask_tensor = torch.zeros((B, L), device=device)  # final return value
#     for i in range(B):
#         prompt_len = prompt_lengths[i].item()
#         non_prompt_len = L - prompt_len
#         full_blocks = non_prompt_len // block_size
#         remainder = non_prompt_len % block_size
#         total_blocks = full_blocks + (1 if remainder > 0 else 0)
#         for block_idx in range(total_blocks):
#             start = prompt_len + block_idx * block_size
#             end = min(start + block_size, L)
#             # block mask probability (linear ramp)
#             if total_blocks > 1:
#                 p_block = p_min + (p_max - p_min) * (block_idx / (total_blocks - 1))
#             else:
#                 p_block = p_max
#             block = noisy_batch[i, start:end].unsqueeze(0)
#             masked_block, mask = forward_process_block_fixed_p(block, mask_id, p_block)
#             noisy_batch[i, start:end] = masked_block.squeeze(0)
#             masked_indices[i, start:end] = mask.squeeze(0)
#             # record p_mask in the output tensor
#             p_mask_tensor[i, start:end] = p_block
#     return noisy_batch, masked_indices, p_mask_tensor

def forward_process(input_ids, mask_id, t_max=1.0, eps=1e-4):
    """Standard (non-blockwise) forward process: one random mask ratio per sample."""
    B, L = input_ids.shape
    # Sample a diffusion time t ~ Uniform(0, t_max) per sample and map it to a mask ratio.
    dist = Uniform(0., t_max)
    t = dist.sample((B,)).to(input_ids.device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None].repeat(1, L)
    masked_indices = torch.rand((B, L), device=input_ids.device) < p_mask
    noisy_batch = torch.where(masked_indices, mask_id, input_ids)
    return noisy_batch, masked_indices, p_mask
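
# Illustrative sketch (not part of the original file) of how noisy_batch,
# masked_indices and p_mask are typically consumed by a masked-diffusion style
# training objective: cross-entropy on masked positions, re-weighted by 1/p_mask.
# The helper name and the exact normalization are assumptions, not the author's code.
def masked_diffusion_loss(logits, target_ids, masked_indices, p_mask):
    """Per-token CE on masked positions, weighted by 1/p_mask and averaged over masked tokens."""
    import torch.nn.functional as F  # local import keeps this sketch self-contained
    token_loss = F.cross_entropy(
        logits[masked_indices],       # (N_masked, vocab)
        target_ids[masked_indices],   # (N_masked,)
        reduction='none',
    ) / p_mask[masked_indices]
    return token_loss.sum() / masked_indices.sum().clamp(min=1)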

def flatten_dict(d, parent_key='', sep='_'):
    """Flatten a nested dict into a single-level dict with keys joined by `sep`."""
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def shift_logits(logits):
    """Shift logits one position to the right so position i holds the logits produced
    at position i-1; the first position is filled with a constant value."""
    shifted_logits = torch.zeros_like(logits)
    shifted_logits[:, 1:, :] = logits[:, :-1, :]
    shifted_logits[:, 0, :] = 1.0
    return shifted_logits

if __name__ == '__main__':
    input_ids = torch.tensor([[1, 5, 4, 3, 25, 6, 7, 9, 5, 8, 7, 6],
                              [1, 3, 8, 9, 7, 34, 6, 9, 5, 8, 7, 6]])
    mask_id = 0
    block_size = 3
    prompt_length = torch.tensor([2, 1])
    noisy_batch, masked_indices, p_mask = forward_process_length(input_ids, mask_id, block_size, prompt_length)
    print("noisy_batch:", noisy_batch)
    print("masked_indices:", masked_indices)
    print("p_mask:", p_mask)