"""
Arabic Consumer Complaint Severity Classifier — Hugging Face Spaces Version

FastAPI service that loads a fine-tuned sequence-classification checkpoint at
startup and exposes it through a small JSON API plus a templated HTML page.
"""

from contextlib import asynccontextmanager
from pathlib import Path
import os

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ============================================================================
# CONFIGURATION
# ============================================================================
MODEL_PATH = os.getenv("MODEL_PATH", "./saved_model")
TOKENIZER_NAME = "aubmindlab/bert-base-arabertv02"
MAX_LENGTH = 128  # same value as in the original training

# The four tables below are parallel: index i of each describes class i.
LABELS_EN = ["Low", "Medium", "High", "Critical"]
LABELS_AR = ["منخفضة", "متوسطة", "عالية", "حرجة"]
SEVERITY_COLORS = ["#1F9D55", "#D69E2E", "#DD6B20", "#C53030"]
SEVERITY_DESCRIPTIONS = [
    "شكوى ذات تأثير محدود، تُعالَج ضمن المسار العادي.",
    "شكوى تستوجب المتابعة من الجهة المختصّة في وقت معقول.",
    "شكوى ذات أولوية عالية وتحتاج إلى معالجة سريعة.",
    "شكوى حرجة تستدعي تدخّلاً فورياً وعاجلاً.",
]

# Process-wide holder for the loaded "tokenizer", "model" and "device".
state: dict = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model on startup and release them on shutdown.

    When ``MODEL_PATH`` does not exist the application still starts, with
    ``state["model"]`` and ``state["tokenizer"]`` set to ``None`` so the
    endpoints can report that no model is available.
    """
    print(f"[startup] Loading tokenizer from: {TOKENIZER_NAME}")
    print(f"[startup] Loading model from: {MODEL_PATH}")

    if not Path(MODEL_PATH).exists():
        print(f"[error] MODEL_PATH '{MODEL_PATH}' not found.")
        state["model"] = None
        state["tokenizer"] = None
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

        # Diagnose vocabulary sizes.
        tokenizer_vocab = len(tokenizer)
        model_vocab = model.config.vocab_size
        print(f"[diagnostic] Tokenizer vocab size: {tokenizer_vocab}")
        print(f"[diagnostic] Model vocab size: {model_vocab}")

        # Synchronise the sizes if they differ, so every id the tokenizer can
        # emit has a row in the embedding matrix.
        if tokenizer_vocab != model_vocab:
            print(f"[fix] Resizing model token embeddings: {model_vocab} -> {tokenizer_vocab}")
            model.resize_token_embeddings(tokenizer_vocab)

        model.to(device).eval()

        state["tokenizer"] = tokenizer
        state["model"] = model
        state["device"] = device
        print(f"[startup] ✅ Model ready on {device} | num_labels={model.config.num_labels}")

        # The label tables are fixed-size; flag a mismatching checkpoint at
        # startup instead of only failing on the first prediction.
        if model.config.num_labels != len(LABELS_EN):
            print(
                f"[warning] Model has {model.config.num_labels} labels but "
                f"{len(LABELS_EN)} are configured; predictions will fail."
            )

    try:
        yield
    finally:
        # Runs even when an exception is thrown into the context manager.
        state.clear()


app = FastAPI(
    title="Arabic Complaint Severity Classifier",
    description="Thesis demo — Vision 2030 consumer protection NLP",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

BASE_DIR = Path(__file__).parent
app.mount("/static", StaticFiles(directory=BASE_DIR / "static"), name="static")
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))


class ComplaintRequest(BaseModel):
    """Request body for the prediction endpoint."""

    # Free-text complaint; must be at least 5 characters long.
    complaint: str = Field(..., min_length=5)
    # Optional context fields supplied by the client.
    product_name: str | None = None
    store_type: str | None = None
    violation_type: str | None = None


def predict_severity(text: str) -> dict:
    """Classify ``text`` and return the predicted severity with metadata.

    Args:
        text: The text to classify. It is truncated to ``MAX_LENGTH`` tokens.

    Returns:
        A dict with the Arabic and English label, the class index, the
        confidence of the predicted class, its display colour and description,
        the probability of every class keyed by English label, and the
        character length of ``text``.

    Raises:
        RuntimeError: If no model/tokenizer is loaded, or if the number of
            classes produced by the model differs from the number of
            configured labels.
    """
    tokenizer = state.get("tokenizer")
    model = state.get("model")
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded.")
    device = state["device"]

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    # Extra protection: make sure every token id is within the embedding
    # range. NOTE(review): any id above the range is silently remapped to the
    # last vocabulary entry; this is presumably a no-op once the vocabulary
    # sizes were synchronised at startup — confirm.
    vocab_size = model.config.vocab_size
    inputs["input_ids"] = torch.clamp(inputs["input_ids"], max=vocab_size - 1)

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    # Fail with a clear message rather than an IndexError (or a mislabeled
    # result) when the checkpoint does not match the configured label tables.
    if len(probs) != len(LABELS_EN):
        raise RuntimeError(
            f"Model returned {len(probs)} classes but "
            f"{len(LABELS_EN)} labels are configured."
        )

    pred_idx = int(probs.argmax())
    return {
        "severity_ar": LABELS_AR[pred_idx],
        "severity_en": LABELS_EN[pred_idx],
        "severity_index": pred_idx,
        "confidence": float(probs[pred_idx]),
        "color": SEVERITY_COLORS[pred_idx],
        "description": SEVERITY_DESCRIPTIONS[pred_idx],
        "all_probabilities": {
            label: float(prob) for label, prob in zip(LABELS_EN, probs)
        },
        "input_length": len(text),
    }

@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` template for the landing page."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/api/predict")
async def predict(req: ComplaintRequest):
    """Build the model input from the request and return its severity.

    The optional product name and violation type are prepended to the
    complaint as labelled segments joined with ``" | "``.

    Raises:
        HTTPException: 503 when no model is loaded, 422 when the complaint is
            only whitespace, 500 when the prediction itself fails.
    """
    if state.get("model") is None:
        raise HTTPException(503, "المودل غير محمّل")

    # A whitespace-only complaint satisfies the schema's minimum length but
    # carries no text to classify.
    complaint = req.complaint.strip()
    if not complaint:
        raise HTTPException(422, "Complaint text must not be blank.")

    # Strip before testing so whitespace-only optional fields are skipped
    # instead of adding an empty labelled segment to the model input.
    product_name = (req.product_name or "").strip()
    violation_type = (req.violation_type or "").strip()

    # NOTE(review): req.store_type is accepted by the schema but is not part
    # of the model input — confirm this matches the training format.
    parts = []
    if product_name:
        parts.append(f"السلعة: {product_name}")
    if violation_type:
        parts.append(f"نوع المخالفة: {violation_type}")
    parts.append(complaint)
    full_text = " | ".join(parts)

    # NOTE(review): inference runs synchronously inside this coroutine, so
    # requests are handled one at a time while the model is busy.
    try:
        return predict_severity(full_text)
    except Exception as e:
        # Top-level boundary: report the failure as a 500 and keep the cause.
        raise HTTPException(500, f"Prediction error: {e}") from e


@app.get("/api/health")
async def health():
    """Report whether a model is loaded, on which device, and the labels."""
    return {
        "status": "ok",
        "model_loaded": state.get("model") is not None,
        "device": str(state.get("device", "n/a")),
        "labels": LABELS_AR,
    }


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than a "main:app" import string so the entry
    # point works whatever this file is named and the module is not imported a
    # second time. This is valid because reload is disabled.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)