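"""Shared perplexity (PPL) evaluation for abprune.

Loads a causal language model (a Hugging Face checkpoint directory/ID, or an
LLM-Pruner ``.bin`` checkpoint holding pickled ``model``/``tokenizer``
entries), slices the WikiText-2 test split and/or the PTB validation split
into fixed-length token windows, and reports perplexity per dataset. Results
are printed as JSON and, when ``--output_dir`` is given, written to a CSV.

Example invocation (script name and model path are illustrative):

    python ppl_eval.py --base_model path/to/model \
        --dataset wikitext2 ptb --max_seq_len 1024 --output_dir results/
"""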
import argparse
import csv
import json
import os
import sys
from typing import Iterable

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class IndexDataset(Dataset):
    """Wraps a 2-D tensor of token-id windows for use with a DataLoader."""

    def __init__(self, tensors: torch.Tensor):
        self.tensors = tensors

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.tensors[index]

    def __len__(self) -> int:
        return len(self.tensors)


def get_dataset(name: str):
    """Return (train_split, eval_split, text_field) for a supported dataset."""
    if name == "wikitext2":
        train_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        test_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        return train_data, test_data, "text"
    if name == "ptb":
        train_data = load_dataset("ptb_text_only", "penn_treebank", split="train")
        test_data = load_dataset("ptb_text_only", "penn_treebank", split="validation")
        return train_data, test_data, "sentence"
    raise ValueError(f"Unsupported dataset: {name}")


def process_data(samples, tokenizer, seq_len: int, field_name: str, add_bos_to_every: bool) -> IndexDataset:
    """Tokenize the joined corpus and slice it into fixed-length windows.

    When ``add_bos_to_every`` is False, a single BOS token (if the tokenizer
    defines one) is prepended to the whole stream; otherwise one BOS is
    prepended to every window, making each window ``seq_len + 1`` tokens long.
    """
    test_ids = tokenizer(
        "\n\n".join(samples[field_name]),
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids[0]

    if not add_bos_to_every and tokenizer.bos_token_id is not None:
        test_ids = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), test_ids), dim=0)

    batches = []
    num_samples = test_ids.numel() // seq_len
    for index in range(num_samples):
        batch = test_ids[(index * seq_len) : ((index + 1) * seq_len)]
        if add_bos_to_every and tokenizer.bos_token_id is not None:
            batch = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), batch), dim=0)
        batches.append(batch)

    return IndexDataset(tensors=torch.stack(batches))


def get_loader(name: str, tokenizer, seq_len: int, batch_size: int, add_bos_to_every: bool):
    """Build a DataLoader over the evaluation split of ``name``."""
    _, test_data, field_name = get_dataset(name)
    dataset = process_data(test_data, tokenizer, seq_len, field_name, add_bos_to_every)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


@torch.no_grad()
def evaluate_ppl(model, test_loader, device: str) -> float:
    """Return corpus perplexity, exp(mean token-level NLL), over ``test_loader``."""
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    nlls = []
    for batch in tqdm(test_loader, desc="Running PPL", dynamic_ncols=True):
        batch = batch.to(device)
        outputs = model(batch)
        # Score token t against the logits at position t - 1; the first token
        # of each window is never scored.
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        nlls.append(loss.cpu())

    return float(np.exp(torch.cat(nlls, dim=-1).mean().item()))


def resolve_dtype(args) -> torch.dtype:
    """Resolve the evaluation dtype.

    ``--use_bfloat`` takes precedence, then ``--dtype``, then ``--torch_dtype``;
    if none is given, float16 is used.
    """
    if args.use_bfloat:
        return torch.bfloat16

    dtype_name = args.dtype if args.dtype is not None else args.torch_dtype
    if dtype_name is None:
        dtype_name = "float16"

    dtype_map = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if dtype_name not in dtype_map:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return dtype_map[dtype_name]


def normalize_datasets(datasets: Iterable[str]) -> list[str]:
    """Map the "wikitext" alias to "wikitext2"; pass other names through."""
    return ["wikitext2" if dataset == "wikitext" else dataset for dataset in datasets]


def build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Shared perplexity evaluation for abprune.")
    parser.add_argument("--base_model", "--model-path", dest="model_path", required=True)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--dataset", nargs="+", default=["wikitext2", "ptb"])
    parser.add_argument("--max_seq_len", "--seq-len", dest="seq_len", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dtype",
        default=None,
        choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"],
    )
    parser.add_argument(
        "--torch_dtype",
        default=None,
        choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"],
    )
    parser.add_argument("--use_bfloat", action="store_true")
    parser.add_argument("--add_bos_to_every", action="store_true")
    parser.add_argument("--fix_decapoda_config", action="store_true")
    parser.add_argument("--local_files_only", action="store_true")
    return parser


def maybe_fix_decapoda_config(tokenizer, enabled: bool) -> None:
    """Backfill missing BOS/pad tokens on tokenizers with broken configs
    (e.g. old decapoda-research LLaMA checkpoints)."""
    if not enabled:
        return
    if tokenizer.bos_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token


def ensure_llmpruner_on_path() -> None:
    """Make the vendored LLM-Pruner package importable so its pickled model
    classes can be restored by ``torch.load``."""
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    llmpruner_root = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    if os.path.isdir(llmpruner_root) and llmpruner_root not in sys.path:
        sys.path.insert(0, llmpruner_root)


def load_model_and_tokenizer(model_path: str, *, torch_dtype: torch.dtype, local_files_only: bool):
    """Load either an LLM-Pruner ``.bin`` checkpoint or a Hugging Face model."""
    if os.path.isfile(model_path) and model_path.endswith(".bin"):
        ensure_llmpruner_on_path()
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
            raise ValueError(
                "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
            )
        model = checkpoint["model"]
        tokenizer = checkpoint["tokenizer"]
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model, tokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        local_files_only=local_files_only,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
    )
    return model, tokenizer


def main() -> None:
    parser = build_arg_parser()
    args = parser.parse_args()

    datasets = normalize_datasets(args.dataset)
    torch_dtype = resolve_dtype(args)

    model, tokenizer = load_model_and_tokenizer(
        args.model_path,
        torch_dtype=torch_dtype,
        local_files_only=args.local_files_only,
    )
    maybe_fix_decapoda_config(tokenizer, args.fix_decapoda_config)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    model.to(args.device)

    metrics = {}
    for dataset in datasets:
        test_loader = get_loader(
            dataset,
            tokenizer,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            add_bos_to_every=args.add_bos_to_every,
        )
        metrics[dataset] = evaluate_ppl(model, test_loader, args.device)
        print(
            f"PPL-{dataset}: {metrics[dataset]} | "
            f"add_bos_to_every: {args.add_bos_to_every} | seq_len: {args.seq_len}"
        )

    mem = None
    if torch.cuda.is_available() and args.device.startswith("cuda"):
        mem = torch.cuda.memory_allocated(args.device) / 1024 / 1024  # MiB

    result = {
        "model_path": os.path.abspath(args.model_path),
        "datasets": datasets,
        "seq_len": args.seq_len,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": str(torch_dtype).replace("torch.", ""),
        "add_bos_to_every": args.add_bos_to_every,
        "metrics": metrics,
        "params": int(sum(parameter.numel() for parameter in model.parameters())),
        "mem_mib": mem,
    }

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        filename = "ppl_bos.csv" if args.add_bos_to_every else "ppl.csv"
        csv_path = os.path.join(args.output_dir, filename)
        with open(csv_path, "w", newline="", encoding="utf-8") as handle:
            writer = csv.writer(handle)
            writer.writerow([*(f"ppl_{dataset}" for dataset in datasets), "params", "mem_mib"])
            writer.writerow([*(metrics[dataset] for dataset in datasets), result["params"], mem])

    print(json.dumps(result, ensure_ascii=True))


if __name__ == "__main__":
    main()