import torch
import torch.nn.functional as F
import torch.distributions as dists
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    """Build an additive float attention mask for block-wise diffusion decoding.

    Every token may attend to the full prompt; generated tokens attend
    bidirectionally within their own block and causally to earlier blocks.
    """
    B, seq_len = input_ids.shape
    # Initialise everything to -inf (no attention allowed).
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
    for i in range(B):
        p_len = int(prompt_length[i])
        # 1. Prompt: every token may attend to the whole prompt.
        attn_mask[i, :, :, :p_len] = 0.0
        # 2. Split the generated region (from p_len onward) into blocks.
        num_blocks = (seq_len - p_len + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = p_len + b * block_size
            block_end = min(block_start + block_size, seq_len)
            # Full (bidirectional) attention inside the block.
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
            # Causal attention across blocks: the current block may only
            # attend to earlier blocks.
            for prev_b in range(b):
                prev_start = p_len + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
    return attn_mask
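
# Quick illustrative check of the mask layout (a sketch; the tiny shapes here
# are made up for the example, not taken from the training setup).
def _demo_attention_mask():
    ids = torch.zeros(1, 6, dtype=torch.long)  # prompt of length 2 + two blocks of size 2
    mask = build_custom_float_attention_mask(ids, torch.tensor([2]), block_size=2)
    # The last block (positions 4:6) sees the prompt, the first block, and itself...
    assert (mask[0, 0, 4:6, :] == 0.0).all()
    # ...while the first block (positions 2:4) cannot see the later block.
    assert (mask[0, 0, 2:4, 4:6] == float('-inf')).all()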

def top_p_logits(logits, top_p=None):
    # Nucleus filtering: keep the smallest set of tokens whose cumulative
    # probability exceeds top_p, mask out the rest.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits
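
# Toy illustration of nucleus filtering (made-up logits): with top_p=0.9 the
# low-probability tail is masked to the dtype minimum, while the argmax token
# always survives.
def _demo_top_p():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    kept = top_p_logits(logits.clone(), top_p=0.9)
    assert kept[0, 0] == 2.0
    assert kept[0, 3] == torch.finfo(logits.dtype).min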

def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits
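
# Toy illustration of top-k filtering (made-up logits): with k=2 only the two
# largest logits survive.
def _demo_top_k():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    kept = top_k_logits(logits.clone(), top_k=2)
    assert (kept[0, :2] == logits[0, :2]).all()
    assert (kept[0, 2:] == torch.finfo(logits.dtype).min).all()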

def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fall back to greedy decoding if sampling fails (e.g. degenerate probs).
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        # Confidence as the margin between the top-1 and top-2 probabilities
        # (logits are expected to be 2D here: (num_tokens, vocab)).
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        confidence = top1_probs - top2_probs

    if neg_entropy:
        # Negative entropy of the distribution: closer to 0 means more confident.
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
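
# A quick sketch of the confidence semantics (toy distributions, assumed
# values): with neg_entropy=True the returned "confidence" is -H(p), so a
# peaked distribution scores closer to 0 than a flat one.
def _demo_sample_tokens():
    peaked = torch.log(torch.tensor([[0.97, 0.01, 0.01, 0.01]]))
    flat = torch.log(torch.tensor([[0.25, 0.25, 0.25, 0.25]]))
    conf_peaked, _ = sample_tokens(peaked, temperature=0.0, neg_entropy=True)
    conf_flat, _ = sample_tokens(flat, temperature=0.0, neg_entropy=True)
    assert conf_peaked > conf_flat  # -H is larger for the peaked distribution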
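
# NOTE: `shift_logits` is called in generate_block below but was not defined in
# this file. The definition here is a reconstruction, assuming the denoiser is
# a causal LM whose logits at position i-1 score the token at position i: the
# shift re-aligns logits[:, i] with input_ids[:, i]. Position 0 keeps its own
# logits as a harmless placeholder, since the prompt is never masked.
def shift_logits(logits):
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)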

def generate_block(denoiser, block_size, mask_id, tokenizer, device):
    denoiser.eval()
    question = 'please give me a code about transformer model'
    messages = [
        {"role": "user", "content": question}
    ]
    prompt = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    ).input_ids
    prompt = prompt.to(device)
    # Number of blocks needed to fill a 384-token budget after the prompt.
    num_blocks = (384 - prompt.shape[1]) // block_size
    print(f"num_blocks: {num_blocks}")
    temperature = 0.2
    top_p = 0.95
    x_t = prompt.clone()
    with torch.inference_mode():
        for _ in range(num_blocks):
            # Append one block of mask tokens, then denoise it token by token.
            mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
            x_t = torch.cat([x_t, mask_block], dim=1)
            attention_mask = build_custom_float_attention_mask(
                x_t, torch.tensor([prompt.shape[1]]), block_size, device=device
            )
            attention_mask = attention_mask.to(torch.bfloat16)
            for _ in range(block_size):
                mask_index = (x_t == mask_id)
                if mask_index.sum() == 0:
                    break
                logits = denoiser(x_t, attention_mask=attention_mask).logits
                logits = shift_logits(logits)
                mask_logits = logits[mask_index]
                confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=None, neg_entropy=True)
                # Commit only the single most confident prediction per step;
                # the rest stay masked for the next denoising pass.
                number_transfer_tokens = 1
                _, transfer_index = torch.topk(confidence, number_transfer_tokens)
                x0_ = torch.full_like(x0, mask_id)
                x0_[transfer_index] = x0[transfer_index].clone()
                x_t[mask_index] = x0_
    answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=False)[0]
    print(answer)
    return answer

if __name__ == "__main__":
    config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    lora_model = PeftModel.from_pretrained(model, "ybelkada/opt-350m-lora")
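    # Hedged driver sketch: the original file stops after loading the adapter
    # and never calls generate_block. The block_size here is an assumed value,
    # mask_id 151666 is the id hard-coded in the original draft, and note that
    # the OPT example checkpoint above has neither a chat template nor that
    # mask token, so a matching diffusion checkpoint must be substituted.
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lora_model = lora_model.to(device)
    generate_block(lora_model, block_size=32, mask_id=151666, tokenizer=tokenizer, device=device)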