sad / scripts /compute_diversity.py
haochengsama's picture
Add files using upload-large-folder tool
8b0aeb2 verified
Raw
History Blame Contribute Delete
3.24 kB
#!/usr/bin/env python3
"""
compute_diversity.py – Compute n-gram diversity metrics for generated text samples.
Metrics (for n ∈ {1, 2, 4}):
- distinct-n (corpus-level):
(# unique n-grams across all samples) / (total # n-grams across all samples)
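- repetition-n (sample-level average):
For each sample: (total_ngrams - unique_ngrams) / total_ngrams
Then average over all samples that contain at least one n-gram
(samples with fewer than n whitespace tokens are skipped, not counted as 0).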
Input : JSON file with {"samples": ["text1", "text2", ...]}
Output: JSON file with {"distinct_1", "distinct_2", "distinct_4",
"repetition_1", "repetition_2", "repetition_4",
"d1", "d2", "d4", "r1", "r2", "r4", "num_samples"}
(dN / rN are short aliases of distinct_N / repetition_N.)
Usage:
python scripts/compute_diversity.py \
--input eval/block_ar_gen_ppl/samples/pps1_lam1.0_temp1.0_gpu0.json \
--output eval/block_ar_gen_ppl/diversity/pps1_lam1.0_temp1.0_gpu0.json
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
def get_ngrams(tokens: list[str], n: int) -> list[tuple[str, ...]]:
    """Return every contiguous window of ``n`` tokens as a tuple.

    Args:
        tokens: Token sequence (e.g. whitespace-split words).
        n: Window length.

    Returns:
        The n-gram tuples in left-to-right order; an empty list when
        ``tokens`` holds fewer than ``n`` items.
    """
    last_start = len(tokens) - n
    if last_start < 0:
        # Not enough tokens to form even a single window.
        return []
    return [tuple(tokens[pos : pos + n]) for pos in range(last_start + 1)]
def compute_metrics(texts: list[str], n_values: tuple[int, ...] = (1, 2, 4)) -> dict:
    """Compute corpus-level distinct-n and sample-level repetition-n.

    Each text is tokenized with ``str.split()`` (runs of whitespace).

    For every ``n`` in ``n_values``:

    - ``distinct_n``: (# unique n-grams across all texts) /
      (total # n-grams across all texts); 0.0 if no text yields an n-gram.
    - ``repetition_n``: mean over texts of (total - unique) / total n-grams
      within that text. Texts with fewer than ``n`` tokens produce no
      n-grams and are excluded from the mean (not counted as 0); 0.0 if no
      text qualifies.

    Args:
        texts: Generated text samples.
        n_values: n-gram orders to evaluate.

    Returns:
        Dict with keys ``distinct_{n}`` and ``repetition_{n}``, plus the
        short aliases ``d{n}`` and ``r{n}`` holding the same values, for each
        ``n`` in ``n_values``.
    """
    # Tokenization does not depend on n, so do it once instead of once per
    # n-gram order. str.split() with no argument already discards leading
    # and trailing whitespace, so a separate strip() is unnecessary.
    tokenized = [text.split() for text in texts]

    metrics: dict[str, float] = {}
    for n in n_values:
        total_ngrams = 0
        unique_ngrams: set[tuple[str, ...]] = set()
        sample_reps: list[float] = []
        for tokens in tokenized:
            ngrams = get_ngrams(tokens, n)
            if not ngrams:
                # Too short for this n: contributes to neither metric.
                continue
            total_in_sample = len(ngrams)
            total_ngrams += total_in_sample
            unique_ngrams.update(ngrams)
            unique_in_sample = len(set(ngrams))
            sample_reps.append((total_in_sample - unique_in_sample) / total_in_sample)

        distinct = len(unique_ngrams) / total_ngrams if total_ngrams > 0 else 0.0
        avg_repetition = sum(sample_reps) / len(sample_reps) if sample_reps else 0.0

        metrics[f"distinct_{n}"] = distinct
        metrics[f"repetition_{n}"] = avg_repetition
        # Short aliases kept for downstream consumers of the output JSON.
        metrics[f"d{n}"] = distinct
        metrics[f"r{n}"] = avg_repetition
    return metrics
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: read samples, compute diversity metrics, write JSON.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps normal command-line behavior.

    Raises:
        RuntimeError: if the input JSON has no non-empty ``samples`` entry,
            or if ``samples`` is not a list of strings.
    """
    p = argparse.ArgumentParser(
        description="Compute n-gram diversity metrics for generated text samples."
    )
    p.add_argument("--input", type=str, required=True,
                   help="Path to input JSON with {'samples': [...]}.")
    p.add_argument("--output", type=str, required=True,
                   help="Path to write output diversity JSON.")
    args = p.parse_args(argv)

    in_path = Path(args.input)
    out_path = Path(args.output)

    # JSON text is UTF-8; do not depend on the platform's locale encoding,
    # since samples may contain non-ASCII text.
    with open(in_path, encoding="utf-8") as f:
        data = json.load(f)

    # A top-level list would otherwise fail with AttributeError on .get().
    texts = data.get("samples") if isinstance(data, dict) else None
    if not texts:
        raise RuntimeError(f"No samples found in {in_path}")
    # A bare string would otherwise be iterated character by character.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        raise RuntimeError(f"'samples' in {in_path} must be a list of strings")

    metrics = compute_metrics(texts)
    metrics["num_samples"] = len(texts)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"Diversity metrics for {in_path.name}:")
    for k, v in metrics.items():
        # num_samples is an integer count; only the ratios get 4 decimals.
        print(f" {k}: {v}" if isinstance(v, int) else f" {k}: {v:.4f}")
    print(f"Saved -> {out_path}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()