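# Provenance (pasted from the Hugging Face Spaces file view; kept as a comment so
# the file parses): author Sehamsaa, commit e8fc8aa "Update main.py" (verified).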
"""
Arabic Consumer Complaint Severity Classifier — Hugging Face Spaces Version
"""
import asyncio
import os
from contextlib import asynccontextmanager
from pathlib import Path

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ============================================================================
# CONFIGURATION
# ============================================================================
# Directory of the fine-tuned checkpoint (override with the MODEL_PATH env var).
MODEL_PATH = os.getenv("MODEL_PATH", "./saved_model")
# Base AraBERT tokenizer on the Hugging Face Hub (override with TOKENIZER_NAME,
# consistent with MODEL_PATH).
TOKENIZER_NAME = os.getenv("TOKENIZER_NAME", "aubmindlab/bert-base-arabertv02")
MAX_LENGTH = 128  # same value as used in the original training

# Parallel per-class tables, all indexed by the model's predicted class id
# (ordered from lowest to highest severity).
LABELS_EN = ["Low", "Medium", "High", "Critical"]
LABELS_AR = ["منخفضة", "متوسطة", "عالية", "حرجة"]
SEVERITY_COLORS = ["#1F9D55", "#D69E2E", "#DD6B20", "#C53030"]
SEVERITY_DESCRIPTIONS = [
    "شكوى ذات تأثير محدود، تُعالَج ضمن المسار العادي.",
    "شكوى تستوجب المتابعة من الجهة المختصّة في وقت معقول.",
    "شكوى ذات أولوية عالية وتحتاج إلى معالجة سريعة.",
    "شكوى حرجة تستدعي تدخّلاً فورياً وعاجلاً.",
]

# predict_severity() looks up every table above with the same class index, so
# they must all have the same length; fail fast at import if one is edited alone.
_table_lengths = {
    len(LABELS_EN),
    len(LABELS_AR),
    len(SEVERITY_COLORS),
    len(SEVERITY_DESCRIPTIONS),
}
if len(_table_lengths) != 1:
    raise RuntimeError(
        "LABELS_EN, LABELS_AR, SEVERITY_COLORS and SEVERITY_DESCRIPTIONS "
        "must all have the same length."
    )

# Module-level holder for the loaded tokenizer, model and device
# (populated at startup by lifespan(), cleared at shutdown).
state: dict = {}
def _load_model_into_state(model_dir: Path) -> None:
    """Load the tokenizer and classifier from *model_dir* and publish them in ``state``.

    Sets ``state["tokenizer"]``, ``state["model"]`` (moved to the device and put
    in eval mode) and ``state["device"]`` (CUDA when available, otherwise CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prefer the tokenizer saved alongside the fine-tuned weights: it is normally
    # the exact tokenizer used during fine-tuning. Fall back to the base hub
    # tokenizer only when the checkpoint directory ships no tokenizer files.
    if (model_dir / "tokenizer_config.json").is_file():
        tokenizer_source = str(model_dir)
    else:
        tokenizer_source = TOKENIZER_NAME
    print(f"[startup] Loading tokenizer from: {tokenizer_source}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

    # Diagnose vocabulary sizes.
    tokenizer_vocab = len(tokenizer)
    model_vocab = model.config.vocab_size
    print(f"[diagnostic] Tokenizer vocab size: {tokenizer_vocab}")
    print(f"[diagnostic] Model vocab size: {model_vocab}")

    # Align sizes if they differ. NOTE: when the tokenizer is larger, the extra
    # embedding rows were never learned during fine-tuning, so inputs containing
    # those token ids give unreliable predictions. A mismatch usually means the
    # tokenizer does not belong to this checkpoint.
    if tokenizer_vocab != model_vocab:
        print(f"[fix] Resizing model token embeddings: {model_vocab} -> {tokenizer_vocab}")
        model.resize_token_embeddings(tokenizer_vocab)

    # The label tables are indexed by class id; warn loudly if the head disagrees.
    num_labels = model.config.num_labels
    if num_labels != len(LABELS_EN):
        print(
            f"[warning] Model has num_labels={num_labels} but {len(LABELS_EN)} "
            "severity labels are configured; predictions cannot be mapped onto them."
        )

    model.to(device).eval()
    state["tokenizer"] = tokenizer
    state["model"] = model
    state["device"] = device
    print(f"[startup] ✅ Model ready on {device} | num_labels={num_labels}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the classifier at startup, drop it at shutdown.

    If ``MODEL_PATH`` does not exist the app still starts, with
    ``state["model"]`` and ``state["tokenizer"]`` set to ``None`` so that the
    endpoints can report the model as unavailable.
    """
    print(f"[startup] Loading model from: {MODEL_PATH}")
    model_dir = Path(MODEL_PATH)
    if model_dir.exists():
        _load_model_into_state(model_dir)
    else:
        print(f"[error] MODEL_PATH '{MODEL_PATH}' not found.")
        state["model"] = None
        state["tokenizer"] = None
    try:
        yield
    finally:
        # Release references even if the lifespan is interrupted while serving.
        state.clear()
# Directory holding this file; static assets and templates are resolved from it.
BASE_DIR = Path(__file__).parent

# Application object; model loading/unloading is delegated to lifespan().
app = FastAPI(
    title="Arabic Complaint Severity Classifier",
    description="Thesis demo — Vision 2030 consumer protection NLP",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully open CORS policy: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static assets under /static and the Jinja2 templates used by the "/" page.
app.mount(
    path="/static",
    app=StaticFiles(directory=BASE_DIR / "static"),
    name="static",
)
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
class ComplaintRequest(BaseModel):
    """JSON body accepted by ``POST /api/predict``.

    Only ``complaint`` is required and must be at least 5 characters long
    (the raw length, before any whitespace stripping). ``product_name`` and
    ``violation_type`` are optional context that the endpoint prefixes to the
    complaint text. ``store_type`` is accepted but is not used when building
    the classifier input.
    """

    complaint: str = Field(default=..., min_length=5)
    product_name: str | None = Field(default=None)
    store_type: str | None = Field(default=None)
    violation_type: str | None = Field(default=None)
def predict_severity(text: str) -> dict:
    """Run the severity classifier on *text* and return a JSON-serialisable result.

    The returned dict contains:
    - the predicted label in Arabic and English, and its class index;
    - the softmax confidence of that class;
    - the UI colour and description for that class;
    - the probability of every class, keyed by English label;
    - the character length of *text*.

    Raises:
        RuntimeError: if the tokenizer/model have not been loaded, or if the
            model's number of output classes differs from the configured labels.
    """
    tokenizer = state.get("tokenizer")
    model = state.get("model")
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded.")
    device = state["device"]

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)
    # Extra safeguard: clamp every token id to the embedding table's last row so
    # an out-of-range id cannot cause an index error inside the model.
    vocab_size = model.config.vocab_size
    inputs["input_ids"] = torch.clamp(inputs["input_ids"], max=vocab_size - 1)

    with torch.no_grad():
        logits = model(**inputs).logits

    # All per-class tables are indexed by the predicted class id. A head of a
    # different size would otherwise raise IndexError or silently drop classes
    # from all_probabilities.
    num_classes = logits.shape[-1]
    if num_classes != len(LABELS_EN):
        raise RuntimeError(
            f"Model outputs {num_classes} classes but {len(LABELS_EN)} labels are configured."
        )

    # A single text was tokenized, so the batch has one row; take row 0.
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    pred_idx = int(probs.argmax())
    return {
        "severity_ar": LABELS_AR[pred_idx],
        "severity_en": LABELS_EN[pred_idx],
        "severity_index": pred_idx,
        "confidence": float(probs[pred_idx]),
        "color": SEVERITY_COLORS[pred_idx],
        "description": SEVERITY_DESCRIPTIONS[pred_idx],
        "all_probabilities": {label: float(p) for label, p in zip(LABELS_EN, probs)},
        "input_length": len(text),
    }
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the demo UI page from the ``index.html`` template."""
    context = {"request": request}
    page = templates.TemplateResponse("index.html", context)
    return page
# Serialises model inference. Every request shares one tokenizer/model pair, and
# Hugging Face fast tokenizers are known to fail ("Already borrowed") when one
# instance is used from several threads at the same time.
_INFERENCE_LOCK = asyncio.Lock()


@app.post("/api/predict")
async def predict(req: ComplaintRequest):
    """Classify the severity of a consumer complaint.

    The classifier input is the complaint text, optionally prefixed by the
    product name and violation type (each with an Arabic label), joined with
    " | ". Optional fields that are missing or blank are omitted. ``store_type``
    is not part of the input (presumably to match the training format —
    confirm).

    Raises:
        HTTPException: 503 if the model is not loaded.
        HTTPException: 422 if the complaint is blank after stripping whitespace.
        HTTPException: 500 if inference fails.
    """
    if state.get("model") is None:
        raise HTTPException(503, "المودل غير محمّل")

    # min_length=5 counts whitespace, so "     " passes validation; reject it here.
    complaint = req.complaint.strip()
    if not complaint:
        raise HTTPException(422, "نص الشكوى فارغ")

    parts = []
    if product := (req.product_name or "").strip():
        parts.append(f"السلعة: {product}")
    if violation := (req.violation_type or "").strip():
        parts.append(f"نوع المخالفة: {violation}")
    parts.append(complaint)
    full_text = " | ".join(parts)

    try:
        async with _INFERENCE_LOCK:
            # Inference is blocking CPU/GPU work. Running it in a worker thread
            # keeps the event loop free for other requests (e.g. /api/health).
            return await asyncio.to_thread(predict_severity, full_text)
    except Exception as e:
        # Top-level boundary: log, then surface as HTTP 500 with the cause chained.
        print(f"[error] Prediction failed: {e!r}")
        raise HTTPException(500, f"Prediction error: {e}") from e
@app.get("/api/health")
async def health():
    """Report liveness, whether the model is loaded, the device, and the Arabic labels.

    ``status`` is always ``"ok"``; check ``model_loaded`` to see if predictions
    are available.
    """
    loaded = state.get("model") is not None
    device = state.get("device", "n/a")
    return dict(
        status="ok",
        model_loaded=loaded,
        device=str(device),
        labels=LABELS_AR,
    )
if __name__ == "__main__":
    import uvicorn

    # Pass the app object itself instead of the "main:app" import string.
    # This avoids importing this module a second time and keeps working if the
    # file is renamed; an import string is only needed for reload/workers.
    # PORT may override the Hugging Face Spaces default port 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)