# Source: AsadIsmail — "Bundle ternary_quant package directly (private repo fix)"
# Commit: 162f86a (verified)
"""
Evaluation utilities for ternary quantized models.
Supports:
- Perplexity evaluation on WikiText-2 and C4
- Side-by-side comparison with original model
- Prompt-bank generation benchmarking with collapse metrics
"""
from collections import Counter
import math
import time
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Default prompt bank for generation benchmarking: a small mix of factual
# completion, story continuation, arithmetic Q&A, instruction-following,
# dialogue, and list-style prompts, used when the caller supplies no prompts.
DEFAULT_PROMPT_BANK = [
    "The capital of France is",
    "Once upon a time there was a",
    "Question: What is 2 + 2?\nAnswer:",
    "Write one sentence about the ocean:",
    "A short conversation between Alice and Bob:\nAlice:",
    "List three uses of a paperclip:\n1.",
]
def _synchronize_device(device: torch.device | str) -> None:
device_type = device.type if isinstance(device, torch.device) else str(device)
if device_type == "cuda" and torch.cuda.is_available():
torch.cuda.synchronize()
elif device_type == "mps" and hasattr(torch, "mps") and torch.backends.mps.is_available():
torch.mps.synchronize()
def _extract_text_samples(dataset) -> list[str]:
for key in ("text", "sentence", "content"):
if key in dataset.column_names:
values = dataset[key]
return [str(v) for v in values if v is not None]
raise ValueError(
f"Could not find a text-like column in dataset. Available columns: {dataset.column_names}"
)
def _load_text_dataset(dataset_name: str, dataset_config: str, split: str):
    """Load a Hugging Face dataset split, retrying legacy script datasets.

    Newer ``datasets`` releases refuse loader-script datasets with a
    RuntimeError mentioning "Dataset scripts are no longer supported"; in
    that one case we retry with ``trust_remote_code=True``. All other
    RuntimeErrors propagate unchanged.
    """
    from datasets import load_dataset

    try:
        return load_dataset(dataset_name, dataset_config, split=split)
    except RuntimeError as exc:
        if "Dataset scripts are no longer supported" in str(exc):
            return load_dataset(
                dataset_name,
                dataset_config,
                split=split,
                trust_remote_code=True,
            )
        raise
def _build_eval_text(
    dataset,
    tokenizer: AutoTokenizer,
    max_tokens: Optional[int],
) -> str:
    """Concatenate dataset samples into a single evaluation string.

    Args:
        dataset: Dataset with a text-like column (see _extract_text_samples).
        tokenizer: Tokenizer used only to count tokens of the growing text.
        max_tokens: Stop accumulating once the joined text reaches this many
            tokens; ``None`` joins every sample unconditionally.

    Returns:
        Samples joined with blank lines, truncated (at a sample boundary) to
        roughly ``max_tokens`` tokens when a budget is given.
    """
    samples = _extract_text_samples(dataset)
    if max_tokens is None:
        return "\n\n".join(samples)
    chunks: list[str] = []
    for idx, sample in enumerate(samples, start=1):
        chunks.append(sample)
        # Re-tokenize only every 8th sample (and on the final sample):
        # tokenizing the whole accumulated text is the expensive step.
        # NOTE: the original also had `if idx < 8 and idx != len(samples)`,
        # which is fully subsumed by this modulo check (idx in 1..7 always
        # has idx % 8 != 0), so it was removed as dead code.
        if idx % 8 != 0 and idx != len(samples):
            continue
        text = "\n\n".join(chunks)
        num_tokens = int(
            tokenizer(text, return_tensors="pt", truncation=False)["input_ids"].shape[1]
        )
        if num_tokens >= max_tokens:
            return text
    return "\n\n".join(chunks)
def get_default_prompt_bank(
    primary_prompt: Optional[str] = None,
    max_prompts: Optional[int] = None,
) -> list[str]:
    """Build a prompt list: optional primary prompt first, then the defaults.

    Duplicates are dropped while preserving order. ``max_prompts`` truncates
    the result but always keeps at least one prompt.
    """
    bank: list[str] = [primary_prompt] if primary_prompt else []
    for candidate in DEFAULT_PROMPT_BANK:
        if candidate not in bank:
            bank.append(candidate)
    if max_prompts is None:
        return bank
    return bank[: max(1, int(max_prompts))]
def _ngram_distinct_ratio(tokens: Sequence[int], n: int) -> float:
if len(tokens) < n or n <= 0:
return 1.0
grams = [tuple(tokens[idx : idx + n]) for idx in range(len(tokens) - n + 1)]
if not grams:
return 1.0
return len(set(grams)) / len(grams)
def _repeated_ngram_ratio(tokens: Sequence[int], n: int) -> float:
    """Complement of the distinct n-gram ratio: fraction of repeated n-grams."""
    distinct = _ngram_distinct_ratio(tokens, n)
    return 1.0 - distinct
def _longest_run(tokens: Sequence[int]) -> int:
if not tokens:
return 0
best = 1
current = 1
for idx in range(1, len(tokens)):
if tokens[idx] == tokens[idx - 1]:
current += 1
best = max(best, current)
else:
current = 1
return best
def compute_generation_collapse_metrics(tokens: Sequence[int]) -> dict[str, float]:
    """Compute degeneration ("collapse") statistics for a generated sequence.

    Returns a dict of diversity/repetition measures plus a combined
    ``collapse_score`` in [0, 1], where higher means more degenerate output.
    An empty sequence is reported as fully collapsed (score 1.0).
    """
    n = len(tokens)
    if n == 0:
        return {
            "num_tokens": 0.0,
            "unique_token_ratio": 0.0,
            "distinct_2": 0.0,
            "distinct_3": 0.0,
            "repeated_3gram_ratio": 1.0,
            "max_token_run": 0.0,
            "max_token_run_ratio": 1.0,
            "normalized_entropy": 0.0,
            "collapse_score": 1.0,
        }
    frequency = Counter(int(tok) for tok in tokens)
    unique_ratio = len(frequency) / n
    distinct_2 = _ngram_distinct_ratio(tokens, 2)
    distinct_3 = _ngram_distinct_ratio(tokens, 3)
    repeated_3 = _repeated_ngram_ratio(tokens, 3)
    run_length = _longest_run(tokens)
    run_ratio = run_length / n
    # Shannon entropy of the unigram distribution; epsilon guards log(0).
    entropy = -sum((c / n) * math.log(c / n + 1e-12) for c in frequency.values())
    if n <= 1:
        normalized_entropy = 0.0
    else:
        # Normalize by the maximum possible entropy for n tokens, then clamp.
        normalized_entropy = min(1.0, max(0.0, entropy / math.log(n)))
    # Weighted blend of repetition signals; the weights sum to 1.
    score = (
        0.30 * (1.0 - unique_ratio)
        + 0.25 * repeated_3
        + 0.20 * (1.0 - distinct_2)
        + 0.10 * (1.0 - distinct_3)
        + 0.10 * run_ratio
        + 0.05 * (1.0 - normalized_entropy)
    )
    score = min(1.0, max(0.0, score))
    return {
        "num_tokens": float(n),
        "unique_token_ratio": float(unique_ratio),
        "distinct_2": float(distinct_2),
        "distinct_3": float(distinct_3),
        "repeated_3gram_ratio": float(repeated_3),
        "max_token_run": float(run_length),
        "max_token_run_ratio": float(run_ratio),
        "normalized_entropy": float(normalized_entropy),
        "collapse_score": float(score),
    }
@torch.no_grad()
def evaluate_prompt_bank(
    model,
    tokenizer: AutoTokenizer,
    prompts: Sequence[str],
    max_new_tokens: int = 64,
) -> dict:
    """Greedily generate one continuation per prompt and aggregate collapse metrics.

    Args:
        model: Causal LM; its device is inferred from its parameters.
        tokenizer: Tokenizer for encoding/decoding (pad token defaults to EOS).
        prompts: Non-empty sequence of prompt strings.
        max_new_tokens: Maximum tokens generated per prompt.

    Returns:
        Dict with per-sample texts/metrics, averaged collapse statistics,
        worst-case collapse score, and overall generation throughput.

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if not prompts:
        raise ValueError("Prompt bank must contain at least one prompt.")
    device = next(model.parameters()).device
    model.eval()
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    samples: list[dict] = []
    # Running sums for the metrics that get averaged at the end.
    sums = {
        "collapse_score": 0.0,
        "unique_token_ratio": 0.0,
        "distinct_2": 0.0,
        "repeated_3gram_ratio": 0.0,
        "normalized_entropy": 0.0,
    }
    total_tokens = 0
    total_time = 0.0

    for prompt in prompts:
        encoded = tokenizer(prompt, return_tensors="pt").to(device)
        start = time.time()
        output_ids = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=1.0,
            top_p=1.0,
            pad_token_id=tokenizer.pad_token_id,
        )
        elapsed = time.time() - start

        # Keep only the newly generated tail, not the echoed prompt.
        prompt_len = encoded["input_ids"].shape[1]
        new_tokens = output_ids[0][prompt_len:].detach().cpu()
        token_ids = [int(tok) for tok in new_tokens.tolist()]
        decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)

        metrics = compute_generation_collapse_metrics(token_ids)
        metrics["generation_time_sec"] = float(elapsed)
        metrics["tokens_per_sec"] = float(len(token_ids) / max(elapsed, 1e-6))

        samples.append(
            {
                "prompt": prompt,
                "text": decoded,
                "generated_token_ids": token_ids,
                "metrics": metrics,
            }
        )
        for key in sums:
            sums[key] += metrics[key]
        total_tokens += len(token_ids)
        total_time += elapsed

    n_prompts = len(samples)
    return {
        "primary_prompt": samples[0]["prompt"],
        "primary_text": samples[0]["text"],
        "n_prompts": n_prompts,
        "total_generated_tokens": int(total_tokens),
        "total_generation_time_sec": float(total_time),
        "tokens_per_sec": float(total_tokens / max(total_time, 1e-6)),
        "avg_collapse_score": float(sums["collapse_score"] / n_prompts),
        "worst_collapse_score": float(
            max(entry["metrics"]["collapse_score"] for entry in samples)
        ),
        "avg_unique_token_ratio": float(sums["unique_token_ratio"] / n_prompts),
        "avg_distinct_2": float(sums["distinct_2"] / n_prompts),
        "avg_repeated_3gram_ratio": float(sums["repeated_3gram_ratio"] / n_prompts),
        "avg_normalized_entropy": float(sums["normalized_entropy"] / n_prompts),
        "samples": samples,
    }
@torch.no_grad()
def benchmark_runtime(
    model,
    tokenizer: AutoTokenizer,
    prompts: Sequence[str],
    max_new_tokens: int = 64,
    warmup_runs: int = 1,
    timed_runs: int = 3,
) -> dict:
    """Benchmark prefill and end-to-end generation on a prompt bank.

    Args:
        model: Causal LM to benchmark; its device is inferred from its parameters.
        tokenizer: Tokenizer used to encode the prompts (pad token defaults to EOS).
        prompts: Non-empty sequence of prompt strings.
        max_new_tokens: New tokens generated per prompt in the generation pass.
        warmup_runs: Untimed full passes before measurement (negatives treated as 0).
        timed_runs: Measured passes (values below 1 treated as 1).

    Returns:
        Dict with mean/min tokens-per-second and median latency for both
        phases: prefill (one forward over the prompt) and generation
        (greedy decode of up to ``max_new_tokens`` tokens).

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if not prompts:
        raise ValueError("Prompt list must contain at least one prompt.")
    device = next(model.parameters()).device
    model.eval()
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Encode all prompts up front so tokenization cost is excluded from timings.
    encoded_prompts = [
        tokenizer(prompt, return_tensors="pt").to(device)
        for prompt in prompts
    ]

    def run_prefill() -> tuple[float, int]:
        # One forward pass per prompt. Device sync before and after each
        # timed region so asynchronous CUDA/MPS kernels are fully counted.
        total_time = 0.0
        total_tokens = 0
        for inputs in encoded_prompts:
            _synchronize_device(device)
            t0 = time.perf_counter()
            model(**inputs)
            _synchronize_device(device)
            total_time += time.perf_counter() - t0
            total_tokens += int(inputs["input_ids"].numel())
        return total_time, total_tokens

    def run_generate() -> tuple[float, int]:
        # Greedy decode per prompt; only newly generated tokens are counted.
        total_time = 0.0
        total_tokens = 0
        for inputs in encoded_prompts:
            _synchronize_device(device)
            t0 = time.perf_counter()
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
                pad_token_id=tokenizer.pad_token_id,
            )
            _synchronize_device(device)
            total_time += time.perf_counter() - t0
            total_tokens += int(outputs.shape[1] - inputs["input_ids"].shape[1])
        return total_time, total_tokens

    # Warmup triggers lazy one-time work (kernel compilation, allocator growth).
    for _ in range(max(0, int(warmup_runs))):
        run_prefill()
        run_generate()

    prefill_latencies = []
    prefill_tps = []
    generation_latencies = []
    generation_tps = []
    for _ in range(max(1, int(timed_runs))):
        prefill_time, prefill_tokens = run_prefill()
        generation_time, generation_tokens = run_generate()
        prefill_latencies.append(prefill_time)
        generation_latencies.append(generation_time)
        # 1e-6 floor avoids division by zero on near-instant runs.
        prefill_tps.append(prefill_tokens / max(prefill_time, 1e-6))
        generation_tps.append(generation_tokens / max(generation_time, 1e-6))

    # Median taken as the middle element of the sorted latencies
    # (upper-middle for even-length lists).
    prefill_sorted = sorted(prefill_latencies)
    generation_sorted = sorted(generation_latencies)
    prefill_mid = len(prefill_sorted) // 2
    generation_mid = len(generation_sorted) // 2
    return {
        "n_prompts": len(prompts),
        "warmup_runs": int(warmup_runs),
        "timed_runs": int(timed_runs),
        "prefill_tokens_per_sec_mean": float(sum(prefill_tps) / len(prefill_tps)),
        "prefill_tokens_per_sec_min": float(min(prefill_tps)),
        "prefill_latency_sec_median": float(prefill_sorted[prefill_mid]),
        "generation_tokens_per_sec_mean": float(
            sum(generation_tps) / len(generation_tps)
        ),
        "generation_tokens_per_sec_min": float(min(generation_tps)),
        "generation_latency_sec_median": float(generation_sorted[generation_mid]),
    }
@torch.no_grad()
def evaluate_perplexity(
    model,
    tokenizer: AutoTokenizer,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    split: str = "test",
    seq_len: int = 2048,
    batch_size: int = 1,
    max_samples: Optional[int] = None,
) -> float:
    """
    Evaluate perplexity of a model on a dataset.

    Args:
        model: HuggingFace model (original or ternary-wrapped)
        tokenizer: tokenizer
        dataset_name: dataset to evaluate on
        dataset_config: dataset config
        split: dataset split
        seq_len: sequence length for evaluation
        batch_size: batch size
        max_samples: maximum number of samples (None = use all)

    Returns:
        Perplexity (float)

    Raises:
        ValueError: if the dataset yields fewer than ``seq_len`` tokens, so
            not even a single evaluation chunk can be formed.
    """
    dataset = _load_text_dataset(dataset_name, dataset_config, split)
    # Request only enough text to fill `max_samples` chunks (+1 token for the
    # shifted labels) so we avoid tokenizing the entire corpus unnecessarily.
    target_tokens = None if max_samples is None else seq_len * max_samples + 1
    text = _build_eval_text(dataset, tokenizer, max_tokens=target_tokens)
    encodings = tokenizer(text, return_tensors="pt", truncation=False)
    input_ids = encodings["input_ids"][0]

    device = next(model.parameters()).device
    model.eval()

    n_chunks = input_ids.shape[0] // seq_len
    if max_samples is not None:
        n_chunks = min(n_chunks, max_samples)
    if n_chunks == 0:
        # Previously this fell through to total_nll / total_tokens and raised
        # an opaque ZeroDivisionError; fail loudly with context instead.
        raise ValueError(
            f"Dataset produced only {int(input_ids.shape[0])} tokens, fewer "
            f"than seq_len={seq_len}; cannot form a single evaluation chunk."
        )

    total_nll = 0.0
    total_tokens = 0
    for i in tqdm(range(0, n_chunks, batch_size), desc="Evaluating perplexity"):
        batch_chunks = []
        for j in range(i, min(i + batch_size, n_chunks)):
            start = j * seq_len
            batch_chunks.append(input_ids[start : start + seq_len])
        batch = torch.stack(batch_chunks).to(device)

        outputs = model(batch)
        logits = outputs.logits
        # Shift logits and labels for next-token prediction: position t
        # predicts token t+1, so the last logit and first label are dropped.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()
        # Sum (not mean) the NLL so partial final batches weight correctly.
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="sum",
        )
        total_nll += loss.item()
        total_tokens += shift_labels.numel()

    # PPL = exp(mean NLL). math.exp in float64 replaces the previous
    # torch.exp(torch.tensor(...)).item() float32 round-trip.
    avg_nll = total_nll / total_tokens
    return math.exp(avg_nll)
def compare_models(
    original_model_name: str,
    ternary_model_dir: str,
    device: str = "auto",
    seq_len: int = 2048,
    max_samples: int = 40,
) -> dict:
    """
    Compare original and ternary model perplexity side by side.

    Args:
        original_model_name: HF hub name/path of the full-precision model.
        ternary_model_dir: Directory containing the saved ternary model.
        device: "auto" picks cuda when available, else cpu.
        seq_len: Evaluation sequence length.
        max_samples: Number of seq_len-sized chunks to evaluate.

    Returns:
        Dict with both perplexities, their ratio, and the absolute increase.
    """
    from ternary_quant.inference import load_ternary_model

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(original_model_name)

    # Evaluate the full-precision baseline first, then free it so both
    # models never occupy accelerator memory at the same time.
    print("Evaluating original model...")
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_name,
        torch_dtype=torch.float16,
        device_map=device if device != "cpu" else None,
    )
    if device == "cpu":
        original_model = original_model.to(device)
    original_model.eval()
    original_ppl = evaluate_perplexity(
        original_model, tokenizer, seq_len=seq_len, max_samples=max_samples
    )
    del original_model
    # Proper statement form (was an expression-statement conditional).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Evaluate ternary
    print("Evaluating ternary model...")
    ternary_model, _ = load_ternary_model(ternary_model_dir, device=device)
    ternary_ppl = evaluate_perplexity(
        ternary_model, tokenizer, seq_len=seq_len, max_samples=max_samples
    )

    results = {
        "original_perplexity": original_ppl,
        "ternary_perplexity": ternary_ppl,
        "perplexity_ratio": ternary_ppl / original_ppl,
        "perplexity_increase": ternary_ppl - original_ppl,
    }
    print(f"\nOriginal PPL: {original_ppl:.2f}")
    print(f"Ternary PPL: {ternary_ppl:.2f}")
    print(f"Ratio: {results['perplexity_ratio']:.2f}x")
    print(f"Increase: +{results['perplexity_increase']:.2f}")
    return results