# PreranTej's picture
# Update app.py
# 4bfcca6 verified
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import sys
import os
from transformers import AutoTokenizer
# ── Load custom model class ──
sys.path.append(os.path.dirname(__file__))
from modeling_roberta_multitask import RobertaMultiTask
app = FastAPI()

# CORS policy: any origin, any HTTP method and any request header is accepted,
# so the API can be called from pages served on other hosts.
_CORS_POLICY = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_POLICY)
# ── Label mapping ──
# Bias names listed in class-index order: the position of a name in this
# tuple is the integer key it receives in ID_TO_BIAS.
_BIAS_NAMES = (
    "Neutral",
    "Anchoring Bias",
    "Bandwagon Bias",
    "Confirmation Bias",
    "Framing Bias",
    "Emotional Appeal",
    "Appeal to Authority",
    "False Cause & Certainty",
    "Hasty Generalization",
)
# Maps a predicted class index (0-8) to its human-readable label.
ID_TO_BIAS = dict(enumerate(_BIAS_NAMES))
# ── Load model once at startup ──
# These statements run at import time, so the tokenizer and model are created
# a single time and then shared by every request handler in this module.
print("Loading model...")
# Identifier passed to from_pretrained below; also echoed by the /health endpoint.
MODEL_ID = "PreranTej/roberta-cognitive-bias-multitask"
# NOTE(review): the tokenizer is loaded from the stock "roberta-base"
# checkpoint rather than from MODEL_ID — presumably the fine-tuned model keeps
# the base vocabulary; confirm.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# num_labels=9 matches the nine entries of ID_TO_BIAS.
model = RobertaMultiTask.from_pretrained(MODEL_ID, num_labels=9)
# Presumably switches the (torch) module to inference mode, e.g. disabling
# dropout — RobertaMultiTask is defined in another file; verify there.
model.eval()
print("Model loaded successfully")
class TextRequest(BaseModel):
    # Request body accepted by both the /predict and /analyze endpoints.

    # Input to classify: treated as one sentence by /predict and split into
    # sentences by /analyze.
    text: str
    # Minimum probability of the predicted class for a non-neutral prediction
    # to be reported as biased (compared with >=). Defaults to 0.45.
    confidence_threshold: float = 0.45
@app.get("/")
def root():
    """Landing endpoint: return a fixed message saying the API is running."""
    status_message = "Bias Detection API is running"
    return {"status": status_message}
@app.get("/health")
def health():
    """Health check: return an "ok" status together with the model identifier."""
    payload = {"status": "ok"}
    payload["model"] = MODEL_ID
    return payload
@app.post("/predict")
def predict(request: TextRequest):
    """
    Classify a single sentence.

    The text is stripped, tokenized (truncated to 128 tokens) and passed
    through the model. The response holds the most probable bias label, its
    class index, the softmax probability rounded to 4 decimals, the stretch
    of text running from the first to the last token the span head tagged
    as 1, and an ``is_biased`` flag that is true only for a non-neutral
    class whose probability reaches ``request.confidence_threshold``.

    Whitespace-only input returns ``{"error": "Empty text"}`` instead.
    """
    text = request.text.strip()
    if not text:
        return {"error": "Empty text"}

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding=True,
        return_offsets_mapping=True,
    )
    # Offsets are removed before the forward pass and kept for the span
    # lookup below; each entry is a (start, end) character pair.
    char_offsets = encoded.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        outputs = model(**encoded)

    class_probs = torch.softmax(outputs["logits"], dim=-1)
    pred_class = class_probs.argmax().item()
    confidence = class_probs[0][pred_class].item()
    bias_label = ID_TO_BIAS[pred_class]

    # Extract span: keep the character ranges of tokens tagged 1, ignoring
    # zero-width offsets (presumably special tokens such as <s> / </s>).
    token_tags = torch.argmax(outputs["span_logits"], dim=-1)[0].numpy()
    tagged_ranges = [
        (begin, finish)
        for position, (begin, finish) in enumerate(char_offsets)
        if begin != finish and token_tags[position] == 1
    ]
    # The reported span runs from the first tagged token to the last one.
    extracted_span = (
        text[tagged_ranges[0][0]:tagged_ranges[-1][1]].strip()
        if tagged_ranges
        else ""
    )

    flagged = pred_class != 0 and confidence >= request.confidence_threshold
    return {
        "label": bias_label,
        "label_index": pred_class,
        "confidence": round(confidence, 4),
        "span": extracted_span,
        "is_biased": flagged,
    }
@app.post("/analyze")
def analyze(request: TextRequest):
    """
    Full text analysis — splits into sentences, runs predict on each,
    and returns all flagged sentences.

    The stripped text is split on whitespace that follows ``.``, ``!`` or
    ``?``; fragments of 5 characters or fewer are ignored. Every remaining
    sentence is classified with ``predict`` using the caller's
    ``confidence_threshold``.

    Returns a dict with:
      - ``bias_score``: sum of ``confidence * 0.2`` over flagged sentences,
        capped at 1.0 and rounded to 4 decimals.
      - ``flags``: one entry per flagged sentence (label, sentence text,
        extracted span, confidence, a fixed suggestion, ``start``/``end``
        character offsets into the stripped text, and the label index).
      - ``total_sentences``: number of sentences that were classified.
      - ``biased_sentences``: number of flagged sentences.
    """
    import re

    text = request.text.strip()
    sentences = [
        piece.strip()
        for piece in re.split(r'(?<=[.!?])\s+', text)
        if len(piece.strip()) > 5
    ]

    flags = []
    bias_score = 0.0
    # Sentences come out of the split in document order, so each one is
    # searched for from the end of the previous match. A plain text.find(sent)
    # would report the first occurrence's offsets for every repeated sentence.
    cursor = 0
    for sent in sentences:
        start_idx = text.find(sent, cursor)
        if start_idx == -1:
            end_idx = -1
        else:
            end_idx = start_idx + len(sent)
            cursor = end_idx

        result = predict(TextRequest(
            text=sent,
            confidence_threshold=request.confidence_threshold
        ))
        # predict() returns an {"error": ...} dict without "is_biased" for
        # empty input, hence .get() rather than indexing.
        if not result.get("is_biased"):
            continue

        flags.append({
            "type": result["label"],
            "text": sent,
            "span": result["span"],
            "confidence": result["confidence"],
            "suggestion": "Consider revising this phrase for neutrality.",
            "start": start_idx,
            "end": end_idx,
            "label_index": result["label_index"]
        })
        bias_score += result["confidence"] * 0.2

    return {
        "bias_score": round(min(bias_score, 1.0), 4),
        "flags": flags,
        "total_sentences": len(sentences),
        "biased_sentences": len(flags)
    }