| |
| """Perplexity evaluation for causal LMs on HF datasets or provided text.""" |
|
|
| import argparse |
| import json |
| import math |
| import os |
| from typing import Dict, Iterable, List, Optional |
|
|
| import torch |
|
|
| try: |
| from datasets import load_dataset |
| except Exception: |
| load_dataset = None |
|
|
| try: |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| except Exception as exc: |
| raise SystemExit("transformers is required: pip install transformers") from exc |
|
|
| try: |
| from tqdm import tqdm |
| except Exception: |
| tqdm = None |
|
|
|
|
| def _tqdm_enabled() -> bool: |
| value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0")) |
| return value.strip().lower() not in {"1", "true", "yes", "on"} |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Options are declared as a flat table of ``(flag, add_argument kwargs)``
    pairs and registered in order, so ``--help`` lists them in the same order.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    # Evaluated once per call; picks GPU when torch can see one.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    option_table = [
        ("--model", dict(required=True, help="HF model id or local path")),
        (
            "--dataset",
            dict(action="append", default=[], help="HF dataset name (repeatable)."),
        ),
        (
            "--dataset_config",
            dict(
                action="append",
                default=[],
                help="Optional dataset config (repeatable or single shared config).",
            ),
        ),
        (
            "--dataset_split",
            dict(default="test", help="Dataset split to use (default: test)"),
        ),
        (
            "--dataset_text_field",
            dict(
                default=None,
                help="Text field in dataset (default: auto-detect, applies to all datasets)",
            ),
        ),
        (
            "--text",
            dict(
                action="append",
                default=[],
                help="Inline text samples (can pass multiple)",
            ),
        ),
        (
            "--text_file",
            dict(default=None, help="Path to a text file for evaluation data"),
        ),
        (
            "--num_samples",
            dict(
                type=int,
                default=0,
                help="Number of token sequences to use per dataset (0 = all)",
            ),
        ),
        ("--seq_len", dict(type=int, default=2048, help="Sequence length")),
        ("--batch_size", dict(type=int, default=2, help="Batch size")),
        (
            "--max_batches",
            dict(
                type=int,
                default=None,
                help="Optional max number of batches to evaluate per dataset",
            ),
        ),
        (
            "--model_family",
            dict(
                type=str,
                choices=["auto", "llama", "qwen"],
                default="auto",
                help="Model family for BOS handling",
            ),
        ),
        (
            "--add_bos",
            dict(
                type=str,
                choices=["auto", "always", "never"],
                default="auto",
                help="Whether to prepend BOS to each sample",
            ),
        ),
        (
            "--device",
            dict(default=fallback_device, help="Device for model + compute"),
        ),
        (
            "--dtype",
            dict(
                default="auto",
                choices=["auto", "float32", "float16", "bfloat16"],
                help="Model dtype",
            ),
        ),
        ("--seed", dict(type=int, default=0, help="Random seed for shuffling")),
        (
            "--shuffle",
            dict(action="store_true", help="Shuffle dataset before sampling"),
        ),
        ("--num_workers", dict(type=int, default=0, help="DataLoader workers")),
        (
            "--cache_dir",
            dict(default=None, help="Optional datasets cache directory"),
        ),
        (
            "--trust_remote_code",
            dict(action="store_true", help="Allow custom model code from hub"),
        ),
        ("--output", dict(default=None, help="Optional JSON output path")),
    ]

    cli = argparse.ArgumentParser(
        description="Compute perplexity for a causal LM on one or more datasets."
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
| def _normalize_config(config: Optional[str]) -> Optional[str]: |
| if config is None: |
| return None |
| if config.strip().lower() in {"none", "null", "-"}: |
| return None |
| return config |
|
|
|
|
| def _expand_dataset_configs( |
| datasets: List[str], configs: List[str] |
| ) -> List[Optional[str]]: |
| if not configs: |
| return [None] * len(datasets) |
| if len(configs) == 1 and len(datasets) > 1: |
| return [_normalize_config(configs[0])] * len(datasets) |
| if len(configs) != len(datasets): |
| raise SystemExit( |
| "Provide zero, one, or matching-count --dataset_config values." |
| ) |
| return [_normalize_config(cfg) for cfg in configs] |
|
|
|
|
def guess_text_field(dataset) -> str:
    """Choose the column that most likely holds the raw text.

    A non-empty ``column_names`` wins. If it is missing or empty, the keys of
    ``features`` are used instead. Within whichever list is used, ``"text"`` is
    preferred, then the first entry. Falls back to ``"text"`` when neither
    source yields any names.
    """
    candidates = getattr(dataset, "column_names", None)
    if not candidates and hasattr(dataset, "features"):
        candidates = list(dataset.features.keys())
    if not candidates:
        return "text"
    return "text" if "text" in candidates else candidates[0]
|
|
|
|
| def _infer_model_family(model) -> str: |
| model_type = str(getattr(getattr(model, "config", None), "model_type", "")).lower() |
| architectures = getattr(getattr(model, "config", None), "architectures", []) |
| arch_lower = " ".join(str(name).lower() for name in architectures) |
| if "qwen" in model_type or "qwen" in arch_lower: |
| return "qwen" |
| if "llama" in model_type or "llama" in arch_lower: |
| return "llama" |
| return "unknown" |
|
|
|
|
| def _resolve_add_bos(setting: str, model_family: str, tokenizer) -> bool: |
| if setting == "always": |
| return True |
| if setting == "never": |
| return False |
| if model_family == "llama": |
| return True |
| if model_family == "qwen": |
| return False |
| if hasattr(tokenizer, "add_bos_token"): |
| return bool(getattr(tokenizer, "add_bos_token")) |
| init_kwargs = getattr(tokenizer, "init_kwargs", None) |
| if isinstance(init_kwargs, dict) and "add_bos_token" in init_kwargs: |
| return bool(init_kwargs["add_bos_token"]) |
| return False |
|
|
|
|
def build_token_chunks(
    texts: Iterable[str],
    tokenizer,
    seq_len: int,
    num_samples: int,
    add_bos: bool = False,
) -> List[torch.Tensor]:
    """Tokenize ``texts`` back-to-back and cut the stream into ``seq_len`` chunks.

    Each text is encoded without special tokens. If ``add_bos`` is set and the
    tokenizer has a ``bos_token_id``, that id is prepended to each text. Tokens
    are concatenated across texts. A trailing remainder shorter than
    ``seq_len`` is discarded.

    Args:
        texts: input strings.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            and ``bos_token_id``.
        seq_len: tokens per chunk; must be >= 1.
        num_samples: maximum number of chunks. A value <= 0 means "no limit",
            matching the CLI's "0 = all" convention.
        add_bos: prepend BOS to every text when possible.

    Returns:
        A list of 1-D ``torch.long`` tensors, each of length ``seq_len``.

    Raises:
        ValueError: if ``seq_len < 1``.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    limit = num_samples if num_samples > 0 else None
    bos_id = tokenizer.bos_token_id if add_bos else None

    def _full() -> bool:
        return limit is not None and len(chunks) >= limit

    chunks: List[torch.Tensor] = []
    buffer: List[int] = []
    for text in texts:
        if _full():
            break
        ids = tokenizer.encode(text, add_special_tokens=False)
        if bos_id is not None:
            ids = [bos_id] + ids
        if not ids:
            continue
        buffer.extend(ids)
        # Walk an offset instead of re-slicing the buffer for every chunk:
        # repeated ``buffer = buffer[seq_len:]`` copies the remainder each time
        # (quadratic on long texts).
        start = 0
        while len(buffer) - start >= seq_len and not _full():
            chunks.append(
                torch.tensor(buffer[start : start + seq_len], dtype=torch.long)
            )
            start += seq_len
        if start:
            del buffer[:start]
    return chunks
|
|
|
|
def get_dtype(dtype: str):
    """Map the ``--dtype`` CLI value to a torch dtype.

    ``"auto"`` yields ``None``, which lets ``from_pretrained`` pick the dtype.
    ``"float16"`` and ``"bfloat16"`` map to their torch counterparts. Anything
    else, including ``"float32"``, maps to ``torch.float32``.
    """
    if dtype == "auto":
        return None
    half_precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    return half_precision.get(dtype, torch.float32)
|
|
|
|
def compute_ppl(model, dataloader, device: str, max_batches: Optional[int]) -> float:
    """Return the token-level perplexity of ``model`` over ``dataloader``.

    Each batch is either a mapping with an ``"input_ids"`` tensor (HF datasets
    formatted as torch) or a sequence whose first item is the ids tensor (e.g.
    a ``TensorDataset`` batch). Within each row, every token after the first is
    scored against the logits at the previous position. Perplexity is
    ``exp(total NLL / number of scored tokens)``.

    Args:
        model: causal LM whose call returns an object with ``.logits`` indexed
            as (batch, position, vocab).
        dataloader: iterable of batches as described above.
        device: device the ids are moved to before the forward pass.
        max_batches: stop after this many batches when not None. The check runs
            after a batch is processed, so at least one batch is always used.

    Raises:
        RuntimeError: if no tokens were scored.
    """
    model.eval()
    nll_sum = 0.0
    token_count = 0
    progress = None
    iterator = dataloader
    if tqdm is not None and _tqdm_enabled():
        progress = tqdm(dataloader, desc="PPL", unit="batch")
        iterator = progress
    try:
        with torch.no_grad():
            for step, batch in enumerate(iterator):
                if isinstance(batch, dict):
                    input_ids = batch["input_ids"].to(device)
                else:
                    input_ids = batch[0].to(device)
                logits = model(input_ids=input_ids).logits
                shift_logits = logits[:, :-1, :]
                shift_labels = input_ids[:, 1:]
                # Per-token losses come back in the logits' dtype. A "sum"
                # reduction would also return the batch total in that dtype,
                # which overflows in fp16 and loses precision in bf16.
                # Accumulate the total in fp32 instead.
                token_losses = torch.nn.functional.cross_entropy(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                    reduction="none",
                )
                nll_sum += token_losses.float().sum().item()
                token_count += shift_labels.numel()
                if max_batches is not None and step + 1 >= max_batches:
                    break
    finally:
        # Close explicitly so an early ``break`` doesn't leave a dangling bar.
        if progress is not None:
            progress.close()

    if token_count == 0:
        raise RuntimeError("No tokens processed; check evaluation inputs.")

    return math.exp(nll_sum / token_count)
|
|
|
|
def _load_lm_dataset(
    tokenizer,
    dataset_name: str,
    config: Optional[str],
    split: str,
    text_field: Optional[str],
    seq_len: int,
    add_bos: bool,
    cache_dir: Optional[str],
    trust_remote_code: bool = True,
):
    """Load one HF dataset split and pack it into fixed-length token rows.

    Pipeline:
      1. ``load_dataset``.
      2. Drop rows whose text field is not a non-blank string.
      3. Tokenize without special tokens, optionally prepending BOS per row.
      4. Concatenate and split into ``seq_len`` windows.

    The result is formatted as torch with a single ``input_ids`` column.

    Args:
        tokenizer: HF tokenizer (callable on a list of strings).
        dataset_name / config / split / cache_dir: forwarded to ``load_dataset``.
        text_field: column with the raw text; auto-detected when falsy.
        seq_len: tokens per packed row.
        add_bos: prepend ``tokenizer.bos_token_id`` (if not None) to every row.
        trust_remote_code: forwarded to ``load_dataset``. Defaults to True to
            preserve the historical behaviour of this helper.

    Raises:
        SystemExit: if the chosen text field is not one of the dataset's columns.
    """
    # NOTE(review): trust_remote_code=True lets a dataset repository's loading
    # script execute locally. Pass False when evaluating untrusted datasets.
    dataset = load_dataset(
        dataset_name,
        config,
        split=split,
        trust_remote_code=trust_remote_code,
        cache_dir=cache_dir,
    )

    field = text_field or guess_text_field(dataset)
    # Without this check a typo'd field filters every row out and surfaces much
    # later as a confusing "No tokens processed" error.
    columns = getattr(dataset, "column_names", None)
    if columns and field not in columns:
        raise SystemExit(
            f"Text field {field!r} not found in dataset {dataset_name!r}; "
            f"available columns: {list(columns)}"
        )

    def is_valid_text(example) -> bool:
        value = example.get(field)
        return isinstance(value, str) and value.strip() != ""

    dataset = dataset.filter(is_valid_text, desc=f"filter-{dataset_name}")

    def tokenize_fn(examples):
        tokenized = tokenizer(
            examples[field],
            add_special_tokens=False,
            return_attention_mask=False,
        )
        if add_bos and tokenizer.bos_token_id is not None:
            tokenized["input_ids"] = [
                [tokenizer.bos_token_id] + ids for ids in tokenized["input_ids"]
            ]
        return tokenized

    tokenized = dataset.map(
        tokenize_fn,
        batched=True,
        remove_columns=dataset.column_names,
        desc=f"tokenize-{dataset_name}",
    )

    def group_texts(examples):
        # NOTE: concatenation is per map batch (batch_size=1000 below), so up to
        # seq_len - 1 trailing tokens of every 1000-row batch are discarded.
        concatenated = []
        for ids in examples["input_ids"]:
            concatenated.extend(ids)
        total_length = (len(concatenated) // seq_len) * seq_len
        if total_length == 0:
            return {"input_ids": []}
        return {
            "input_ids": [
                concatenated[i : i + seq_len] for i in range(0, total_length, seq_len)
            ]
        }

    lm_dataset = tokenized.map(
        group_texts,
        batched=True,
        batch_size=1000,
        remove_columns=tokenized.column_names,
        desc=f"group-{dataset_name}",
    )
    lm_dataset.set_format(type="torch", columns=["input_ids"])
    return lm_dataset
|
|
|
|
def prepare_ppl_dataloaders(
    tokenizer,
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seq_len: int,
    batch_size: int,
    seed: int,
    shuffle: bool,
    model_family: str = "auto",
    add_bos: str = "auto",
    cache_dir: Optional[str] = None,
    num_workers: int = 0,
    model=None,
) -> Dict[str, torch.utils.data.DataLoader]:
    """Build one non-shuffling DataLoader of packed token rows per dataset.

    Args:
        tokenizer: HF tokenizer used for packing.
        datasets / configs: parallel lists; ``configs[i]`` applies to
            ``datasets[i]`` and the lists must have equal length.
        split / text_field / seq_len / cache_dir: forwarded to dataset loading.
        num_samples: keep at most this many packed rows per dataset; a value
            <= 0 keeps all.
        batch_size / num_workers: DataLoader settings.
        seed / shuffle: when ``shuffle`` is set, rows are shuffled with
            ``seed + index`` before the ``num_samples`` cut.
        model_family: "llama", "qwen", or "auto" (inferred from ``model``).
        add_bos: "always", "never" or "auto".
        model: needed only when ``model_family == "auto"``.

    Returns:
        Mapping from label (``name`` or ``name:config``) to DataLoader. A
        repeated label overwrites the earlier entry.

    Raises:
        SystemExit: if ``datasets`` is not installed, ``model`` is missing for
            auto family detection, or the list lengths differ.
    """
    import warnings

    if load_dataset is None:
        raise SystemExit("datasets is required for dataset evaluation")
    # zip() would silently drop unmatched datasets; fail loudly instead.
    if len(configs) != len(datasets):
        raise SystemExit(
            f"Got {len(datasets)} datasets but {len(configs)} configs; "
            "pass one config per dataset."
        )

    resolved_family = model_family
    if resolved_family == "auto":
        if model is None:
            raise SystemExit("model is required when model_family is 'auto'")
        resolved_family = _infer_model_family(model)
    use_bos = _resolve_add_bos(add_bos, resolved_family, tokenizer)
    if use_bos and tokenizer.bos_token_id is None:
        use_bos = False

    dataloaders: Dict[str, torch.utils.data.DataLoader] = {}
    for idx, (dataset_name, config) in enumerate(zip(datasets, configs)):
        lm_dataset = _load_lm_dataset(
            tokenizer=tokenizer,
            dataset_name=dataset_name,
            config=config,
            split=split,
            text_field=text_field,
            seq_len=seq_len,
            add_bos=use_bos,
            cache_dir=cache_dir,
        )
        if shuffle:
            try:
                lm_dataset = lm_dataset.shuffle(seed=seed + idx)
            except Exception as exc:  # best-effort: keep going unshuffled, but say so
                warnings.warn(
                    f"Could not shuffle {dataset_name}: {exc}; using original order.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        # Only positive counts limit the rows; 0 (and negatives) mean "all".
        if num_samples and num_samples > 0 and hasattr(lm_dataset, "__len__"):
            limit = min(num_samples, len(lm_dataset))
            lm_dataset = lm_dataset.select(range(limit))

        data_loader = torch.utils.data.DataLoader(
            lm_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        label = dataset_name if config is None else f"{dataset_name}:{config}"
        dataloaders[label] = data_loader

    return dataloaders
|
|
|
|
def evaluate_ppl_dataloaders(
    model,
    dataloaders: Dict[str, torch.utils.data.DataLoader],
    device: str,
    max_batches: Optional[int] = None,
) -> Dict[str, float]:
    """Run ``compute_ppl`` on every loader and collect the results.

    Loaders are evaluated in the mapping's iteration order. The returned dict
    uses the same labels, in the same order.
    """
    return {
        name: compute_ppl(model, loader, device, max_batches=max_batches)
        for name, loader in dataloaders.items()
    }
|
|
|
|
def evaluate_ppl_datasets(
    model,
    tokenizer,
    datasets: List[str],
    configs: List[Optional[str]],
    split: str,
    text_field: Optional[str],
    num_samples: int,
    seq_len: int,
    batch_size: int,
    device: str,
    seed: int,
    shuffle: bool,
    model_family: str = "auto",
    add_bos: str = "auto",
    max_batches: Optional[int] = None,
    cache_dir: Optional[str] = None,
    num_workers: int = 0,
) -> Dict[str, float]:
    """Load, pack and evaluate perplexity for each dataset.

    This is the composition of ``prepare_ppl_dataloaders`` and
    ``evaluate_ppl_dataloaders``. It delegates to them so dataset preparation
    lives in one place instead of two copies that can drift apart.

    Returns:
        Mapping from label (``name`` or ``name:config``) to perplexity.

    Raises:
        SystemExit: if ``datasets`` is not installed (raised by the loader
            preparation).
    """
    # Resolve "auto" here, using ``model`` directly, so this function keeps its
    # historical behaviour of never requiring an explicit family.
    resolved_family = model_family
    if resolved_family == "auto":
        resolved_family = _infer_model_family(model)

    dataloaders = prepare_ppl_dataloaders(
        tokenizer,
        datasets=datasets,
        configs=configs,
        split=split,
        text_field=text_field,
        num_samples=num_samples,
        seq_len=seq_len,
        batch_size=batch_size,
        seed=seed,
        shuffle=shuffle,
        model_family=resolved_family,
        add_bos=add_bos,
        cache_dir=cache_dir,
        num_workers=num_workers,
        model=model,
    )
    return evaluate_ppl_dataloaders(
        model, dataloaders, device, max_batches=max_batches
    )
|
|
|
|
def main() -> None:
    """Command-line entry point: parse args, load the model, report perplexities.

    Datasets given with ``--dataset`` are evaluated first. Text from ``--text``
    and ``--text_file`` is then evaluated under the ``"custom"`` label.
    Results are printed and, with ``--output``, written as JSON.

    Raises:
        SystemExit: on missing or inconsistent inputs, or when the custom text
            is too short to form a single ``--seq_len`` sequence.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # Validate inputs and read local files *before* the model load, so CLI
    # mistakes fail fast instead of after a potentially long checkpoint load.
    if not (args.dataset or args.text or args.text_file):
        raise SystemExit("Provide --dataset and/or --text/--text_file for evaluation")
    datasets = list(args.dataset)
    configs: List[Optional[str]] = []
    if datasets:
        configs = _expand_dataset_configs(datasets, list(args.dataset_config))

    custom_texts: List[str] = []
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as handle:
            custom_texts.extend(line.strip() for line in handle if line.strip())
    custom_texts.extend(t for t in args.text if t)

    dtype = get_dtype(args.dtype)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        trust_remote_code=args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model.to(args.device)

    # Resolve BOS handling once so the dataset and custom-text paths agree.
    resolved_family = args.model_family
    if resolved_family == "auto":
        resolved_family = _infer_model_family(model)
    use_bos = _resolve_add_bos(args.add_bos, resolved_family, tokenizer)
    if use_bos and tokenizer.bos_token_id is None:
        use_bos = False

    results: Dict[str, float] = {}
    if datasets:
        results.update(
            evaluate_ppl_datasets(
                model,
                tokenizer,
                datasets=datasets,
                configs=configs,
                split=args.dataset_split,
                text_field=args.dataset_text_field,
                num_samples=args.num_samples,
                seq_len=args.seq_len,
                batch_size=args.batch_size,
                device=args.device,
                seed=args.seed,
                shuffle=args.shuffle,
                model_family=resolved_family,
                add_bos="always" if use_bos else "never",
                max_batches=args.max_batches,
                cache_dir=args.cache_dir,
                num_workers=args.num_workers,
            )
        )

    if custom_texts:
        chunks = build_token_chunks(
            custom_texts,
            tokenizer,
            args.seq_len,
            # --num_samples 0 means "all"; use a large cap for the chunker.
            args.num_samples if args.num_samples > 0 else 1_000_000,
            add_bos=use_bos,
        )
        if not chunks:
            raise SystemExit(
                "Not enough custom text to build token sequences. "
                "Provide more --text/--text_file content or reduce --seq_len."
            )
        dataset = torch.utils.data.TensorDataset(torch.stack(chunks))
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        )
        results["custom"] = compute_ppl(
            model, dataloader, args.device, max_batches=args.max_batches
        )

    # Still reachable, e.g. when only empty --text values were given.
    if not results:
        raise SystemExit("Provide --dataset and/or --text/--text_file for evaluation")

    print("Perplexity results:")
    for name, ppl in results.items():
        print(f"{name}: {ppl:.4f}")

    if args.output:
        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump({"model": args.model, "results": results}, handle, indent=2)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|
|