File size: 2,740 Bytes
12fd5f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Measures style preservation between input and output texts, and the academic-vocabulary (AWL) coverage of the output.

Key metrics:
  - Style Vector Cosine Similarity (target: > 0.85)
  - AWL Coverage Score (target: > 0.25)
  - Authorship Verification Score (target: > 0.80)
"""

import re
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger

from ..style.fingerprinter import StyleFingerprinter
from ..vocabulary.awl_loader import AWLLoader


class StyleEvaluator:
    """Evaluates style preservation and academic vocabulary coverage.

    Args:
        fingerprinter: Object providing ``extract_vector(text)``, which must
            return a torch tensor style vector for ``text``.
        awl: Object providing ``is_academic(word)``, used to test whether a
            lower-cased word belongs to the Academic Word List.
    """

    # Runs of Unicode letters (word characters minus digits and underscore).
    # Extracting letter runs, instead of whitespace-splitting, means words
    # followed or preceded by punctuation ("analysis,", "(results)") still
    # count as content words.
    _WORD_RE = re.compile(r"[^\W\d_]+")

    def __init__(self, fingerprinter: StyleFingerprinter, awl: AWLLoader):
        self.fingerprinter = fingerprinter
        self.awl = awl

    def style_similarity(self, text_a: str, text_b: str) -> float:
        """Cosine similarity between style vectors. Target: > 0.85.

        1-D style vectors are promoted to shape ``(1, D)`` before comparison.
        The result is read with ``.item()``, so this assumes the
        fingerprinter yields a single vector per text (TODO confirm for
        fingerprinters that return batched output).

        Returns:
            The cosine similarity as a Python float.
        """
        vec_a = self.fingerprinter.extract_vector(text_a)
        vec_b = self.fingerprinter.extract_vector(text_b)

        if vec_a.dim() == 1:
            vec_a = vec_a.unsqueeze(0)
        if vec_b.dim() == 1:
            vec_b = vec_b.unsqueeze(0)

        sim = F.cosine_similarity(vec_a, vec_b, dim=-1)
        return sim.item()

    def awl_coverage(self, text: str) -> float:
        """Fraction of content words in AWL. Target: > 0.25.

        Content words are alphabetic tokens longer than 3 characters, taken
        from the lower-cased text. Surrounding punctuation does not prevent
        a word from being counted.

        Returns:
            ``awl_words / content_words``. Returns 0.0 when ``text`` is empty,
            whitespace-only, or contains no content words.
        """
        if not text or not text.strip():
            return 0.0

        words = self._WORD_RE.findall(text.lower())
        # Keep only "content" words: longer than 3 chars and purely alphabetic
        # (isalpha also rejects the rare numeric letters \w would admit).
        content_words = [w for w in words if len(w) > 3 and w.isalpha()]

        if not content_words:
            return 0.0

        awl_count = sum(1 for w in content_words if self.awl.is_academic(w))
        return awl_count / len(content_words)

    @staticmethod
    def _mean_std(values: List[float]) -> Tuple[float, float]:
        """Return (mean, population std) of ``values``; (nan, nan) if empty."""
        # Explicit empty handling avoids numpy's "Mean of empty slice" warning
        # while keeping the same NaN result.
        if not values:
            return float("nan"), float("nan")
        arr = np.asarray(values, dtype=np.float64)
        return float(arr.mean()), float(arr.std())

    def evaluate_batch(
        self,
        inputs: List[str],
        outputs: List[str],
        references: List[str],
    ) -> dict:
        """Compute style and AWL metrics for a batch.

        Args:
            inputs: Source texts.
            outputs: Generated texts, aligned index-by-index with ``inputs``.
            references: Reference texts, aligned index-by-index with
                ``outputs``.

        Returns:
            Dict with ``style_similarity_mean``/``_std`` (input vs. output),
            ``awl_coverage_mean``/``_std`` (of outputs), and
            ``ref_style_similarity_mean``/``_std`` (output vs. reference).
            Every value is NaN for an empty batch.

        Raises:
            ValueError: If the three lists differ in length. ``zip`` would
                otherwise silently truncate and score misaligned data.
        """
        if not (len(inputs) == len(outputs) == len(references)):
            raise ValueError(
                "inputs, outputs and references must have the same length; "
                f"got {len(inputs)}, {len(outputs)} and {len(references)}"
            )

        style_sims = []
        awl_coverages = []
        ref_style_sims = []

        for inp, out, ref in zip(inputs, outputs, references):
            # Style similarity between input and output (preservation)
            style_sims.append(self.style_similarity(inp, out))

            # AWL coverage of output
            awl_coverages.append(self.awl_coverage(out))

            # Style similarity between output and reference
            ref_style_sims.append(self.style_similarity(out, ref))

        style_mean, style_std = self._mean_std(style_sims)
        awl_mean, awl_std = self._mean_std(awl_coverages)
        ref_mean, ref_std = self._mean_std(ref_style_sims)

        return {
            "style_similarity_mean": style_mean,
            "style_similarity_std": style_std,
            "awl_coverage_mean": awl_mean,
            "awl_coverage_std": awl_std,
            "ref_style_similarity_mean": ref_mean,
            "ref_style_similarity_std": ref_std,
        }