"""
Evaluation Utilities for CDD
============================

Provides evaluation metrics used in the paper:
- Perplexity (PPL) using GPT-2-XL
- Toxicity scoring
- LLM-as-a-Judge coherence evaluation
- Entropy diversity metric
- Molecular metrics (validity, QED, SA, novelty)
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Callable, Dict, List, Optional
from collections import Counter


def compute_entropy(
    texts: List[str],
    tokenizer: Optional[Callable[[str], List]] = None,
) -> float:
    """Compute entropy-based diversity metric (Appendix E).

    For a sequence of length L with K distinct tokens, where token k occurs
    L_k times, the entropy is

        H = -sum_k (L_k / L) * log(L_k / L)

    using the natural logarithm. Higher entropy = more diverse generation.

    Args:
        texts: List of generated texts.
        tokenizer: Optional callable mapping a text to a sequence of tokens.
            Defaults to character-level tokenization (``list(text)``).

    Returns:
        Average entropy across the non-empty texts; 0.0 if there are none.
    """
    entropies = []
    for text in texts:
        # Materialize so generators returned by a custom tokenizer work too.
        tokens = list(text) if tokenizer is None else list(tokenizer(text))
        if not tokens:
            # Empty sequences have no distribution; skip rather than count as 0.
            continue
        length = len(tokens)
        entropy = 0.0
        for count in Counter(tokens).values():
            p = count / length  # always > 0: Counter only holds seen tokens
            entropy -= p * np.log(p)
        entropies.append(entropy)
    return float(np.mean(entropies)) if entropies else 0.0


def compute_self_bleu(
    texts: List[str],
    n_gram: int = 4,
    max_refs: int = 10,
) -> float:
    """Compute Self-BLEU diversity metric.

    Each text is scored with BLEU against the other texts as references.
    Lower Self-BLEU = more diverse generation.

    Args:
        texts: List of generated texts (whitespace-tokenized).
        n_gram: Maximum n-gram order; uniform weights 1/n_gram are used for
            orders 1..n_gram (n_gram=4 matches NLTK's default BLEU-4).
        max_refs: Number of references per hypothesis (the first ``max_refs``
            other texts), capped for efficiency.

    Returns:
        Average Self-BLEU score over scorable texts (0.0 if none), or -1.0
        if NLTK is not installed.

    Raises:
        ValueError: If ``n_gram`` is less than 1.
    """
    try:
        from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    except ImportError:
        print("NLTK not available for Self-BLEU computation.")
        return -1.0

    if n_gram < 1:
        raise ValueError(f"n_gram must be >= 1, got {n_gram}")

    smoothie = SmoothingFunction().method1
    weights = (1.0 / n_gram,) * n_gram
    # Tokenize once up front instead of re-splitting every text per hypothesis.
    tokenized = [text.split() for text in texts]

    scores = []
    for i, hypothesis in enumerate(tokenized):
        # The first `max_refs` texts other than i all lie within the first
        # `max_refs + 1` entries, so there is no need to scan the whole list.
        references = [
            ref for j, ref in enumerate(tokenized[: max_refs + 1]) if j != i
        ][:max_refs]
        if not hypothesis or not references:
            continue
        try:
            score = sentence_bleu(
                references,
                hypothesis,
                weights=weights,
                smoothing_function=smoothie,
            )
        except Exception:
            # Deliberate best-effort: a single degenerate sample should not
            # abort the corpus-level metric; it is simply excluded.
            continue
        scores.append(score)

    return float(np.mean(scores)) if scores else 0.0


def violation_rate(
    scores: List[float],
    threshold: float,
) -> float:
    """Compute constraint violation rate.

    Args:
        scores: List of constraint metric values.
        threshold: Constraint threshold; a sample violates the constraint
            when its score is strictly greater than this value.

    Returns:
        Fraction of samples that violate the constraint (0.0 for no samples).
    """
    violations = sum(1 for s in scores if s > threshold)
    return violations / len(scores) if scores else 0.0


def _format_cell(value) -> str:
    """Right-align one table value: floats to 4 decimals, anything else via str()."""
    # np.floating covers np.float32, which is not a Python float subclass.
    if isinstance(value, (float, np.floating)):
        return f"{float(value):>15.4f}"
    # str() keeps bools as True/False (bool is an int subclass, so ':d' would
    # print 1/0) and renders None, lists, numpy ints, etc. instead of dropping them.
    return f"{str(value):>15}"


def format_results_table(results: Dict) -> str:
    """Format results as a readable table matching paper format.

    Args:
        results: Dictionary of evaluation results. Dict values are rendered
            as an indented sub-section; every other value gets one row.

    Returns:
        Formatted string table.
    """
    lines = [
        "=" * 70,
        f"{'Metric':<30} {'Value':>15}",
        "-" * 70,
    ]
    for key, value in results.items():
        if isinstance(value, dict):
            lines.append(f"{key}:")
            for sub_key, sub_value in value.items():
                # 2-space indent + 28-wide name keeps values aligned with
                # the 30-wide top-level name column.
                lines.append(f"  {str(sub_key):<28} {_format_cell(sub_value)}")
        else:
            lines.append(f"{str(key):<30} {_format_cell(value)}")
    lines.append("=" * 70)
    return "\n".join(lines)