"""Bias Detection API.

FastAPI service wrapping a RoBERTa multitask model (MODEL_ID) that predicts a
cognitive-bias label for a sentence together with a token-level span marking
the phrase the model tagged.
"""

import os
import re
import sys
from typing import Any, Dict, List, Sequence, Tuple

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer

# ── Load custom model class ──
# Use an absolute directory: __file__ may be a relative path, in which case
# dirname() is relative (possibly "") and the import below would depend on the
# current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from modeling_roberta_multitask import RobertaMultiTask  # noqa: E402

app = FastAPI()
# NOTE(review): wildcard origins/methods/headers let any site call this API;
# consider restricting allow_origins before exposing it publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Label mapping ──
# Classifier output index -> human-readable label. Index 0 ("Neutral") is the
# only class that is never reported as biased.
ID_TO_BIAS = {
    0: "Neutral",
    1: "Anchoring Bias",
    2: "Bandwagon Bias",
    3: "Confirmation Bias",
    4: "Framing Bias",
    5: "Emotional Appeal",
    6: "Appeal to Authority",
    7: "False Cause & Certainty",
    8: "Hasty Generalization",
}
NEUTRAL_INDEX = 0

# Tokens fed to the model per input; longer text is truncated, so nothing past
# this point can be part of an extracted span.
MAX_TOKENS = 128
# /analyze only classifies sentences strictly longer than this many characters.
MIN_SENTENCE_LEN = 5
# Each flagged sentence adds confidence * this weight to bias_score (capped at 1.0).
BIAS_SCORE_WEIGHT = 0.2

# Sentence boundary: a run of whitespace directly after '.', '!' or '?'.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

# ── Load model once at startup ──
print("Loading model...")
MODEL_ID = "PreranTej/roberta-cognitive-bias-multitask"
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = RobertaMultiTask.from_pretrained(MODEL_ID, num_labels=len(ID_TO_BIAS))
model.eval()
print("Model loaded successfully")


class TextRequest(BaseModel):
    """Request body shared by /predict and /analyze."""

    text: str
    # Minimum softmax probability for a non-neutral prediction to count as biased.
    confidence_threshold: float = 0.45


@app.get("/")
def root():
    """Liveness message."""
    return {"status": "Bias Detection API is running"}


@app.get("/health")
def health():
    """Health check that also reports which model id is served."""
    return {"status": "ok", "model": MODEL_ID}


def _decode_span(
    text: str, offsets: Sequence[Sequence[int]], span_preds: Sequence[int]
) -> str:
    """Return text from the first to the last token tagged 1, stripped.

    Tokens whose offset pair is empty (start == end) are skipped; the
    tokenizer gives special tokens such empty offsets. Any untagged tokens
    between the first and last tagged token are included in the result.
    """
    tagged = [
        (start, end)
        for (start, end), tag in zip(offsets, span_preds)
        if start != end and tag == 1
    ]
    if not tagged:
        return ""
    return text[tagged[0][0]:tagged[-1][1]].strip()


def _classify(text: str, confidence_threshold: float) -> Dict[str, Any]:
    """Run the model on one non-empty text and build a /predict-style result.

    Returns a dict with "label", "label_index", "confidence" (rounded to 4
    places), "span" (decoded from the span head) and "is_biased" (non-neutral
    class with unrounded confidence >= confidence_threshold).
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_TOKENS,
        padding=True,
        return_offsets_mapping=True,
    )
    # Offsets are not a model input; keep them for mapping tokens back to chars.
    offsets = inputs.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        outputs = model(**inputs)

    # A single text is tokenized, so row 0 holds this input's class probabilities.
    probs = torch.softmax(outputs["logits"], dim=-1)[0]
    pred_class = int(torch.argmax(probs).item())
    confidence = probs[pred_class].item()

    span_preds = torch.argmax(outputs["span_logits"], dim=-1)[0].tolist()

    return {
        "label": ID_TO_BIAS[pred_class],
        "label_index": pred_class,
        "confidence": round(confidence, 4),
        "span": _decode_span(text, offsets, span_preds),
        "is_biased": pred_class != NEUTRAL_INDEX and confidence >= confidence_threshold,
    }


@app.post("/predict")
def predict(request: TextRequest):
    """
    Single sentence prediction. Returns bias label, confidence, and extracted span.

    Whitespace-only input yields {"error": "Empty text"} instead of a prediction.
    Span text is taken from the whitespace-stripped input.
    """
    text = request.text.strip()
    if not text:
        return {"error": "Empty text"}
    return _classify(text, request.confidence_threshold)


def _split_sentences(text: str) -> List[Tuple[int, str]]:
    """Split text at _SENTENCE_BOUNDARY and return (start, sentence) pairs.

    Each sentence is whitespace-stripped and kept only if it is longer than
    MIN_SENTENCE_LEN characters. start is the sentence's index in text, so
    text[start:start + len(sentence)] == sentence. Offsets come from the regex
    match positions, so repeated sentences each get their own position.
    """
    sentences = []
    piece_start = 0
    # Sentinel boundary at the end so the final piece is handled by the loop.
    boundaries = [m.span() for m in _SENTENCE_BOUNDARY.finditer(text)]
    boundaries.append((len(text), len(text)))
    for piece_end, next_start in boundaries:
        piece = text[piece_start:piece_end]
        sentence = piece.strip()
        if len(sentence) > MIN_SENTENCE_LEN:
            # Skip the leading whitespace that strip() removed.
            leading_ws = len(piece) - len(piece.lstrip())
            sentences.append((piece_start + leading_ws, sentence))
        piece_start = next_start
    return sentences


@app.post("/analyze")
def analyze(request: TextRequest):
    """
    Full text analysis — splits into sentences, classifies each one, and
    returns all flagged sentences.

    "start"/"end" in each flag are character offsets of the sentence within
    the submitted request.text.
    """
    sentences = _split_sentences(request.text)

    flags = []
    bias_score = 0.0
    for start, sentence in sentences:
        result = _classify(sentence, request.confidence_threshold)
        if not result["is_biased"]:
            continue
        flags.append({
            "type": result["label"],
            "text": sentence,
            "span": result["span"],
            "confidence": result["confidence"],
            "suggestion": "Consider revising this phrase for neutrality.",
            "start": start,
            "end": start + len(sentence),
            "label_index": result["label_index"],
        })
        bias_score += result["confidence"] * BIAS_SCORE_WEIGHT

    return {
        "bias_score": round(min(bias_score, 1.0), 4),
        "flags": flags,
        "total_sentences": len(sentences),
        "biased_sentences": len(flags),
    }