lijiang committed on
Commit
0f53954
·
verified ·
1 Parent(s): ae01bed

Delete generate_from_llada.py

Browse files
Files changed (1) hide show
  1. generate_from_llada.py +0 -294
generate_from_llada.py DELETED
@@ -1,294 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import torch.nn.functional as F
4
-
5
- from transformers import AutoTokenizer, AutoModel
6
-
7
-
8
def add_gumbel_noise(logits, temperature):
    """
    Perturb `logits` for categorical sampling via the Gumbel-max trick.

    Per arXiv:2409.02908, low-precision Gumbel-max improves perplexity for
    MDM but reduces generation quality, so the computation is done in
    float64. When `temperature == 0` the logits are returned untouched
    (greedy decoding); otherwise the caller takes an argmax over the
    returned values, which is equivalent to temperature sampling.
    """
    if temperature == 0:
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    # (-log U)^T with U ~ Uniform(0,1); dividing exp(logits) by this noise
    # and taking argmax realizes the Gumbel-max sampling scheme.
    perturbation = torch.log(uniform).neg() ** temperature
    return torch.exp(logits64) / perturbation
20
-
21
-
22
def get_num_transfer_tokens(mask_index, steps):
    """
    Evenly distribute each row's masked-token count across `steps` steps.

    The reverse process discretizes [0, 1] into `steps` uniform intervals;
    because LLaDA uses a linear noise schedule (Eq. (8)), the expected
    number of tokens unmasked per step is constant. Each step therefore
    receives floor(mask_num / steps) tokens, with the first
    `mask_num % steps` steps getting one extra.

    Args:
        mask_index: (B, L) boolean tensor marking masked positions.
        steps: number of sampling steps.

    Returns:
        (B, steps) int64 tensor of per-step transfer counts.
    """
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    quotient = masked_per_row // steps
    leftover = masked_per_row % steps

    # Start every step at the floor value, then top up the earliest steps.
    schedule = quotient.expand(-1, steps).clone().to(torch.int64)
    for row, extra in enumerate(leftover.squeeze(1).tolist()):
        schedule[row, :extra] += 1

    return schedule
41
-
42
def get_num_transfer_tokens_sch(mask_index, steps, schedule=None, schedule_kwargs=None):
    '''
    Per-step transfer counts under a (possibly non-linear) noise schedule.

    Like get_num_transfer_tokens, but the cumulative unmasking fraction is
    taken from `schedule` ('logit_normal', 'shift', 'cosine', or None/other
    for linear) instead of being uniform. Falls back to the linear helper
    when no schedule is given.

    Args:
        mask_index: (B, L) boolean tensor marking masked positions.
        steps: requested number of steps; capped at the masked count of the
            FIRST row only (rows are assumed to share one mask count).
        schedule: schedule name, or None for the uniform split.
        schedule_kwargs: extra options; only 'shift' is read (default 3).

    Returns:
        (B, steps) int64 tensor of per-step counts, reversed so larger
        steps come last (counts are consumed front-to-back by generate()).
    '''
    if schedule is None:
        return get_num_transfer_tokens(mask_index, steps)
    if schedule_kwargs is None:
        schedule_kwargs = {}

    mask_num = mask_index.sum(dim=1, keepdim=True)
    # Cap steps so every step can transfer at least one token.
    # NOTE(review): only mask_num[0] is consulted — assumes batch rows have
    # equal mask counts; verify for batched callers.
    steps = int(min(steps, mask_num[0]))
    t = torch.linspace(0, 1, steps + 1)
    # at least one sample per step
    if schedule == 'logit_normal':
        sigmas = sigmoid_normal_cdf(t)
    elif schedule == 'shift':
        sigmas = logit_normal_schedule(schedule_kwargs.get('shift', 3), t)
    elif schedule == 'cosine':
        sigmas = cosine_schedule(t)
    else:
        # Unknown names silently degrade to the linear schedule.
        sigmas = t
    sigmas = sigmas.to(mask_num.device)
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64)

    for i in range(mask_num.size(0)):
        # Discretize the cumulative fraction into integer token counts,
        # then difference to get per-step counts.
        sigmas_sample = (sigmas * mask_num[i]).to(torch.int64)
        sigmas_sample = sigmas_sample[1:] - sigmas_sample[:-1]
        # Force at least one token per step; this can inflate the total.
        sigmas_sample = torch.clamp(sigmas_sample, 1, None)  # should only increase
        # `delta` is the surplus introduced by clamping; it is non-negative
        # because the unclamped differences telescope to exactly mask_num[i].
        delta = sigmas_sample.sum() - mask_num[i]
        assert delta >= 0
        j = 0

        # Round-robin: shave 1 off steps that still have more than one
        # token until the surplus is gone. Terminates because delta <= 0
        # would hold if every entry were already 1.
        while delta > 0:
            j = j % len(sigmas_sample)
            if sigmas_sample[j] == 1:
                j += 1
                continue

            delta -= 1
            sigmas_sample[j] -= 1
            j += 1
        assert sigmas_sample.sum() == mask_num[i]
        num_transfer_tokens[i] = sigmas_sample
    # Reverse so the schedule's large steps land at the intended end.
    return num_transfer_tokens.flip(-1)
96
-
97
def linear(y):
    """Linear (identity) noise schedule: maps y to itself."""
    return y
99
-
100
def cosine_schedule(x):
    """
    Cosine schedule mapping [0, 1] -> [0, 1], monotonically increasing.

    Computes 1 - 0.5 * (1 + cos(pi * x)), which maps 0 -> 0 and 1 -> 1 —
    the cumulative-unmasking-fraction shape the caller
    (get_num_transfer_tokens_sch) requires. The old docstring's claim of a
    "[0, 1] -> [1, 0]" mapping was incorrect.

    BUG FIX: the previous implementation applied np.clip/np.cos to the
    torch tensor produced by torch.linspace, yielding a NumPy ndarray; the
    caller then crashed on `sigmas.to(device)` (ndarrays have no `.to`).
    Torch ops are used throughout so a tensor in yields a tensor out.

    Args:
        x: tensor (or anything torch.as_tensor accepts) of schedule times.

    Returns:
        Tensor of the same shape with values clamped into [0, 1] first.
    """
    x = torch.clamp(torch.as_tensor(x), 0.0, 1.0)
    return 1.0 - 0.5 * (1.0 + torch.cos(torch.pi * x))
106
-
107
def sigmoid_normal_cdf(y):
    """
    Standard-normal CDF evaluated at logit(y), i.e. Phi(log(y / (1 - y))).

    `y` must lie in (0, 1); the endpoints map to 0 and 1 through the
    +/-inf logit. Expressed via erf: Phi(z) = 0.5 * (1 + erf(z / sqrt(2))).
    """
    # Logit transform of y.
    z = torch.log(y / (1.0 - y))
    inv_sqrt2 = torch.rsqrt(torch.tensor(2.0))
    return 0.5 * (torch.erf(z * inv_sqrt2) + 1.0)
111
def logit_normal_schedule(shift, sigmas):
    """
    Time-shifted schedule: s -> shift * s / (1 + (shift - 1) * s).

    shift == 1 is the identity; shift > 1 front-loads the schedule.
    Fixed points 0 and 1 are preserved for any shift.
    """
    numerator = shift * sigmas
    denominator = 1 + (shift - 1) * sigmas
    return numerator / denominator
115
import os

# Debug-printing switch for generate().
# NOTE(review): environment values are strings, so ANY non-empty setting —
# including "0" or "false" — enables printing; only an unset variable
# yields the False default. Confirm that is the intended behavior.
DEBUG_PRINT_OUTPUT = os.environ.get('DEBUG_PRINT_OUTPUT',False)
117
@ torch.no_grad()
def generate(model, prompt=None, steps=None, max_new_tokens=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336, inputs_embeds=None, position_ids=None, attention_mask=None,
             tokenizer=None,
             verbose=False,
             step_per_block=None,
             prefix_lm=False,
             schedule=None,
             schedule_kwargs=None,
             draft_tokens=None,
             step_ratio=None,
             **kwargs):
    '''
    Semi-autoregressive masked-diffusion generation for LLaDA.

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L), or None when `inputs_embeds`
            supplies the prompt directly as embeddings.
        steps: Sampling steps. NOTE(review): this argument is currently
            ignored — it is overwritten with `max_new_tokens` below.
        max_new_tokens: Generated answer length (also the step budget).
        block_length: Block length, less than or equal to the generated
            length. If smaller, semi-autoregressive remasking is used.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale (the
            cfg branch is effectively unsupported, see below).
        remasking: Remasking strategy: 'low_confidence', 'random',
            'entrophy', or 'margin'.
        mask_id: The token id of [MASK] is 126336.
        step_per_block: Overrides the per-block step count.
        prefix_lm: Cache the prompt embeddings as past key-values and
            denoise only the generated region.
        draft_tokens: Optional pre-filled tokens copied into the canvas.
        step_ratio: Multiplier (< 1) shrinking the per-block step count;
            mutually exclusive with step_per_block.

    Returns:
        x (1, prompt_len + gen_length) token tensor; with verbose=True,
        also the list of per-step snapshots.
    '''
    # The passed-in `steps` is discarded; one step per generated token.
    steps = max_new_tokens  # min(steps,max_new_tokens)
    gen_length = max_new_tokens
    assert position_ids is None
    if prompt is None:
        assert inputs_embeds is not None
        bsz, seq_len = inputs_embeds.shape[:2]
        # Placeholder zero-token prompt matching the embeds length; the
        # real embeddings are injected into inputs_embeds_curr below.
        prompt = torch.full((bsz, seq_len), 0, dtype=torch.long).to(model.device)
    past_key_values = None
    # NOTE(review): indentation reconstructed from a mangled diff — this
    # nesting is the only one under which every call path (token prompt,
    # embeds prompt, prefix_lm) defines `x`; confirm against the original.
    if prefix_lm:
        # Encode the prompt once and reuse its KV cache; the canvas then
        # contains only the generated region and the prompt is emptied.
        past_key_values = model(None, input_embeddings=inputs_embeds, use_cache=True).attn_key_values
        x = torch.full((1, gen_length), mask_id, dtype=torch.long).to(model.device)
        prompt = torch.full((bsz, 0), 0, dtype=torch.long).to(model.device)
    else:
        # Full canvas: prompt tokens followed by an all-[MASK] region.
        x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
        x[:, :prompt.shape[1]] = prompt.clone()

    prompt_index = (x != mask_id)
    assert prompt.shape[0] == 1
    if draft_tokens is not None:
        # Seed the generation region with draft tokens (speculative start).
        assert draft_tokens.shape[1] <= gen_length
        x[:, prompt.shape[1]:prompt.shape[1] + draft_tokens.shape[1]] = draft_tokens.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert (steps % num_blocks == 0) or step_per_block is not None
    steps = steps // num_blocks
    if step_per_block:
        steps = min(step_per_block, block_length)
        assert step_ratio is None, 'Please do not pass both step_ratio and step_per_block'
    if step_ratio:
        steps = int(steps * step_ratio)

    if verbose:
        history = []
    for num_block in range(num_blocks):

        # Mask positions inside the current block only.
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens_sch(block_mask_index, steps, schedule=schedule, schedule_kwargs=schedule_kwargs)
        if DEBUG_PRINT_OUTPUT:
            print(f"Block: {num_block + 1}/{num_blocks}, Steps per Block: {steps}, Block Length: {block_length}")
            print(f"Tokens generated per step {num_transfer_tokens[0]}")
        for i in range(steps):
            mask_index = (x == mask_id)
            # Nothing left to unmask (e.g. draft tokens filled the block).
            if mask_index.sum() == 0:
                continue
            if cfg_scale > 0.:
                # NOTE(review): `assert NotImplementedError(...)` asserts a
                # truthy exception INSTANCE, so it never fires — this branch
                # runs despite being marked unsupported. Should probably be
                # `raise NotImplementedError(...)`.
                assert NotImplementedError('cfg_scale > 0. is not supported.')
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_, input_embeds_inference=[inputs_embeds, None]).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                # Embed the current canvas tokens.
                inputs_embeds_curr = model.transformer.wte(x)
                if prefix_lm:
                    # Prompt is served from the cached key-values.
                    logits = model(None, input_embeddings=inputs_embeds_curr, past_key_values=past_key_values).logits
                else:
                    if inputs_embeds is not None:
                        # Replace the placeholder-token embeddings with the
                        # caller-provided prompt embeddings.
                        inputs_embeds_curr[:, :inputs_embeds.shape[1]] = inputs_embeds
                    logits = model(None, input_embeddings=inputs_embeds_curr).logits
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
            # Per-position confidence used to pick which masks to commit.
            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            elif remasking == 'entrophy':
                # sum(p * log p) is NEGATIVE entropy, so topk selects the
                # lowest-entropy (most certain) positions.
                epsilon = 1e-10
                probs = F.softmax(logits.to(torch.float64), dim=-1)
                log_probs = torch.log(probs + epsilon)
                x0_p = torch.sum(probs * log_probs, dim=-1)
            elif remasking == 'margin':
                # Top-1 minus top-2 probability, similar to the margin
                # algorithm in Dream.
                p = F.softmax(logits.to(torch.float64), dim=-1)
                sorted_probs, _ = torch.sort(p, dim=-1, descending=True)
                top1_probs = sorted_probs[:, :, 0]
                top2_probs = sorted_probs[:, :, 1]
                x0_p = top1_probs - top2_probs
            else:
                raise NotImplementedError(remasking)

            # Positions beyond the current block can never be selected.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            # Keep already-committed tokens; only masked slots may change.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Commit the num_transfer_tokens[j, i] most confident positions.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]
            if verbose:
                history.append(x.clone().cpu())
    if verbose:
        return x, history
    return x
271
-
272
-
273
def main():
    """
    Smoke test: load LLaDA-8B-Instruct, build a chat-formatted math prompt,
    generate 128 tokens, and print the decoded answer.

    Requires a CUDA device and network access to download the checkpoint.
    """
    device = 'cuda'

    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}, ]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    # BUG FIX: generate() has no `gen_length` parameter — the old call's
    # gen_length=128 was silently swallowed by **kwargs and only worked
    # because max_new_tokens defaults to 128. Pass max_new_tokens
    # explicitly. Also dropped the redundant duplicate generate() call
    # that followed the print.
    out = generate(model, input_ids, steps=128, max_new_tokens=128, block_length=32,
                   temperature=0., cfg_scale=0., remasking='low_confidence')
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])