pbansal committed on
Commit
c2f1c2b
·
verified ·
1 Parent(s): 59275df

Create generate.py

Browse files
Files changed (1) hide show
  1. generate.py +178 -0
generate.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ from transformers import AutoTokenizer, AutoModel
6
+
7
+
8
def add_gumbel_noise(logits, temperature):
    '''
    Sample from a categorical distribution via the Gumbel-max trick.

    According to arXiv:2409.02908, for MDM, low-precision Gumbel-max improves
    perplexity score but reduces generation quality, so the math is done in
    float64. The returned scores satisfy: argmax of them is a sample from the
    temperature-scaled softmax of `logits`. With temperature == 0 the logits
    are returned untouched (greedy decoding).
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    uniform = torch.rand_like(logits, dtype=torch.float64)
    # (-log U)^T is the temperature-tempered Gumbel perturbation; dividing
    # exp(logits) by it is equivalent to adding Gumbel noise to logits / T.
    perturbation = (- torch.log(uniform)) ** temperature
    return logits.exp() / perturbation
20
+
21
+
22
def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens to unmask at each sampling step.

    In the reverse process, the interval [0, 1] is uniformly discretized into
    `steps` intervals. Because LLaDA employs a linear noise schedule (as
    defined in Eq. (8)), the expected number of tokens transitioned at each
    step should be consistent, so each row's mask count is split as evenly as
    possible across the steps (the remainder goes to the earliest steps).

    Args:
        mask_index: Bool tensor of shape (b, l); True marks masked positions.
        steps: Number of sampling steps to spread the transfers over.

    Returns:
        Int64 tensor of shape (b, steps); each row sums to that row's mask count.
    '''
    total_masked = mask_index.sum(dim=1, keepdim=True)

    per_step = total_masked // steps
    leftover = total_masked % steps

    schedule = per_step + torch.zeros(
        total_masked.size(0), steps, device=mask_index.device, dtype=torch.int64)

    # Distribute the remainder: the first `leftover[row]` steps get one extra token.
    for row in range(total_masked.size(0)):
        schedule[row, :leftover[row]] += 1

    return schedule
41
+
42
+
43
@torch.no_grad()
def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False, use_adjust=False):
    '''
    Iterative mask-denoising generation for LLaDA-style masked diffusion models.

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (b, L). Left-padded when b > 1.
        attention_mask: Optional (b, L) mask over the prompt; it is extended
            with ones over the generated region.
        steps: Sampling steps, less than or equal to gen_length.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        logits_eos_inf: Whether to set the logits of the EOS token to -inf. See Appendix B.4 of LLaDA for details.
        confidence_eos_eot_inf: Whether to set the confidence of the EOS and EoT tokens to -inf. See Appendix B.4 of LLaDA for details.
        use_adjust: Whether to use the drafter head (model.drafter_forward) to
            refine logits within a step, unmasking one token per refinement
            instead of all of a step's tokens at once. Requires cfg_scale == 0.

    Returns:
        Tensor of shape (b, L + gen_length): prompt followed by the generated tokens.
    '''
    # Start from the prompt followed by gen_length [MASK] tokens.
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        # The generated region is always attended to.
        attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks  # steps per block from here on

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            # With use_adjust, a single model call is followed by up to
            # num_transfer_tokens[:, i] drafter refinements, each unmasking one token.
            max_adjust_steps = num_transfer_tokens[:, i].max().item() if use_adjust else 1
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: the second batch half sees the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                # BUG FIX: attention_mask_ was previously only bound when
                # attention_mask was not None, so the model call below raised
                # NameError for cfg_scale > 0 without an attention mask.
                attention_mask_ = None
                if attention_mask is not None:
                    attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                outputs = model(x_, attention_mask=attention_mask_, output_hidden_states=True)
                logits = outputs.logits
                hidden_states = outputs.hidden_states[-1]
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                outputs = model(x, attention_mask=attention_mask, output_hidden_states=True)
                logits = outputs.logits
                hidden_states = outputs.hidden_states[-1]

            for adjust_step in range(max_adjust_steps):
                if logits_eos_inf:
                    logits[:, :, 126081] = -torch.inf  # 126081: EOS id (LLaDA Appendix B.4)

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

                if confidence_eos_eot_inf:
                    # BUG FIX: the first assignment used to target logits_with_noise,
                    # which is dead at this point (x0 is already decoded from it).
                    # Suppress both EOS (126081) and EoT (126348) in `logits` so their
                    # softmax confidence below is driven to ~0, as the docstring states.
                    logits[:, :, 126081] = logits[:, :, 126348] = -torch.inf

                if remasking == 'low_confidence':
                    p = F.softmax(logits, dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
                elif remasking == 'random':
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                # Never unmask positions beyond the current block.
                x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    # Without adjustment: unmask the k most confident tokens at once.
                    # With adjustment: unmask one token per drafter refinement, as long
                    # as this step's budget for row j has not been exhausted.
                    k = num_transfer_tokens[j, i] if (not use_adjust) else int(num_transfer_tokens[j, i]>adjust_step)
                    if (k>0):
                        _, select_index = torch.topk(confidence[j], k=k)
                        transfer_index[j, select_index] = True
                x[transfer_index] = x0[transfer_index]
                if use_adjust:
                    assert cfg_scale == 0.
                    mask_index = (x == mask_id)
                    # Drafter head refines logits from the last hidden states given the updated x.
                    hidden_states, logits = model.drafter_forward(x, hidden_states)
    return x
132
+
133
+
134
def main():
    """Demo: batched LLaDA generation on three grade-school math prompts."""
    device = 'cuda'

    # model_name = 'GSAI-ML/LLaDA-8B-Base'
    model_name = 'pbansal/LLaDA-8B-Base-Adjust'

    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # The LLaDA architecture theoretically supports both left-padding and right-padding.
    # However, the sampling code implementation is simpler with left-padding.
    tokenizer.padding_side = 'left'

    # If the padding ID equals the mask ID, you need to modify our generate function to achieve correct inference.
    assert tokenizer.pad_token_id != 126336

    prompts = [
        "Question : Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?\n\nAnswer :",
        "Question : Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?\n\nAnswer :",
        "Question : Randy has 60 mango trees on his farm. He also has 5 less than half as many coconut trees as mango trees. How many trees does Randy have in all on his farm?\n\nAnswer :",
    ]

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    # messages = [{"role": "user", "content": prompt} for prompt in prompts]
    # prompts = [tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in messages]

    batch = tokenizer(prompts, add_special_tokens=False, padding=True, return_tensors="pt")
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    gen_length = 64
    num_steps = gen_length // 8  # 8 tokens unmasked per step on average
    out = generate(model, input_ids, attention_mask, steps=num_steps, gen_length=gen_length,
                   block_length=gen_length, temperature=1.0, cfg_scale=0., remasking='low_confidence',
                   use_adjust=True)

    for completion in tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True):
        print(completion)
        print('-' * 50)


if __name__ == '__main__':
    main()