fake-news-api / src /models /inference.py
aviseth
fix: f-string backslash syntax error in generate_explanation_text
69e2ee8
"""
Model inference — lazy-loads fine-tuned models and runs predictions with explainability.
"""
import os
import json
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from dotenv import load_dotenv
load_dotenv()
ID2LABEL = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
LABEL2ID = {v: k for k, v in ID2LABEL.items()}
_here = Path(__file__).resolve()
PROJECT_ROOT = next(
(p for p in _here.parents if (p / "models" / "distilbert" / "config.json").exists()),
_here.parents[2]
)
MODELS_DIR = PROJECT_ROOT / "models"
MODEL_NAMES = {
"distilbert": os.getenv("HF_REPO_DISTILBERT", "aviseth/distilbert-fakenews"),
"roberta": os.getenv("HF_REPO_ROBERTA", "aviseth/roberta-fakenews"),
"xlnet": os.getenv("HF_REPO_XLNET", "aviseth/xlnet-fakenews"),
}
def _patch_xlnet_configs(source: str):
"""Fix known XLNet config issues that cause warnings or errors on load."""
tok_cfg_path = Path(source) / "tokenizer_config.json"
if tok_cfg_path.exists():
tok_cfg = json.loads(tok_cfg_path.read_text())
if isinstance(tok_cfg.get("extra_special_tokens"), list):
tok_cfg["extra_special_tokens"] = {}
tok_cfg_path.write_text(json.dumps(tok_cfg, indent=2))
cfg_path = Path(source) / "config.json"
if cfg_path.exists():
cfg = json.loads(cfg_path.read_text())
if "use_cache" in cfg:
del cfg["use_cache"]
cfg_path.write_text(json.dumps(cfg, indent=2))
class FakeNewsClassifier:
"""Wraps a fine-tuned HuggingFace model. Lazy-loads on first call and caches in memory."""
def __init__(self, model_key: str = "distilbert", max_length: int = 256):
self.model_key = model_key
self.max_length = max_length
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self._model = None
self._tokenizer = None
def _load(self):
local_path = MODELS_DIR / self.model_key
source = str(local_path) if (
local_path / "config.json").exists() else MODEL_NAMES[self.model_key]
print(f"[inference] Loading {self.model_key} from: {source}")
if self.model_key == "xlnet":
_patch_xlnet_configs(source)
# RoBERTa: use slow tokenizer to avoid tokenizer.json format incompatibilities
use_fast = self.model_key != "roberta"
self._tokenizer = AutoTokenizer.from_pretrained(
source, use_fast=use_fast)
self._model = AutoModelForSequenceClassification.from_pretrained(
source,
num_labels=4,
id2label=ID2LABEL,
label2id=LABEL2ID,
ignore_mismatched_sizes=True,
)
self._model.to(self.device)
self._model.eval()
print(f"[inference] Model ready on {self.device}")
@property
def model(self):
if self._model is None:
self._load()
return self._model
@property
def tokenizer(self):
if self._tokenizer is None:
self._load()
return self._tokenizer
def predict(self, text: str) -> dict:
enc = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
padding=True,
)
inputs = {}
for k in {"input_ids", "attention_mask"}:
if k not in enc:
continue
v = enc[k]
if not isinstance(v, torch.Tensor):
v = torch.tensor(v)
inputs[k] = v.to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
pred_id = int(np.argmax(probs))
label = ID2LABEL[pred_id]
confidence = float(probs[pred_id])
scores = {ID2LABEL[i]: round(float(p), 4) for i, p in enumerate(probs)}
tokens = self._token_importance(inputs, pred_id)
return {
"label": label,
"confidence": round(confidence, 4),
"scores": scores,
"tokens": tokens,
}
def _token_importance(self, enc, pred_id: int, top_k: int = 8) -> List[Dict]:
"""Gradient saliency — returns top-k tokens sorted by importance."""
try:
self.model.zero_grad()
input_ids = enc["input_ids"]
if not isinstance(input_ids, torch.Tensor):
input_ids = torch.tensor(input_ids)
input_ids = input_ids.to(self.device)
attn_mask = enc.get("attention_mask")
if attn_mask is not None and not isinstance(attn_mask, torch.Tensor):
attn_mask = torch.tensor(attn_mask)
if attn_mask is not None:
attn_mask = attn_mask.to(self.device)
embeds = self.model.get_input_embeddings()(
input_ids).detach().requires_grad_(True)
outputs = self.model(inputs_embeds=embeds,
attention_mask=attn_mask)
outputs.logits[0, pred_id].backward()
importance = embeds.grad[0].norm(dim=-1).cpu().numpy()
tokens = self.tokenizer.convert_ids_to_tokens(
input_ids[0].cpu().tolist())
special = {"[CLS]", "[SEP]", "[PAD]", "<s>",
"</s>", "<pad>", "<cls>", "<sep>", "▁", "Ġ"}
pairs = [
(t.replace("##", "").replace("▁", "").replace("Ġ", ""), float(s))
for t, s in zip(tokens, importance)
if t not in special and len(t.strip()) > 1
]
if pairs:
max_s = max(s for _, s in pairs) or 1.0
pairs = [(t, round(s / max_s, 4)) for t, s in pairs]
pairs.sort(key=lambda x: x[1], reverse=True)
return [{"token": t, "score": s} for t, s in pairs[:top_k]]
except Exception:
return []
def attention_weights(self, text: str) -> List[Dict]:
"""Gradient saliency mapped to original words in reading order."""
try:
enc = self.tokenizer(
text, return_tensors="pt", truncation=True,
max_length=self.max_length, padding=False,
)
enc = {k: v.to(self.device) if isinstance(
v, torch.Tensor) else v for k, v in enc.items()}
input_ids = enc["input_ids"]
self.model.zero_grad()
embeds = self.model.get_input_embeddings()(
input_ids).detach().requires_grad_(True)
outputs = self.model(inputs_embeds=embeds,
attention_mask=enc.get("attention_mask"))
pred_id = int(torch.argmax(outputs.logits, dim=-1)[0])
outputs.logits[0, pred_id].backward()
importance = embeds.grad[0].norm(dim=-1).cpu().numpy()
tokens = self.tokenizer.convert_ids_to_tokens(
input_ids[0].cpu().tolist())
SPECIAL = {"[CLS]", "[SEP]", "[PAD]", "<s>",
"</s>", "<pad>", "<cls>", "<sep>", "<unk>"}
words, current_word, current_score = [], "", 0.0
for tok, score in zip(tokens, importance):
if tok in SPECIAL:
if current_word:
words.append((current_word, current_score))
current_word, current_score = "", 0.0
continue
is_continuation = tok.startswith("##")
is_new_word = tok.startswith("Ġ") or tok.startswith("▁")
clean = tok.replace("##", "").replace("Ġ", "").replace("▁", "")
if is_continuation:
current_word += clean
current_score = max(current_score, float(score))
elif is_new_word:
if current_word:
words.append((current_word, current_score))
current_word, current_score = clean, float(score)
else:
if current_word:
words.append((current_word, current_score))
current_word, current_score = clean, float(score)
if current_word:
words.append((current_word, current_score))
if not words:
return []
max_s = max(s for _, s in words) or 1.0
return [{"word": w, "attention": round(s / max_s, 4)} for w, s in words if w.strip()]
except Exception as e:
print(f"[attention_weights] failed: {e}")
return []
def shap_explain(self, text: str) -> List[Dict]:
"""Word-level SHAP explanation using RoBERTa."""
try:
import shap
clf = get_classifier("roberta")
def predict_proba(texts):
results = []
for t in texts:
enc = clf.tokenizer(
t, return_tensors="pt", truncation=True,
max_length=clf.max_length, padding=True,
).to(clf.device)
with torch.no_grad():
logits = clf.model(**enc).logits
probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
results.append(probs)
return np.array(results)
masker = shap.maskers.Text(r"\W+")
explainer = shap.Explainer(
predict_proba, masker, output_names=list(ID2LABEL.values()))
shap_values = explainer([text], max_evals=200, batch_size=8)
enc = clf.tokenizer(text, return_tensors="pt", truncation=True,
max_length=clf.max_length).to(clf.device)
with torch.no_grad():
pred_id = int(torch.argmax(clf.model(**enc).logits, dim=-1)[0])
words = shap_values.data[0]
values = shap_values.values[0, :, pred_id]
max_abs = float(np.max(np.abs(values))) if len(values) else 1.0
if max_abs == 0:
max_abs = 1.0
return [
{"word": w.strip(), "shap_value": round(float(v) / max_abs, 4)}
for w, v in zip(words, values) if w.strip()
]
except Exception as e:
print(f"[shap_explain] failed: {e}")
return []
_classifiers: dict[str, FakeNewsClassifier] = {}
def get_classifier(model_key: str = "distilbert") -> FakeNewsClassifier:
if model_key not in _classifiers:
_classifiers[model_key] = FakeNewsClassifier(model_key)
return _classifiers[model_key]
def predict(text: str, model_key: str = "distilbert") -> dict:
return get_classifier(model_key).predict(text)
def generate_explanation_text(shap_tokens: List[Dict], label: str, confidence: float, model_key: str) -> str:
if not shap_tokens:
return (
f"The {model_key} model classified this article as {label} "
f"with {round(confidence * 100)}% confidence, but no word-level explanation was available."
)
positive = sorted([t for t in shap_tokens if t["shap_value"]
> 0.05], key=lambda x: x["shap_value"], reverse=True)[:5]
negative = sorted([t for t in shap_tokens if t["shap_value"]
< -0.05], key=lambda x: x["shap_value"])[:3]
conf_pct = round(confidence * 100)
model_display = {"distilbert": "DistilBERT", "roberta": "RoBERTa",
"xlnet": "XLNet"}.get(model_key, model_key)
conf_phrase = (
"with very high confidence" if conf_pct >= 90 else
"with high confidence" if conf_pct >= 75 else
"with moderate confidence" if conf_pct >= 55 else
"with low confidence"
)
label_descriptions = {
"True": "factual and credible reporting",
"Fake": "fabricated or misleading content",
"Satire": "satirical or parody content",
"Bias": "politically or ideologically biased reporting",
}
parts = [
f"{model_display} classified this article as {label} ({label_descriptions.get(label, label)}) {conf_phrase} ({conf_pct}%)."]
if positive:
pos_words = ', '.join(f'"{t["word"]}"' for t in positive)
parts.append(
f"The words most strongly associated with this classification were {pos_words}, which the model weighted heavily toward a {label} prediction.")
if negative:
neg_words = ', '.join(f'"{t["word"]}"' for t in negative)
parts.append(
f"On the other hand, terms like {neg_words} pulled against this classification.")
else:
parts.append(
"The model found little linguistic evidence contradicting this classification.")
if conf_pct < 65:
parts.append(
"The relatively lower confidence suggests the article contains mixed signals and the prediction should be interpreted with caution.")
return " ".join(parts)