File size: 3,240 Bytes
8b0aeb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
"""
compute_diversity.py – Compute n-gram diversity metrics for generated text samples.

Metrics (for n ∈ {1, 2, 4}):
  - distinct-n (corpus-level):
      (# unique n-grams across all samples) / (total # n-grams across all samples)
  - repetition-n (sample-level average):
      For each sample: (total_ngrams - unique_ngrams) / total_ngrams
      Then average over all samples.

Input : JSON file with {"samples": ["text1", "text2", ...]}
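Output: JSON file with keys {"distinct_1", "distinct_2", "distinct_4",
                              "repetition_1", "repetition_2", "repetition_4",
                              "d1", "d2", "d4", "r1", "r2", "r4",
                              "num_samples"}
        (d<n> / r<n> are short aliases holding the same values as
        distinct_<n> / repetition_<n>)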

Usage:
  python scripts/compute_diversity.py \
      --input eval/block_ar_gen_ppl/samples/pps1_lam1.0_temp1.0_gpu0.json \
      --output eval/block_ar_gen_ppl/diversity/pps1_lam1.0_temp1.0_gpu0.json
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path


def get_ngrams(tokens: list[str], n: int) -> list[tuple[str, ...]]:
    """Extract all contiguous n-grams from a token list.

    Args:
        tokens: Token sequence to slide the window over.
        n: Window size; must be at least 1.

    Returns:
        The n-grams as tuples, in order of appearance (duplicates are kept).
        Empty when ``tokens`` has fewer than ``n`` items.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        # Without this guard, n == 0 yields len(tokens) + 1 empty tuples and a
        # negative n yields arbitrary slices, silently corrupting the metrics.
        raise ValueError(f"n must be a positive integer, got {n}")
    # range() is empty whenever len(tokens) < n, so short inputs give [].
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


def compute_metrics(texts: list[str], n_values: tuple[int, ...] = (1, 2, 4)) -> dict:
    """Compute distinct-n and repetition-n for a list of texts.

    Texts are tokenized by whitespace. For every ``n`` in ``n_values``:

    * distinct-n (corpus-level): number of unique n-grams across all samples
      divided by the total number of n-grams across all samples.
    * repetition-n (sample-level): for each sample,
      ``(total_ngrams - unique_ngrams) / total_ngrams``, averaged over samples.

    A sample with fewer than ``n`` tokens produces no n-grams; it is skipped
    for that ``n`` and does not count toward the repetition average.

    Args:
        texts: Generated text samples.
        n_values: The n-gram orders to evaluate.

    Returns:
        A dict with, for each ``n``, the keys ``distinct_{n}`` and
        ``repetition_{n}`` plus the short aliases ``d{n}`` and ``r{n}``
        holding the same values. A metric is ``0.0`` when no sample yields
        any n-gram of that order.
    """
    # Tokenize once instead of once per n: the split is independent of n, and
    # a single pass over `texts` also keeps one-shot iterables working for
    # every n rather than only the first. str.split() with no separator
    # already ignores leading/trailing whitespace.
    tokenized = [text.split() for text in texts]

    metrics = {}

    for n in n_values:
        total_ngrams = 0
        unique_ngrams = set()
        sample_reps = []

        for tokens in tokenized:
            ngrams = get_ngrams(tokens, n)
            if not ngrams:
                # Sample too short for this n; contributes nothing.
                continue

            total_in_sample = len(ngrams)
            unique_in_sample = len(set(ngrams))

            # Corpus-level accumulators for distinct-n.
            total_ngrams += total_in_sample
            unique_ngrams.update(ngrams)

            # Per-sample repetition rate.
            sample_reps.append((total_in_sample - unique_in_sample) / total_in_sample)

        distinct = len(unique_ngrams) / total_ngrams if total_ngrams > 0 else 0.0
        avg_repetition = sum(sample_reps) / len(sample_reps) if sample_reps else 0.0

        metrics[f"distinct_{n}"] = distinct
        metrics[f"repetition_{n}"] = avg_repetition
        # Short aliases of the same values.
        metrics[f"d{n}"] = distinct
        metrics[f"r{n}"] = avg_repetition

    return metrics


def main() -> None:
    """Command-line entry point.

    Reads the JSON file given by ``--input``, computes diversity metrics over
    its ``"samples"`` list, writes the metrics (plus ``"num_samples"``) as
    JSON to ``--output`` (creating parent directories as needed), and prints
    a summary to stdout.

    Raises:
        RuntimeError: If the input has no ``"samples"`` entry, the entry is
            empty, or the JSON root is not an object.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", type=str, required=True,
                   help="Path to input JSON with {'samples': [...]}.")
    p.add_argument("--output", type=str, required=True,
                   help="Path to write output diversity JSON.")
    args = p.parse_args()

    in_path = Path(args.input)
    out_path = Path(args.output)

    # Explicit UTF-8: the platform default encoding may not be able to decode
    # non-ASCII characters in the generated samples.
    with open(in_path, encoding="utf-8") as f:
        data = json.load(f)

    # A non-object root (e.g. a bare list) has no "samples" key; report it via
    # the same error as a missing/empty key instead of an AttributeError.
    texts = data.get("samples", []) if isinstance(data, dict) else []
    if not texts:
        raise RuntimeError(f"No samples found in {in_path}")

    metrics = compute_metrics(texts)
    metrics["num_samples"] = len(texts)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"Diversity metrics for {in_path.name}:")
    for k, v in metrics.items():
        if isinstance(v, int):
            # Counts such as num_samples are integers; don't render "N.0000".
            print(f"  {k}: {v}")
        else:
            print(f"  {k}: {v:.4f}")
    print(f"Saved -> {out_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()