File size: 5,158 Bytes
c34dda4
 
 
361c20d
c34dda4
361c20d
c34dda4
 
361c20d
c34dda4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a47fe92
c34dda4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b11294
c34dda4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b11294
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# app/inference.py
import os
import io
import json
import base64
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
MODEL_REPO = "Arew99/dinov2-costum"  # your Hugging Face repo
# Class-id -> species-name mapping, expected next to this module.
ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")

# Pick GPU when available; all tensors/model weights are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🧠 Using device: {DEVICE}")

# Lazily-initialized module-level singletons — populated once by
# load_classification_model() and reused by every subsequent request.
_model = None
_processor = None
_id2name = None


# ─────────────────────────────────────────────
# HELPER β€” load id2name mapping
# ─────────────────────────────────────────────
def _load_id2name():
    """Load the class-id -> display-name mapping from id2name.json.

    Returns:
        dict[int, str]: integer class ids mapped to label names. Falls back
        to generic ``Class {i}`` placeholders (101 entries, presumably the
        size of the fine-tuned head — TODO confirm) when the file is absent.
    """
    if os.path.exists(ID2NAME_PATH):
        # BUG FIX: JSON is UTF-8 by spec; without an explicit encoding,
        # open() uses the locale default and can garble non-ASCII names.
        with open(ID2NAME_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
        # JSON object keys are always strings — convert back to ints.
        return {int(k): v for k, v in data.items()}
    print("⚠️ id2name.json not found β€” using placeholder labels.")
    return {i: f"Class {i}" for i in range(101)}


# ─────────────────────────────────────────────
# LOAD MODEL (cached globally)
# ─────────────────────────────────────────────
def load_classification_model():
    """Return the cached ``(model, processor, id2name)`` triple.

    The heavy Hugging Face download/initialization happens only on the
    first call; afterwards the module-level singletons are reused.
    """
    global _model, _processor, _id2name

    if _model is None:
        print(f"πŸ” Loading model from Hugging Face repo: {MODEL_REPO}")
        _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
        # ignore_mismatched_sizes: the fine-tuned head may differ from the
        # checkpoint's original classifier shape.
        network = AutoModelForImageClassification.from_pretrained(
            MODEL_REPO,
            ignore_mismatched_sizes=True,
        )
        _model = network.to(DEVICE)
        _model.eval()  # inference mode: disable dropout/batch-norm updates
        _id2name = _load_id2name()
        print(f"βœ… Model loaded and ready on {DEVICE}")

    return _model, _processor, _id2name


# ─────────────────────────────────────────────
# CLASSIFY IMAGE BYTES
# ─────────────────────────────────────────────
def classify_bytes(image_bytes: bytes, top_k: int = 5):
    """Classify raw image bytes and return top-k predictions plus a plot.

    Args:
        image_bytes: Encoded image data in any format PIL can open.
        top_k: How many predictions to return (default 5; clamped to the
            model's actual number of classes).

    Returns:
        dict with keys:
            "top1": best prediction ({"rank", "id", "label", "score"}),
            "top5": list of the top-k predictions (key name kept for
                backward compatibility with existing callers),
            "plot": data-URI PNG showing the input image and a top-3 bar chart.
    """
    model, processor, id2name = load_classification_model()

    # Decode and normalize to RGB (handles RGBA / grayscale uploads).
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # BUG FIX: the original passed padding="True" — a truthy *string*, not a
    # boolean, and image processors take no padding argument at all.
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    # Forward pass (no autograd bookkeeping needed for inference).
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)

    # Clamp k so a model with fewer classes than top_k cannot crash topk().
    k = min(top_k, probs.shape[-1])
    topk = torch.topk(probs, k=k)
    indices = topk.indices[0].tolist()
    values = topk.values[0].tolist()

    results = [
        {
            "rank": rank,
            "id": int(idx),
            "label": id2name.get(int(idx), f"Class {idx}"),
            "score": float(prob),
        }
        for rank, (idx, prob) in enumerate(zip(indices, values), 1)
    ]

    # ───────────────────────────────
    # MATPLOTLIB TOP-3 PLOT
    # ───────────────────────────────
    top3 = results[:3]
    labels = [p["label"] for p in top3]
    probs_top3 = [p["score"] * 100 for p in top3]
    # Slice the palette so matplotlib gets exactly one color per bar even
    # when fewer than 3 classes exist.
    palette = ["#C44E52", "#55A868", "#4C72B0"][:len(top3)]

    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))

    # Left: input image
    ax[0].imshow(image)
    ax[0].axis("off")
    ax[0].set_title("Input Image", fontsize=12, weight="bold")

    # Right: horizontal bar chart (reversed so rank 1 renders on top)
    bars = ax[1].barh(labels[::-1], probs_top3[::-1],
                      color=palette,
                      edgecolor="none", height=0.6)
    ax[1].set_xlim(0, 100)
    ax[1].set_xlabel("Probability (%)", fontsize=11)
    ax[1].set_title("Top-3 Predicted Species (DINOv2-G)", fontsize=12, weight="bold")

    # Annotate each bar with its percentage.
    for bar, prob in zip(bars, probs_top3[::-1]):
        ax[1].text(prob + 1, bar.get_y() + bar.get_height()/2,
                   f"{prob:.1f}%", va="center", fontsize=10, weight="bold")

    fig.tight_layout()

    # Encode plot as base64. Use fig.savefig (not plt.savefig) so we target
    # this figure explicitly, and close it in finally so a failed save
    # cannot leak the figure in long-running server processes.
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    plot_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    buf.close()

    # ───────────────────────────────
    # FINAL OUTPUT
    # ───────────────────────────────
    return {
        "top1": results[0],
        "top5": results,
        "plot": f"data:image/png;base64,{plot_b64}"
    }