# NOTE: "Spaces: Sleeping" is a Hugging Face Spaces page-status artifact from
# the scrape this file came from — it is not part of the source code.
| import os | |
| import io | |
| import json | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional | |
| import torch | |
| from PIL import Image, ImageOps | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# --- Environment-driven configuration --------------------------------------

# Local directory (or Hub repo id) the model is loaded from.
MODEL_DIR = os.getenv("MODEL_DIR", "model")
# Default number of predictions returned by the /predict endpoint.
TOP_K_DEFAULT = int(os.getenv("TOP_K", "5"))

# CORS_ORIGINS accepts a comma-separated list, e.g.:
#   CORS_ORIGINS=http://localhost:3000,https://nama-app.vercel.app
cors_origins_env = os.getenv("CORS_ORIGINS", "http://localhost:3000")
# Strip whitespace around each entry and drop empties (trailing commas etc.).
CORS_ORIGINS = list(filter(None, (part.strip() for part in cors_origins_env.split(","))))

# Run inference on GPU when torch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# FastAPI application object; interactive docs are served at /docs.
app = FastAPI(
    title="Batik ViT Classifier API",
    description="API klasifikasi jenis batik menggunakan Vision Transformer",
    version="1.0.0",
)

# Allow browser clients from the origins configured via CORS_ORIGINS.
# Methods/headers are unrestricted; credentials (cookies/auth) are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PredictionItem(BaseModel):
    """One (label, confidence) entry in a prediction list."""

    # Class name taken from the model's id2label mapping.
    label: str
    # Softmax probability in [0, 1].
    confidence: float
class PredictionResponse(BaseModel):
    """Response schema for image classification.

    NOTE(review): no endpoint in this view declares
    ``response_model=PredictionResponse`` — presumably intended for /predict;
    confirm where it is wired up.
    """

    # Confidence verdict, e.g. "Model yakin" / "Model cukup yakin" / "Model belum yakin".
    status: str
    # Human-readable explanation for the status.
    reason: str
    # Best prediction.
    top_prediction: PredictionItem
    # Runner-up prediction; None when only one class is returned.
    second_prediction: Optional[PredictionItem]
    # top-1 confidence minus top-2 confidence.
    margin: float
    # Full top-k list, best first.
    predictions: List[PredictionItem]
# Lazily-initialized globals, populated by load_model() at startup.
processor = None  # image processor (AutoImageProcessor) once loaded
model = None  # classifier (AutoModelForImageClassification) once loaded
model_info: Dict[str, Any] = {}  # contents of model_info.json when present locally
def load_model() -> None:
    """Load the processor and classifier into the module-level globals.

    Resolves MODEL_DIR as a local directory when it exists, otherwise treats
    it as a Hugging Face Hub repo id (authenticated via HF_TOKEN when set).
    Also reads an optional local model_info.json into ``model_info``.
    """
    global processor, model, model_info

    hf_token = os.getenv("HF_TOKEN")
    local_path = Path(MODEL_DIR)

    # Prefer a local checkout; fall back to downloading from the Hub.
    if local_path.exists():
        source = str(local_path.resolve())
        print(f"Loading local model from: {source}")
    else:
        source = MODEL_DIR
        print(f"Loading remote model from Hugging Face Hub: {source}")

    # Only pass a token when one is configured (public repos need none).
    auth_kwargs = {"token": hf_token} if hf_token else {}

    processor = AutoImageProcessor.from_pretrained(source, **auth_kwargs)
    model = AutoModelForImageClassification.from_pretrained(source, **auth_kwargs)
    model.to(device)
    model.eval()

    # Optional sidecar metadata shipped next to a local model.
    model_info = {}
    info_file = local_path / "model_info.json"
    if local_path.exists() and info_file.exists():
        with open(info_file, "r", encoding="utf-8") as fh:
            model_info = json.load(fh)

    print(f"Model loaded from: {source}")
    print(f"Device: {device}")
def startup_event():
    """Load the model once when the application starts."""
    # NOTE(review): no @app.on_event("startup") decorator is visible in this
    # view — confirm this handler is actually registered, otherwise the model
    # is never loaded and predict_image() will raise.
    load_model()
def get_status(label: str, top1_conf: float, margin: float) -> tuple[str, str]:
    """Map top-1 confidence and top1-top2 margin to a (status, reason) pair.

    The Parang classes use stricter thresholds because Solo_Parang and
    Yogyakarta_Parang look alike and are frequently confused by the model.
    Rules are evaluated top-down; the first satisfied rule wins.
    """
    # Each rule: (min top-1 confidence, min margin, status, reason).
    parang_rules = (
        (0.75, 0.30, "Model yakin",
         "Prediksi kelas Parang memiliki confidence tinggi dan margin cukup aman."),
        (0.50, 0.25, "Model cukup yakin",
         "Prediksi kelas Parang cukup kuat, tetapi tetap perlu hati-hati karena kelas Parang mirip."),
    )
    general_rules = (
        (0.60, 0.20, "Model yakin",
         "Confidence tinggi dan jarak prediksi pertama dengan kedua cukup jauh."),
        (0.40, 0.25, "Model cukup yakin",
         "Confidence sedang, tetapi prediksi pertama jauh lebih dominan dari prediksi kedua."),
        (0.35, 0.35, "Model cukup yakin",
         "Confidence tidak terlalu tinggi, tetapi prediksi pertama sangat jauh dari prediksi kedua."),
    )

    if label in ("Solo_Parang", "Yogyakarta_Parang"):
        rules = parang_rules
        fallback = (
            "Model belum yakin",
            "Prediksi kelas Parang belum cukup aman karena confidence atau margin masih rendah.",
        )
    else:
        rules = general_rules
        fallback = (
            "Model belum yakin",
            "Confidence rendah atau prediksi pertama terlalu dekat dengan prediksi kedua.",
        )

    for min_conf, min_margin, status, reason in rules:
        if top1_conf >= min_conf and margin >= min_margin:
            return status, reason
    return fallback
def predict_image(image: Image.Image, top_k: int = TOP_K_DEFAULT, use_tta: bool = True) -> Dict[str, Any]:
    """Classify a PIL image and return a status verdict plus top-k predictions.

    Args:
        image: input image; converted to RGB internally.
        top_k: number of predictions to return. Clamped to [1, num_labels] —
            previously a non-positive value (reachable from the /predict query
            string) produced an empty prediction list and an IndexError.
        use_tta: when True, also scores a horizontally mirrored copy and
            averages the logits (cheap test-time augmentation).

    Returns:
        Dict with status/reason (see get_status), top and second predictions,
        their confidence margin, and the full top-k prediction list.

    Raises:
        RuntimeError: if load_model() has not populated the globals yet.
    """
    if processor is None or model is None:
        raise RuntimeError("Model belum diload.")

    image = image.convert("RGB")
    images = [image, ImageOps.mirror(image)] if use_tta else [image]

    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Average logits over original + mirrored views (no-op when TTA is off).
    avg_logits = logits.mean(dim=0, keepdim=True)
    probs = torch.softmax(avg_logits, dim=-1)[0]

    # Clamp k: at least 1 so predictions[0] below is always valid, at most
    # the number of labels so torch.topk cannot overflow.
    max_k = max(1, min(top_k, probs.shape[-1]))
    top_probs, top_indices = torch.topk(probs, k=max_k)

    predictions = []
    for prob, idx in zip(top_probs, top_indices):
        idx_int = int(idx.item())
        predictions.append({
            # Fall back to the raw index when id2label lacks an entry.
            "label": model.config.id2label.get(idx_int, str(idx_int)),
            "confidence": float(prob.item()),
        })

    top1 = predictions[0]
    top2 = predictions[1] if len(predictions) > 1 else None
    top1_conf = top1["confidence"]
    # Margin against a missing runner-up is the full top-1 confidence.
    top2_conf = top2["confidence"] if top2 else 0.0
    margin = top1_conf - top2_conf

    status, reason = get_status(
        label=top1["label"],
        top1_conf=top1_conf,
        margin=margin,
    )

    return {
        "status": status,
        "reason": reason,
        "top_prediction": top1,
        "second_prediction": top2,
        "margin": margin,
        "predictions": predictions,
    }
def root():
    """Landing payload pointing clients at the docs and health endpoints."""
    # NOTE(review): no @app.get("/") decorator visible in this chunk — confirm
    # the route registration elsewhere.
    payload = {"message": "Batik ViT Classifier API"}
    payload["docs"] = "/docs"
    payload["health"] = "/health"
    return payload
def health():
    """Health/config probe: device, model dir, label count, CORS origins."""
    # num_labels is None until load_model() has run.
    num_labels = getattr(model.config, "num_labels", None) if model else None
    return {
        "status": "ok",
        "device": device,
        "model_dir": str(Path(MODEL_DIR).resolve()),
        "num_labels": num_labels,
        "cors_origins": CORS_ORIGINS,
    }
def get_model_info():
    """Expose the optional model_info.json contents and the label mapping."""
    # Both fall back to empty mappings until load_model() has run.
    labels = getattr(model.config, "id2label", {}) if model else {}
    return {"model_info": model_info, "labels": labels}
async def predict(
    file: UploadFile = File(...),
    top_k: int = TOP_K_DEFAULT,
    use_tta: bool = True,
):
    """Classify an uploaded image.

    NOTE(review): no @app.post("/predict") decorator is visible in this chunk
    — confirm the route registration elsewhere.

    Args:
        file: uploaded file; must have an image/* content type.
        top_k: number of predictions to return.
        use_tta: enable horizontal-mirror test-time augmentation.

    Returns:
        The prediction dict produced by predict_image().

    Raises:
        HTTPException: 400 for a non-image upload or an unreadable image,
            500 when inference fails.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail="File harus berupa gambar."
        )

    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        # Chain the cause (PEP 3134) so server logs keep the root traceback.
        raise HTTPException(
            status_code=400,
            detail=f"Gagal membaca gambar: {exc}"
        ) from exc

    try:
        return predict_image(
            image=image,
            top_k=top_k,
            use_tta=use_tta,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Gagal melakukan prediksi: {exc}"
        ) from exc