#!/usr/bin/env python3
"""Shared perplexity (PPL) evaluation for abprune.

Loads a causal LM (a Hugging Face model directory/hub ID or an LLM-Pruner
``.bin`` checkpoint), tokenizes a held-out corpus into fixed-length chunks,
and reports per-dataset perplexity plus parameter and memory statistics.
"""

import argparse
import csv
import json
import os
import sys
from typing import Iterable

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class IndexDataset(Dataset):
    """Wraps a pre-chunked tensor of token IDs so DataLoader can batch it."""

    def __init__(self, tensors: torch.Tensor):
        self.tensors = tensors

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.tensors[index]

    def __len__(self) -> int:
        return len(self.tensors)


def get_dataset(name: str):
    """Return (train_split, test_split, text_field) for a supported dataset."""
    if name == "wikitext2":
        train_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        test_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        return train_data, test_data, "text"
    if name == "ptb":
        train_data = load_dataset("ptb_text_only", "penn_treebank", split="train")
        test_data = load_dataset("ptb_text_only", "penn_treebank", split="validation")
        return train_data, test_data, "sentence"
    raise ValueError(f"Unsupported dataset: {name}")


def process_data(samples, tokenizer, seq_len: int, field_name: str, add_bos_to_every: bool) -> IndexDataset:
    """Tokenize the corpus as one stream and split it into seq_len chunks.

    With ``add_bos_to_every``, a BOS token is prepended to every chunk;
    otherwise a single BOS is prepended to the whole stream.
    """
    test_ids = tokenizer(
        "\n\n".join(samples[field_name]),
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids[0]
    if not add_bos_to_every and tokenizer.bos_token_id is not None:
        test_ids = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), test_ids), dim=0)
    batches = []
    num_samples = test_ids.numel() // seq_len
    for index in range(num_samples):
        batch = test_ids[(index * seq_len) : ((index + 1) * seq_len)]
        if add_bos_to_every and tokenizer.bos_token_id is not None:
            batch = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), batch), dim=0)
        batches.append(batch)
    return IndexDataset(tensors=torch.stack(batches))


def get_loader(name: str, tokenizer, seq_len: int, batch_size: int, add_bos_to_every: bool):
    _, test_data, field_name = get_dataset(name)
    dataset = process_data(test_data, tokenizer, seq_len, field_name, add_bos_to_every)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


@torch.no_grad()
def evaluate_ppl(model, test_loader, device: str) -> float:
    """Compute perplexity as exp of the mean per-token cross-entropy."""
    nlls = []
    for batch in tqdm(test_loader, desc="Running PPL", dynamic_ncols=True):
        batch = batch.to(device)
        outputs = model(batch)
        # Standard causal-LM shift: predict token t+1 from tokens <= t.
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        nlls.append(loss.cpu())
    return float(np.exp(torch.cat(nlls, dim=-1).mean().item()))


def resolve_dtype(args) -> torch.dtype:
    """Map CLI dtype flags to a torch.dtype.

    Precedence: --use_bfloat, then --dtype, then --torch_dtype; defaults
    to float16 when none are given.
    """
    if args.use_bfloat:
        return torch.bfloat16
    dtype_name = args.dtype if args.dtype is not None else args.torch_dtype
    if dtype_name is None:
        dtype_name = "float16"
    dtype_map = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if dtype_name not in dtype_map:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return dtype_map[dtype_name]


def normalize_datasets(datasets: Iterable[str]) -> list[str]:
    """Map the legacy alias 'wikitext' onto 'wikitext2'."""
    normalized = []
    for dataset in datasets:
        normalized.append("wikitext2" if dataset == "wikitext" else dataset)
    return normalized


def build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Shared perplexity evaluation for abprune.")
    parser.add_argument("--base_model", "--model-path", dest="model_path", required=True)
    parser.add_argument("--output_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=["wikitext2", "ptb"]) parser.add_argument("--max_seq_len", "--seq-len", dest="seq_len", type=int, default=1024) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--device", default="cuda") parser.add_argument( "--dtype", default=None, choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"], ) parser.add_argument( "--torch_dtype", default=None, choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"], ) parser.add_argument("--use_bfloat", action="store_true") parser.add_argument("--add_bos_to_every", action="store_true") parser.add_argument("--fix_decapoda_config", action="store_true") parser.add_argument("--local_files_only", action="store_true") return parser def maybe_fix_decapoda_config(tokenizer, enabled: bool) -> None: if not enabled: return if tokenizer.bos_token_id is None and tokenizer.eos_token_id is not None: tokenizer.bos_token = tokenizer.eos_token if tokenizer.pad_token is None and tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token def ensure_llmpruner_on_path() -> None: repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) llmpruner_root = os.path.join(repo_root, "compare_model", "LLM-Pruner") if os.path.isdir(llmpruner_root) and llmpruner_root not in sys.path: sys.path.insert(0, llmpruner_root) def load_model_and_tokenizer(model_path: str, *, torch_dtype: torch.dtype, local_files_only: bool): if os.path.isfile(model_path) and model_path.endswith(".bin"): ensure_llmpruner_on_path() checkpoint = torch.load(model_path, map_location="cpu", weights_only=False) if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint: raise ValueError( "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries." 
            )
        model = checkpoint["model"]
        tokenizer = checkpoint["tokenizer"]
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        local_files_only=local_files_only,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
    )
    return model, tokenizer


def main() -> None:
    parser = build_arg_parser()
    args = parser.parse_args()
    datasets = normalize_datasets(args.dataset)
    torch_dtype = resolve_dtype(args)
    model, tokenizer = load_model_and_tokenizer(
        args.model_path,
        torch_dtype=torch_dtype,
        local_files_only=args.local_files_only,
    )
    maybe_fix_decapoda_config(tokenizer, args.fix_decapoda_config)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    model.to(args.device)

    metrics = {}
    for dataset in datasets:
        test_loader = get_loader(
            dataset,
            tokenizer,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            add_bos_to_every=args.add_bos_to_every,
        )
        metrics[dataset] = evaluate_ppl(model, test_loader, args.device)
        print(
            f"PPL-{dataset}: {metrics[dataset]} | "
            f"add_bos_to_every: {args.add_bos_to_every} | seq_len: {args.seq_len}"
        )

    mem = None
    if torch.cuda.is_available() and args.device.startswith("cuda"):
        # Allocated CUDA memory in MiB (weights plus any live activations).
        mem = torch.cuda.memory_allocated(args.device) / 1024 / 1024

    result = {
        "model_path": os.path.abspath(args.model_path),
        "datasets": datasets,
        "seq_len": args.seq_len,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": str(torch_dtype).replace("torch.", ""),
        "add_bos_to_every": args.add_bos_to_every,
        "metrics": metrics,
        "params": int(sum(parameter.numel() for parameter in model.parameters())),
        "mem_mib": mem,
    }

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        filename = "ppl_bos.csv" if args.add_bos_to_every else "ppl.csv"
        csv_path = os.path.join(args.output_dir, filename)
        with open(csv_path, "w", newline="", encoding="utf-8") as handle:
            writer = csv.writer(handle)
            writer.writerow([*(f"ppl_{dataset}" for dataset in datasets), "params", "mem_mib"])
            writer.writerow([*(metrics[dataset] for dataset in datasets), result["params"], mem])

    print(json.dumps(result, ensure_ascii=True))


if __name__ == "__main__":
    main()
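
# Example invocation (a sketch: the script filename, model ID, and output
# directory below are placeholders, not names defined by this repo):
#
#   python evaluate_ppl.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --dataset wikitext2 ptb \
#       --max_seq_len 1024 --batch_size 4 --dtype float16 \
#       --output_dir results/llama2-7b
#
# This prints a PPL line per dataset, writes results/llama2-7b/ppl.csv, and
# emits a one-line JSON summary on stdout. Passing an LLM-Pruner checkpoint
# (a pruned `.bin` file) as --base_model evaluates the pruned model instead.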