import time

import torch
from fastapi import FastAPI, HTTPException, Query
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
# Directory handed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_DIR = "./marbert-toxic-model"

# Output labels. Probabilities are paired with these names purely by position,
# so this order must match the order of the model's output logits.
LABEL_COLS = [
    "Appearance",
    "Cussing",
    "Hatred",
    "NOT",  # the single non-toxic class; every other label is a toxicity type
    "Racial",
    "Sexual",
    "Violence",
]

# Human-readable meaning of every label, echoed back with each prediction.
# Keys are listed in the same order as LABEL_COLS.
LABEL_DESCRIPTIONS = {
    "Appearance": "Insult targeting physical appearance",
    "Cussing": "Explicit profanity or direct obscene insult",
    "Hatred": "Hostile or degrading language without necessarily containing profanity",
    "NOT": "Non-toxic text",
    "Racial": "Attack based on race, ethnicity, or national identity",
    "Sexual": "Sexual harassment or sexually abusive language",
    "Violence": "Threats or violent language",
}
|
|
app = FastAPI(
    title="Arabic Toxicity API",
    version="1.0.0",
)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loaded once at import time and shared by every request handled by this module.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# nn.Module.to() moves the weights and returns the same module, so it can be chained.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(DEVICE)
# Put layers such as dropout into inference behaviour.
model.eval()
|
|
def _severity(toxic_score: float) -> str:
    """Bucket the strongest toxic-label probability into "high"/"medium"/"low"."""
    if toxic_score >= 0.90:
        return "high"
    if toxic_score >= 0.70:
        return "medium"
    return "low"


def predict_text(text: str, threshold: float = 0.5, *, max_length: int = 128) -> dict:
    """Classify *text* with the toxicity model and summarise the scores.

    Args:
        text: Input text; leading/trailing whitespace is stripped first.
        threshold: A label is listed in ``predicted_labels`` when its
            probability is >= this value. It does NOT influence ``is_toxic``
            or ``severity``, which follow the highest-scoring label.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated. Defaults to 128, the previously hard-coded value.

    Returns:
        A JSON-serialisable dict with these entries:
        - ``scores``: per-label sigmoid probabilities, rounded to 6 decimals.
        - ``top_label`` / ``top_score``: the highest-scoring label and its score.
        - ``predicted_labels``: labels at or above ``threshold``.
        - ``toxic_score``: the maximum over all labels other than "NOT".
        - ``safe_score``: the score of "NOT".
        - ``is_toxic`` and ``severity``.
        - a copy of ``LABEL_DESCRIPTIONS``.
        - processing time in milliseconds and the device name.

    Raises:
        RuntimeError: if the model emits a different number of outputs than
            ``LABEL_COLS`` has entries. Labels are paired with outputs by
            position, so a mismatch would otherwise mislabel scores silently.
    """
    # perf_counter is monotonic; time.time() can jump when the wall clock is adjusted.
    start = time.perf_counter()
    text = text.strip()

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid per label: scores are independent and need not sum to 1.
    # A single text is encoded, so take row 0 of the batch dimension.
    probs = torch.sigmoid(logits)[0].cpu().tolist()

    if len(probs) != len(LABEL_COLS):
        raise RuntimeError(
            f"Model produced {len(probs)} outputs but LABEL_COLS has "
            f"{len(LABEL_COLS)} labels; check MODEL_DIR / LABEL_COLS."
        )

    scores = {
        label: round(float(score), 6)
        for label, score in zip(LABEL_COLS, probs)
    }

    predicted_labels = [label for label, score in scores.items() if score >= threshold]

    top_label = max(scores, key=scores.get)
    top_score = scores[top_label]

    # Scores are already rounded, so no further rounding is needed here.
    toxic_score = max(score for label, score in scores.items() if label != "NOT")
    safe_score = scores["NOT"]

    # Verdict follows the arg-max: toxic only when some toxic label outscores "NOT".
    is_toxic = top_label != "NOT"
    severity = _severity(toxic_score) if is_toxic else "none"

    return {
        "text": text,
        "is_toxic": is_toxic,
        "top_label": top_label,
        "top_score": top_score,
        "predicted_labels": predicted_labels,
        "toxic_score": toxic_score,
        "safe_score": safe_score,
        "severity": severity,
        "scores": scores,
        # Copy so callers mutating the response cannot alter the module constant.
        "label_descriptions": dict(LABEL_DESCRIPTIONS),
        "processing_time_ms": round((time.perf_counter() - start) * 1000, 2),
        "device": str(DEVICE),
    }
|
|
@app.get("/")
def root():
    """Return a static banner confirming the service is up."""
    banner = {"message": "Arabic Toxicity API is running"}
    return banner
|
|
@app.get("/health")
def health():
    """Report a fixed "ok" status plus the torch device the model was placed on."""
    device_name = str(DEVICE)
    return {"status": "ok", "device": device_name}
|
|
@app.get("/toxic")
def toxic(
    p: str = Query(..., description="Arabic text to analyze"),
    threshold: float = Query(0.5, ge=0.0, le=1.0),
):
    """Score the text in query parameter ``p`` for toxicity.

    ``threshold`` (validated to [0, 1]) controls which labels appear in
    ``predicted_labels``. Empty or whitespace-only ``p`` is rejected with
    HTTP 422: after stripping there is nothing to classify, and running the
    model on it would return a meaningless verdict.
    """
    if not p.strip():
        raise HTTPException(
            status_code=422,
            detail="Query parameter 'p' must contain non-whitespace text",
        )
    return predict_text(p, threshold)