"""Arabic toxicity classification API.

Loads a sequence-classification model from ``MODEL_DIR`` (presumably a
fine-tuned MARBERT, judging by the directory name) and serves it over HTTP
with FastAPI. The model is used in a multi-label fashion: each label gets an
independent sigmoid probability.
"""

import time

import torch
from fastapi import FastAPI, Query
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_DIR = "./marbert-toxic-model"

# Maximum number of tokens fed to the model; longer inputs are truncated.
MAX_LENGTH = 128

# Label names, mapped positionally onto the model's output logits.
# NOTE(review): assumed to match the model's id2label order — verify.
# "NOT" is the non-toxic class; every other label is a toxicity category.
LABEL_COLS = ["Appearance", "Cussing", "Hatred", "NOT", "Racial", "Sexual", "Violence"]

LABEL_DESCRIPTIONS = {
    "Appearance": "Insult targeting physical appearance",
    "Cussing": "Explicit profanity or direct obscene insult",
    "Hatred": "Hostile or degrading language without necessarily containing profanity",
    "NOT": "Non-toxic text",
    "Racial": "Attack based on race, ethnicity, or national identity",
    "Sexual": "Sexual harassment or sexually abusive language",
    "Violence": "Threats or violent language",
}

app = FastAPI(title="Arabic Toxicity API", version="1.0.0")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)

# Scores are paired with LABEL_COLS positionally via zip(), which silently
# truncates on a length mismatch (dropping labels, or leaving "NOT" missing and
# crashing every request). Fail fast at startup instead.
if model.config.num_labels != len(LABEL_COLS):
    raise RuntimeError(
        f"Model in {MODEL_DIR!r} has {model.config.num_labels} labels, "
        f"but LABEL_COLS defines {len(LABEL_COLS)}"
    )

model.to(DEVICE)
model.eval()


def _severity(toxic_score: float) -> str:
    """Bucket the highest toxic-label probability into a coarse severity level."""
    if toxic_score >= 0.90:
        return "high"
    if toxic_score >= 0.70:
        return "medium"
    return "low"


def predict_text(text: str, threshold: float = 0.5) -> dict:
    """Classify *text* and return per-label scores plus a toxicity summary.

    Args:
        text: Input text. Leading/trailing whitespace is stripped before
            tokenization, and the stripped text is echoed back.
        threshold: Minimum sigmoid probability for a label to be listed in
            ``predicted_labels``.

    Returns:
        A JSON-serializable dict with the keys ``text``, ``is_toxic``,
        ``top_label``, ``top_score``, ``predicted_labels``, ``toxic_score``,
        ``safe_score``, ``severity``, ``scores``, ``label_descriptions``,
        ``processing_time_ms`` and ``device``.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump, yielding wrong or negative durations.
    start = time.perf_counter()
    text = text.strip()

    # Only the first MAX_LENGTH tokens of long inputs are scored.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Independent sigmoid per label (not a softmax), so several labels —
    # including "NOT" — can score high at the same time.
    probs = torch.sigmoid(outputs.logits).squeeze(0).cpu().tolist()
    scores = {label: round(float(score), 6) for label, score in zip(LABEL_COLS, probs)}

    predicted_labels = [label for label, score in scores.items() if score >= threshold]

    top_label = max(scores, key=scores.get)
    top_score = scores[top_label]

    # Scores are already rounded to 6 places, so the max needs no re-rounding.
    toxic_score = max(score for label, score in scores.items() if label != "NOT")
    safe_score = scores["NOT"]

    # NOTE: is_toxic follows the arg-max label and ignores `threshold`. A toxic
    # label can therefore appear in predicted_labels while is_toxic is False
    # (when "NOT" scores higher). Conversely, is_toxic can be True with an
    # empty predicted_labels (when the top toxic label is below threshold).
    is_toxic = top_label != "NOT"
    severity = _severity(toxic_score) if is_toxic else "none"

    return {
        "text": text,
        "is_toxic": is_toxic,
        "top_label": top_label,
        "top_score": top_score,
        "predicted_labels": predicted_labels,
        "toxic_score": toxic_score,
        "safe_score": safe_score,
        "severity": severity,
        "scores": scores,
        "label_descriptions": LABEL_DESCRIPTIONS,
        "processing_time_ms": round((time.perf_counter() - start) * 1000, 2),
        "device": str(DEVICE),
    }


@app.get("/")
def root():
    """Return a fixed message confirming the service is up."""
    return {"message": "Arabic Toxicity API is running"}


@app.get("/health")
def health():
    """Report service status and the torch device the model was placed on."""
    return {"status": "ok", "device": str(DEVICE)}


@app.get("/toxic")
def toxic(
    p: str = Query(..., description="Arabic text to analyze"),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
):
    """Classify the ``p`` query parameter; see ``predict_text`` for the response."""
    return predict_text(p, threshold)