| """ |
| Arabic Consumer Complaint Severity Classifier — Hugging Face Spaces Version |
| """ |
|
|
| from contextlib import asynccontextmanager |
| from pathlib import Path |
| import os |
|
|
| import torch |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| from pydantic import BaseModel, Field |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
| |
| |
| |
|
|
# Directory of the fine-tuned classifier; override with the MODEL_PATH env var.
MODEL_PATH = os.getenv("MODEL_PATH", "./saved_model")
# Tokenizer identifier; override with the TOKENIZER_NAME env var (same pattern as
# MODEL_PATH) so a different base checkpoint can be used without code changes.
TOKENIZER_NAME = os.getenv("TOKENIZER_NAME", "aubmindlab/bert-base-arabertv02")
# Maximum tokens per input; the tokenizer is called with truncation at this length.
MAX_LENGTH = 128


# The four lists below are parallel: index i in each describes the same severity
# class. Predictions index them with the model's argmax, so their order must match
# the order of the model's output classes.
LABELS_EN = ["Low", "Medium", "High", "Critical"]
LABELS_AR = ["منخفضة", "متوسطة", "عالية", "حرجة"]


# Display colour (hex) and Arabic explanation per severity class.
SEVERITY_COLORS = ["#1F9D55", "#D69E2E", "#DD6B20", "#C53030"]
SEVERITY_DESCRIPTIONS = [
    "شكوى ذات تأثير محدود، تُعالَج ضمن المسار العادي.",
    "شكوى تستوجب المتابعة من الجهة المختصّة في وقت معقول.",
    "شكوى ذات أولوية عالية وتحتاج إلى معالجة سريعة.",
    "شكوى حرجة تستدعي تدخّلاً فورياً وعاجلاً.",
]


# Process-wide holder for the loaded tokenizer, model and device; populated by the
# lifespan handler at startup and cleared at shutdown.
state: dict = {}
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and classifier at startup and drop them at shutdown.

    On success ``state`` receives ``"tokenizer"``, ``"model"`` (moved to the chosen
    device and put in eval mode) and ``"device"``. When MODEL_PATH does not exist the
    app still starts, with ``state["model"]`` and ``state["tokenizer"]`` set to None,
    so endpoints can answer with 503 instead of the process failing to boot.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    print(f"[startup] Loading model from: {MODEL_PATH}")
    model_dir = Path(MODEL_PATH)
    if not model_dir.exists():
        print(f"[error] MODEL_PATH '{MODEL_PATH}' not found.")
        state["model"] = None
        state["tokenizer"] = None
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Prefer a tokenizer saved alongside the fine-tuned model: presumably it is
        # the one used during fine-tuning, so its ids match the trained embeddings.
        # Fall back to the base tokenizer when the model directory has none.
        if (model_dir / "tokenizer_config.json").exists():
            tokenizer_source = MODEL_PATH
        else:
            tokenizer_source = TOKENIZER_NAME
        print(f"[startup] Loading tokenizer from: {tokenizer_source}")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

        # Diagnostics for a tokenizer/model vocabulary mismatch.
        tokenizer_vocab = len(tokenizer)
        model_vocab = model.config.vocab_size
        print(f"[diagnostic] Tokenizer vocab size: {tokenizer_vocab}")
        print(f"[diagnostic] Model vocab size: {model_vocab}")

        # Only grow the embedding table. The tokenizer never emits ids >= len(tokenizer),
        # so a smaller tokenizer needs no change (shrinking would just rebuild the
        # matrix). Rows added when growing are freshly initialised, not fine-tuned.
        if tokenizer_vocab > model_vocab:
            print(f"[fix] Resizing model token embeddings: {model_vocab} -> {tokenizer_vocab}")
            model.resize_token_embeddings(tokenizer_vocab)

        # predict_severity indexes the label lists with the model's argmax, so a class
        # count mismatch would only surface as request-time failures; flag it now.
        num_labels = model.config.num_labels
        if num_labels != len(LABELS_EN):
            print(
                f"[warning] Model has num_labels={num_labels} but "
                f"{len(LABELS_EN)} severity labels are defined; predictions will fail."
            )

        model.to(device).eval()
        state["tokenizer"] = tokenizer
        state["model"] = model
        state["device"] = device
        print(f"[startup] ✅ Model ready on {device} | num_labels={num_labels}")
    try:
        yield
    finally:
        # Release model references even if shutdown is triggered by an exception.
        state.clear()
|
|
|
|
# Application object; loading and unloading of the model is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Arabic Complaint Severity Classifier",
    version="1.0.0",
    description="Thesis demo — Vision 2030 consumer protection NLP",
)


# Fully permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


# Front-end assets are resolved relative to this file: `static/` is served under
# /static and `templates/` holds the Jinja2 pages.
BASE_DIR = Path(__file__).parent
_static_dir = BASE_DIR / "static"
_templates_dir = BASE_DIR / "templates"
app.mount("/static", StaticFiles(directory=_static_dir), name="static")
templates = Jinja2Templates(directory=str(_templates_dir))
|
|
|
|
class ComplaintRequest(BaseModel):
    # Request body accepted by POST /api/predict.

    # Free-text complaint; must be at least 5 characters long (whitespace counts).
    complaint: str = Field(min_length=5)
    # Optional context. product_name and violation_type are prefixed to the model
    # input by the predict endpoint; store_type is accepted but not used there.
    product_name: str | None = None
    store_type: str | None = None
    violation_type: str | None = None
|
|
|
|
def predict_severity(text: str) -> dict:
    """Classify *text* and return the predicted severity with display metadata.

    Args:
        text: The complete model input (complaint plus any context prefixes).

    Returns:
        A JSON-serialisable dict with the Arabic and English label, the class index,
        the softmax probability of the predicted class as ``confidence``, a display
        colour, an Arabic description, per-class probabilities keyed by English
        label, and the character length of *text*.

    Raises:
        RuntimeError: if the model or tokenizer is not loaded, or if the model's
            number of output classes differs from the number of defined labels.
    """
    tokenizer = state.get("tokenizer")
    model = state.get("model")
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded.")

    device = state["device"]
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    # Safety net for a tokenizer/model vocabulary mismatch: an id past the end of the
    # embedding table would fail inside the model. Map such ids to the unknown token
    # when it is in range (rather than to an arbitrary last-vocab token), otherwise
    # to the last valid id.
    vocab_size = model.config.vocab_size
    input_ids = inputs["input_ids"]
    out_of_range = input_ids >= vocab_size
    if out_of_range.any():
        unk_id = tokenizer.unk_token_id
        if unk_id is not None and 0 <= unk_id < vocab_size:
            fallback_id = unk_id
        else:
            fallback_id = vocab_size - 1
        inputs["input_ids"] = input_ids.masked_fill(out_of_range, fallback_id)

    with torch.no_grad():
        logits = model(**inputs).logits
    # Single input, so take the first (only) row of probabilities.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    # The label lists are indexed by class position; fail clearly on a mismatch
    # instead of with an IndexError (or silently ignoring extra classes).
    if probs.shape[-1] != len(LABELS_EN):
        raise RuntimeError(
            f"Model returned {probs.shape[-1]} classes but "
            f"{len(LABELS_EN)} severity labels are defined."
        )

    pred_idx = int(probs.argmax())
    return {
        "severity_ar": LABELS_AR[pred_idx],
        "severity_en": LABELS_EN[pred_idx],
        "severity_index": pred_idx,
        "confidence": float(probs[pred_idx]),
        "color": SEVERITY_COLORS[pred_idx],
        "description": SEVERITY_DESCRIPTIONS[pred_idx],
        "all_probabilities": {label: float(p) for label, p in zip(LABELS_EN, probs)},
        "input_length": len(text),
    }
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the demo page from the `index.html` template."""
    # Jinja2Templates requires the current request to be present in the context.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
# Minimum complaint length after stripping; mirrors min_length on ComplaintRequest.
_MIN_COMPLAINT_CHARS = 5


def _compose_model_input(
    complaint: str, product_name: str | None, violation_type: str | None
) -> str:
    """Build the single model-input string from a complaint and its optional context.

    Non-empty context fields become Arabic-labelled prefixes joined with " | ", e.g.
    "السلعة: X | نوع المخالفة: Y | <complaint>". Fields that are None or
    whitespace-only are skipped instead of producing an empty prefix.
    """
    parts = []
    product = (product_name or "").strip()
    if product:
        parts.append(f"السلعة: {product}")
    violation = (violation_type or "").strip()
    if violation:
        parts.append(f"نوع المخالفة: {violation}")
    parts.append(complaint.strip())
    return " | ".join(parts)


@app.post("/api/predict")
async def predict(req: ComplaintRequest):
    """Classify a complaint's severity.

    Responds with 503 if the model is not loaded, 422 if the complaint is too short
    once surrounding whitespace is removed, and 500 (with the error text) if
    inference fails. Inference is synchronous, so it runs in the event loop.
    """
    if state.get("model") is None:
        raise HTTPException(503, "المودل غير محمّل")
    # Field(min_length=5) counts whitespace, so a blank complaint like "     " passes
    # validation; re-check after stripping so the model never sees empty text.
    if len(req.complaint.strip()) < _MIN_COMPLAINT_CHARS:
        raise HTTPException(422, "نص الشكوى قصير جداً")
    # NOTE(review): req.store_type is accepted but not fed to the model — presumably
    # the classifier was trained without it; confirm.
    full_text = _compose_model_input(req.complaint, req.product_name, req.violation_type)
    try:
        return predict_severity(full_text)
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log and chain the cause.
        import traceback

        traceback.print_exc()
        raise HTTPException(500, f"Prediction error: {e}") from e
|
|
|
|
@app.get("/api/health")
async def health():
    """Report service status, whether the model is loaded, its device and the labels."""
    loaded = state.get("model") is not None
    # "n/a" when startup did not record a device (e.g. the model directory was missing).
    device_name = str(state.get("device", "n/a"))
    return dict(
        status="ok",
        model_loaded=loaded,
        device=device_name,
        labels=LABELS_AR,
    )
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn.
    import uvicorn

    # Pass the app object instead of an import string such as "main:app": the string
    # only resolves when this file is named main.py, and it makes uvicorn import the
    # module a second time. An import string is only needed for reload/workers.
    # Port defaults to 7860 and can be overridden with the PORT env var.
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)