Sehamsaa committed on
Commit
b41e161
·
verified ·
1 Parent(s): a619053

Add backend files

Browse files
Files changed (3) hide show
  1. Dockerfile +34 -0
  2. main.py +150 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dockerfile for Hugging Face Spaces
FROM python:3.11-slim

# Install system requirements (apt lists are removed in the same layer)
RUN apt-get update && apt-get install -y --no-install-recommends \
        git \
    && rm -rf /var/lib/apt/lists/*

# Set up environment variables.
# HF_HOME alone relocates the Hugging Face cache; the deprecated
# TRANSFORMERS_CACHE variable is intentionally no longer set.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    HF_HOME=/tmp/hf_home

# Create a non-root user (a Hugging Face Spaces requirement)
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"
WORKDIR /home/user/app

# Copy requirements first (to benefit from the Docker layer cache)
COPY --chown=user requirements.txt .
RUN pip install --user --upgrade pip && \
    pip install --user -r requirements.txt

# Copy the remaining files
COPY --chown=user . .

# Hugging Face Spaces uses port 7860
EXPOSE 7860

# Start the server
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Arabic Consumer Complaint Severity Classifier — Hugging Face Spaces Version
3
+ ============================================================================
4
+ نسخة معدّلة للنشر على Hugging Face Spaces (Docker mode).
5
+ الفرق الرئيسي عن النسخة المحلية: المنفذ 7860 (المعياري في HF Spaces).
6
+ """
7
+
8
+ from contextlib import asynccontextmanager
9
+ from pathlib import Path
10
+ import os
11
+
12
+ import torch
13
+ from fastapi import FastAPI, HTTPException, Request
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import HTMLResponse
16
+ from fastapi.staticfiles import StaticFiles
17
+ from fastapi.templating import Jinja2Templates
18
+ from pydantic import BaseModel, Field
19
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
20
+
21
+
22
# ============================================================================
# CONFIGURATION
# ============================================================================

# Path the tokenizer and model are loaded from; overridable through the
# MODEL_PATH environment variable.
MODEL_PATH = os.getenv("MODEL_PATH", "./saved_model")
# Maximum token length passed to the tokenizer (inputs are truncated to it).
MAX_LENGTH = 256

# Severity labels, looked up by the predicted class index.
LABELS_EN = ["Low", "Medium", "High", "Critical"]
LABELS_AR = ["منخفضة", "متوسطة", "عالية", "حرجة"]

# Display color and Arabic description, looked up by the same class index.
SEVERITY_COLORS = ["#1F9D55", "#D69E2E", "#DD6B20", "#C53030"]
SEVERITY_DESCRIPTIONS = [
    "شكوى ذات تأثير محدود، تُعالَج ضمن المسار العادي.",
    "شكوى تستوجب المتابعة من الجهة المختصّة في وقت معقول.",
    "شكوى ذات أولوية عالية وتحتاج إلى معالجة سريعة.",
    "شكوى حرجة تستدعي تدخّلاً فورياً وعاجلاً.",
]

# Shared runtime state: the lifespan handler stores the tokenizer, model and
# device here, and the request handlers read them back.
state: dict = {}
41
+
42
+
43
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model at startup; clear ``state`` at shutdown.

    When ``MODEL_PATH`` does not exist, the ``model`` and ``tokenizer``
    entries of ``state`` are set to ``None`` so the app still starts.
    Otherwise both are loaded from ``MODEL_PATH``, the model is moved to
    CUDA when available (CPU otherwise) and switched to eval mode, and
    ``state`` receives the ``tokenizer``, ``model`` and ``device`` entries.
    """
    print(f"[startup] Loading model from: {MODEL_PATH}")
    if Path(MODEL_PATH).exists():
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tok = AutoTokenizer.from_pretrained(MODEL_PATH)
        clf = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
        clf.to(target).eval()
        state.update(tokenizer=tok, model=clf, device=target)
        print(f"[startup] Model ready on {target} | num_labels={clf.config.num_labels}")
    else:
        print(f"[error] MODEL_PATH '{MODEL_PATH}' not found.")
        state.update(model=None, tokenizer=None)

    yield  # the application serves requests while suspended here
    state.clear()
61
+
62
+
63
# Application instance; the lifespan handler loads the model at startup.
app = FastAPI(
    title="Arabic Complaint Severity Classifier",
    description="Thesis demo — Vision 2030 consumer protection NLP",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static assets and Jinja2 templates are resolved relative to this file.
BASE_DIR = Path(__file__).parent
app.mount("/static", StaticFiles(directory=BASE_DIR / "static"), name="static")
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
80
+
81
+
82
# Request body accepted by the POST /api/predict endpoint.
class ComplaintRequest(BaseModel):
    # Complaint text; validation requires at least 5 characters.
    complaint: str = Field(..., min_length=5)
    # Optional context fields; each is None when omitted.
    product_name: str | None = None
    store_type: str | None = None
    violation_type: str | None = None
87
+
88
+
89
def predict_severity(text: str) -> dict:
    """Classify a complaint text and return its severity prediction.

    Args:
        text: The text to classify; it is truncated to ``MAX_LENGTH`` tokens.

    Returns:
        A dict with the predicted label (``severity_ar``, ``severity_en``),
        its ``severity_index``, the softmax ``confidence``, the display
        ``color`` and ``description``, ``all_probabilities`` keyed by the
        English label, and ``input_length`` (character count of ``text``).

    Raises:
        RuntimeError: If the model or tokenizer is not loaded, or if the
            number of model outputs differs from the configured labels.
    """
    tokenizer = state.get("tokenizer")
    model = state.get("model")
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded.")

    device = state["device"]
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding=True, max_length=MAX_LENGTH).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single row of plain Python floats; no NumPy round-trip is required.
    probs = torch.softmax(logits, dim=-1)[0].tolist()

    # Fail with a clear message instead of an IndexError (or a silently
    # truncated probability table) when the checkpoint and labels disagree.
    if len(probs) != len(LABELS_EN):
        raise RuntimeError(
            f"Model outputs {len(probs)} classes but "
            f"{len(LABELS_EN)} labels are configured."
        )

    # First index holding the maximum probability (same tie-break as argmax).
    pred_idx = max(range(len(probs)), key=probs.__getitem__)
    return {
        "severity_ar": LABELS_AR[pred_idx],
        "severity_en": LABELS_EN[pred_idx],
        "severity_index": pred_idx,
        "confidence": probs[pred_idx],
        "color": SEVERITY_COLORS[pred_idx],
        "description": SEVERITY_DESCRIPTIONS[pred_idx],
        "all_probabilities": dict(zip(LABELS_EN, probs)),
        "input_length": len(text),
    }
113
+
114
+
115
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the ``index.html`` template as the landing page."""
    # The request is passed explicitly: the older
    # TemplateResponse(name, {"request": request}) form is deprecated.
    return templates.TemplateResponse(request=request, name="index.html")
118
+
119
+
120
@app.post("/api/predict")
async def predict(req: ComplaintRequest):
    """Build the model input from the request and return the prediction.

    The optional product name and violation type are prepended to the
    complaint as labelled segments joined with " | ".

    Raises:
        HTTPException: 503 when no model is loaded; 500 when the
            prediction itself fails.
    """
    # Local import keeps this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    if state.get("model") is None:
        raise HTTPException(503, "المودل غير محمّل")

    # Strip before testing so whitespace-only values do not produce an
    # empty "label: " segment in the model input.
    product = (req.product_name or "").strip()
    violation = (req.violation_type or "").strip()
    # NOTE(review): req.store_type is accepted but never added to the model
    # input — confirm this is intended.

    parts = []
    if product:
        parts.append(f"السلعة: {product}")
    if violation:
        parts.append(f"نوع المخالفة: {violation}")
    parts.append(req.complaint.strip())
    full_text = " | ".join(parts)

    try:
        # Inference is blocking work; run it in a worker thread so the
        # event loop stays free for other requests.
        return await run_in_threadpool(predict_severity, full_text)
    except Exception as e:
        raise HTTPException(500, f"Prediction error: {e}") from e
135
+
136
+
137
@app.get("/api/health")
async def health():
    """Report service status, model availability, device and Arabic labels."""
    model_loaded = state.get("model") is not None
    device_name = str(state.get("device", "n/a"))
    return dict(
        status="ok",
        model_loaded=model_loaded,
        device=device_name,
        labels=LABELS_AR,
    )
145
+
146
+
147
# Start uvicorn directly when this module is executed as a script.
if __name__ == "__main__":
    import uvicorn
    # Hugging Face Spaces uses port 7860
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Python dependencies installed by the Dockerfile's `pip install -r` step.
fastapi==0.115.0
uvicorn[standard]==0.32.0
transformers==4.45.0
# NOTE(review): torch is the only entry without an exact pin — consider
# pinning it for reproducible builds.
torch>=2.0.0
jinja2==3.1.4
python-multipart==0.0.12
pydantic==2.9.0