| ''' |
| This file is inspired by the code from https://github.com/ML-GSAI/SMDM |
| ''' |
| import accelerate |
| import torch |
| import re |
| from pathlib import Path |
| import random |
| import numpy as np |
| import torch.nn.functional as F |
| from datasets import Dataset |
| from lm_eval.__main__ import cli_evaluate |
| from lm_eval.api.instance import Instance |
| from lm_eval.api.model import LM |
| from lm_eval.api.registry import register_model |
| from tqdm import tqdm |
|
|
| from transformers import AutoTokenizer, AutoModel |
| import json |
| import os |
| import time |
|
|
|
|
def set_seed(seed):
    '''Seed every RNG this file relies on and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed handed to ``torch.manual_seed``, ``random.seed``
            and ``np.random.seed``.
    '''
    # The cuDNN switches do not depend on the seed value, so set them first.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
|
|
|
|
def _sample_categorical(categorical_probs):
    '''Draw one category index per distribution along the last dimension.

    Each probability is divided by a positive noise term built from
    ``-log(U)`` with ``U`` uniform on [0, 1), and the arg-max of the ratios
    is returned.

    Args:
        categorical_probs: Tensor whose last dimension holds category
            probabilities.

    Returns:
        Integer tensor of sampled indices, shaped like ``categorical_probs``
        without its last dimension.
    '''
    uniform = torch.rand_like(categorical_probs)
    # The 1e-10 offsets keep the logarithm finite and the divisor non-zero.
    neg_log_uniform = 1e-10 - torch.log(uniform + 1e-10)
    noise = neg_log_uniform.to(categorical_probs.dtype)
    scores = categorical_probs / noise
    return torch.argmax(scores, dim=-1)
|
|
|
|
@register_model("llada_dist")
class LLaDAEvalHarness(LM):
    '''lm-evaluation-harness model wrapper around a LLaDA masked-diffusion LM.

    Registered under the name ``llada_dist``. It provides a Monte Carlo
    log-likelihood estimate (``loglikelihood``) and block-wise confidence
    based generation (``generate_until``) with two samplers:
    ``llada_conf`` and ``llada_remdm``.
    '''

    def __init__(
        self,
        model_path='',
        mask_id=126336,
        max_length=4096,
        generated_samples_path='',
        batch_size=32,
        mc_num=128,
        is_check_greedy=True,
        cfg=0.,
        sampling_steps=512,
        mask_length=512,
        block_size=32,
        remasking='low_confidence',
        device="cuda",
        sampler='',
        remdm_number=0
    ):
        '''
        Args:
            model_path: LLaDA-8B-Base model path (also used to load the tokenizer).
            mask_id: The token id of [MASK] is 126336.
            max_length: the max sequence length (prefix + target) accepted by loglikelihood().
            generated_samples_path: directory in which generate_until() writes "<rank>.json".
                An empty string means the current working directory.
            batch_size: mini batch size used for the Monte Carlo likelihood estimate.
            mc_num: Monte Carlo estimation iterations; must be a multiple of batch_size.
            is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer
                             is generated through greedy sampling conditioned on the prompt (note that this differs from conditional
                             generation). We implement this verification through the suffix_greedy_prediction() function, which
                             returns a True/False judgment used for accuracy calculation.
                             When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function.
                             However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality,
                             we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False
                             by default, significantly accelerating the evaluation process.
            cfg: Unsupervised classifier-free guidance scale. Only get_logits() applies it;
                the generation samplers call the model directly.
            sampling_steps: total number of denoising steps spread over all blocks.
            mask_length: number of [MASK] tokens appended to the prompt for generation.
            block_size: number of tokens per generation block; must divide mask_length.
            remasking: stored on the instance; no method of this class reads it.
            device: device used when running as a single process.
            sampler: 'llada_conf' or 'llada_remdm'; required by generate_until().
            remdm_number: number of tokens re-masked (and re-decoded) per step during the
                re-masking phase of the 'llada_remdm' sampler.
        '''
        super().__init__()

        # Keep the accelerator only when more than one process is running.
        accelerator = accelerate.Accelerator()
        self.accelerator = accelerator if accelerator.num_processes > 1 else None

        model_kwargs = {}
        if self.accelerator is not None:
            # Load the weights straight onto this process's device.
            model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})

        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs)
        self.model.eval()

        self.device = torch.device(device)
        if self.accelerator is not None:
            self.model = self.accelerator.prepare(self.model)
            self.device = torch.device(f'{self.accelerator.device}')
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self.model = self.model.to(device)

        self.mask_id = mask_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        self.mc_num = mc_num
        self.batch_size = int(batch_size)
        assert mc_num % self.batch_size == 0, "mc_num must be a multiple of batch_size"
        self.sampling_eps = 0.
        self.max_length = max_length
        self.is_check_greedy = is_check_greedy

        self.generated_samples_path = generated_samples_path
        self.sampler = sampler
        self.remdm_number = remdm_number

        self.cfg = cfg
        self.sampling_steps = sampling_steps
        self.mask_length = mask_length
        self.block_size = block_size
        self.remasking = remasking
        print(self.generated_samples_path)

    @property
    def rank(self):
        '''Process rank used by the harness.

        NOTE(review): in the single-process path this class never assigns
        ``_rank``; presumably the ``LM`` base class provides a default -- confirm.
        '''
        return self._rank

    @property
    def world_size(self):
        '''Number of processes taking part in the evaluation (see the note on ``rank``).'''
        return self._world_size

    def _forward_process(self, batch, prompt_index):
        '''Randomly mask target tokens, using a different mask count per row.

        Args:
            batch: (b, l) token ids; every row holds the same prefix + target.
            prompt_index: (l,) boolean mask, True on prefix positions. The
                prefix is assumed to come first in the sequence.

        Returns:
            ``(noisy_batch, p_mask)``: the batch with the chosen target
            positions replaced by ``mask_id``, and a (b, l) tensor holding each
            row's masking ratio (masked count / target length).
        '''
        b, l = batch.shape

        prompt_len = int(prompt_index.sum())
        target_len = l - prompt_len
        k = torch.randint(1, target_len + 1, (), device=batch.device)

        # Spread the per-row mask counts evenly over [1, target_len], starting
        # from a random offset k and wrapping around.
        x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
        x = ((x - 1) % target_len) + 1
        assert x.min() >= 1 and x.max() <= target_len

        indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
        is_mask = indices < x.unsqueeze(1)

        # Shuffle each row so the x[i] masked positions land at random places.
        for i in range(b):
            is_mask[i] = is_mask[i][torch.randperm(target_len)]

        # Prefix positions are never masked.
        is_mask = torch.cat((torch.zeros(b, prompt_len, dtype=torch.bool, device=batch.device), is_mask), dim=1)

        noisy_batch = torch.where(is_mask, self.mask_id, batch)

        return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)

    @torch.no_grad()
    def get_logits(self, batch, prompt_index):
        '''Return model logits for ``batch``, with classifier-free guidance if ``cfg > 0``.

        With guidance enabled, a second copy of the batch whose prompt
        positions are replaced by ``mask_id`` is run in the same forward pass
        and the two sets of logits are combined as
        ``un_logits + (cfg + 1) * (logits - un_logits)``.
        '''
        if self.cfg > 0.:
            assert len(prompt_index) == batch.shape[1]
            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
            un_batch = batch.clone()
            un_batch[prompt_index] = self.mask_id
            batch = torch.cat([batch, un_batch])

        logits = self.model(batch).logits

        if self.cfg > 0.:
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (self.cfg + 1) * (logits - un_logits)
        return logits[:, :batch.shape[1]]

    @torch.no_grad()
    def get_loglikelihood(self, prefix, target):
        '''Monte Carlo estimate of the log-likelihood of ``target`` given ``prefix``.

        Runs ``mc_num // batch_size`` iterations. Each one masks the target
        with ``_forward_process``, takes the cross-entropy on masked positions
        weighted by ``1 / p_mask``, and averages it over the batch.

        Returns:
            The negated mean loss as a Python float.
        '''
        seq = torch.concatenate([prefix, target])[None, :]
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)

        loss_acc = []
        for _ in range(self.mc_num // self.batch_size):
            perturbed_seq, p_mask = self._forward_process(seq, prompt_index)

            mask_indices = perturbed_seq == self.mask_id

            logits = self.get_logits(perturbed_seq, prompt_index)

            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return - sum(loss_acc) / len(loss_acc)

    @torch.no_grad()
    def suffix_greedy_prediction(self, prefix, target):
        '''Check whether greedy decoding of the suffix reproduces ``target``.

        Starting from a fully masked suffix, every iteration commits only the
        single most confident arg-max prediction and leaves the rest masked,
        so ``len(target)`` iterations fill the whole suffix.

        Returns:
            ``False`` when ``is_check_greedy`` is off; otherwise a 0-dim boolean
            tensor that is True iff the decoded suffix equals ``target``.
        '''
        if not self.is_check_greedy:
            return False

        seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device)
        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        prefix, target = prefix.to(self.device), target.to(self.device)
        seq[0, :len(prefix)] = prefix

        for _ in range(len(target)):
            mask_index = (seq == self.mask_id)
            logits = self.get_logits(seq, prompt_index)[mask_index]
            x0 = torch.argmax(logits, dim=-1)

            p = torch.softmax(logits.to(torch.float32), dim=-1)
            confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1)
            _, index = torch.sort(confidence, descending=True)
            # Keep only the top-confidence prediction; re-mask all others.
            x0[index[1:]] = self.mask_id
            seq[mask_index] = x0.clone()
        correct = target == seq[0, len(prefix):]
        correct = torch.all(correct)
        return correct

    def _encode_pair(self, context, continuation):
        '''Tokenize ``context`` and ``continuation`` so their ids concatenate cleanly.

        Trailing whitespace of the context is moved to the front of the
        continuation. The continuation ids are the tail of the jointly
        tokenized string after the context ids.
        '''
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tokenizer(context + continuation)["input_ids"]
        context_enc = self.tokenizer(context)["input_ids"]

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(self, requests):
        '''Score ``(context, continuation)`` requests.

        Returns:
            One ``(log_likelihood, is_greedy)`` tuple per request, where
            ``is_greedy`` is 1.0 or 0.0 from ``suffix_greedy_prediction``.

        Raises:
            AssertionError: if any prefix + target exceeds ``max_length`` tokens.
        '''
        # Nothing to score; also avoids max() on an empty list below.
        if not requests:
            return []

        def _tokenize(e):
            prefix, target = self._encode_pair(e["prefix"], e["target"])
            return {
                "prefix_text": e["prefix"],
                "target_text": e["target"],
                "prefix": prefix,
                "target": target,
            }

        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
        ds = Dataset.from_list(ds)
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")
        seq_lens = [len(x["prefix"]) + len(x["target"]) for x in ds]

        assert max(seq_lens) <= self.max_length, "sequence longer than max_length"

        out = []
        with torch.no_grad():
            for elem in tqdm(ds, desc="Computing likelihood..."):
                prefix = elem["prefix"]
                target = elem["target"]

                ll = self.get_loglikelihood(prefix, target)

                is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)

                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
        torch.cuda.empty_cache()
        return out

    def loglikelihood_rolling(self, requests):
        '''Rolling log-likelihood is not supported by this wrapper.'''
        raise NotImplementedError

    def _init_generation(self, prompt):
        '''Validate the schedule and build the canvas ``prompt + mask_length * [MASK]``.

        Returns:
            ``(xt, prompt_len, num_blocks, steps, tokens_per_step)`` where
            ``steps`` is the number of denoising steps per block and
            ``tokens_per_step`` the number of tokens committed per step.
        '''
        assert self.mask_length % self.block_size == 0
        num_blocks = self.mask_length // self.block_size

        assert self.sampling_steps % num_blocks == 0
        steps = self.sampling_steps // num_blocks

        assert self.mask_length % self.sampling_steps == 0
        tokens_per_step = self.mask_length // self.sampling_steps

        # self.device is valid whether or not the model was wrapped by
        # accelerator.prepare(), unlike an attribute lookup on the model.
        xt = torch.full((1, prompt.shape[1] + self.mask_length), self.mask_id, dtype=torch.long, device=self.device)
        xt[:, :prompt.shape[1]] = prompt

        prompt_len = (xt != self.mask_id).sum(1).item()
        return xt, prompt_len, num_blocks, steps, tokens_per_step

    def _predict_tokens(self, xt, block_end):
        '''Run the model once and sample a candidate token for every position.

        Returns:
            ``(x0, confidence)``: ``x0`` equals ``xt`` on unmasked positions and
            the sampled token elsewhere; ``confidence`` is the probability of
            the sampled token on masked positions before ``block_end`` and
            ``-inf`` everywhere else.
        '''
        mask_index = (xt == self.mask_id)
        logits = self.model(xt).logits
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0 = _sample_categorical(p)
        x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)

        # Positions beyond the current block are not eligible yet.
        x0_p[:, block_end:] = -np.inf
        x0 = torch.where(mask_index, x0, xt)
        confidence = torch.where(mask_index, x0_p, -np.inf)
        return x0, confidence

    @staticmethod
    def _select_most_confident(confidence, k):
        '''Boolean mask of the ``k`` highest-confidence eligible positions per row.'''
        transfer_index = torch.zeros_like(confidence, dtype=torch.bool)
        for j in range(confidence.shape[0]):
            _, select_index = torch.topk(confidence[j], k=k)
            transfer_index[j, select_index] = True
        # With fewer than k eligible positions topk fills up with -inf entries
        # (prompt, already decoded or out-of-block tokens); never commit those.
        return transfer_index & torch.isfinite(confidence)

    @torch.no_grad()
    def llada_conf_sample(self, prompt):
        '''Generate ``mask_length`` tokens block by block, most confident first.

        Args:
            prompt: (1, prompt_len) tensor of prompt token ids.

        Returns:
            (1, prompt_len + mask_length) tensor. Generation stops after the
            first completed block that leaves an EOS token in the sequence, so
            later blocks may still contain ``mask_id``.
        '''
        xt, prompt_len, num_blocks, steps, tokens_per_step = self._init_generation(prompt)

        for num_block in range(num_blocks):
            block_end = prompt_len + (num_block + 1) * self.block_size
            for _ in range(steps):
                x0, confidence = self._predict_tokens(xt, block_end)
                transfer_index = self._select_most_confident(confidence, tokens_per_step)
                xt[transfer_index] = x0[transfer_index]

            # Checked once per finished block so the text before EOS is fully decoded.
            if torch.sum(xt == self.tokenizer.eos_token_id) > 0:
                return xt

        return xt

    @torch.no_grad()
    def llada_remdm_sample(self, prompt):
        '''Block-wise confidence sampling with a re-masking phase in every block.

        Each block runs ``2 * steps`` iterations. For iterations ``i`` with
        ``remask_thres <= i < remask_thres + steps`` the ``remdm_number``
        lowest-confidence tokens decoded in this block are re-masked and
        ``remdm_number`` tokens are decoded again; all other iterations decode
        ``mask_length // sampling_steps`` tokens.

        Args:
            prompt: (1, prompt_len) tensor of prompt token ids.

        Returns:
            (1, prompt_len + mask_length) tensor. Generation stops after the
            first completed block that leaves an EOS token in the sequence.
        '''
        xt, prompt_len, num_blocks, steps, tokens_per_step = self._init_generation(prompt)
        remask_thres = int(self.block_size / 8 * 7)

        for num_block in range(num_blocks):
            block_end = prompt_len + (num_block + 1) * self.block_size
            # +inf marks positions that have no cached confidence in this block.
            conf_cache = torch.ones_like(xt, dtype=torch.float64) * np.inf
            for i in range(2 * steps):
                in_remask_phase = remask_thres <= i < remask_thres + steps
                if in_remask_phase:
                    _, lowest = torch.topk(conf_cache, k=self.remdm_number, largest=False, dim=1)
                    remask_index = torch.zeros_like(xt, dtype=torch.bool)
                    remask_index[0, lowest] = True
                    # Only tokens decoded in this block may be re-masked; this
                    # protects the prompt when fewer than remdm_number are cached.
                    remask_index &= torch.isfinite(conf_cache)
                    conf_cache[remask_index] = np.inf
                    xt[remask_index] = self.mask_id

                x0, confidence = self._predict_tokens(xt, block_end)
                transfer_length = self.remdm_number if in_remask_phase else tokens_per_step
                transfer_index = self._select_most_confident(confidence, transfer_length)
                xt[transfer_index] = x0[transfer_index]
                conf_cache[transfer_index] = confidence[transfer_index]

            # Checked once per finished block so the text before EOS is fully decoded.
            if torch.sum(xt == self.tokenizer.eos_token_id) > 0:
                return xt

        return xt

    @torch.no_grad()
    def generate_until(self, requests: list[Instance]):
        '''Generate one completion per request and dump all of them to JSON.

        Each completion is cut at the first occurrence of any stop sequence
        (the request's ``until`` list plus ``"<|eot_id|>"`` and the tokenizer's
        EOS token) and stripped of special tokens. The results and the total
        wall-clock time are written to ``<generated_samples_path>/<rank>.json``.

        Returns:
            List of generated strings, in request order.

        Raises:
            ValueError: if ``sampler`` is neither 'llada_conf' nor 'llada_remdm'.
        '''
        samplers = {
            'llada_conf': self.llada_conf_sample,
            'llada_remdm': self.llada_remdm_sample,
        }
        if self.sampler not in samplers:
            raise ValueError(
                f"Unknown sampler {self.sampler!r}; expected one of {sorted(samplers)}."
            )
        sample = samplers[self.sampler]

        # Create the output directory before spending time on generation.
        if self.generated_samples_path:
            os.makedirs(self.generated_samples_path, exist_ok=True)

        start_time = time.time()

        def _tokenize(e):
            return {
                "question": self.tokenizer(e["question"])["input_ids"],
                "question_text": e["question"],
                "until": e["until"],
            }

        ds = [{"question": req.args[0], "until": req.args[1]['until']} for req in requests]
        ds = Dataset.from_list(ds)
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")

        out, out_for_json = [], []
        for elem in tqdm(ds, desc="Generating..."):
            prompt = elem["question"].unsqueeze(0).to(self.device)
            stop_tokens = elem["until"] + ["<|eot_id|>", self.tokenizer.eos_token]

            generated_ids = sample(prompt)
            generated_answer = self.tokenizer.decode(generated_ids[0][prompt.shape[1]:], skip_special_tokens=False)

            for stop_seq in stop_tokens:
                if stop_seq in generated_answer:
                    generated_answer = generated_answer.split(stop_seq)[0]

            # Round-trip through the tokenizer to drop remaining special tokens.
            generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
            generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True)

            out.append(generated_answer)
            out_for_json.append({
                "prefix": elem["question_text"],
                "result": generated_answer,
            })

        # One barrier after the loop so the timing covers every rank's work.
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()

        end_time = time.time()
        total_duration = end_time - start_time
        # Message reads "total time: <seconds> s".
        print(f"\n总耗时: {total_duration:.2f} 秒")

        with open(os.path.join(self.generated_samples_path, str(self._rank) + ".json"), "w", encoding="utf-8") as f:
            final_output = {
                "total_time_seconds": total_duration,
                "samples": out_for_json
            }
            json.dump(final_output, f, indent=2)

        return out
|
|
|
|
# Script entry point: hand control to the lm-evaluation-harness command line
# (cli_evaluate, imported from lm_eval.__main__). Importing this file has
# already applied the register_model("llada_dist") decorator above.
if __name__ == "__main__":
    cli_evaluate()