File size: 2,476 Bytes
e66fc58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
from typing import Dict

# Initialize FastAPI app
app = FastAPI(title="Toxicity Detection API", description="API to detect hate/toxicity in text using unitary/unbiased-toxic-roberta")

# Load pretrained hate/toxicity detection model once at import time.
# top_k=None makes the pipeline return scores for ALL labels (list of
# {label, score} dicts) rather than just the single best label.
# NOTE(review): this downloads/loads the model at module import — server
# startup blocks until the model is ready.
classifier = pipeline("text-classification", model="unitary/unbiased-toxic-roberta", top_k=None)

# Minimum per-label score for a prediction to count as toxic.
THRESHOLD = 0.6  # 60%

# Pydantic model for request body
class TextInput(BaseModel):
    """Request body for /check-toxicity: the raw text to classify."""
    text: str

def check_hate(text: str) -> Dict:
    """Classify *text* for hate/toxicity.

    Runs the shared ``classifier`` pipeline and flags the text when any
    toxicity-related label scores at or above ``THRESHOLD``. Returns a
    summary dict with the original text, a human-readable prediction,
    the confidence (rounded to 2 decimals), a boolean flag, and the
    label the confidence refers to.
    """
    # All label scores for the input (pipeline returns a list per input).
    scores = classifier(text)[0]

    # Labels this model emits that we treat as toxic.
    toxic_labels = {"toxic", "insult", "obscene", "identity_attack", "threat", "sexual_explicit"}

    # Collect every toxic label whose score clears the threshold.
    hits = [
        (item['score'], item['label'].lower())
        for item in scores
        if item['label'].lower() in toxic_labels and item['score'] >= THRESHOLD
    ]

    if hits:
        # Report the strongest toxic signal (first maximum on ties,
        # matching a strict-greater comparison scan).
        top_score, top_label = max(hits, key=lambda pair: pair[0])
        return {
            "text": text,
            "prediction": "⚠️ Hate/Toxic",
            "confidence": round(top_score, 2),
            "flagged": True,
            "label": top_label,
        }

    # Nothing crossed the threshold: report the overall best-scoring label.
    best = max(scores, key=lambda item: item['score'])
    return {
        "text": text,
        "prediction": "✅ Clean",
        "confidence": round(best['score'], 2),
        "flagged": False,
        "label": best['label'].lower(),
    }

# API endpoint to check toxicity
@app.post("/check-toxicity", response_model=Dict)
async def check_toxicity(input: TextInput):
    """Classify the submitted text for hate/toxicity.

    Returns the summary dict produced by ``check_hate``. Raises a 400
    for empty/whitespace-only input and a 500 if classification fails.
    """
    # Validate BEFORE the try block: previously the 400 raised here was
    # caught by the blanket `except Exception` below and re-wrapped as a
    # 500, so clients never received the intended 400 status.
    if not input.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")
    try:
        return check_hate(input.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing text: {str(e)}")

# Root endpoint for API welcome message
@app.get("/")
async def root():
    """Return a short usage hint pointing clients at /check-toxicity."""
    welcome = "Welcome to the Toxicity Detection API. Use POST /check-toxicity with a JSON body containing 'text'."
    return {"message": welcome}