# app/inference.py
"""Image-classification inference helpers.

Loads a fine-tuned DINOv2 classifier from the Hugging Face Hub (cached in
module globals after the first call) and exposes :func:`classify_bytes`,
which turns raw image bytes into top-k predictions plus a base64-encoded
matplotlib summary plot.
"""

import base64
import io
import json
import os

import matplotlib

# Non-interactive backend: this module runs server-side (no display), and the
# default backend may try to open a GUI or fail in headless environments.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
MODEL_REPO = "Arew99/dinov2-costum"  # your Hugging Face repo
ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🧠 Using device: {DEVICE}")

# Lazily-initialized module-level cache (populated by load_classification_model).
_model = None
_processor = None
_id2name = None


# ─────────────────────────────────────────────
# HELPER — load id2name mapping
# ─────────────────────────────────────────────
def _load_id2name():
    """Return the class-id → display-name mapping.

    Reads ``id2name.json`` next to this file if present (JSON keys are
    strings, so they are coerced to int).  Falls back to generic
    ``"Class N"`` placeholders when the file is missing.
    """
    if os.path.exists(ID2NAME_PATH):
        with open(ID2NAME_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {int(k): v for k, v in data.items()}
    print("⚠️ id2name.json not found — using placeholder labels.")
    # NOTE(review): 101 matches the expected label count of the custom model —
    # confirm against the checkpoint's config if the model changes.
    return {i: f"Class {i}" for i in range(101)}


# ─────────────────────────────────────────────
# LOAD MODEL (cached globally)
# ─────────────────────────────────────────────
def load_classification_model():
    """Load (once) and return ``(model, processor, id2name)``.

    The model, processor, and label map are cached in module globals so
    repeated calls are cheap.  The model is moved to ``DEVICE`` and put in
    eval mode.
    """
    global _model, _processor, _id2name
    if _model is not None:
        return _model, _processor, _id2name

    print(f"🔁 Loading model from Hugging Face repo: {MODEL_REPO}")
    _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
    _model = AutoModelForImageClassification.from_pretrained(
        MODEL_REPO,
        # Tolerate a classifier head whose size differs from the base config.
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    _model.eval()

    _id2name = _load_id2name()
    print(f"✅ Model loaded and ready on {DEVICE}")
    return _model, _processor, _id2name


# ─────────────────────────────────────────────
# CLASSIFY IMAGE BYTES
# ─────────────────────────────────────────────
def classify_bytes(image_bytes: bytes):
    """Classify a raw image and return predictions plus a summary plot.

    Args:
        image_bytes: Encoded image data (any format PIL can open).

    Returns:
        dict with keys:
            ``"top1"``  — best prediction ({rank, id, label, score}),
            ``"top5"``  — up to five predictions, best first,
            ``"plot"``  — ``data:image/png;base64,...`` URI of a two-panel
                          figure (input image + top-3 bar chart).
    """
    model, processor, id2name = load_classification_model()

    # Load and preprocess image.
    # BUGFIX: the original passed padding="True" — image processors do not
    # accept a `padding` kwarg (that is a tokenizer argument), and recent
    # transformers versions raise on unexpected kwargs.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)

    # Top-k predictions (clamped so models with <5 labels don't crash).
    k = min(5, probs.shape[-1])
    topk = torch.topk(probs, k=k)
    indices = topk.indices[0].tolist()
    values = topk.values[0].tolist()

    results = []
    for rank, (idx, prob) in enumerate(zip(indices, values), 1):
        label = id2name.get(int(idx), f"Class {idx}")
        results.append({
            "rank": rank,
            "id": int(idx),
            "label": label,
            "score": float(prob),
        })

    # ───────────────────────────────
    # MATPLOTLIB TOP-3 PLOT
    # ───────────────────────────────
    top3 = results[:3]
    labels = [p["label"] for p in top3]
    probs_top3 = [p["score"] * 100 for p in top3]

    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))

    # Left: input image
    ax[0].imshow(image)
    ax[0].axis("off")
    ax[0].set_title("Input Image", fontsize=12, weight="bold")

    # Right: horizontal bar chart (reversed so the best class is on top).
    bars = ax[1].barh(
        labels[::-1],
        probs_top3[::-1],
        color=["#C44E52", "#55A868", "#4C72B0"],
        edgecolor="none",
        height=0.6,
    )
    ax[1].set_xlim(0, 100)
    ax[1].set_xlabel("Probability (%)", fontsize=11)
    ax[1].set_title("Top-3 Predicted Species (DINOv2-G)", fontsize=12, weight="bold")
    for bar, prob in zip(bars, probs_top3[::-1]):
        ax[1].text(
            prob + 1,
            bar.get_y() + bar.get_height() / 2,
            f"{prob:.1f}%",
            va="center",
            fontsize=10,
            weight="bold",
        )
    plt.tight_layout()

    # Encode plot as base64 (close the figure to avoid leaking memory
    # across requests).
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    plot_b64 = base64.b64encode(buf.read()).decode("utf-8")
    buf.close()

    # ───────────────────────────────
    # FINAL OUTPUT
    # ───────────────────────────────
    return {
        "top1": results[0],
        "top5": results,
        "plot": f"data:image/png;base64,{plot_b64}",
    }