| |
"""
compute_diversity.py – Compute n-gram diversity metrics for generated text samples.

Metrics (for n ∈ {1, 2, 4}), computed over whitespace-separated tokens:
  - distinct-n (corpus-level):
      (# unique n-grams across all samples) / (total # n-grams across all samples)
  - repetition-n (sample-level average):
      For each sample: (total_ngrams - unique_ngrams) / total_ngrams
      Then average over all samples that contain at least one n-gram
      (samples with fewer than n tokens are skipped).

Input : JSON file with {"samples": ["text1", "text2", ...]}
Output: JSON file with keys "distinct_1", "distinct_2", "distinct_4",
        "repetition_1", "repetition_2", "repetition_4",
        "d1", "d2", "d4", "r1", "r2", "r4" (short aliases of the keys above),
        and "num_samples" (the number of input samples).

Usage:
    python scripts/compute_diversity.py \
        --input eval/block_ar_gen_ppl/samples/pps1_lam1.0_temp1.0_gpu0.json \
        --output eval/block_ar_gen_ppl/diversity/pps1_lam1.0_temp1.0_gpu0.json
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
|
|
def get_ngrams(tokens: list[str], n: int) -> list[tuple[str, ...]]:
    """Return every contiguous n-gram of ``tokens`` as a tuple, in order.

    Args:
        tokens: The token sequence to slide a window over.
        n: The n-gram order; must be at least 1.

    Returns:
        A list of ``len(tokens) - n + 1`` tuples, or an empty list when
        ``tokens`` has fewer than ``n`` items.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        # Without this guard n == 0 yields len(tokens) + 1 empty tuples and a
        # negative n yields arbitrary truncated slices, both of which would
        # silently corrupt any statistic computed from the result.
        raise ValueError(f"n must be a positive integer, got {n}")
    # range() is empty whenever len(tokens) < n, so short inputs give [].
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]
|
|
|
|
def compute_metrics(texts: list[str], n_values: tuple[int, ...] = (1, 2, 4)) -> dict:
    """Score a collection of texts with distinct-n and repetition-n.

    Each text is split on whitespace. For every order in ``n_values`` the
    result holds four entries: ``distinct_{n}`` (unique n-grams over total
    n-grams, pooled across all texts), ``repetition_{n}`` (mean over texts of
    the fraction of n-grams that are repeats within that text), and the short
    aliases ``d{n}`` / ``r{n}`` carrying the same two values. Texts too short
    to yield an n-gram of a given order are left out of that order's
    statistics; if no text contributes, both values are 0.0.
    """
    results = {}

    for order in n_values:
        corpus_vocab = set()
        corpus_count = 0
        per_sample = []

        grams_by_sample = (
            get_ngrams(sample.strip().split(), order) for sample in texts
        )
        # filter(None, ...) discards the empty n-gram lists of short samples.
        for grams in filter(None, grams_by_sample):
            count = len(grams)
            corpus_count += count
            corpus_vocab.update(grams)
            per_sample.append((count - len(set(grams))) / count)

        if per_sample:
            # At least one sample contributed, so corpus_count is positive.
            distinct_score = len(corpus_vocab) / corpus_count
            repetition_score = sum(per_sample) / len(per_sample)
        else:
            distinct_score = 0.0
            repetition_score = 0.0

        results.update(
            {
                f"distinct_{order}": distinct_score,
                f"repetition_{order}": repetition_score,
                f"d{order}": distinct_score,
                f"r{order}": repetition_score,
            }
        )

    return results
|
|
|
|
def main():
    """Command-line entry point.

    Reads the JSON file given by ``--input``, computes the diversity metrics
    for its ``"samples"`` list, adds a ``"num_samples"`` count, writes the
    result as JSON to ``--output`` (creating parent directories as needed),
    and prints a summary.

    Raises:
        RuntimeError: If the input has no non-empty ``"samples"`` list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True,
                        help="Path to input JSON with {'samples': [...]}.")
    parser.add_argument("--output", type=str, required=True,
                        help="Path to write output diversity JSON.")
    args = parser.parse_args()

    in_path = Path(args.input)
    out_path = Path(args.output)

    # JSON text is UTF-8; do not depend on the platform's locale encoding,
    # which can mis-decode (or fail on) non-ASCII samples.
    with open(in_path, encoding="utf-8") as f:
        data = json.load(f)

    # A top-level JSON value that is not an object has no "samples" key;
    # report it the same way as a missing or empty list instead of failing
    # with an AttributeError on .get().
    texts = data.get("samples", []) if isinstance(data, dict) else []
    if not texts:
        raise RuntimeError(f"No samples found in {in_path}")

    metrics = compute_metrics(texts)
    metrics["num_samples"] = len(texts)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"Diversity metrics for {in_path.name}:")
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k}: {v:.4f}")
        else:
            # Integer counts such as num_samples are shown without decimals.
            print(f" {k}: {v}")
    print(f"Saved -> {out_path}")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|