#!/usr/bin/env python3
"""
compute_diversity.py – Compute n-gram diversity metrics for generated text samples.

Metrics (for n ∈ {1, 2, 4}):
  - distinct-n (corpus-level):
      (# unique n-grams across all samples) / (total # n-grams across all samples)
  - repetition-n (sample-level average):
      For each sample: (total_ngrams - unique_ngrams) / total_ngrams
      Then average over all samples that contain at least one n-gram
      (samples with fewer than n tokens are skipped).

Input : JSON file with {"samples": ["text1", "text2", ...]}
Output: JSON file with {"distinct_1", "distinct_2", "distinct_4",
                        "repetition_1", "repetition_2", "repetition_4",
                        "d1", "d2", "d4", "r1", "r2", "r4",
                        "num_samples"}

Usage:
    python scripts/compute_diversity.py \
        --input  eval/block_ar_gen_ppl/samples/pps1_lam1.0_temp1.0_gpu0.json \
        --output eval/block_ar_gen_ppl/diversity/pps1_lam1.0_temp1.0_gpu0.json
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path


def get_ngrams(tokens: list[str], n: int) -> list[tuple[str, ...]]:
    """Extract all contiguous n-grams from a token list.

    Args:
        tokens: Sequence of tokens, in order.
        n: Size of each n-gram; must be at least 1.

    Returns:
        The n-grams as tuples, in order of appearance. Empty when
        ``tokens`` has fewer than ``n`` elements.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    # range() is empty for a non-positive stop, which covers len(tokens) < n.
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


def compute_metrics(
    texts: list[str], n_values: tuple[int, ...] = (1, 2, 4)
) -> dict[str, float]:
    """Compute distinct-n and repetition-n for a list of texts.

    Texts are tokenized by splitting on whitespace. For every ``n`` in
    ``n_values`` four keys are produced: ``distinct_{n}`` and
    ``repetition_{n}``, plus the short aliases ``d{n}`` and ``r{n}`` that
    carry the same values.

    Args:
        texts: Text samples to evaluate.
        n_values: n-gram sizes to compute metrics for.

    Returns:
        Mapping from metric name to value. A metric is ``0.0`` when no
        sample is long enough to contain an n-gram of that size.
    """
    # Tokenize once up front instead of once per n-gram size.
    tokenized = [text.split() for text in texts]

    metrics: dict[str, float] = {}
    for n in n_values:
        total_ngrams = 0
        unique_ngrams: set[tuple[str, ...]] = set()
        sample_reps: list[float] = []

        for tokens in tokenized:
            ngrams = get_ngrams(tokens, n)
            if not ngrams:
                # Sample too short for this n: contributes to neither metric.
                continue

            total_ngrams += len(ngrams)
            unique_ngrams.update(ngrams)
            sample_reps.append((len(ngrams) - len(set(ngrams))) / len(ngrams))

        distinct = len(unique_ngrams) / total_ngrams if total_ngrams else 0.0
        avg_repetition = sum(sample_reps) / len(sample_reps) if sample_reps else 0.0

        metrics[f"distinct_{n}"] = distinct
        metrics[f"repetition_{n}"] = avg_repetition
        metrics[f"d{n}"] = distinct
        metrics[f"r{n}"] = avg_repetition

    return metrics


def main() -> None:
    """Read samples from ``--input``, compute metrics, write them to ``--output``.

    Raises:
        RuntimeError: If the input JSON has no non-empty ``"samples"`` entry.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", type=str, required=True,
                   help="Path to input JSON with {'samples': [...]}.")
    p.add_argument("--output", type=str, required=True,
                   help="Path to write output diversity JSON.")
    args = p.parse_args()

    in_path = Path(args.input)
    out_path = Path(args.output)

    # Explicit encoding: generated text is frequently non-ASCII and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(in_path, encoding="utf-8") as f:
        data = json.load(f)

    # A top-level JSON value that is not an object has no "samples" key either.
    texts = data.get("samples", []) if isinstance(data, dict) else []
    if not texts:
        raise RuntimeError(f"No samples found in {in_path}")

    metrics = compute_metrics(texts)
    metrics["num_samples"] = len(texts)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"Diversity metrics for {in_path.name}:")
    for k, v in metrics.items():
        # Counts such as num_samples are integers; only ratios get 4 decimals.
        shown = f"{v:.4f}" if isinstance(v, float) else str(v)
        print(f" {k}: {shown}")
    print(f"Saved -> {out_path}")


if __name__ == "__main__":
    main()