# Author: syedmohaiminulhoque
# Commit: Complete CDD implementation: Constrained Discrete Diffusion (arXiv:2503.09790v3)
# Revision: 2d0a056 (verified)
"""
Evaluation Utilities for CDD
==============================
Provides evaluation metrics used in the paper:
- Perplexity (PPL) using GPT-2-XL
- Toxicity scoring
- LLM-as-a-Judge coherence evaluation
- Entropy diversity metric
- Molecular metrics (validity, QED, SA, novelty)
"""
from collections import Counter
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
def compute_entropy(
    texts: List[str],
    tokenize: Optional[Callable[[str], List[str]]] = None,
) -> float:
    """Compute entropy-based diversity metric (Appendix E).

    For a sequence of length L with K distinct tokens, where token k occurs
    L_k times, the entropy is

        H = -sum_k (L_k / L) * log(L_k / L)

    (natural log). Higher entropy = more diverse generation.

    Args:
        texts: List of generated texts.
        tokenize: Optional callable that splits one text into tokens
            (e.g. ``str.split`` or a model tokenizer's ``tokenize``). When
            omitted, the text is split into individual characters, which was
            the only behaviour before this parameter existed.

    Returns:
        Mean entropy over the texts that produce at least one token. Empty
        texts are skipped rather than counted as zero. Returns 0.0 if no
        text produced any token.
    """
    entropies = []
    for text in texts:
        tokens = list(text) if tokenize is None else list(tokenize(text))
        if not tokens:
            continue
        counts = np.fromiter(Counter(tokens).values(), dtype=np.float64)
        probs = counts / counts.sum()
        # Every count is >= 1, so log() never sees 0. Adding 0.0 turns the
        # -0.0 produced by a single distinct token into a clean 0.0.
        entropies.append(float(-np.dot(probs, np.log(probs))) + 0.0)
    return float(np.mean(entropies)) if entropies else 0.0
def compute_self_bleu(texts: List[str], n_gram: int = 4) -> float:
    """Compute the Self-BLEU diversity metric.

    Each text is whitespace-tokenised and scored with NLTK's
    ``sentence_bleu`` against at most the first 10 *other* texts in
    ``texts``. The scoring uses uniform weights over 1..``n_gram``-grams
    and smoothing method 1. Lower Self-BLEU = more diverse generation.

    Args:
        texts: List of generated texts.
        n_gram: Maximum n-gram order for BLEU. Each order gets weight
            ``1 / n_gram``. The default of 4 reproduces NLTK's default
            BLEU-4 weights.

    Returns:
        Average Self-BLEU score over the texts that could be scored.
        Empty texts are skipped, and fewer than two texts yields 0.0.
        Returns -1.0 if NLTK is not installed.

    Raises:
        ValueError: If ``n_gram`` is less than 1.
    """
    if n_gram < 1:
        raise ValueError(f"n_gram must be >= 1, got {n_gram}")

    # Only the import is guarded, so unrelated ImportErrors are not masked.
    try:
        from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    except ImportError:
        print("NLTK not available for Self-BLEU computation.")
        return -1.0

    if len(texts) < 2:
        # No text has any reference to be compared against.
        return 0.0

    max_refs = 10  # cap references per hypothesis for efficiency
    weights = tuple([1.0 / n_gram] * n_gram)
    smoothie = SmoothingFunction().method1

    # Tokenise every text once instead of re-splitting all texts per hypothesis.
    tokenized = [t.split() for t in texts]
    # The first max_refs + 1 texts always contain the first max_refs texts
    # other than any given index i, so references are built in O(1) per text.
    head = tokenized[: max_refs + 1]

    scores = []
    skipped = 0
    for i, hypothesis in enumerate(tokenized):
        if not hypothesis:
            continue
        refs_subset = [toks for j, toks in enumerate(head) if j != i][:max_refs]
        try:
            scores.append(
                sentence_bleu(
                    refs_subset,
                    hypothesis,
                    weights=weights,
                    smoothing_function=smoothie,
                )
            )
        except Exception:
            # Best effort: keep scoring the remaining texts, but make the
            # failures visible instead of silently shrinking the sample.
            skipped += 1
    if skipped:
        print(f"Self-BLEU: NLTK failed to score {skipped} text(s); they were skipped.")
    return float(np.mean(scores)) if scores else 0.0
def violation_rate(
    scores: List[float],
    threshold: float,
) -> float:
    """Compute constraint violation rate.

    A sample violates the constraint when its score is strictly greater
    than ``threshold``. A score equal to the threshold counts as satisfied.
    NaN also counts as satisfied, because NaN > x is False.

    Args:
        scores: Constraint metric values. Any iterable of numbers is
            accepted, including lists, tuples, 1-D NumPy arrays and generators.
        threshold: Constraint threshold.

    Returns:
        Fraction of samples that violate the constraint, or 0.0 when
        ``scores`` is empty.
    """
    # Materialise once: `if scores` is ambiguous for NumPy arrays, and
    # generators have no len().
    values = list(scores)
    if not values:
        return 0.0
    violations = sum(1 for s in values if s > threshold)
    return violations / len(values)
def _format_value(value) -> str:
    """Render one table cell; the caller applies right-alignment.

    Floats, including NumPy floating scalars, get 4 decimals. Integers,
    including NumPy integer scalars, are shown without decimals. Anything
    else (strings, bools, None, lists, tensors, ...) falls back to ``str``,
    so no result is silently dropped.
    """
    # bool is a subclass of int; show True/False rather than 1/0.
    if isinstance(value, (bool, np.bool_)):
        return str(value)
    if isinstance(value, (int, np.integer)):
        return f"{int(value):d}"
    if isinstance(value, (float, np.floating)):
        return f"{float(value):.4f}"
    return str(value)


def format_results_table(results: Dict) -> str:
    """Format results as a readable table matching paper format.

    Each top-level entry becomes one ``metric  value`` row. A dict value
    becomes a ``key:`` header followed by one indented row per sub-entry.

    Args:
        results: Dictionary of evaluation results. Values may be numbers,
            strings, one level of nested dicts, or any other object. Other
            objects are rendered with ``str``.

    Returns:
        Formatted string table.
    """
    rule = "=" * 70
    lines = [rule, f"{'Metric':<30} {'Value':>15}", "-" * 70]
    for key, value in results.items():
        if isinstance(value, dict):
            lines.append(f"{key}:")
            for sub_key, sub_value in value.items():
                # 2-space indent + 28-wide key keeps the value column aligned
                # with the 30-wide key of top-level rows.
                lines.append(f"  {str(sub_key):<28} {_format_value(sub_value):>15}")
        else:
            # str(key) so non-string keys (e.g. tuples) accept the width spec.
            lines.append(f"{str(key):<30} {_format_value(value):>15}")
    lines.append(rule)
    return "\n".join(lines)