"""Batik ViT Classifier API.

FastAPI service that classifies batik patterns with a Vision Transformer
model loaded either from a local directory (MODEL_DIR) or from the
Hugging Face Hub, with optional horizontal-mirror test-time augmentation.
"""

import io
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageOps
from pydantic import BaseModel
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Model source: a local directory if it exists, otherwise a HF Hub repo id.
MODEL_DIR = os.getenv("MODEL_DIR", "model")
TOP_K_DEFAULT = int(os.getenv("TOP_K", "5"))

# CORS_ORIGINS is a comma-separated list, e.g.:
# CORS_ORIGINS=http://localhost:3000,https://nama-app.vercel.app
cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:3000")
CORS_ORIGINS = [origin.strip() for origin in cors_origins_env.split(",") if origin.strip()]

device = "cuda" if torch.cuda.is_available() else "cpu"

app = FastAPI(
    title="Batik ViT Classifier API",
    description="API klasifikasi jenis batik menggunakan Vision Transformer",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class PredictionItem(BaseModel):
    """A single (label, confidence) prediction."""

    label: str
    confidence: float


class PredictionResponse(BaseModel):
    """Full prediction response: confidence status plus ranked predictions."""

    status: str
    reason: str
    top_prediction: PredictionItem
    second_prediction: Optional[PredictionItem]
    margin: float
    predictions: List[PredictionItem]


# Populated by load_model() at application startup.
processor = None
model = None
model_info: Dict[str, Any] = {}


def load_model() -> None:
    """Load the image processor and classifier into module globals.

    Resolves MODEL_DIR to a local path when it exists, otherwise treats it
    as a Hugging Face Hub repo id (authenticated with HF_TOKEN if set).
    Also reads an optional local ``model_info.json`` into ``model_info``.
    """
    global processor, model, model_info

    hf_token = os.getenv("HF_TOKEN")
    model_source = MODEL_DIR
    local_model_path = Path(MODEL_DIR)
    if local_model_path.exists():
        model_source = str(local_model_path.resolve())
        print(f"Loading local model from: {model_source}")
    else:
        print(f"Loading remote model from Hugging Face Hub: {model_source}")

    model_kwargs = {}
    if hf_token:
        model_kwargs["token"] = hf_token

    processor = AutoImageProcessor.from_pretrained(model_source, **model_kwargs)
    model = AutoModelForImageClassification.from_pretrained(model_source, **model_kwargs)
    model.to(device)
    model.eval()

    model_info = {}
    info_path = local_model_path / "model_info.json"
    if local_model_path.exists() and info_path.exists():
        with open(info_path, "r", encoding="utf-8") as f:
            model_info = json.load(f)

    print(f"Model loaded from: {model_source}")
    print(f"Device: {device}")


# NOTE(review): on_event is deprecated in recent FastAPI in favor of the
# lifespan context manager; kept as-is for compatibility.
@app.on_event("startup")
def startup_event():
    """Load the model once when the server starts."""
    load_model()


def get_status(label: str, top1_conf: float, margin: float) -> tuple[str, str]:
    """Decide the final confidence status for a prediction.

    The Parang classes use stricter thresholds because Solo_Parang and
    Yogyakarta_Parang look very similar and are frequently confused.

    Args:
        label: Predicted top-1 class name.
        top1_conf: Probability of the top-1 prediction.
        margin: Difference between top-1 and top-2 probabilities.

    Returns:
        A ``(status, reason)`` tuple (Indonesian, user-facing).
    """
    parang_classes = {"Solo_Parang", "Yogyakarta_Parang"}

    if label in parang_classes:
        if top1_conf >= 0.75 and margin >= 0.30:
            return (
                "Model yakin",
                "Prediksi kelas Parang memiliki confidence tinggi dan margin cukup aman."
            )
        if top1_conf >= 0.50 and margin >= 0.25:
            return (
                "Model cukup yakin",
                "Prediksi kelas Parang cukup kuat, tetapi tetap perlu hati-hati karena kelas Parang mirip."
            )
        return (
            "Model belum yakin",
            "Prediksi kelas Parang belum cukup aman karena confidence atau margin masih rendah."
        )

    if top1_conf >= 0.60 and margin >= 0.20:
        return (
            "Model yakin",
            "Confidence tinggi dan jarak prediksi pertama dengan kedua cukup jauh."
        )
    if top1_conf >= 0.40 and margin >= 0.25:
        return (
            "Model cukup yakin",
            "Confidence sedang, tetapi prediksi pertama jauh lebih dominan dari prediksi kedua."
        )
    if top1_conf >= 0.35 and margin >= 0.35:
        return (
            "Model cukup yakin",
            "Confidence tidak terlalu tinggi, tetapi prediksi pertama sangat jauh dari prediksi kedua."
        )
    return (
        "Model belum yakin",
        "Confidence rendah atau prediksi pertama terlalu dekat dengan prediksi kedua."
    )


def predict_image(image: Image.Image, top_k: int = TOP_K_DEFAULT, use_tta: bool = True) -> Dict[str, Any]:
    """Classify a PIL image and return status, margin, and top-k predictions.

    When ``use_tta`` is True, the image and its horizontal mirror are both
    run through the model and their logits averaged (simple TTA).

    Raises:
        RuntimeError: if load_model() has not been called yet.
    """
    if processor is None or model is None:
        raise RuntimeError("Model belum diload.")

    image = image.convert("RGB")
    images = [image, ImageOps.mirror(image)] if use_tta else [image]

    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Average logits over the original (and mirrored, if TTA) views.
    avg_logits = logits.mean(dim=0, keepdim=True)
    probs = torch.softmax(avg_logits, dim=-1)[0]

    # Clamp into [1, num_labels] so a non-positive top_k query parameter
    # cannot crash torch.topk (previously surfaced as an opaque 500).
    max_k = max(1, min(top_k, probs.shape[-1]))
    top_probs, top_indices = torch.topk(probs, k=max_k)

    predictions = []
    for prob, idx in zip(top_probs, top_indices):
        idx_int = int(idx.item())
        label = model.config.id2label.get(idx_int, str(idx_int))
        predictions.append({
            "label": label,
            "confidence": float(prob.item()),
        })

    top1 = predictions[0]
    top2 = predictions[1] if len(predictions) > 1 else None
    top1_conf = top1["confidence"]
    top2_conf = top2["confidence"] if top2 else 0.0
    margin = top1_conf - top2_conf

    status, reason = get_status(
        label=top1["label"],
        top1_conf=top1_conf,
        margin=margin,
    )

    return {
        "status": status,
        "reason": reason,
        "top_prediction": top1,
        "second_prediction": top2,
        "margin": margin,
        "predictions": predictions,
    }


@app.get("/")
def root():
    """Landing endpoint pointing at docs and health check."""
    return {
        "message": "Batik ViT Classifier API",
        "docs": "/docs",
        "health": "/health",
    }


@app.get("/health")
def health():
    """Liveness/readiness info: device, model dir, label count, CORS config."""
    return {
        "status": "ok",
        "device": device,
        "model_dir": str(Path(MODEL_DIR).resolve()),
        "num_labels": getattr(model.config, "num_labels", None) if model else None,
        "cors_origins": CORS_ORIGINS,
    }


@app.get("/model-info")
def get_model_info():
    """Expose model_info.json contents (if present) and the label map."""
    return {
        "model_info": model_info,
        "labels": getattr(model.config, "id2label", {}) if model else {},
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict(
    file: UploadFile = File(...),
    top_k: int = TOP_K_DEFAULT,
    use_tta: bool = True,
):
    """Classify an uploaded image file.

    Returns 400 for non-image uploads or unreadable image data, and 500
    when inference itself fails.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail="File harus berupa gambar."
        )

    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        # `from exc` keeps the original traceback for server-side logs.
        raise HTTPException(
            status_code=400,
            detail=f"Gagal membaca gambar: {exc}"
        ) from exc

    try:
        return predict_image(
            image=image,
            top_k=top_k,
            use_tta=use_tta,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Gagal melakukan prediksi: {exc}"
        ) from exc