| """ |
| Evaluation Utilities for CDD |
| ============================== |
| |
| Provides evaluation metrics used in the paper: |
| - Perplexity (PPL) using GPT-2-XL |
| - Toxicity scoring |
| - LLM-as-a-Judge coherence evaluation |
| - Entropy diversity metric |
| - Molecular metrics (validity, QED, SA, novelty) |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import List, Dict, Optional |
| from collections import Counter |
|
|
|
|
def compute_entropy(texts: List[str]) -> float:
    """Average per-text Shannon entropy, used as a diversity score (Appendix E).

    Every text is broken into its individual characters (``list(text)``) and
    the entropy of the resulting empirical distribution is computed with the
    natural logarithm:

        H = -sum_k (L_k / L) * log(L_k / L)

    where L is the number of symbols in the text and L_k is the number of
    occurrences of the k-th distinct symbol. A larger value means the
    symbols are spread more evenly, i.e. the generation is more diverse.

    Args:
        texts: Generated texts to score. Empty texts are skipped and do not
            contribute to the average.

    Returns:
        Mean entropy (in nats) over the non-empty texts, or 0.0 when there
        is no non-empty text.
    """
    per_text_entropy = []

    for sample in texts:
        symbols = list(sample)
        total = len(symbols)
        if total == 0:
            continue

        h = 0.0
        for occurrences in Counter(symbols).values():
            # Counter only stores symbols that occurred, so the probability
            # is strictly positive and the logarithm is always defined.
            prob = occurrences / total
            h -= prob * np.log(prob)
        per_text_entropy.append(h)

    if not per_text_entropy:
        return 0.0
    return np.mean(per_text_entropy)
|
|
|
|
def compute_self_bleu(texts: List[str], n_gram: int = 4) -> float:
    """Compute the Self-BLEU diversity metric.

    Each text is scored with BLEU as a hypothesis against (up to) the first
    10 *other* texts as references, and the per-text scores are averaged.
    Texts are tokenized by whitespace. Lower Self-BLEU = more diverse
    generation.

    Args:
        texts: List of generated texts.
        n_gram: Maximum n-gram order for BLEU. Uniform weights ``1 / n_gram``
            are applied to orders 1..n_gram, so the default of 4 is the
            standard BLEU-4.

    Returns:
        Average Self-BLEU score, 0.0 if no text could be scored (fewer than
        two texts, or only empty hypotheses), or -1.0 if NLTK is not
        installed.

    Raises:
        ValueError: If ``n_gram`` is smaller than 1.
    """
    if n_gram < 1:
        raise ValueError(f"n_gram must be a positive integer, got {n_gram}")

    # Only the import is guarded: an ImportError is the one failure that is
    # reported through the -1.0 sentinel.
    try:
        from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    except ImportError:
        print("NLTK not available for Self-BLEU computation.")
        return -1.0

    max_refs = 10
    smoothie = SmoothingFunction().method1
    # Honor the requested n-gram order instead of NLTK's implicit BLEU-4.
    weights = tuple(1.0 / n_gram for _ in range(n_gram))

    # Tokenize once; only the first max_refs + 1 texts can ever serve as
    # references (the extra one covers the case where the hypothesis itself
    # is among the first max_refs).
    tokenized = [text.split() for text in texts]
    window = tokenized[:max_refs + 1]

    scores = []
    for i, hypothesis in enumerate(tokenized):
        refs_subset = [ref for j, ref in enumerate(window) if j != i][:max_refs]
        if not hypothesis or not refs_subset:
            continue

        try:
            score = sentence_bleu(
                refs_subset,
                hypothesis,
                weights=weights,
                smoothing_function=smoothie,
            )
        except Exception:
            # Deliberate best-effort: a text NLTK cannot score is left out
            # of the average rather than aborting the whole evaluation.
            continue
        scores.append(score)

    return float(np.mean(scores)) if scores else 0.0
|
|
|
|
def violation_rate(
    scores: List[float],
    threshold: float,
) -> float:
    """Return the share of samples whose score exceeds the threshold.

    A sample counts as a violation only when its score is strictly greater
    than ``threshold``; a score equal to the threshold is not a violation.

    Args:
        scores: Constraint metric value of each sample.
        threshold: Largest value that still satisfies the constraint.

    Returns:
        Number of violating samples divided by the number of samples, or
        0.0 when ``scores`` is empty.
    """
    if not scores:
        return 0.0

    violating = [value for value in scores if value > threshold]
    return len(violating) / len(scores)
|
|
|
|
def _format_metric_row(label, value, indent: int = 0) -> str:
    """Render one ``label  value`` row of the results table.

    Floats (Python or NumPy) are shown with four decimals; every other value
    is rendered with ``str()``. The label column shrinks by ``indent`` so the
    value column stays aligned for nested rows.
    """
    if isinstance(value, (float, np.floating)):
        text = f"{value:.4f}"
    else:
        text = str(value)
    label_width = 30 - indent
    return f"{' ' * indent}{label:<{label_width}} {text:>15}"


def format_results_table(results: Dict) -> str:
    """Format results as a readable table matching paper format.

    Each entry of ``results`` becomes one row. A value that is itself a
    dictionary is printed as a ``key:`` header followed by one row, indented
    by two spaces, per nested item. Floats are printed with four decimals;
    any other value (int, str, bool, None, ...) is printed via ``str()`` so
    that no entry is silently left out of the table.

    Args:
        results: Dictionary of evaluation results.

    Returns:
        Formatted string table.
    """
    lines = [
        "=" * 70,
        f"{'Metric':<30} {'Value':>15}",
        "-" * 70,
    ]

    for key, value in results.items():
        if isinstance(value, dict):
            lines.append(f"{key}:")
            lines.extend(
                _format_metric_row(k, v, indent=2) for k, v in value.items()
            )
        else:
            lines.append(_format_metric_row(key, value))

    lines.append("=" * 70)
    return "\n".join(lines)
|
|