JustFadjrin committed
Commit fa6f400 · 1 Parent(s): ac55f65

Deploy Batik ViT FastAPI backend

Files changed (4)
  1. Dockerfile +26 -0
  2. README.md +44 -6
  3. main.py +272 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ ENV PYTHONDONTWRITEBYTECODE=1
+ ENV PYTHONUNBUFFERED=1
+ ENV MODEL_DIR=JustFadjrin/batik-vit-model-classification
+ ENV TOP_K=5
+ ENV CORS_ORIGINS=http://localhost:3000
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     libgl1 \
+     libglib2.0-0 \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt .
+
+ RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ EXPOSE 8000
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,11 +1,49 @@
+
  ---
- title: Batik Vit Api
- emoji: 🚀
- colorFrom: pink
- colorTo: indigo
+ title: Batik ViT API
+ emoji: 🧵
+ colorFrom: amber
+ colorTo: brown
  sdk: docker
+ app_port: 8000
  pinned: false
- license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Batik ViT FastAPI Backend
+ ## Structure
+
+ ```text
+ backend/
+   main.py
+   requirements.txt
+   Dockerfile
+   model/
+     config.json
+     model.safetensors or pytorch_model.bin
+     preprocessor_config.json
+     labels.json
+     model_info.json
+ ```
+
+ ## Run locally
+
+ ```bash
+ pip install -r requirements.txt
+ uvicorn main:app --host 0.0.0.0 --port 8000 --reload
+ ```
+
+ ## Endpoints
+
+ ```text
+ GET /
+ GET /health
+ GET /model-info
+ POST /predict
+ ```
+
+ ## cURL example
+
+ ```bash
+ curl -X POST "http://localhost:8000/predict" \
+   -F "file=@contoh_batik.jpg"
+ ```
main.py ADDED
@@ -0,0 +1,272 @@
+ import os
+ import io
+ import json
+ from pathlib import Path
+ from typing import List, Dict, Any, Optional
+
+ import torch
+ from PIL import Image, ImageOps
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
+
+
+ MODEL_DIR = os.getenv("MODEL_DIR", "model")
+ TOP_K_DEFAULT = int(os.getenv("TOP_K", "5"))
+
+ # CORS_ORIGINS can hold a comma-separated list of allowed origins, e.g.:
+ # CORS_ORIGINS=http://localhost:3000,https://nama-app.vercel.app
+ cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:3000")
+ CORS_ORIGINS = [origin.strip() for origin in cors_origins_env.split(",") if origin.strip()]
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ app = FastAPI(
+     title="Batik ViT Classifier API",
+     description="Batik type classification API using a Vision Transformer",
+     version="1.0.0",
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=CORS_ORIGINS,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ class PredictionItem(BaseModel):
+     label: str
+     confidence: float
+
+
+ class PredictionResponse(BaseModel):
+     status: str
+     reason: str
+     top_prediction: PredictionItem
+     second_prediction: Optional[PredictionItem]
+     margin: float
+     predictions: List[PredictionItem]
+
+
+ processor = None
+ model = None
+ model_info: Dict[str, Any] = {}
+
+
+ def load_model() -> None:
+     global processor, model, model_info
+
+     hf_token = os.getenv("HF_TOKEN")
+     model_source = MODEL_DIR
+
+     local_model_path = Path(MODEL_DIR)
+
+     if local_model_path.exists():
+         model_source = str(local_model_path.resolve())
+         print(f"Loading local model from: {model_source}")
+     else:
+         print(f"Loading remote model from Hugging Face Hub: {model_source}")
+
+     model_kwargs = {}
+
+     if hf_token:
+         model_kwargs["token"] = hf_token
+
+     processor = AutoImageProcessor.from_pretrained(model_source, **model_kwargs)
+     model = AutoModelForImageClassification.from_pretrained(model_source, **model_kwargs)
+
+     model.to(device)
+     model.eval()
+
+     model_info = {}
+
+     info_path = local_model_path / "model_info.json"
+
+     if local_model_path.exists() and info_path.exists():
+         with open(info_path, "r", encoding="utf-8") as f:
+             model_info = json.load(f)
+
+     print(f"Model loaded from: {model_source}")
+     print(f"Device: {device}")
+
+ @app.on_event("startup")
+ def startup_event():
+     load_model()
+
+
+ def get_status(label: str, top1_conf: float, margin: float) -> tuple[str, str]:
+     """
+     Final status logic.
+     The Parang classes are held to stricter thresholds because Solo_Parang and
+     Yogyakarta_Parang look very similar and are often confused with each other.
+     """
+
+     parang_classes = {"Solo_Parang", "Yogyakarta_Parang"}
+
+     if label in parang_classes:
+         if top1_conf >= 0.75 and margin >= 0.30:
+             return (
+                 "Model is confident",
+                 "The Parang prediction has high confidence and a sufficiently safe margin."
+             )
+
+         if top1_conf >= 0.50 and margin >= 0.25:
+             return (
+                 "Model is fairly confident",
+                 "The Parang prediction is fairly strong, but caution is still needed because the Parang classes look alike."
+             )
+
+         return (
+             "Model is not confident yet",
+             "The Parang prediction is not safe enough because the confidence or the margin is still low."
+         )
+
+     if top1_conf >= 0.60 and margin >= 0.20:
+         return (
+             "Model is confident",
+             "High confidence and a wide gap between the first and second predictions."
+         )
+
+     if top1_conf >= 0.40 and margin >= 0.25:
+         return (
+             "Model is fairly confident",
+             "Moderate confidence, but the first prediction clearly dominates the second."
+         )
+
+     if top1_conf >= 0.35 and margin >= 0.35:
+         return (
+             "Model is fairly confident",
+             "Confidence is not very high, but the first prediction is far ahead of the second."
+         )
+
+     return (
+         "Model is not confident yet",
+         "Low confidence, or the first prediction is too close to the second."
+     )
+
+
+ def predict_image(image: Image.Image, top_k: int = TOP_K_DEFAULT, use_tta: bool = True) -> Dict[str, Any]:
+     if processor is None or model is None:
+         raise RuntimeError("Model has not been loaded yet.")
+
+     image = image.convert("RGB")
+
+     if use_tta:
+         images = [image, ImageOps.mirror(image)]
+     else:
+         images = [image]
+
+     inputs = processor(images=images, return_tensors="pt")
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+
+     # Average the logits of the original and the mirrored image
+     avg_logits = logits.mean(dim=0, keepdim=True)
+     probs = torch.softmax(avg_logits, dim=-1)[0]
+
+     max_k = min(top_k, probs.shape[-1])
+     top_probs, top_indices = torch.topk(probs, k=max_k)
+
+     predictions = []
+
+     for prob, idx in zip(top_probs, top_indices):
+         idx_int = int(idx.item())
+
+         label = model.config.id2label.get(idx_int, str(idx_int))
+         confidence = float(prob.item())
+
+         predictions.append({
+             "label": label,
+             "confidence": confidence,
+         })
+
+     top1 = predictions[0]
+     top2 = predictions[1] if len(predictions) > 1 else None
+
+     top1_conf = top1["confidence"]
+     top2_conf = top2["confidence"] if top2 else 0.0
+     margin = top1_conf - top2_conf
+
+     status, reason = get_status(
+         label=top1["label"],
+         top1_conf=top1_conf,
+         margin=margin,
+     )
+
+     return {
+         "status": status,
+         "reason": reason,
+         "top_prediction": top1,
+         "second_prediction": top2,
+         "margin": margin,
+         "predictions": predictions,
+     }
+
+
+ @app.get("/")
+ def root():
+     return {
+         "message": "Batik ViT Classifier API",
+         "docs": "/docs",
+         "health": "/health",
+     }
+
+
+ @app.get("/health")
+ def health():
+     return {
+         "status": "ok",
+         "device": device,
+         "model_dir": str(Path(MODEL_DIR).resolve()),
+         "num_labels": getattr(model.config, "num_labels", None) if model else None,
+         "cors_origins": CORS_ORIGINS,
+     }
+
+
+ @app.get("/model-info")
+ def get_model_info():
+     return {
+         "model_info": model_info,
+         "labels": getattr(model.config, "id2label", {}) if model else {},
+     }
+
+
+ @app.post("/predict", response_model=PredictionResponse)
+ async def predict(
+     file: UploadFile = File(...),
+     top_k: int = TOP_K_DEFAULT,
+     use_tta: bool = True,
+ ):
+     if not file.content_type or not file.content_type.startswith("image/"):
+         raise HTTPException(
+             status_code=400,
+             detail="The uploaded file must be an image."
+         )
+
+     try:
+         image_bytes = await file.read()
+         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+     except Exception as exc:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Failed to read the image: {exc}"
+         )
+
+     try:
+         result = predict_image(
+             image=image,
+             top_k=top_k,
+             use_tta=use_tta,
+         )
+         return result
+     except Exception as exc:
+         raise HTTPException(
+             status_code=500,
+             detail=f"Prediction failed: {exc}"
+         )
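As a quick illustration of how the thresholds in `get_status` interact (again, a sketch rather than part of the commit), the snippet below runs a few hypothetical (label, confidence, margin) combinations through the function; the non-Parang label used here is made up and need not exist in the model's label set:

```python
# Illustrative only: exercises get_status from main.py with made-up numbers to
# show how the confidence/margin thresholds behave. Importing main requires the
# backend's dependencies but does not load the model (that happens on startup).
from main import get_status

examples = [
    ("Solo_Parang", 0.80, 0.35),       # meets the stricter Parang rule -> confident
    ("Solo_Parang", 0.65, 0.28),       # only the looser Parang rule -> fairly confident
    ("Some_Other_Label", 0.62, 0.22),  # generic rule (>= 0.60 / >= 0.20) -> confident
    ("Some_Other_Label", 0.38, 0.10),  # fails every rule -> not confident yet
]

for label, conf, margin in examples:
    status, reason = get_status(label=label, top1_conf=conf, margin=margin)
    print(f"{label:<18} conf={conf:.2f} margin={margin:.2f} -> {status}")
```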
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi==0.115.0
+ uvicorn[standard]==0.30.6
+ python-multipart==0.0.9
+ pillow==10.4.0
+ torch
+ torchvision
+ transformers
+ safetensors
+ pydantic==2.8.2