File size: 7,028 Bytes
b1c84b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c78c2c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1c84b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c78c2c1
 
b1c84b5
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
PhilVerify β€” XLM-RoBERTa Sequence Classifier (Layer 1, Phase 10)

Fine-tuned on Philippine misinformation data (English / Filipino / Taglish).
Drop-in replacement for TFIDFClassifier β€” same predict() interface.

Uses `ml/models/xlmr_model/` if it exists (populated by train_xlmr.py).
Raises ModelNotFoundError if the model has not been trained yet; the
scoring engine falls back to TFIDFClassifier in that case.
"""
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from pathlib import Path

logger = logging.getLogger(__name__)

# Where train_xlmr.py saves the fine-tuned checkpoint; resolved relative to
# this file so it works regardless of the process working directory.
MODEL_DIR = Path(__file__).parent / "models" / "xlmr_model"

# Class-id -> human-readable verdict. Must match the id2label mapping saved
# during training, or predictions will be mislabeled.
LABEL_NAMES = {0: "Credible", 1: "Unverified", 2: "Likely Fake"}
NUM_LABELS  = 3     # keep in sync with LABEL_NAMES
MAX_LENGTH  = 256   # tokens; 256 covers 95%+ of PH news headlines/paragraphs


class ModelNotFoundError(FileNotFoundError):
    """Raised when the fine-tuned checkpoint directory (MODEL_DIR) is missing.

    Subclasses FileNotFoundError so callers that catch the builtin still work.
    """


@dataclass
class Layer1Result:
    """Prediction produced by the Layer-1 classifier."""

    # One of the LABEL_NAMES values: "Credible" | "Unverified" | "Likely Fake".
    verdict: str
    # Model confidence as a percentage in the range 0.0 to 100.0.
    confidence: float
    # Salient tokens that most influenced the verdict (may be empty).
    triggered_features: list[str] = field(default_factory=list)


class XLMRobertaClassifier:
    """
    XLM-RoBERTa-based misinformation classifier.

    Loading is lazy: the model is not loaded until the first call to predict().
    This keeps FastAPI startup fast when the model is available.

    Raises ModelNotFoundError on instantiation if MODEL_DIR does not exist,
    so the scoring engine can detect the missing checkpoint immediately.
    """

    def __init__(self) -> None:
        if not MODEL_DIR.exists():
            raise ModelNotFoundError(
                f"XLM-RoBERTa checkpoint not found at {MODEL_DIR}. "
                "Run `python ml/train_xlmr.py` to fine-tune the model first."
            )
        self._tokenizer = None
        self._model     = None

    # ── Lazy load ─────────────────────────────────────────────────────────────

    def _ensure_loaded(self) -> None:
        if self._model is not None:
            return
        try:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            import torch
            self._torch = torch
            logger.info("Loading XLM-RoBERTa from %s …", MODEL_DIR)
            self._tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR))
            self._model = AutoModelForSequenceClassification.from_pretrained(
                str(MODEL_DIR),
                num_labels=NUM_LABELS,
            )
            self._model.eval()
            logger.info("XLM-RoBERTa loaded β€” device: %s", self._device)
        except Exception as exc:
            logger.exception("Failed to load XLM-RoBERTa model: %s", exc)
            raise

    @property
    def _device(self) -> str:
        try:
            import torch
            if torch.backends.mps.is_available():
                return "mps"
        except Exception:
            pass
        try:
            import torch
            if torch.cuda.is_available():
                return "cuda"
        except Exception:
            pass
        return "cpu"

    # ── Saliency: attention-based token importance ────────────────────────────

    def _salient_tokens(
        self,
        input_ids,       # (1, seq_len) torch.Tensor
        attentions,      # tuple of (1, heads, seq_len, seq_len) per layer
        n: int = 5,
    ) -> list[str]:
        """
        Average last-layer attention from CLS β†’ all tokens.
        Returns top-N decoded sub-word tokens as human-readable strings.
        Strips the sentencepiece ▁ prefix and SFX tokens.
        """
        import torch
        last_layer_attn = attentions[-1]               # (1, heads, seq, seq)
        cls_attn = last_layer_attn[0, :, 0, :].mean(0)  # (seq,) β€” avg over heads
        seq_len  = cls_attn.shape[-1]
        tokens   = self._tokenizer.convert_ids_to_tokens(
            input_ids[0].tolist()[:seq_len]
        )

        # Score each token; skip special tokens
        scored = []
        for i, (tok, score) in enumerate(zip(tokens, cls_attn.tolist())):
            if tok in ("<s>", "</s>", "<pad>", "<unk>"):
                continue
            clean = tok.lstrip("▁").strip()
            if len(clean) >= 3 and clean.isalpha():
                scored.append((clean, score))

        # Sort descending, dedup, return top N
        seen: set[str] = set()
        result = []
        for word, _ in sorted(scored, key=lambda x: x[1], reverse=True):
            if word.lower() not in seen:
                seen.add(word.lower())
                result.append(word)
            if len(result) >= n:
                break
        return result

    # ── Public API (same interface as TFIDFClassifier) ────────────────────────

    def predict_probs(self, text: str):
        """Return raw softmax probability tensor for ensemble averaging."""
        self._ensure_loaded()
        import torch

        encoding = self._tokenizer(
            text,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self._model(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
                output_attentions=True,
            )
        return torch.softmax(outputs.logits[0], dim=-1), outputs.attentions, encoding["input_ids"]

    def predict(self, text: str) -> Layer1Result:
        self._ensure_loaded()
        import torch

        encoding = self._tokenizer(
            text,
            truncation=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        input_ids      = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]

        with torch.no_grad():
            outputs = self._model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=True,
            )

        logits     = outputs.logits[0]                        # (num_labels,)
        probs      = torch.softmax(logits, dim=-1)
        pred_label = int(probs.argmax().item())
        confidence = round(float(probs[pred_label].item()) * 100, 1)
        verdict    = LABEL_NAMES[pred_label]

        # SDPA attention doesn't return attentions; fallback to empty
        triggered  = self._salient_tokens(input_ids, outputs.attentions) if outputs.attentions else []

        return Layer1Result(
            verdict=verdict,
            confidence=confidence,
            triggered_features=triggered,
        )