""" Model inference — lazy-loads fine-tuned models and runs predictions with explainability. """ import os import json import torch import numpy as np from pathlib import Path from typing import List, Dict from transformers import AutoTokenizer, AutoModelForSequenceClassification from dotenv import load_dotenv load_dotenv() ID2LABEL = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"} LABEL2ID = {v: k for k, v in ID2LABEL.items()} _here = Path(__file__).resolve() PROJECT_ROOT = next( (p for p in _here.parents if (p / "models" / "distilbert" / "config.json").exists()), _here.parents[2] ) MODELS_DIR = PROJECT_ROOT / "models" MODEL_NAMES = { "distilbert": os.getenv("HF_REPO_DISTILBERT", "aviseth/distilbert-fakenews"), "roberta": os.getenv("HF_REPO_ROBERTA", "aviseth/roberta-fakenews"), "xlnet": os.getenv("HF_REPO_XLNET", "aviseth/xlnet-fakenews"), } def _patch_xlnet_configs(source: str): """Fix known XLNet config issues that cause warnings or errors on load.""" tok_cfg_path = Path(source) / "tokenizer_config.json" if tok_cfg_path.exists(): tok_cfg = json.loads(tok_cfg_path.read_text()) if isinstance(tok_cfg.get("extra_special_tokens"), list): tok_cfg["extra_special_tokens"] = {} tok_cfg_path.write_text(json.dumps(tok_cfg, indent=2)) cfg_path = Path(source) / "config.json" if cfg_path.exists(): cfg = json.loads(cfg_path.read_text()) if "use_cache" in cfg: del cfg["use_cache"] cfg_path.write_text(json.dumps(cfg, indent=2)) class FakeNewsClassifier: """Wraps a fine-tuned HuggingFace model. Lazy-loads on first call and caches in memory.""" def __init__(self, model_key: str = "distilbert", max_length: int = 256): self.model_key = model_key self.max_length = max_length self.device = "cuda" if torch.cuda.is_available() else "cpu" self._model = None self._tokenizer = None def _load(self): local_path = MODELS_DIR / self.model_key source = str(local_path) if ( local_path / "config.json").exists() else MODEL_NAMES[self.model_key] print(f"[inference] Loading {self.model_key} from: {source}") if self.model_key == "xlnet": _patch_xlnet_configs(source) # RoBERTa: use slow tokenizer to avoid tokenizer.json format incompatibilities use_fast = self.model_key != "roberta" self._tokenizer = AutoTokenizer.from_pretrained( source, use_fast=use_fast) self._model = AutoModelForSequenceClassification.from_pretrained( source, num_labels=4, id2label=ID2LABEL, label2id=LABEL2ID, ignore_mismatched_sizes=True, ) self._model.to(self.device) self._model.eval() print(f"[inference] Model ready on {self.device}") @property def model(self): if self._model is None: self._load() return self._model @property def tokenizer(self): if self._tokenizer is None: self._load() return self._tokenizer def predict(self, text: str) -> dict: enc = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True, ) inputs = {} for k in {"input_ids", "attention_mask"}: if k not in enc: continue v = enc[k] if not isinstance(v, torch.Tensor): v = torch.tensor(v) inputs[k] = v.to(self.device) with torch.no_grad(): outputs = self.model(**inputs) probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy() pred_id = int(np.argmax(probs)) label = ID2LABEL[pred_id] confidence = float(probs[pred_id]) scores = {ID2LABEL[i]: round(float(p), 4) for i, p in enumerate(probs)} tokens = self._token_importance(inputs, pred_id) return { "label": label, "confidence": round(confidence, 4), "scores": scores, "tokens": tokens, } def _token_importance(self, enc, pred_id: int, top_k: int = 8) -> List[Dict]: """Gradient saliency — returns top-k tokens sorted by importance.""" try: self.model.zero_grad() input_ids = enc["input_ids"] if not isinstance(input_ids, torch.Tensor): input_ids = torch.tensor(input_ids) input_ids = input_ids.to(self.device) attn_mask = enc.get("attention_mask") if attn_mask is not None and not isinstance(attn_mask, torch.Tensor): attn_mask = torch.tensor(attn_mask) if attn_mask is not None: attn_mask = attn_mask.to(self.device) embeds = self.model.get_input_embeddings()( input_ids).detach().requires_grad_(True) outputs = self.model(inputs_embeds=embeds, attention_mask=attn_mask) outputs.logits[0, pred_id].backward() importance = embeds.grad[0].norm(dim=-1).cpu().numpy() tokens = self.tokenizer.convert_ids_to_tokens( input_ids[0].cpu().tolist()) special = {"[CLS]", "[SEP]", "[PAD]", "", "", "", "", "", "▁", "Ġ"} pairs = [ (t.replace("##", "").replace("▁", "").replace("Ġ", ""), float(s)) for t, s in zip(tokens, importance) if t not in special and len(t.strip()) > 1 ] if pairs: max_s = max(s for _, s in pairs) or 1.0 pairs = [(t, round(s / max_s, 4)) for t, s in pairs] pairs.sort(key=lambda x: x[1], reverse=True) return [{"token": t, "score": s} for t, s in pairs[:top_k]] except Exception: return [] def attention_weights(self, text: str) -> List[Dict]: """Gradient saliency mapped to original words in reading order.""" try: enc = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=False, ) enc = {k: v.to(self.device) if isinstance( v, torch.Tensor) else v for k, v in enc.items()} input_ids = enc["input_ids"] self.model.zero_grad() embeds = self.model.get_input_embeddings()( input_ids).detach().requires_grad_(True) outputs = self.model(inputs_embeds=embeds, attention_mask=enc.get("attention_mask")) pred_id = int(torch.argmax(outputs.logits, dim=-1)[0]) outputs.logits[0, pred_id].backward() importance = embeds.grad[0].norm(dim=-1).cpu().numpy() tokens = self.tokenizer.convert_ids_to_tokens( input_ids[0].cpu().tolist()) SPECIAL = {"[CLS]", "[SEP]", "[PAD]", "", "", "", "", "", ""} words, current_word, current_score = [], "", 0.0 for tok, score in zip(tokens, importance): if tok in SPECIAL: if current_word: words.append((current_word, current_score)) current_word, current_score = "", 0.0 continue is_continuation = tok.startswith("##") is_new_word = tok.startswith("Ġ") or tok.startswith("▁") clean = tok.replace("##", "").replace("Ġ", "").replace("▁", "") if is_continuation: current_word += clean current_score = max(current_score, float(score)) elif is_new_word: if current_word: words.append((current_word, current_score)) current_word, current_score = clean, float(score) else: if current_word: words.append((current_word, current_score)) current_word, current_score = clean, float(score) if current_word: words.append((current_word, current_score)) if not words: return [] max_s = max(s for _, s in words) or 1.0 return [{"word": w, "attention": round(s / max_s, 4)} for w, s in words if w.strip()] except Exception as e: print(f"[attention_weights] failed: {e}") return [] def shap_explain(self, text: str) -> List[Dict]: """Word-level SHAP explanation using RoBERTa.""" try: import shap clf = get_classifier("roberta") def predict_proba(texts): results = [] for t in texts: enc = clf.tokenizer( t, return_tensors="pt", truncation=True, max_length=clf.max_length, padding=True, ).to(clf.device) with torch.no_grad(): logits = clf.model(**enc).logits probs = torch.softmax(logits, dim=-1)[0].cpu().numpy() results.append(probs) return np.array(results) masker = shap.maskers.Text(r"\W+") explainer = shap.Explainer( predict_proba, masker, output_names=list(ID2LABEL.values())) shap_values = explainer([text], max_evals=200, batch_size=8) enc = clf.tokenizer(text, return_tensors="pt", truncation=True, max_length=clf.max_length).to(clf.device) with torch.no_grad(): pred_id = int(torch.argmax(clf.model(**enc).logits, dim=-1)[0]) words = shap_values.data[0] values = shap_values.values[0, :, pred_id] max_abs = float(np.max(np.abs(values))) if len(values) else 1.0 if max_abs == 0: max_abs = 1.0 return [ {"word": w.strip(), "shap_value": round(float(v) / max_abs, 4)} for w, v in zip(words, values) if w.strip() ] except Exception as e: print(f"[shap_explain] failed: {e}") return [] _classifiers: dict[str, FakeNewsClassifier] = {} def get_classifier(model_key: str = "distilbert") -> FakeNewsClassifier: if model_key not in _classifiers: _classifiers[model_key] = FakeNewsClassifier(model_key) return _classifiers[model_key] def predict(text: str, model_key: str = "distilbert") -> dict: return get_classifier(model_key).predict(text) def generate_explanation_text(shap_tokens: List[Dict], label: str, confidence: float, model_key: str) -> str: if not shap_tokens: return ( f"The {model_key} model classified this article as {label} " f"with {round(confidence * 100)}% confidence, but no word-level explanation was available." ) positive = sorted([t for t in shap_tokens if t["shap_value"] > 0.05], key=lambda x: x["shap_value"], reverse=True)[:5] negative = sorted([t for t in shap_tokens if t["shap_value"] < -0.05], key=lambda x: x["shap_value"])[:3] conf_pct = round(confidence * 100) model_display = {"distilbert": "DistilBERT", "roberta": "RoBERTa", "xlnet": "XLNet"}.get(model_key, model_key) conf_phrase = ( "with very high confidence" if conf_pct >= 90 else "with high confidence" if conf_pct >= 75 else "with moderate confidence" if conf_pct >= 55 else "with low confidence" ) label_descriptions = { "True": "factual and credible reporting", "Fake": "fabricated or misleading content", "Satire": "satirical or parody content", "Bias": "politically or ideologically biased reporting", } parts = [ f"{model_display} classified this article as {label} ({label_descriptions.get(label, label)}) {conf_phrase} ({conf_pct}%)."] if positive: pos_words = ', '.join(f'"{t["word"]}"' for t in positive) parts.append( f"The words most strongly associated with this classification were {pos_words}, which the model weighted heavily toward a {label} prediction.") if negative: neg_words = ', '.join(f'"{t["word"]}"' for t in negative) parts.append( f"On the other hand, terms like {neg_words} pulled against this classification.") else: parts.append( "The model found little linguistic evidence contradicting this classification.") if conf_pct < 65: parts.append( "The relatively lower confidence suggests the article contains mixed signals and the prediction should be interpreted with caution.") return " ".join(parts)