import torch
import torch.nn.functional as F
import torch.distributions as dists
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    """Build a (B, 1, L, L) additive float mask: 0.0 = attend, -inf = blocked."""
    B, seq_len = input_ids.shape
    # Initialize everything to -inf (fully blocked).
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
    for i in range(B):
        p = int(prompt_length[i])
        # 1. Prompt region: every token may attend to the whole prompt.
        attn_mask[i, :, :, :p] = 0.0
        # 2. Split the remainder into blocks, starting at the prompt boundary.
        num_blocks = (seq_len - p + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = p + b * block_size
            block_end = min(block_start + block_size, seq_len)
            # Full (bidirectional) attention within a block.
            attn_mask[i, :, block_start:block_end, block_start:block_end] = 0.0
            # Causal attention across blocks: only earlier blocks are visible.
            for prev_b in range(b):
                prev_start = p + prev_b * block_size
                prev_end = min(prev_start + block_size, seq_len)
                attn_mask[i, :, block_start:block_end, prev_start:prev_end] = 0.0
    return attn_mask
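# A minimal sanity-check sketch (the toy sizes below are assumptions, not
# from the original code): with a length-2 prompt, block_size=2, and a
# length-6 sequence, each position should see the prompt, its own block
# fully, and earlier blocks causally; 1 marks "may attend", 0 "blocked".
def _demo_mask():
    ids = torch.zeros(1, 6, dtype=torch.long)
    mask = build_custom_float_attention_mask(ids, torch.tensor([2]), block_size=2)
    print((mask[0, 0] == 0).int())  # 6x6 visibility pattern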
def top_p_logits(logits, top_p=None):
    """Nucleus filtering: keep the smallest token set whose cumulative probability exceeds top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)


def top_k_logits(logits, top_k=None):
    """Top-k filtering: mask out every token outside the k highest logits."""
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fall back to greedy decoding if sampling fails (e.g., invalid probs).
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence is the margin between the top-1 and top-2 probabilities.
        confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        # Negative entropy: closer to 0 means a more peaked distribution.
        confidence = torch.sum(probs * log_probs, dim=-1)
    return confidence, x0
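# Usage sketch (the shapes and values here are assumptions): draw one token
# per position from nucleus-filtered logits; `conf` defaults to the sampled
# token's probability.
def _demo_sampling():
    logits = torch.randn(4, 32000)  # (num_positions, vocab_size)
    conf, tokens = sample_tokens(logits, temperature=0.2, top_p=0.95)
    print(tokens.shape, conf.shape)  # both torch.Size([4])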
# Earlier draft of block-wise generation (superseded by generate_block below):
# def generate(model, prompt, block_size, max_length, mask_id, eos_token_id=None):
#     device = prompt.device
#     output = prompt.clone()
#     while output.shape[1] < max_length:
#         # Append one block of mask tokens.
#         mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
#         input_ids = torch.cat([output, mask_block], dim=1)
#         attention_mask = build_custom_float_attention_mask(input_ids, torch.tensor([prompt.shape[1]]), block_size, device=device)
#         attention_mask = attention_mask.to(torch.bfloat16)
#         for i in range(block_size):
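# `shift_logits` is called by generate_block but is not defined in this file.
# Below is a common definition from diffusion-LM codebases (an assumption,
# not taken from the original source): realign next-token logits so that
# logits[:, i] scores the token at position i itself.
def shift_logits(logits):
    # Shift the sequence axis right by one, duplicating the first position.
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)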
def generate_block(denoiser, block_size, mask_id, tokenizer, device):
    denoiser.eval()
    question = 'please give me a code about transformer model'
    messages = [
        {"role": "user", "content": question}
    ]
    prompt = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    ).input_ids
    prompt = prompt.to(device)
    # Fill a total length budget of 384 tokens, one block at a time.
    gen_len = (384 - prompt.shape[1]) // block_size
    print(f"blocks to generate: {gen_len}")
    temperature = 0.2
    top_p = 0.95
    with torch.inference_mode():
        for i in range(gen_len):
            # Append one block of mask tokens to the current sequence.
            mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
            if i == 0:
                x_t = torch.cat([prompt, mask_block], dim=1)
            else:
                x_t = torch.cat([x_t, mask_block], dim=1)
            attention_mask = build_custom_float_attention_mask(
                x_t, torch.tensor([prompt.shape[1]]), block_size, device=device
            )
            attention_mask = attention_mask.to(torch.bfloat16)
            # Iteratively denoise the block, revealing one token per step.
            for n in range(block_size):
                mask_index = (x_t == mask_id)
                if mask_index.sum() == 0:
                    break
                logits = denoiser(x_t, attention_mask=attention_mask).logits
                logits = shift_logits(logits)
                mask_logits = logits[mask_index]
                confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=None, neg_entropy=True)
                # Commit only the single most confident prediction this step.
                number_transfer_tokens = 1
                _, transfer_index = torch.topk(confidence, number_transfer_tokens)
                x0_ = torch.full_like(x0, mask_id)
                x0_[transfer_index] = x0[transfer_index].clone()
                x_t[mask_index] = x0_
    answer = tokenizer.batch_decode(x_t[:, prompt.shape[1]:], skip_special_tokens=False)[0]
    print(answer)
if __name__ == "__main__":
    config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    lora_model = PeftModel.from_pretrained(model, "ybelkada/opt-350m-lora")
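    # A minimal invocation sketch (assumptions: block_size=32 is arbitrary,
    # and mask id 151666 implies a vocabulary far larger than OPT-350m's, so
    # the checkpoint above is presumably a placeholder for the intended
    # diffusion denoiser, whose tokenizer would also need a chat template).
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lora_model.to(device)
    generate_block(lora_model, block_size=32, mask_id=151666, tokenizer=tokenizer, device=device)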