File size: 1,799 Bytes
12fd5f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Authorship verification module.
Uses a fine-tuned model to verify whether the corrected output
could plausibly have been written by the same author as the input.
Target: > 0.80 same-author probability.
"""

from typing import Tuple
from loguru import logger
import torch
import torch.nn.functional as F


class AuthorshipVerifier:
    """Verifies authorship consistency between input and output text.

    The score is a proxy: both texts are embedded with a
    sentence-transformers model and the cosine similarity of the two
    embeddings is rescaled to ``[0, 1]``. When the model cannot be loaded,
    or the inputs cannot be compared, the neutral score is reported instead
    of raising.
    """

    # Score reported whenever no meaningful comparison can be made
    # (model unavailable, blank input, encoding failure, non-finite result).
    _NEUTRAL_SCORE = 0.5

    def __init__(self, model_name: str = "roberta-base"):
        """Load the embedding model on CPU.

        Loading is best-effort: any failure (including a missing
        ``sentence_transformers`` install) is logged and leaves
        ``self.model`` set to ``None``, in which case ``verify`` always
        returns the neutral score.

        Args:
            model_name: Model name or path handed to ``SentenceTransformer``.
        """
        try:
            # Imported here so an import failure is handled by the same
            # fallback as a model-loading failure.
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model_name, device="cpu")
            logger.info(f"AuthorshipVerifier loaded with {model_name}")
        except Exception as e:
            logger.warning(f"Failed to load authorship model: {e}")
            self.model = None

    @staticmethod
    def _is_blank(text) -> bool:
        """Return True for empty/None input or a whitespace-only string."""
        if not text:
            return True
        return isinstance(text, str) and not text.strip()

    def verify(self, text_a: str, text_b: str) -> float:
        """Return probability that both texts were written by the same author.

        Uses sentence embedding similarity as a proxy for authorship.
        Higher cosine similarity suggests same author.

        Args:
            text_a: First text (e.g. the original input).
            text_b: Second text (e.g. the corrected output).

        Returns:
            A float in ``[0.0, 1.0]``. The neutral score ``0.5`` is returned
            when the model is unavailable, when either text is empty or
            whitespace-only, when encoding fails, or when the similarity is
            not a finite number.
        """
        if self.model is None:
            return self._NEUTRAL_SCORE  # Neutral score if model unavailable

        # Blank text carries no stylistic signal, so scoring it would only
        # report noise as if it were evidence.
        if self._is_blank(text_a) or self._is_blank(text_b):
            return self._NEUTRAL_SCORE

        try:
            embeddings = self.model.encode([text_a, text_b], convert_to_tensor=True)
            sim = F.cosine_similarity(
                embeddings[0].unsqueeze(0),
                embeddings[1].unsqueeze(0),
            )
            if not torch.isfinite(sim).all():
                # A NaN/inf similarity must not leak out as a probability.
                logger.warning(
                    "Authorship similarity was not finite; returning neutral score"
                )
                return self._NEUTRAL_SCORE
            # Cosine similarity lies in [-1, 1]; shift and scale to [0, 1].
            prob = (sim.item() + 1.0) / 2.0
            # Floating-point rounding can push the cosine marginally past
            # +/-1, so clamp to keep the result a valid probability.
            return min(1.0, max(0.0, prob))
        except Exception as e:
            logger.warning(f"Authorship verification failed: {e}")
            return self._NEUTRAL_SCORE