File size: 7,343 Bytes
fa6f400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import os
import io
import json
from pathlib import Path
from typing import List, Dict, Any, Optional

import torch
from PIL import Image, ImageOps
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification


MODEL_DIR = os.getenv("MODEL_DIR", "model")
TOP_K_DEFAULT = int(os.getenv("TOP_K", "5"))

# Isi CORS_ORIGINS bisa:
# CORS_ORIGINS=http://localhost:3000,https://nama-app.vercel.app
cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:3000")
CORS_ORIGINS = [origin.strip() for origin in cors_origins_env.split(",") if origin.strip()]

device = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI(
    title="Batik ViT Classifier API",
    description="API klasifikasi jenis batik menggunakan Vision Transformer",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class PredictionItem(BaseModel):
    """A single classification result: class label plus softmax confidence."""

    label: str
    confidence: float


class PredictionResponse(BaseModel):
    """Response payload for the /predict endpoint."""

    status: str  # human-readable confidence verdict (see get_status)
    reason: str  # explanation backing the verdict
    top_prediction: PredictionItem
    second_prediction: Optional[PredictionItem]  # None when only one prediction exists
    margin: float  # top-1 confidence minus top-2 confidence
    predictions: List[PredictionItem]  # full top-k list, highest confidence first


# Populated by load_model() during application startup.
processor = None
model = None
model_info: Dict[str, Any] = {}


def load_model() -> None:
    """Load the image processor and classification model into module globals.

    Prefers a local checkpoint directory at MODEL_DIR; otherwise treats
    MODEL_DIR as a Hugging Face Hub repo id. Also reads an optional
    model_info.json that ships next to a local checkpoint.
    """
    global processor, model, model_info

    local_model_path = Path(MODEL_DIR)
    if local_model_path.exists():
        model_source = str(local_model_path.resolve())
        print(f"Loading local model from: {model_source}")
    else:
        model_source = MODEL_DIR
        print(f"Loading remote model from Hugging Face Hub: {model_source}")

    # Only pass an auth token when one is configured (needed for private repos).
    hf_token = os.getenv("HF_TOKEN")
    model_kwargs = {"token": hf_token} if hf_token else {}

    processor = AutoImageProcessor.from_pretrained(model_source, **model_kwargs)
    model = AutoModelForImageClassification.from_pretrained(model_source, **model_kwargs)
    model.to(device)
    model.eval()

    # Optional side-car metadata shipped with a local checkpoint.
    model_info = {}
    info_path = local_model_path / "model_info.json"
    if local_model_path.exists() and info_path.exists():
        model_info = json.loads(info_path.read_text(encoding="utf-8"))

    print(f"Model loaded from: {model_source}")
    print(f"Device: {device}")

@app.on_event("startup")
def startup_event():
    """Load the model once when the API process starts."""
    # NOTE(review): on_event("startup") is deprecated in recent FastAPI in
    # favor of lifespan handlers — works today, but worth migrating.
    load_model()


def get_status(label: str, top1_conf: float, margin: float) -> tuple[str, str]:
    """Derive a (status, reason) pair from the top-1 confidence and the
    margin between the top-1 and top-2 predictions.

    The Parang classes use stricter thresholds because Solo_Parang and
    Yogyakarta_Parang look alike and are frequently confused.
    """
    is_parang = label in ("Solo_Parang", "Yogyakarta_Parang")

    if is_parang:
        # (min confidence, min margin, status, reason) — checked in order.
        tiers = (
            (0.75, 0.30,
             "Model yakin",
             "Prediksi kelas Parang memiliki confidence tinggi dan margin cukup aman."),
            (0.50, 0.25,
             "Model cukup yakin",
             "Prediksi kelas Parang cukup kuat, tetapi tetap perlu hati-hati karena kelas Parang mirip."),
        )
        fallback = (
            "Model belum yakin",
            "Prediksi kelas Parang belum cukup aman karena confidence atau margin masih rendah.",
        )
    else:
        tiers = (
            (0.60, 0.20,
             "Model yakin",
             "Confidence tinggi dan jarak prediksi pertama dengan kedua cukup jauh."),
            (0.40, 0.25,
             "Model cukup yakin",
             "Confidence sedang, tetapi prediksi pertama jauh lebih dominan dari prediksi kedua."),
            (0.35, 0.35,
             "Model cukup yakin",
             "Confidence tidak terlalu tinggi, tetapi prediksi pertama sangat jauh dari prediksi kedua."),
        )
        fallback = (
            "Model belum yakin",
            "Confidence rendah atau prediksi pertama terlalu dekat dengan prediksi kedua.",
        )

    for min_conf, min_margin, status, reason in tiers:
        if top1_conf >= min_conf and margin >= min_margin:
            return status, reason
    return fallback


def predict_image(image: Image.Image, top_k: int = TOP_K_DEFAULT, use_tta: bool = True) -> Dict[str, Any]:
    """Classify a PIL image and return top-k predictions plus a verdict.

    With use_tta=True, the logits of the original image and its horizontal
    mirror are averaged before the softmax (simple test-time augmentation).
    Raises RuntimeError when the model has not been loaded yet.
    """
    if processor is None or model is None:
        raise RuntimeError("Model belum diload.")

    image = image.convert("RGB")
    batch = [image, ImageOps.mirror(image)] if use_tta else [image]

    inputs = processor(images=batch, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        # Average logits over the (possibly augmented) batch, then softmax.
        logits = model(**inputs).logits
        probs = torch.softmax(logits.mean(dim=0, keepdim=True), dim=-1)[0]

    k = min(top_k, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, k=k)

    id2label = model.config.id2label
    predictions = [
        {
            "label": id2label.get(int(idx.item()), str(int(idx.item()))),
            "confidence": float(prob.item()),
        }
        for prob, idx in zip(top_probs, top_indices)
    ]

    top1 = predictions[0]
    top2 = predictions[1] if len(predictions) > 1 else None
    margin = top1["confidence"] - (top2["confidence"] if top2 else 0.0)

    status, reason = get_status(
        label=top1["label"],
        top1_conf=top1["confidence"],
        margin=margin,
    )

    return {
        "status": status,
        "reason": reason,
        "top_prediction": top1,
        "second_prediction": top2,
        "margin": margin,
        "predictions": predictions,
    }


@app.get("/")
def root():
    """Landing endpoint: points clients at the interactive docs and health check."""
    return dict(
        message="Batik ViT Classifier API",
        docs="/docs",
        health="/health",
    )


@app.get("/health")
def health():
    """Report service liveness together with runtime configuration details."""
    num_labels = getattr(model.config, "num_labels", None) if model is not None else None
    return {
        "status": "ok",
        "device": device,
        "model_dir": str(Path(MODEL_DIR).resolve()),
        "num_labels": num_labels,
        "cors_origins": CORS_ORIGINS,
    }


@app.get("/model-info")
def get_model_info():
    """Expose the optional model_info.json contents and the id-to-label map."""
    if model is None:
        labels = {}
    else:
        labels = getattr(model.config, "id2label", {})
    return {
        "model_info": model_info,
        "labels": labels,
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict(
    file: UploadFile = File(...),
    top_k: int = TOP_K_DEFAULT,
    use_tta: bool = True,
):
    """Classify an uploaded image.

    Parameters:
        file: the image to classify; content type must be image/*.
        top_k: number of predictions to include in the response.
        use_tta: average logits of the image and its mirror before softmax.

    Raises:
        HTTPException 400 when the upload is not a readable image,
        HTTPException 500 when inference fails.
    """
    # Cheap content-type gate before reading the request body.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail="File harus berupa gambar."
        )

    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        # Chain the original error so tracebacks keep the root cause (B904).
        raise HTTPException(
            status_code=400,
            detail=f"Gagal membaca gambar: {exc}"
        ) from exc

    try:
        result = predict_image(
            image=image,
            top_k=top_k,
            use_tta=use_tta,
        )
        return result
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Gagal melakukan prediksi: {exc}"
        ) from exc