#!/usr/bin/env python3
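"""Evaluate causal-LM perplexity on WikiText-2 (test) and PTB (validation).

Accepts either a standard Hugging Face checkpoint directory/ID or an
LLM-Pruner ``.bin`` checkpoint that bundles model and tokenizer in a single
pickle. Results are printed as JSON and optionally written to a CSV.
"""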
import argparse
import csv
import json
import os
import sys
from typing import Iterable
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class IndexDataset(Dataset):
    """Wrap a 2-D tensor of token-id chunks so a DataLoader can index rows."""

    def __init__(self, tensors: torch.Tensor):
        self.tensors = tensors

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.tensors[index]

    def __len__(self) -> int:
        return len(self.tensors)


def get_dataset(name: str):
    if name == "wikitext2":
        train_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
        test_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        return train_data, test_data, "text"
    if name == "ptb":
        train_data = load_dataset("ptb_text_only", "penn_treebank", split="train")
        test_data = load_dataset("ptb_text_only", "penn_treebank", split="validation")
        return train_data, test_data, "sentence"
    raise ValueError(f"Unsupported dataset: {name}")


def process_data(samples, tokenizer, seq_len: int, field_name: str, add_bos_to_every: bool) -> IndexDataset:
    # Tokenize the entire split as one contiguous stream, then slice it into
    # seq_len-sized chunks.
    test_ids = tokenizer(
        "\n\n".join(samples[field_name]),
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids[0]
    if not add_bos_to_every and tokenizer.bos_token_id is not None:
        # Single-BOS mode: prepend one BOS token to the whole stream.
        test_ids = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), test_ids), dim=0)
    batches = []
    num_samples = test_ids.numel() // seq_len
    for index in range(num_samples):
        batch = test_ids[(index * seq_len) : ((index + 1) * seq_len)]
        if add_bos_to_every and tokenizer.bos_token_id is not None:
            # Per-chunk mode: prepend BOS to every chunk instead.
            batch = torch.cat((torch.LongTensor([tokenizer.bos_token_id]), batch), dim=0)
        batches.append(batch)
    return IndexDataset(tensors=torch.stack(batches))


def get_loader(name: str, tokenizer, seq_len: int, batch_size: int, add_bos_to_every: bool):
    _, test_data, field_name = get_dataset(name)
    dataset = process_data(test_data, tokenizer, seq_len, field_name, add_bos_to_every)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


@torch.no_grad()
def evaluate_ppl(model, test_loader, device: str) -> float:
    nlls = []
    for batch in tqdm(test_loader, desc="Running PPL", dynamic_ncols=True):
        batch = batch.to(device)
        outputs = model(batch)
        # Shift so that logits at position t are scored against token t + 1.
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        nlls.append(loss.cpu())
    # Perplexity = exp(mean per-token negative log-likelihood).
    return float(np.exp(torch.cat(nlls, dim=-1).mean().item()))


def resolve_dtype(args) -> torch.dtype:
    # --use_bfloat overrides --dtype/--torch_dtype; the default is float16.
    if args.use_bfloat:
        return torch.bfloat16
    dtype_name = args.dtype if args.dtype is not None else args.torch_dtype
    if dtype_name is None:
        dtype_name = "float16"
    dtype_map = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    if dtype_name not in dtype_map:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return dtype_map[dtype_name]


def normalize_datasets(datasets: Iterable[str]) -> list[str]:
    # Accept "wikitext" as an alias for "wikitext2".
    normalized = []
    for dataset in datasets:
        normalized.append("wikitext2" if dataset == "wikitext" else dataset)
    return normalized


def build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Shared perplexity evaluation for abprune.")
    parser.add_argument("--base_model", "--model-path", dest="model_path", required=True)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--dataset", nargs="+", default=["wikitext2", "ptb"])
    parser.add_argument("--max_seq_len", "--seq-len", dest="seq_len", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dtype",
        default=None,
        choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"],
    )
    parser.add_argument(
        "--torch_dtype",
        default=None,
        choices=["float16", "fp16", "bfloat16", "bf16", "float32", "fp32"],
    )
    parser.add_argument("--use_bfloat", action="store_true")
    parser.add_argument("--add_bos_to_every", action="store_true")
    parser.add_argument("--fix_decapoda_config", action="store_true")
    parser.add_argument("--local_files_only", action="store_true")
    return parser
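

# Example invocation (model name and output path are illustrative, not from
# the original source):
#   python src/eval_ppl.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --dataset wikitext2 ptb \
#       --max_seq_len 1024 --batch_size 4 \
#       --dtype float16 --output_dir results/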


def maybe_fix_decapoda_config(tokenizer, enabled: bool) -> None:
    # Older decapoda-research LLaMA exports can lack usable BOS/PAD tokens;
    # fall back to EOS where one is missing.
    if not enabled:
        return
    if tokenizer.bos_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.bos_token = tokenizer.eos_token
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token


def ensure_llmpruner_on_path() -> None:
    # LLM-Pruner checkpoints are full pickles, so the LLM-Pruner repo must be
    # importable for torch.load to reconstruct its custom classes.
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    llmpruner_root = os.path.join(repo_root, "compare_model", "LLM-Pruner")
    if os.path.isdir(llmpruner_root) and llmpruner_root not in sys.path:
        sys.path.insert(0, llmpruner_root)


def load_model_and_tokenizer(model_path: str, *, torch_dtype: torch.dtype, local_files_only: bool):
    if os.path.isfile(model_path) and model_path.endswith(".bin"):
        # LLM-Pruner saves the pruned model and its tokenizer together in one
        # .bin checkpoint.
        ensure_llmpruner_on_path()
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        if not isinstance(checkpoint, dict) or "model" not in checkpoint or "tokenizer" not in checkpoint:
            raise ValueError(
                "Expected an LLM-Pruner checkpoint dict with `model` and `tokenizer` entries."
            )
        model = checkpoint["model"]
        tokenizer = checkpoint["tokenizer"]
        if torch_dtype is not None:
            model = model.to(dtype=torch_dtype)
        return model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        local_files_only=local_files_only,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
    )
    return model, tokenizer


def main() -> None:
    parser = build_arg_parser()
    args = parser.parse_args()
    datasets = normalize_datasets(args.dataset)
    torch_dtype = resolve_dtype(args)
    model, tokenizer = load_model_and_tokenizer(
        args.model_path,
        torch_dtype=torch_dtype,
        local_files_only=args.local_files_only,
    )
    maybe_fix_decapoda_config(tokenizer, args.fix_decapoda_config)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    model.to(args.device)
    metrics = {}
    for dataset in datasets:
        test_loader = get_loader(
            dataset,
            tokenizer,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            add_bos_to_every=args.add_bos_to_every,
        )
        metrics[dataset] = evaluate_ppl(model, test_loader, args.device)
        print(
            f"PPL-{dataset}: {metrics[dataset]} | "
            f"add_bos_to_every: {args.add_bos_to_every} | seq_len: {args.seq_len}"
        )
    mem = None
    if torch.cuda.is_available() and args.device.startswith("cuda"):
        # Allocated CUDA memory, reported in MiB.
        mem = torch.cuda.memory_allocated(args.device) / 1024 / 1024
    result = {
        "model_path": os.path.abspath(args.model_path),
        "datasets": datasets,
        "seq_len": args.seq_len,
        "batch_size": args.batch_size,
        "device": args.device,
        "dtype": str(torch_dtype).replace("torch.", ""),
        "add_bos_to_every": args.add_bos_to_every,
        "metrics": metrics,
        "params": int(sum(parameter.numel() for parameter in model.parameters())),
        "mem_mib": mem,
    }
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        filename = "ppl_bos.csv" if args.add_bos_to_every else "ppl.csv"
        csv_path = os.path.join(args.output_dir, filename)
        with open(csv_path, "w", newline="", encoding="utf-8") as handle:
            writer = csv.writer(handle)
            writer.writerow([*(f"ppl_{dataset}" for dataset in datasets), "params", "mem"])
            writer.writerow([*(metrics[dataset] for dataset in datasets), result["params"], mem])
    print(json.dumps(result, ensure_ascii=True))


if __name__ == "__main__":
    main()