File size: 6,635 Bytes
4f2b2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import torch.nn.functional as F
import torch.distributions as dists
from peft import PeftModel, PeftConfig
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None):
    """Build an additive float attention mask for block-wise generation.

    For each sample ``i`` with prompt length ``p = prompt_length[i]``:

    * every query position may attend to every prompt key (positions ``< p``);
    * positions ``>= p`` are split into consecutive blocks of ``block_size``
      tokens (the last block may be shorter); a query inside a block may attend
      to all keys in its own block and in every earlier block;
    * prompt queries may NOT attend to any key after the prompt.

    Args:
        input_ids: 2-D tensor ``(B, seq_len)``; only its shape is used.
        prompt_length: per-sample prompt lengths, indexable by batch index with
            entries convertible by ``int()`` (e.g. a list of ints, or a tensor of
            shape ``(B,)`` or ``(B, 1)``).
        block_size: number of generated tokens per block; must be positive.
        device: device on which to allocate the mask.

    Returns:
        float32 tensor of shape ``(B, 1, seq_len, seq_len)`` holding ``0.0``
        where attention is allowed and ``-inf`` where it is blocked.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    B, seq_len = input_ids.shape

    lengths = torch.tensor([int(prompt_length[i]) for i in range(B)],
                           dtype=torch.long, device=device)           # (B,)
    positions = torch.arange(seq_len, device=device)                  # (S,)
    rel = positions[None, :] - lengths[:, None]                       # (B, S)
    in_gen = rel >= 0                                                 # position lies after the prompt
    # Block index of each generated position; clamped so prompt positions get 0
    # (their block id is never consulted because they are excluded via in_gen).
    block_id = torch.div(rel.clamp(min=0), block_size, rounding_mode='floor')

    # Keys inside the prompt are visible to every query.
    key_in_prompt = ~in_gen[:, None, :]                               # (B, 1, S)
    # Generated queries see generated keys in the same or an earlier block.
    same_or_earlier_block = (
        in_gen[:, :, None]
        & in_gen[:, None, :]
        & (block_id[:, None, :] <= block_id[:, :, None])
    )                                                                 # (B, S, S)
    allowed = key_in_prompt | same_or_earlier_block

    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'),
                           dtype=torch.float32, device=device)
    attn_mask.masked_fill_(allowed.unsqueeze(1), 0.0)
    return attn_mask
def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    Tokens are ranked by logit. Every token whose preceding cumulative softmax
    mass already exceeds ``top_p`` is replaced by the dtype's minimum value.
    The highest-ranked token is always kept.

    Args:
        logits: tensor of unnormalized scores; the last dimension is the vocabulary.
        top_p: cumulative-probability threshold. ``None`` disables filtering.

    Returns:
        A new tensor with the same shape and dtype as ``logits``.
    """
    if top_p is None:
        # Previously `cumulative_probs > None` raised TypeError for the default value.
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is itself kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the removal flags from sorted order back to vocabulary order.
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits

def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Logits strictly below the k-th largest value are replaced by the dtype's
    minimum value. Ties with the k-th value are kept. ``top_k`` is clamped to
    the vocabulary size.

    Args:
        logits: tensor of unnormalized scores; the last dimension is the vocabulary.
        top_k: number of candidates to keep. ``None`` disables filtering.

    Returns:
        A new tensor with the same shape and dtype as ``logits``.
    """
    if top_k is None:
        # Previously min(None, int) raised TypeError for the default value.
        return logits
    top_k = min(top_k, logits.size(-1))  # Safety check
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    indices_to_remove = logits < kth_largest
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits

def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Choose one token per distribution in ``logits`` and score the choice.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        temperature: if > 0, logits are divided by it and a token is sampled
            from the resulting distribution; otherwise the argmax is taken.
        top_p: nucleus threshold, applied via ``top_p_logits`` when set and < 1.
        top_k: candidate count, applied via ``top_k_logits`` when set.
        margin_confidence: if True, confidence is top-1 minus top-2 probability.
        neg_entropy: if True, confidence is ``sum(p * log p)`` (the negative
            entropy). This overrides ``margin_confidence`` when both are set.

    Returns:
        ``(confidence, x0)``: both have the shape of ``logits`` without its last
        dimension. ``x0`` holds the chosen token ids. By default ``confidence``
        is the probability of the chosen token.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # Invalid distributions (e.g. NaN probs) fail argument validation or
            # sampling; fall back to greedy. Interrupts are no longer swallowed.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        # Use the last dim so inputs of any rank work, not only 2-D ones.
        top2 = torch.topk(probs, 2, dim=-1).values
        confidence = top2[..., 0] - top2[..., 1]

    if neg_entropy:
        epsilon = 1e-10  # avoids log(0)
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
# def generate(model,prompt,block_size,max_length,mask_id):
# def generate(model, prompt, block_size, max_length, mask_id, eos_token_id=None):
#     device = prompt.device
#     output = prompt.clone()

#     while output.shape[1] < max_length:
#         # 添加一个 block 的 mask
#         mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
#         input_ids = torch.cat([output, mask_block], dim=1)
#         attention_mask = build_custom_float_attention_mask(input_ids, torch.tensor([[prompt.shape[1]]]), block_size, device=device)
#         attention_mask = attention_mask.to(torch.bfloat16)
#         for i in range(block_size):
def generate_block(denoiser, block_size, mask_id, tokenizer, device, *,
                   question='please give me a code about transformer model',
                   max_length=384, temperature=0.2, top_p=0.95):
    """Generate a reply to ``question`` block by block and print the progress.

    The prompt is built with the tokenizer's chat template. Then, repeatedly:
    ``block_size`` copies of ``mask_id`` are appended, and the block is filled
    one token per denoiser call. Each call commits the masked position whose
    sampled token has the highest ``sample_tokens(..., neg_entropy=True)``
    confidence. Blocks are added while they fit within ``max_length`` tokens
    (prompt included).

    Args:
        denoiser: model called as ``denoiser(x_t, attention_mask=...)`` and
            returning an object with ``.logits``; put into eval mode here.
        block_size: number of tokens generated per block.
        mask_id: token id used for not-yet-generated positions. It is now
            honoured; previously it was silently overwritten with 151666.
        tokenizer: tokenizer providing ``apply_chat_template`` and ``batch_decode``.
        device: device for the prompt and generated tensors.
        question: user message to answer.
        max_length: total token budget (prompt + generated).
        temperature: sampling temperature passed to ``sample_tokens``.
        top_p: nucleus threshold passed to ``sample_tokens``.

    Returns:
        The decoded text after the prompt, with special tokens kept.
    """
    denoiser.eval()
    messages = [
        {"role": "user", "content": question}
    ]
    prompt = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    ).input_ids
    prompt = prompt.to(device)
    prompt_len = prompt.shape[1]

    # Whole blocks that fit in the budget; clamp so an over-long prompt yields 0.
    num_blocks = max(0, (max_length - prompt_len) // block_size)
    print(num_blocks)

    # Start from the prompt so x_t is defined even when no block is generated
    # (previously this raised UnboundLocalError at the final decode).
    x_t = prompt
    with torch.inference_mode():
        for _ in range(num_blocks):
            mask_block = torch.full((1, block_size), mask_id, dtype=torch.long, device=device)
            x_t = torch.cat([x_t, mask_block], dim=1)
            attention_mask = build_custom_float_attention_mask(
                x_t, torch.tensor([[prompt_len]]), block_size, device=device
            ).to(torch.bfloat16)
            for _ in range(block_size):
                mask_index = x_t == mask_id
                if not mask_index.any():
                    break
                logits = denoiser(x_t, attention_mask=attention_mask).logits
                # NOTE(review): shift_logits is not defined in the visible source;
                # this raises NameError unless it is provided elsewhere — TODO
                # define or import it.
                logits = shift_logits(logits)
                confidence, x0 = sample_tokens(
                    logits[mask_index], temperature, top_p=top_p, top_k=None, neg_entropy=True
                )
                # Commit only the single most confident prediction; the other
                # masked positions are written back as mask_id and stay masked.
                best = torch.topk(confidence, 1).indices
                update = torch.full_like(x0, mask_id, dtype=torch.long)
                update[best] = x0[best]
                x_t[mask_index] = update
            answer = tokenizer.batch_decode(x_t[:, prompt_len:], skip_special_tokens=False)[0]
            print(answer)
    answer = tokenizer.batch_decode(x_t[:, prompt_len:], skip_special_tokens=False)[0]
    print(answer)
    return answer

if __name__ == "__main__":
    # Script entry point: load the PEFT adapter config, its base model, and wrap
    # the base model with the adapter. Nothing here calls generate_block.
    # NOTE(review): AutoModelForCausalLM is not imported in the visible source
    # (presumably `from transformers import AutoModelForCausalLM` is missing), so
    # this block raises NameError as written — TODO add the import.
    config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
    # Base model named by the adapter config's base_model_name_or_path.
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    # Load the adapter weights on top of the base model.
    lora_model = PeftModel.from_pretrained(model, "ybelkada/opt-350m-lora")