import torch
import torch.nn.functional as F
import torch.distributions as dists
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig


def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    B, seq_len = input_ids.shape
    # Initialize everything to -inf (no attention allowed).
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
    for i in range(B):
        pl = int(prompt_length[i])
        # 1. Prompt part: every token may attend to the whole prompt.
        attn_mask[i, :, :, :pl] = 0.0
        # 2. Block partition: split the sequence into blocks starting at prompt_length.
        num_blocks = (seq_len - pl + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = pl + b * block_size
            block_end = min(block_start + block_size, seq_len)
            # Full (bidirectional) attention within a block.
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
            # Causal attention across blocks: a block may only see earlier blocks.
            for prev_b in range(b):
                prev_start = pl + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
    return attn_mask
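

# Quick usage sketch for the mask builder above (hypothetical sizes; this helper
# is illustrative and not called anywhere by default).
def _mask_example():
    """One 6-token sequence with a 2-token prompt and block_size=2: every query
    may attend to the prompt, each block attends to itself bidirectionally, and
    later blocks see earlier ones."""
    ids = torch.zeros(1, 6, dtype=torch.long)
    mask = build_custom_float_attention_mask(ids, torch.tensor([[2]]), block_size=2)
    assert mask.shape == (1, 1, 6, 6)
    assert (mask[0, 0, 5] == 0).all()                # last row sees everything
    assert mask[0, 0, 2, 4].item() == float('-inf')  # block 1 cannot see block 2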


def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits


def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # safety check
    # Remove all tokens with a probability less than the last token of the top-k.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fall back to greedy decoding if sampling fails (e.g. degenerate probs).
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence as the margin between the top-1 and top-2 probabilities.
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        confidence = top1_probs - top2_probs

    if neg_entropy:
        # Confidence as the negative entropy of the distribution.
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
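

# Quick usage sketch for sample_tokens (made-up shapes; not called by default).
def _sampling_example():
    """Draw one token per row from random logits with temperature plus nucleus
    and top-k filtering; confidence is the probability of the sampled token."""
    logits = torch.randn(4, 1000)
    confidence, tokens = sample_tokens(logits, temperature=0.8, top_p=0.9, top_k=50)
    assert tokens.shape == (4,) and confidence.shape == (4,)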


# Earlier (truncated) draft of an incremental generate() loop, kept for reference:
# def generate(model, prompt, block_size, max_length, mask_id, eos_token_id=None):
#     device = prompt.device
#     output = prompt.clone()
#     while output.shape[1] < max_length:
#         # Append one block of mask tokens.
#         mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
#         input_ids = torch.cat([output, mask_block], dim=1)
#         attention_mask = build_custom_float_attention_mask(input_ids, torch.tensor([[prompt.shape[1]]]), block_size, device=device)
#         attention_mask = attention_mask.to(torch.bfloat16)
#         for i in range(block_size):
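

# NOTE: shift_logits is called in generate_block below but is not defined in this
# file. A minimal sketch, assuming the usual next-token alignment of a causal LM
# head (the logit at position i predicts token i+1): shift the logits right by
# one position so they line up with the tokens being denoised, duplicating the
# first position as filler.
def shift_logits(logits):
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)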


def generate_block(denoiser, block_size, mask_id, tokenizer, device):
    denoiser.eval()
    question = 'please give me a code about transformer model'
    messages = [
        {"role": "user", "content": question}
    ]
    prompt = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    ).input_ids
    prompt = prompt.to(device)
    mask_id = 151666  # NOTE: overrides the mask_id argument; [MASK] id of the target vocab
    gen_len = (384 - prompt.shape[1]) // block_size  # number of blocks to generate
    print(gen_len)
    temperature = 0.2
    top_p = 0.95
    with torch.inference_mode():
        for i in range(gen_len):
            # Append one fresh block of mask tokens to the running sequence.
            if i == 0:
                x_t = torch.cat([prompt, torch.tensor([[mask_id] * block_size]).to(device)], dim=1)
            else:
                x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(device)], dim=1)
            attention_mask = build_custom_float_attention_mask(x_t, torch.tensor([[prompt.shape[1]]]), block_size, device=device)
            attention_mask = attention_mask.to(torch.bfloat16)
            # Iteratively denoise the block, unmasking one token per step.
            for n in range(block_size):
                mask_index = (x_t == mask_id)
                if mask_index.sum() == 0:
                    break
                logits = denoiser(x_t, attention_mask=attention_mask).logits
                logits = shift_logits(logits)
                mask_logits = logits[mask_index]
                confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=None, neg_entropy=True)
                # Commit only the single most confident token; keep the rest masked.
                number_transfer_tokens = 1
                _, transfer_index = torch.topk(confidence, number_transfer_tokens)
                x0_ = torch.zeros_like(x0, device=device, dtype=torch.long) + mask_id
                x0_[transfer_index] = x0[transfer_index].clone()
                x_t[mask_index] = x0_
            # Per-block progress output.
            answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=False)[0]
            print(answer)
    answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=False)[0]
    print(answer)


if __name__ == "__main__":
    config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    lora_model = PeftModel.from_pretrained(model, "ybelkada/opt-350m-lora")
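    # A hedged wiring sketch (assumptions, not part of the original script):
    # generate_block expects a tokenizer with a chat template and a vocabulary
    # containing mask id 151666 (a Qwen-style id); the OPT checkpoint above
    # satisfies neither, so the call below is illustrative only, and
    # block_size=32 is a made-up value.
    # tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    # lora_model.to(device)
    # generate_block(lora_model, block_size=32, mask_id=151666, tokenizer=tokenizer, device=device)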