toxic / app.py
Alshargi's picture
Create app.py
a13db66 verified
from fastapi import FastAPI, Query
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import time
# Path handed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_DIR = "./marbert-toxic-model"

# Human-readable explanation for each label. The insertion order of this
# mapping is also the order in which labels are paired with model scores.
LABEL_DESCRIPTIONS = {
    "Appearance": "Insult targeting physical appearance",
    "Cussing": "Explicit profanity or direct obscene insult",
    "Hatred": "Hostile or degrading language without necessarily containing profanity",
    "NOT": "Non-toxic text",
    "Racial": "Attack based on race, ethnicity, or national identity",
    "Sexual": "Sexual harassment or sexually abusive language",
    "Violence": "Threats or violent language",
}

# Ordered label names, taken from the mapping above so the two never drift.
LABEL_COLS = list(LABEL_DESCRIPTIONS)
# Application object that the route decorators in this file attach to.
app = FastAPI(title="Arabic Toxicity API", version="1.0.0")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and classifier are loaded once, at import time, from MODEL_DIR.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)

# Move the weights to the selected device and switch to evaluation mode.
model.to(DEVICE)
model.eval()
def predict_text(text: str, threshold: float = 0.5) -> dict:
    """Score *text* against every toxicity label and summarize the outcome.

    The text is stripped of surrounding whitespace, tokenized (truncated to
    128 tokens) and run through the module-level ``model``. A sigmoid is
    applied to the logits, so each label receives its own probability.

    Args:
        text: Text to classify.
        threshold: Minimum score for a label to appear in
            ``predicted_labels``. It does not influence ``is_toxic`` or
            ``severity``, which depend on the top-scoring label only.

    Returns:
        A dict with the keys ``text``, ``is_toxic``, ``top_label``,
        ``top_score``, ``predicted_labels``, ``toxic_score``,
        ``safe_score``, ``severity`` (``"none"``, ``"low"``, ``"medium"``
        or ``"high"``), ``scores``, ``label_descriptions``,
        ``processing_time_ms`` and ``device``.
    """
    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-request.
    started = time.perf_counter()
    text = text.strip()

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    )
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Drop the batch dimension and convert to plain Python floats.
    probs = torch.sigmoid(outputs.logits).squeeze(0).cpu().tolist()

    scores = {
        label: round(float(prob), 6)
        for label, prob in zip(LABEL_COLS, probs)
    }
    predicted_labels = [
        label for label, score in scores.items() if score >= threshold
    ]

    top_label = max(scores, key=scores.get)
    top_score = scores[top_label]

    # Scores are already rounded to 6 decimals, so no second round() needed.
    toxic_score = max(
        score for label, score in scores.items() if label != "NOT"
    )
    safe_score = scores["NOT"]

    # Toxicity is decided by the single best label, not by the threshold.
    is_toxic = top_label != "NOT"

    if not is_toxic:
        severity = "none"
    elif toxic_score >= 0.90:
        severity = "high"
    elif toxic_score >= 0.70:
        severity = "medium"
    else:
        severity = "low"

    elapsed_ms = round((time.perf_counter() - started) * 1000, 2)

    return {
        "text": text,
        "is_toxic": is_toxic,
        "top_label": top_label,
        "top_score": top_score,
        "predicted_labels": predicted_labels,
        "toxic_score": toxic_score,
        "safe_score": safe_score,
        "severity": severity,
        "scores": scores,
        "label_descriptions": LABEL_DESCRIPTIONS,
        "processing_time_ms": elapsed_ms,
        "device": str(DEVICE),
    }
# Landing route: returns a fixed message confirming the service is up.
@app.get("/")
def root():
    banner = {"message": "Arabic Toxicity API is running"}
    return banner
# Health probe: reports a static "ok" status plus the device in use.
@app.get("/health")
def health():
    device_name = str(DEVICE)
    return {"status": "ok", "device": device_name}
# Classification route: `p` is the required text to analyze and `threshold`
# (0.0-1.0, default 0.5) is forwarded unchanged to predict_text.
@app.get("/toxic")
def toxic(
    p: str = Query(..., description="Arabic text to analyze"),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
):
    analysis = predict_text(text=p, threshold=threshold)
    return analysis