# app/inference.py
import os
import io
import json
import base64
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# βββββββββββββββββββββββββββββββββββββββββββββ
# CONFIG
# βββββββββββββββββββββββββββββββββββββββββββββ
MODEL_REPO = "Arew99/dinov2-costum" # your Hugging Face repo
ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"π§ Using device: {DEVICE}")
_model = None
_processor = None
_id2name = None
# βββββββββββββββββββββββββββββββββββββββββββββ
# HELPER β load id2name mapping
# βββββββββββββββββββββββββββββββββββββββββββββ
def _load_id2name():
if os.path.exists(ID2NAME_PATH):
with open(ID2NAME_PATH, "r") as f:
data = json.load(f)
return {int(k): v for k, v in data.items()}
print("β οΈ id2name.json not found β using placeholder labels.")
return {i: f"Class {i}" for i in range(101)}
# ---------------------------------------------
# LOAD MODEL (cached globally)
# ---------------------------------------------
def load_classification_model():
    """Load (and cache) the classification model, processor and label map.

    Downloads from ``MODEL_REPO`` on first call; subsequent calls return
    the cached objects, so the expensive load happens once per process.

    Returns:
        tuple: ``(model, processor, id2name)`` with the model in eval
        mode on ``DEVICE``.
    """
    global _model, _processor, _id2name
    # Fast path: everything is already loaded.
    if _model is not None:
        return _model, _processor, _id2name
    print(f"Loading model from Hugging Face repo: {MODEL_REPO}")
    _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
    _model = AutoModelForImageClassification.from_pretrained(
        MODEL_REPO,
        # The fine-tuned classifier head differs from the base checkpoint.
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    _model.eval()  # inference mode: disables dropout etc.
    _id2name = _load_id2name()
    # BUG FIX: this message was split across two physical lines by a garbled
    # emoji, leaving an unterminated f-string (SyntaxError). Re-joined.
    print(f"Model loaded and ready on {DEVICE}")
    return _model, _processor, _id2name
# βββββββββββββββββββββββββββββββββββββββββββββ
# CLASSIFY IMAGE BYTES
# βββββββββββββββββββββββββββββββββββββββββββββ
def classify_bytes(image_bytes: bytes):
model, processor, id2name = load_classification_model()
# Load and preprocess image
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
inputs = processor(images=image, return_tensors="pt", padding="True").to(DEVICE)
# Forward pass
with torch.no_grad():
outputs = model(**inputs)
probs = F.softmax(outputs.logits, dim=-1)
# Top-5 predictions
topk = torch.topk(probs, k=5)
indices = topk.indices[0].tolist()
values = topk.values[0].tolist()
results = []
for rank, (idx, prob) in enumerate(zip(indices, values), 1):
label = id2name.get(int(idx), f"Class {idx}")
results.append({
"rank": rank,
"id": int(idx),
"label": label,
"score": float(prob),
})
# βββββββββββββββββββββββββββββββ
# MATPLOTLIB TOP-3 PLOT
# βββββββββββββββββββββββββββββββ
top3 = results[:3]
labels = [p["label"] for p in top3]
probs_top3 = [p["score"] * 100 for p in top3]
plt.style.use("seaborn-v0_8-whitegrid")
fig, ax = plt.subplots(1, 2, figsize=(9, 4))
# Left: input image
ax[0].imshow(image)
ax[0].axis("off")
ax[0].set_title("Input Image", fontsize=12, weight="bold")
# Right: horizontal bar chart
bars = ax[1].barh(labels[::-1], probs_top3[::-1],
color=["#C44E52", "#55A868", "#4C72B0"],
edgecolor="none", height=0.6)
ax[1].set_xlim(0, 100)
ax[1].set_xlabel("Probability (%)", fontsize=11)
ax[1].set_title("Top-3 Predicted Species (DINOv2-G)", fontsize=12, weight="bold")
for bar, prob in zip(bars, probs_top3[::-1]):
ax[1].text(prob + 1, bar.get_y() + bar.get_height()/2,
f"{prob:.1f}%", va="center", fontsize=10, weight="bold")
plt.tight_layout()
# Encode plot as base64
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
plt.close(fig)
buf.seek(0)
plot_b64 = base64.b64encode(buf.read()).decode("utf-8")
buf.close()
# βββββββββββββββββββββββββββββββ
# FINAL OUTPUT
# βββββββββββββββββββββββββββββββ
return {
"top1": results[0],
"top5": results,
"plot": f"data:image/png;base64,{plot_b64}"
} |