Update README.md
Browse files
README.md
CHANGED
|
@@ -50,337 +50,177 @@ This model has been 4-bit quantized Llada-8B-Base model with [GPTQModel](https:/
|
|
| 50 |
## Example:
|
| 51 |
```python
|
| 52 |
'''
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
import torch
|
| 57 |
-
import
|
| 58 |
-
from
|
| 59 |
-
import
|
|
|
|
| 60 |
import numpy as np
|
| 61 |
-
import torch.nn.functional as F
|
| 62 |
-
from datasets import Dataset
|
| 63 |
-
from lm_eval.__main__ import cli_evaluate
|
| 64 |
-
from lm_eval.api.instance import Instance
|
| 65 |
-
from lm_eval.api.model import LM
|
| 66 |
-
from lm_eval.models.huggingface import HFLM
|
| 67 |
-
from lm_eval.api.registry import register_model
|
| 68 |
-
from tqdm import tqdm
|
| 69 |
-
|
| 70 |
-
from transformers import AutoTokenizer, AutoModel
|
| 71 |
-
from gptqmodel import GPTQModel
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
@register_model("llada_dist")
|
| 76 |
-
class LLaDAEvalHarness(LM):
|
| 77 |
-
def __init__(
|
| 78 |
-
self,
|
| 79 |
-
model_path='',
|
| 80 |
-
mask_id=126336,
|
| 81 |
-
max_length=4096,
|
| 82 |
-
block_length = 4096,
|
| 83 |
-
steps = 128,
|
| 84 |
-
batch_size=32,
|
| 85 |
-
mc_num=128,
|
| 86 |
-
is_check_greedy=True,
|
| 87 |
-
cfg=0.,
|
| 88 |
-
device="cuda",
|
| 89 |
-
gptqmodel=True
|
| 90 |
-
):
|
| 91 |
-
"""
|
| 92 |
-
Args:
|
| 93 |
-
model_path: LLaDA-8B-Base model path.
|
| 94 |
-
mask_id: The token id of [MASK] is 126336.
|
| 95 |
-
max_length: the max sequence length.
|
| 96 |
-
batch_size: mini batch size.
|
| 97 |
-
mc_num: Monte Carlo estimation iterations
|
| 98 |
-
is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer
|
| 99 |
-
is generated through greedy sampling conditioned on the prompt (note that this differs from conditional
|
| 100 |
-
generation). We implement this verification through the suffix_greedy_prediction() function, which
|
| 101 |
-
returns a True/False judgment used for accuracy calculation.
|
| 102 |
-
When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function.
|
| 103 |
-
However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality,
|
| 104 |
-
we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False
|
| 105 |
-
by default, significantly accelerating the evaluation process.
|
| 106 |
-
cfg_scale: Unsupervised classifier-free guidance scale.
|
| 107 |
-
"""
|
| 108 |
-
super().__init__()
|
| 109 |
-
|
| 110 |
-
accelerator = accelerate.Accelerator()
|
| 111 |
-
if accelerator.num_processes > 1:
|
| 112 |
-
self.accelerator = accelerator
|
| 113 |
-
else:
|
| 114 |
-
self.accelerator = None
|
| 115 |
-
|
| 116 |
-
model_kwargs = {}
|
| 117 |
-
if self.accelerator is not None:
|
| 118 |
-
model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})
|
| 119 |
-
|
| 120 |
-
#self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, gptqmodel=gptqmodel, **model_kwargs)
|
| 121 |
-
self.model = GPTQModel.load(model_path, device='cuda' , trust_remote_code=True )
|
| 122 |
-
self.model.eval()
|
| 123 |
-
|
| 124 |
-
self.device = torch.device(device)
|
| 125 |
-
if self.accelerator is not None:
|
| 126 |
-
self.model = self.accelerator.prepare(self.model)
|
| 127 |
-
self.device = torch.device(f'{self.accelerator.device}')
|
| 128 |
-
self._rank = self.accelerator.local_process_index
|
| 129 |
-
self._world_size = self.accelerator.num_processes
|
| 130 |
-
|
| 131 |
-
self.mask_id = mask_id
|
| 132 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 133 |
-
|
| 134 |
-
self.mc_num = mc_num
|
| 135 |
-
self.batch_size = int(batch_size)
|
| 136 |
-
assert mc_num % self.batch_size == 0
|
| 137 |
-
self.sampling_eps = 0.
|
| 138 |
-
self.max_length = max_length
|
| 139 |
-
self.block_length = block_length
|
| 140 |
-
self.steps = steps
|
| 141 |
-
self.is_check_greedy = is_check_greedy
|
| 142 |
-
|
| 143 |
-
self.cfg = cfg
|
| 144 |
-
print(f'model: {model_path}')
|
| 145 |
-
print(f'Is check greedy: {is_check_greedy}')
|
| 146 |
-
print(f'cfg: {cfg}')
|
| 147 |
-
|
| 148 |
-
@property
|
| 149 |
-
def rank(self):
|
| 150 |
-
return self._rank
|
| 151 |
-
|
| 152 |
-
@property
|
| 153 |
-
def world_size(self):
|
| 154 |
-
return self._world_size
|
| 155 |
|
| 156 |
-
def _forward_process(self, batch, prompt_index):
|
| 157 |
-
b, l = batch.shape
|
| 158 |
|
| 159 |
-
target_len = (l - prompt_index.sum()).item()
|
| 160 |
-
k = torch.randint(1, target_len + 1, (), device=batch.device)
|
| 161 |
|
| 162 |
-
x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
|
| 163 |
-
x = ((x - 1) % target_len) + 1
|
| 164 |
-
assert x.min() >= 1 and x.max() <= target_len
|
| 165 |
|
| 166 |
-
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
|
| 167 |
-
is_mask = indices < x.unsqueeze(1)
|
| 168 |
-
|
| 169 |
-
for i in range(b):
|
| 170 |
-
is_mask[i] = is_mask[i][torch.randperm(target_len)]
|
| 171 |
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
-
noisy_batch = torch.where(is_mask, self.mask_id, batch)
|
| 175 |
-
|
| 176 |
-
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
un_batch = batch.clone()
|
| 184 |
-
un_batch[prompt_index] = self.mask_id
|
| 185 |
-
batch = torch.cat([batch, un_batch])
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 191 |
-
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
|
| 192 |
-
return logits[:, :batch.shape[1]]
|
| 193 |
-
|
| 194 |
-
@torch.no_grad()
|
| 195 |
-
def get_loglikelihood(self, prefix, target):
|
| 196 |
-
seq = torch.concatenate([prefix, target])[None, :]
|
| 197 |
-
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
| 198 |
-
|
| 199 |
-
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 200 |
-
|
| 201 |
-
loss_acc = []
|
| 202 |
-
for _ in range(self.mc_num // self.batch_size):
|
| 203 |
-
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
|
| 204 |
-
|
| 205 |
-
mask_indices = perturbed_seq == self.mask_id
|
| 206 |
-
|
| 207 |
-
logits = self.get_logits(perturbed_seq, prompt_index)
|
| 208 |
-
|
| 209 |
-
loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
|
| 210 |
-
loss = loss.sum() / self.batch_size
|
| 211 |
-
loss_acc.append(loss.item())
|
| 212 |
-
|
| 213 |
-
return - sum(loss_acc) / len(loss_acc)
|
| 214 |
-
|
| 215 |
-
@torch.no_grad()
|
| 216 |
-
def suffix_greedy_prediction(self, prefix, target):
|
| 217 |
-
if not self.is_check_greedy:
|
| 218 |
-
return False
|
| 219 |
-
|
| 220 |
-
seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device)
|
| 221 |
-
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
| 222 |
-
prefix, target = prefix.to(self.device), target.to(self.device)
|
| 223 |
-
seq[0, :len(prefix)] = prefix
|
| 224 |
-
|
| 225 |
-
for i in range(len(target)):
|
| 226 |
-
mask_index = (seq == self.mask_id)
|
| 227 |
-
logits = self.get_logits(seq, prompt_index)[mask_index]
|
| 228 |
-
x0 = torch.argmax(logits, dim=-1)
|
| 229 |
-
|
| 230 |
-
p = torch.softmax(logits.to(torch.float32), dim=-1)
|
| 231 |
-
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1)
|
| 232 |
-
_, index = torch.sort(confidence, descending=True)
|
| 233 |
-
x0[index[1:]] = self.mask_id
|
| 234 |
-
seq[mask_index] = x0.clone()
|
| 235 |
-
correct = target == seq[0, len(prefix):]
|
| 236 |
-
correct = torch.all(correct)
|
| 237 |
-
return correct
|
| 238 |
-
|
| 239 |
-
def _encode_pair(self, context, continuation):
|
| 240 |
-
n_spaces = len(context) - len(context.rstrip())
|
| 241 |
-
if n_spaces > 0:
|
| 242 |
-
continuation = context[-n_spaces:] + continuation
|
| 243 |
-
context = context[:-n_spaces]
|
| 244 |
-
|
| 245 |
-
whole_enc = self.tokenizer(context + continuation)["input_ids"]
|
| 246 |
-
context_enc = self.tokenizer(context)["input_ids"]
|
| 247 |
-
|
| 248 |
-
context_enc_len = len(context_enc)
|
| 249 |
-
continuation_enc = whole_enc[context_enc_len:]
|
| 250 |
-
|
| 251 |
-
return context_enc, continuation_enc
|
| 252 |
-
|
| 253 |
-
def loglikelihood(self, requests):
|
| 254 |
-
def _tokenize(e):
|
| 255 |
-
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
| 256 |
-
return {
|
| 257 |
-
"prefix_text": e["prefix"],
|
| 258 |
-
"target_text": e["target"],
|
| 259 |
-
"prefix": prefix,
|
| 260 |
-
"target": target,
|
| 261 |
-
}
|
| 262 |
-
|
| 263 |
-
ds = []
|
| 264 |
-
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
| 265 |
-
ds = Dataset.from_list(ds)
|
| 266 |
-
ds = ds.map(_tokenize)
|
| 267 |
-
ds = ds.with_format("torch")
|
| 268 |
-
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
|
| 269 |
-
|
| 270 |
-
assert max(prompt_len) <= 4096
|
| 271 |
-
|
| 272 |
-
out = []
|
| 273 |
-
with torch.no_grad():
|
| 274 |
-
for elem in tqdm(ds, desc="Computing likelihood..."):
|
| 275 |
-
prefix = elem["prefix"]
|
| 276 |
-
target = elem["target"]
|
| 277 |
-
|
| 278 |
-
ll = self.get_loglikelihood(prefix, target)
|
| 279 |
-
|
| 280 |
-
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
|
| 281 |
-
|
| 282 |
-
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
| 283 |
-
print('=' * 20)
|
| 284 |
-
print('prefix: ', elem['prefix_text'])
|
| 285 |
-
print('target: ', elem['target_text'])
|
| 286 |
-
print(ll, is_target_greedy_dec)
|
| 287 |
-
print('=' * 20, end='\n\n')
|
| 288 |
-
torch.cuda.empty_cache()
|
| 289 |
-
return out
|
| 290 |
-
|
| 291 |
-
def loglikelihood_rolling(self, requests):
|
| 292 |
-
|
| 293 |
-
raise NotImplementedError
|
| 294 |
-
def generate_until(self, context, max_length, stop, **generation_kwargs):
|
| 295 |
-
raise NotImplementedError
|
| 296 |
-
@torch.no_grad()
|
| 297 |
-
def _model_generate(self, context, max_length, stop, **generation_kwargs):
|
| 298 |
-
'''
|
| 299 |
-
Args:
|
| 300 |
-
model: Mask predictor.
|
| 301 |
-
prompt: A tensor of shape (1, l).
|
| 302 |
-
steps: Sampling steps, less than or equal to gen_length.
|
| 303 |
-
gen_length: Generated answer length.
|
| 304 |
-
block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
|
| 305 |
-
temperature: Categorical distribution sampling temperature.
|
| 306 |
-
cfg_scale: Unsupervised classifier-free guidance scale.
|
| 307 |
-
remasking: Remasking strategy. 'low_confidence' or 'random'.
|
| 308 |
-
mask_id: The toke id of [MASK] is 126336.
|
| 309 |
-
'''
|
| 310 |
-
|
| 311 |
-
# using the hyperparams in orginal paper
|
| 312 |
-
prompt = context
|
| 313 |
-
|
| 314 |
-
#
|
| 315 |
-
gen_length = self.max_length
|
| 316 |
-
block_length = self.block_length
|
| 317 |
-
steps = self.max_length
|
| 318 |
-
temperature=0.
|
| 319 |
-
cfg_scale=0.
|
| 320 |
-
remasking='low_confidence'
|
| 321 |
-
mask_id=126336
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(self.model.device)
|
| 325 |
-
x[:, :prompt.shape[1]] = prompt.clone()
|
| 326 |
-
|
| 327 |
-
prompt_index = (x != mask_id)
|
| 328 |
-
|
| 329 |
-
assert gen_length % block_length == 0
|
| 330 |
-
num_blocks = gen_length // block_length
|
| 331 |
-
|
| 332 |
-
assert steps % num_blocks == 0
|
| 333 |
-
steps = steps // num_blocks
|
| 334 |
-
|
| 335 |
-
for num_block in range(num_blocks):
|
| 336 |
-
block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
|
| 337 |
-
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
|
| 338 |
-
for i in range(steps):
|
| 339 |
-
|
| 340 |
-
mask_index = (x == mask_id)
|
| 341 |
-
if cfg_scale > 0.:
|
| 342 |
-
un_x = x.clone()
|
| 343 |
-
un_x[prompt_index] = mask_id
|
| 344 |
-
x_ = torch.cat([x, un_x], dim=0)
|
| 345 |
-
logits = self.model(x_).logits
|
| 346 |
-
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 347 |
-
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
| 348 |
-
else:
|
| 349 |
-
logits = self.model(x).logits
|
| 350 |
-
|
| 351 |
-
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
| 352 |
-
x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
|
| 353 |
-
|
| 354 |
-
if remasking == 'low_confidence':
|
| 355 |
-
p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 356 |
-
x0_p = torch.squeeze(
|
| 357 |
-
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
| 358 |
-
elif remasking == 'random':
|
| 359 |
-
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
|
| 360 |
-
else:
|
| 361 |
-
raise NotImplementedError(remasking)
|
| 362 |
-
|
| 363 |
-
x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf
|
| 364 |
-
|
| 365 |
-
x0 = torch.where(mask_index, x0, x)
|
| 366 |
-
confidence = torch.where(mask_index, x0_p, -np.inf)
|
| 367 |
-
|
| 368 |
-
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 369 |
-
for j in range(confidence.shape[0]):
|
| 370 |
-
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
|
| 371 |
-
transfer_index[j, select_index] = True
|
| 372 |
-
x[transfer_index] = x0[transfer_index]
|
| 373 |
-
|
| 374 |
-
return x
|
| 375 |
|
|
|
|
|
|
|
| 376 |
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
-
```
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
```
|
|
|
|
| 50 |
## Example:
|
| 51 |
```python
|
| 52 |
'''
|
| 53 |
+
|
| 54 |
+
# Copyright 2024-2025 ModelCloud.ai
|
| 55 |
+
# Copyright 2024-2025 qubitium@modelcloud.ai
|
| 56 |
+
# Contact: qubitium@modelcloud.ai, x.com/qubitium
|
| 57 |
+
#
|
| 58 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 59 |
+
# you may not use this file except in compliance with the License.
|
| 60 |
+
# You may obtain a copy of the License at
|
| 61 |
+
#
|
| 62 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 63 |
+
#
|
| 64 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 65 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 66 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 67 |
+
# See the License for the specific language governing permissions and
|
| 68 |
+
# limitations under the License.
|
| 69 |
+
|
| 70 |
import torch
|
| 71 |
+
from datasets import load_dataset
|
| 72 |
+
from gptqmodel import GPTQModel, QuantizeConfig, BACKEND
|
| 73 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 74 |
+
import torch.nn.functional as F
|
| 75 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
|
|
|
|
|
|
| 77 |
|
|
|
|
|
|
|
| 78 |
|
|
|
|
|
|
|
|
|
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
def add_gumbel_noise(logits, temperature):
|
| 82 |
+
'''
|
| 83 |
+
The Gumbel max is a method for sampling categorical distributions.
|
| 84 |
+
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
|
| 85 |
+
Thus, we use float64.
|
| 86 |
+
'''
|
| 87 |
+
logits = logits.to(torch.float64)
|
| 88 |
+
noise = torch.rand_like(logits, dtype=torch.float64)
|
| 89 |
+
gumbel_noise = (- torch.log(noise)) ** temperature
|
| 90 |
+
return logits.exp() / gumbel_noise
|
| 91 |
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
+
def get_num_transfer_tokens(mask_index, steps):
|
| 94 |
+
'''
|
| 95 |
+
In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
|
| 96 |
+
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
|
| 97 |
+
the expected number of tokens transitioned at each step should be consistent.
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
This function is designed to precompute the number of tokens that need to be transitioned at each step.
|
| 100 |
+
'''
|
| 101 |
+
mask_num = mask_index.sum(dim=1, keepdim=True) #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
+
base = mask_num // steps
|
| 104 |
+
remainder = mask_num % steps
|
| 105 |
|
| 106 |
+
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
| 107 |
+
|
| 108 |
+
for i in range(mask_num.size(0)):
|
| 109 |
+
num_transfer_tokens[i, :remainder[i]] += 1
|
| 110 |
+
|
| 111 |
+
return num_transfer_tokens
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@ torch.no_grad()
|
| 120 |
+
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
|
| 121 |
+
cfg_scale=0., remasking='low_confidence', mask_id=126336):
|
| 122 |
+
'''
|
| 123 |
+
Args:
|
| 124 |
+
model: Mask predictor.
|
| 125 |
+
prompt: A tensor of shape (1, l).
|
| 126 |
+
steps: Sampling steps, less than or equal to gen_length.
|
| 127 |
+
gen_length: Generated answer length.
|
| 128 |
+
block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
|
| 129 |
+
temperature: Categorical distribution sampling temperature.
|
| 130 |
+
cfg_scale: Unsupervised classifier-free guidance scale.
|
| 131 |
+
remasking: Remasking strategy. 'low_confidence' or 'random'.
|
| 132 |
+
mask_id: The toke id of [MASK] is 126336.
|
| 133 |
+
'''
|
| 134 |
+
x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
|
| 135 |
+
x[:, :prompt.shape[1]] = prompt.clone()
|
| 136 |
+
|
| 137 |
+
prompt_index = (x != mask_id)
|
| 138 |
+
|
| 139 |
+
assert gen_length % block_length == 0
|
| 140 |
+
num_blocks = gen_length // block_length
|
| 141 |
+
|
| 142 |
+
assert steps % num_blocks == 0
|
| 143 |
+
steps = steps // num_blocks
|
| 144 |
+
|
| 145 |
+
for num_block in range(num_blocks):
|
| 146 |
+
block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
|
| 147 |
+
num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
|
| 148 |
+
for i in range(steps):
|
| 149 |
+
|
| 150 |
+
mask_index = (x == mask_id)
|
| 151 |
+
if cfg_scale > 0.:
|
| 152 |
+
un_x = x.clone()
|
| 153 |
+
un_x[prompt_index] = mask_id
|
| 154 |
+
x_ = torch.cat([x, un_x], dim=0)
|
| 155 |
+
logits = model(x_).logits
|
| 156 |
+
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
| 157 |
+
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
| 158 |
+
else:
|
| 159 |
+
logits = model(x).logits
|
| 160 |
+
|
| 161 |
+
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
| 162 |
+
x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
|
| 163 |
+
|
| 164 |
+
if remasking == 'low_confidence':
|
| 165 |
+
p = F.softmax(logits.to(torch.float64), dim=-1)
|
| 166 |
+
x0_p = torch.squeeze(
|
| 167 |
+
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
|
| 168 |
+
elif remasking == 'random':
|
| 169 |
+
x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
|
| 170 |
+
else:
|
| 171 |
+
raise NotImplementedError(remasking)
|
| 172 |
+
|
| 173 |
+
x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf
|
| 174 |
+
|
| 175 |
+
x0 = torch.where(mask_index, x0, x)
|
| 176 |
+
confidence = torch.where(mask_index, x0_p, -np.inf)
|
| 177 |
+
|
| 178 |
+
transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
|
| 179 |
+
for j in range(confidence.shape[0]):
|
| 180 |
+
_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
|
| 181 |
+
transfer_index[j, select_index] = True
|
| 182 |
+
x[transfer_index] = x0[transfer_index]
|
| 183 |
+
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
def main():
|
| 187 |
+
quantized_model_id="FunAGI/LLaDA-8B-Base-gptqmodel-4bit"
|
| 188 |
+
tokenizer = AutoTokenizer.from_pretrained(quantized_model_id ,use_fast=False)
|
| 189 |
|
| 190 |
+
|
| 191 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 192 |
+
prompt = "Paul is at a train station and is waiting for his train. He isn't sure how long he needs to wait, but he knows that the fourth train scheduled to arrive at the station is the one he needs to get on. The first train is scheduled to arrive in 10 minutes, and this train will stay in the station for 20 minutes. The second train is to arrive half an hour after the first train leaves the station, and this second train will stay in the station for a quarter of the amount of time that the first train stayed in the station. The third train is to arrive an hour after the second train leaves the station, and this third train is to leave the station immediately after it arrives. The fourth train will arrive 20 minutes after the third train leaves, and this is the train Paul will board. In total, how long, in minutes, will Paul wait for his train?"
|
| 193 |
+
|
| 194 |
+
# # # Add special tokens for the Instruct model. The Base model does not require the following two lines.
|
| 195 |
+
m = [{"role": "user", "content": prompt}, ]
|
| 196 |
+
prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
|
| 197 |
+
|
| 198 |
+
input_ids = tokenizer(prompt)['input_ids']
|
| 199 |
+
input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
|
| 200 |
+
|
| 201 |
|
|
|
|
| 202 |
|
| 203 |
+
|
| 204 |
+
model = GPTQModel.load(quantized_model_id, device=device , trust_remote_code=True )
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
steps=256
|
| 208 |
+
out = generate(model, input_ids, steps=steps , gen_length=256, block_length=8, temperature=0., cfg_scale=0., remasking='low_confidence')
|
| 209 |
+
print("*"*30+ f"Steps {steps}"+ "*"*30)
|
| 210 |
+
print(input_ids.shape)
|
| 211 |
+
print( tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
|
| 216 |
+
import logging
|
| 217 |
+
|
| 218 |
+
logging.basicConfig(
|
| 219 |
+
format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
|
| 220 |
+
level=logging.INFO,
|
| 221 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
main()
|
| 225 |
+
|
| 226 |
```
|