File size: 4,008 Bytes
07c81f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bfcca6
07c81f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import sys
import os
from transformers import AutoTokenizer

# ── Load custom model class ──
sys.path.append(os.path.dirname(__file__))
from modeling_roberta_multitask import RobertaMultiTask

app = FastAPI()

# Wide-open CORS: any origin, method and header may call this API.
# NOTE(review): consider restricting allow_origins outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# ── Label mapping ──
# Index of the argmax over the model's "logits" output -> display name.
# Index 0 ("Neutral") is the one label predict() never reports as biased.
_BIAS_LABELS = (
    "Neutral",
    "Anchoring Bias",
    "Bandwagon Bias",
    "Confirmation Bias",
    "Framing Bias",
    "Emotional Appeal",
    "Appeal to Authority",
    "False Cause & Certainty",
    "Hasty Generalization",
)
ID_TO_BIAS = dict(enumerate(_BIAS_LABELS))

# ── Load model once at startup ──
# Runs at module import time; every request below reuses these objects.
print("Loading model...")
MODEL_ID = "PreranTej/roberta-cognitive-bias-multitask"
# The tokenizer is loaded from the base checkpoint rather than MODEL_ID.
# NOTE(review): presumably the fine-tuned model keeps roberta-base's vocabulary
# unchanged — confirm, otherwise load the tokenizer from MODEL_ID as well.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# Size the classification head from the label table so the two cannot drift
# apart (predict() indexes ID_TO_BIAS with the argmax over these logits).
model = RobertaMultiTask.from_pretrained(MODEL_ID, num_labels=len(ID_TO_BIAS))
# Inference mode: disables training-only behaviour such as dropout.
model.eval()
print("Model loaded successfully")

class TextRequest(BaseModel):
    # Request body shared by /predict and /analyze.
    # Raw input text; both endpoints strip surrounding whitespace before use.
    text: str
    # Minimum softmax confidence (compared with >= in predict) for a
    # non-"Neutral" prediction to be reported as biased. Not range-checked.
    confidence_threshold: float = 0.45

@app.get("/")
def root():
    # Static banner confirming the service is up; /health also reports MODEL_ID.
    banner = "Bias Detection API is running"
    return dict(status=banner)

@app.get("/health")
def health():
    # Static "ok" status plus the identifier the model was loaded from.
    payload = dict(status="ok")
    payload["model"] = MODEL_ID
    return payload

def _extract_span(text, offsets, span_preds):
    """Return the slice of *text* covered by the tokens tagged as span.

    Args:
        text: The exact string that was tokenized.
        offsets: Per-token ``(start, end)`` character offsets into *text*.
            Zero-width tokens (``start == end``) are skipped.
        span_preds: Per-token span tag aligned with *offsets*; a tag of ``1``
            marks the token as part of the span.

    Returns:
        The stripped substring running from the first tagged token's start to
        the last tagged token's end (untagged tokens in between are included),
        or ``""`` when no token is tagged.
    """
    tagged = [
        (start, end)
        for (start, end), tag in zip(offsets, span_preds)
        if start != end and tag == 1
    ]
    if not tagged:
        return ""
    return text[tagged[0][0]:tagged[-1][1]].strip()


@app.post("/predict")
def predict(request: TextRequest):
    """
    Single sentence prediction.

    Returns the predicted bias label and its index, the softmax confidence of
    that label, the text span the model tagged, and whether the sentence counts
    as biased (a non-"Neutral" label with confidence >= confidence_threshold).
    Empty or whitespace-only text yields {"error": "Empty text"} instead.
    """
    text = request.text.strip()
    if not text:
        # Returned as a normal (200) response; /analyze reads results through
        # .get("is_biased"), which tolerates this shape.
        return {"error": "Empty text"}

    # A single sequence needs no padding; inputs longer than 128 tokens are
    # truncated, so only their prefix is classified.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        return_offsets_mapping=True,
    )
    # Offsets map tokens back to characters; remove them so they are not
    # forwarded to the model.
    offsets = inputs.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        outputs = model(**inputs)

    # Batch of one: work on row 0 explicitly instead of taking argmax over the
    # flattened (1, num_labels) tensor.
    probs = torch.softmax(outputs["logits"][0], dim=-1)
    pred_class = int(torch.argmax(probs).item())
    confidence = probs[pred_class].item()
    bias_label = ID_TO_BIAS[pred_class]

    # Per-token span tag = argmax over the last dimension of span_logits.
    span_preds = torch.argmax(outputs["span_logits"][0], dim=-1).tolist()
    extracted_span = _extract_span(text, offsets, span_preds)

    return {
        "label": bias_label,
        "label_index": pred_class,
        "confidence": round(confidence, 4),
        "span": extracted_span,
        "is_biased": pred_class != 0 and confidence >= request.confidence_threshold,
    }

def _split_sentences(text, min_length=5):
    """Split *text* into sentences, returning ``(sentence, start_offset)`` pairs.

    Splits on whitespace that follows '.', '!' or '?', strips each piece, and
    drops pieces of ``min_length`` characters or fewer. Offsets are located by
    scanning forward from the end of the previous piece, so a sentence that
    occurs more than once is reported at *its own* position, not the first one.
    """
    import re

    sentences = []
    cursor = 0
    for piece in re.split(r'(?<=[.!?])\s+', text):
        # Pieces appear in order, separated only by whitespace, so searching
        # from the cursor always lands on this piece's real position.
        start = text.find(piece, cursor)
        cursor = start + len(piece)
        stripped = piece.strip()
        if len(stripped) > min_length:
            leading = len(piece) - len(piece.lstrip())
            sentences.append((stripped, start + leading))
    return sentences


@app.post("/analyze")
def analyze(request: TextRequest):
    """
    Full text analysis — splits into sentences,
    runs predict on each, returns all flagged sentences.

    Each flag's start/end are character offsets into the whitespace-stripped
    request text. bias_score adds 0.2 * confidence per flagged sentence and is
    capped at 1.0.
    """
    text = request.text.strip()
    located = _split_sentences(text)

    flags = []
    bias_score = 0.0

    for sent, start_idx in located:
        result = predict(TextRequest(
            text=sent,
            confidence_threshold=request.confidence_threshold
        ))

        # .get() also tolerates predict's {"error": ...} response shape.
        if not result.get("is_biased"):
            continue

        flags.append({
            "type": result["label"],
            "text": sent,
            "span": result["span"],
            "confidence": result["confidence"],
            "suggestion": "Consider revising this phrase for neutrality.",
            "start": start_idx,
            "end": start_idx + len(sent),
            "label_index": result["label_index"]
        })
        # Heuristic weighting: each flag contributes a fifth of its confidence.
        bias_score += result["confidence"] * 0.2

    return {
        "bias_score": round(min(bias_score, 1.0), 4),
        "flags": flags,
        "total_sentences": len(located),
        "biased_sentences": len(flags)
    }