# batik-vit-api / main.py
# Author: JustFadjrin
# Deploy Batik ViT FastAPI backend (commit fa6f400)
import os
import io
import json
from pathlib import Path
from typing import List, Dict, Any, Optional
import torch
from PIL import Image, ImageOps
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Local directory (or Hugging Face Hub repo id) holding the model weights.
MODEL_DIR = os.getenv("MODEL_DIR", "model")
# Default number of top predictions returned by /predict.
TOP_K_DEFAULT = int(os.getenv("TOP_K", "5"))
# CORS_ORIGINS may be a comma-separated list, e.g.:
# CORS_ORIGINS=http://localhost:3000,https://my-app.vercel.app
cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:3000")
CORS_ORIGINS = [origin.strip() for origin in cors_origins_env.split(",") if origin.strip()]
# Run inference on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# FastAPI application; title/description (Indonesian) are shown in /docs.
app = FastAPI(
    title="Batik ViT Classifier API",
    description="API klasifikasi jenis batik menggunakan Vision Transformer",
    version="1.0.0",
)
# Allow browser clients from the configured origins to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PredictionItem(BaseModel):
    """One (label, confidence) pair for a predicted batik class."""

    # Human-readable class label taken from the model config's id2label map.
    label: str
    # Softmax probability for this class, in [0, 1].
    confidence: float
class PredictionResponse(BaseModel):
    """Response schema for the /predict endpoint."""

    # Confidence verdict, e.g. "Model yakin" / "Model belum yakin".
    status: str
    # Human-readable explanation of the verdict.
    reason: str
    # Best prediction (highest averaged softmax probability).
    top_prediction: PredictionItem
    # Runner-up prediction; None when only a single class is returned.
    # Default added: under Pydantic v2 a plain Optional[...] annotation is
    # still a *required* field, so `= None` keeps the schema permissive.
    second_prediction: Optional[PredictionItem] = None
    # Confidence gap between the top-1 and top-2 predictions.
    margin: float
    # Full top-k list, including the top prediction itself.
    predictions: List[PredictionItem]
# Lazily-initialized globals, populated by load_model() at app startup.
processor = None  # AutoImageProcessor once loaded
model = None  # AutoModelForImageClassification once loaded
model_info: Dict[str, Any] = {}  # optional metadata read from model_info.json
def load_model() -> None:
    """Load the image processor and ViT classifier, plus optional metadata.

    Populates the module-level ``processor``, ``model`` and ``model_info``
    globals. The model is read from a local directory when MODEL_DIR exists
    on disk; otherwise it is fetched from the Hugging Face Hub.
    """
    global processor, model, model_info

    local_path = Path(MODEL_DIR)
    is_local = local_path.exists()
    source = str(local_path.resolve()) if is_local else MODEL_DIR
    if is_local:
        print(f"Loading local model from: {source}")
    else:
        print(f"Loading remote model from Hugging Face Hub: {source}")

    # Pass an auth token only when one is configured (needed for private repos).
    auth_token = os.getenv("HF_TOKEN")
    load_kwargs = {"token": auth_token} if auth_token else {}

    processor = AutoImageProcessor.from_pretrained(source, **load_kwargs)
    model = AutoModelForImageClassification.from_pretrained(source, **load_kwargs)
    model.to(device)
    model.eval()

    # Optional sidecar metadata shipped alongside a local model directory.
    model_info = {}
    metadata_file = local_path / "model_info.json"
    if is_local and metadata_file.exists():
        with open(metadata_file, "r", encoding="utf-8") as fh:
            model_info = json.load(fh)

    print(f"Model loaded from: {source}")
    print(f"Device: {device}")
@app.on_event("startup")
def startup_event():
    # Load model weights once at process start so the first request is fast.
    # NOTE(review): on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers; fine to keep while the installed version supports it.
    load_model()
def get_status(label: str, top1_conf: float, margin: float) -> tuple[str, str]:
    """Map (label, confidence, margin) to a final (status, reason) verdict.

    The Parang classes get stricter thresholds because Solo_Parang and
    Yogyakarta_Parang look alike and are frequently confused by the model.
    """
    if label in ("Solo_Parang", "Yogyakarta_Parang"):
        if top1_conf >= 0.75 and margin >= 0.30:
            return (
                "Model yakin",
                "Prediksi kelas Parang memiliki confidence tinggi dan margin cukup aman."
            )
        if top1_conf >= 0.50 and margin >= 0.25:
            return (
                "Model cukup yakin",
                "Prediksi kelas Parang cukup kuat, tetapi tetap perlu hati-hati karena kelas Parang mirip."
            )
        return (
            "Model belum yakin",
            "Prediksi kelas Parang belum cukup aman karena confidence atau margin masih rendah."
        )

    # General classes: (min confidence, min margin, status, reason); first match wins.
    general_rules = (
        (
            0.60,
            0.20,
            "Model yakin",
            "Confidence tinggi dan jarak prediksi pertama dengan kedua cukup jauh.",
        ),
        (
            0.40,
            0.25,
            "Model cukup yakin",
            "Confidence sedang, tetapi prediksi pertama jauh lebih dominan dari prediksi kedua.",
        ),
        (
            0.35,
            0.35,
            "Model cukup yakin",
            "Confidence tidak terlalu tinggi, tetapi prediksi pertama sangat jauh dari prediksi kedua.",
        ),
    )
    for min_conf, min_margin, status, reason in general_rules:
        if top1_conf >= min_conf and margin >= min_margin:
            return (status, reason)
    return (
        "Model belum yakin",
        "Confidence rendah atau prediksi pertama terlalu dekat dengan prediksi kedua.",
    )
def predict_image(image: Image.Image, top_k: int = TOP_K_DEFAULT, use_tta: bool = True) -> Dict[str, Any]:
    """Classify a batik image and attach a confidence verdict.

    Args:
        image: Input image; converted to RGB internally.
        top_k: Number of top classes to return (clamped to [1, num_labels]).
        use_tta: When True, also run a horizontally mirrored copy through the
            model and average the logits (simple test-time augmentation).

    Returns:
        Dict with status/reason, the top and second predictions, their
        confidence margin, and the full top-k prediction list.

    Raises:
        RuntimeError: If the model has not been loaded yet.
    """
    if processor is None or model is None:
        raise RuntimeError("Model belum diload.")

    image = image.convert("RGB")
    images = [image, ImageOps.mirror(image)] if use_tta else [image]

    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    # Average logits of the original and mirrored views.
    avg_logits = logits.mean(dim=0, keepdim=True)
    probs = torch.softmax(avg_logits, dim=-1)[0]

    # Clamp to [1, num_labels]: top_k <= 0 would previously make topk return
    # an empty tensor and predictions[0] below raise IndexError.
    max_k = min(max(top_k, 1), probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, k=max_k)
    predictions = [
        {
            "label": model.config.id2label.get(int(idx.item()), str(int(idx.item()))),
            "confidence": float(prob.item()),
        }
        for prob, idx in zip(top_probs, top_indices)
    ]

    top1 = predictions[0]
    top2 = predictions[1] if len(predictions) > 1 else None
    top1_conf = top1["confidence"]
    # Margin against 0.0 when there is no runner-up (single-class model).
    top2_conf = top2["confidence"] if top2 else 0.0
    margin = top1_conf - top2_conf
    status, reason = get_status(
        label=top1["label"],
        top1_conf=top1_conf,
        margin=margin,
    )
    return {
        "status": status,
        "reason": reason,
        "top_prediction": top1,
        "second_prediction": top2,
        "margin": margin,
        "predictions": predictions,
    }
@app.get("/")
def root():
    """Landing endpoint pointing at the interactive docs and health check."""
    payload = {}
    payload["message"] = "Batik ViT Classifier API"
    payload["docs"] = "/docs"
    payload["health"] = "/health"
    return payload
@app.get("/health")
def health():
    """Liveness/diagnostics: device, resolved model dir, label count, CORS."""
    # num_labels is only available after the model has been loaded.
    label_count = getattr(model.config, "num_labels", None) if model else None
    report = {
        "status": "ok",
        "device": device,
        "model_dir": str(Path(MODEL_DIR).resolve()),
        "num_labels": label_count,
        "cors_origins": CORS_ORIGINS,
    }
    return report
@app.get("/model-info")
def get_model_info():
    """Expose optional sidecar metadata and the model's label mapping."""
    # Fall back to an empty mapping until the model has been loaded.
    labels = getattr(model.config, "id2label", {}) if model else {}
    return {"model_info": model_info, "labels": labels}
@app.post("/predict", response_model=PredictionResponse)
async def predict(
    file: UploadFile = File(...),
    top_k: int = TOP_K_DEFAULT,
    use_tta: bool = True,
):
    """Classify an uploaded batik image.

    Args:
        file: Uploaded file; content type must start with "image/".
        top_k: Number of top predictions to return.
        use_tta: Enable mirror test-time augmentation.

    Raises:
        HTTPException: 400 for non-image uploads or unreadable image data,
            500 when inference itself fails.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail="File harus berupa gambar."
        )
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        # Chain the original error so tracebacks keep the root cause (B904).
        raise HTTPException(
            status_code=400,
            detail=f"Gagal membaca gambar: {exc}"
        ) from exc
    try:
        result = predict_image(
            image=image,
            top_k=top_k,
            use_tta=use_tta,
        )
        return result
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Gagal melakukan prediksi: {exc}"
        ) from exc