"""
Measures style preservation between input and output.

Key metrics:
- Style Vector Cosine Similarity (target: > 0.85)
- AWL Coverage Score (target: > 0.25)
- Authorship Verification Score (target: > 0.80)
"""
import string

import torch
import torch.nn.functional as F
from typing import List, Tuple

from ..style.fingerprinter import StyleFingerprinter
from ..vocabulary.awl_loader import AWLLoader
from loguru import logger
import numpy as np


# Characters stripped from both ends of each whitespace token before the
# content-word filter, so that e.g. "analysis." or "(theory)" still count.
_TOKEN_STRIP_CHARS = string.punctuation + "\u201c\u201d\u2018\u2019\u00ab\u00bb\u2013\u2014\u2026"


class StyleEvaluator:
    """Evaluates style preservation and academic vocabulary coverage."""

    def __init__(self, fingerprinter: StyleFingerprinter, awl: AWLLoader):
        """Store the collaborators used to compute metrics.

        Args:
            fingerprinter: Produces style vectors via ``extract_vector(text)``.
            awl: Academic Word List lookup exposing ``is_academic(word)``.
        """
        self.fingerprinter = fingerprinter
        self.awl = awl

    @staticmethod
    def _cosine(vec_a: torch.Tensor, vec_b: torch.Tensor) -> float:
        """Return the cosine similarity of two style vectors as a Python float.

        1-D vectors are promoted to shape ``(1, dim)`` so that
        ``F.cosine_similarity`` over the last dimension yields a single value.
        NOTE(review): assumes ``extract_vector`` returns a single vector (1-D,
        or 2-D with one row); a multi-row result would make ``.item()`` fail.
        """
        if vec_a.dim() == 1:
            vec_a = vec_a.unsqueeze(0)
        if vec_b.dim() == 1:
            vec_b = vec_b.unsqueeze(0)
        return F.cosine_similarity(vec_a, vec_b, dim=-1).item()

    def style_similarity(self, text_a: str, text_b: str) -> float:
        """Cosine similarity between style vectors. Target: > 0.85.

        Args:
            text_a: First text.
            text_b: Second text.

        Returns:
            Cosine similarity of the two texts' style vectors.
        """
        vec_a = self.fingerprinter.extract_vector(text_a)
        vec_b = self.fingerprinter.extract_vector(text_b)
        return self._cosine(vec_a, vec_b)

    def awl_coverage(self, text: str) -> float:
        """Fraction of content words in AWL. Target: > 0.25.

        Tokens are obtained by lowercasing and splitting on whitespace, then
        stripping surrounding punctuation (so "analysis." counts as
        "analysis"). Content words are tokens longer than 3 characters that
        are purely alphabetic; hyphenated or digit-bearing tokens are excluded.

        Args:
            text: Text to score.

        Returns:
            ``academic_content_words / content_words``, or 0.0 for empty /
            whitespace-only text or text with no content words.
        """
        if not text or not text.strip():
            return 0.0

        tokens = (w.strip(_TOKEN_STRIP_CHARS) for w in text.lower().split())
        # Filter to content words (longer than 3 chars, alphabetic)
        content_words = [w for w in tokens if len(w) > 3 and w.isalpha()]
        if not content_words:
            return 0.0

        awl_count = sum(1 for w in content_words if self.awl.is_academic(w))
        return awl_count / len(content_words)

    @staticmethod
    def _mean_std(values: List[float]) -> Tuple[float, float]:
        """Return (mean, population std) of *values*; (nan, nan) when empty.

        The empty case is handled explicitly to avoid numpy's
        "Mean of empty slice" RuntimeWarning while keeping the NaN result.
        """
        if not values:
            return float("nan"), float("nan")
        return float(np.mean(values)), float(np.std(values))

    def evaluate_batch(
        self,
        inputs: List[str],
        outputs: List[str],
        references: List[str],
    ) -> dict:
        """Compute style and AWL metrics for a batch.

        Items are paired positionally. If the three lists differ in length,
        a warning is logged and only the first ``min(len(...))`` items are
        evaluated.

        Args:
            inputs: Source texts.
            outputs: Generated texts, aligned with ``inputs``.
            references: Reference texts, aligned with ``inputs``.

        Returns:
            Dict with ``style_similarity_{mean,std}`` (input vs output),
            ``awl_coverage_{mean,std}`` (of outputs) and
            ``ref_style_similarity_{mean,std}`` (output vs reference).
            All values are NaN for an empty batch.
        """
        n_in, n_out, n_ref = len(inputs), len(outputs), len(references)
        if not (n_in == n_out == n_ref):
            logger.warning(
                "evaluate_batch received mismatched lengths (inputs={}, "
                "outputs={}, references={}); only the first {} items are evaluated.",
                n_in, n_out, n_ref, min(n_in, n_out, n_ref),
            )

        style_sims: List[float] = []
        awl_coverages: List[float] = []
        ref_style_sims: List[float] = []

        for inp, out, ref in zip(inputs, outputs, references):
            # Extract the output's style vector once; it is compared twice.
            out_vec = self.fingerprinter.extract_vector(out)
            inp_vec = self.fingerprinter.extract_vector(inp)
            ref_vec = self.fingerprinter.extract_vector(ref)

            # Style similarity between input and output (preservation)
            style_sims.append(self._cosine(inp_vec, out_vec))
            # AWL coverage of output
            awl_coverages.append(self.awl_coverage(out))
            # Style similarity between output and reference
            ref_style_sims.append(self._cosine(out_vec, ref_vec))

        style_mean, style_std = self._mean_std(style_sims)
        awl_mean, awl_std = self._mean_std(awl_coverages)
        ref_mean, ref_std = self._mean_std(ref_style_sims)

        return {
            "style_similarity_mean": style_mean,
            "style_similarity_std": style_std,
            "awl_coverage_mean": awl_mean,
            "awl_coverage_std": awl_std,
            "ref_style_similarity_mean": ref_mean,
            "ref_style_similarity_std": ref_std,
        }