# NEMOtools / app / inference.py
# Author: AndrewKof — commit 5b11294 ("Update UI with LFS for images and models")
# app/inference.py
import os
import io
import json
import base64
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
# Hugging Face Hub repo id for the fine-tuned classifier.
# NOTE: "costum" spelling matches the actual repo name — do not "fix" it.
MODEL_REPO = "Arew99/dinov2-costum" # your Hugging Face repo
# JSON file mapping class ids -> human-readable names, kept next to this module.
ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json")
# Prefer GPU when available; everything (model + inputs) is moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🧠 Using device: {DEVICE}")
# Lazily-initialized module-level singletons, populated by
# load_classification_model() on first use and reused afterwards.
_model = None
_processor = None
_id2name = None
# ─────────────────────────────────────────────
# HELPER β€” load id2name mapping
# ─────────────────────────────────────────────
def _load_id2name():
if os.path.exists(ID2NAME_PATH):
with open(ID2NAME_PATH, "r") as f:
data = json.load(f)
return {int(k): v for k, v in data.items()}
print("⚠️ id2name.json not found β€” using placeholder labels.")
return {i: f"Class {i}" for i in range(101)}
# ─────────────────────────────────────────────
# LOAD MODEL (cached globally)
# ─────────────────────────────────────────────
def load_classification_model():
    """Return the cached ``(model, processor, id2name)`` triple.

    On first call, downloads the image processor and classification model
    from ``MODEL_REPO``, moves the model to ``DEVICE`` in eval mode, and
    loads the id->name mapping; subsequent calls reuse the module-level
    cache without touching the network.
    """
    global _model, _processor, _id2name
    if _model is None:
        print(f"πŸ” Loading model from Hugging Face repo: {MODEL_REPO}")
        _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
        # ignore_mismatched_sizes tolerates a classifier head whose size
        # differs from the checkpoint's original head.
        net = AutoModelForImageClassification.from_pretrained(
            MODEL_REPO,
            ignore_mismatched_sizes=True,
        )
        net = net.to(DEVICE)
        net.eval()
        _model = net
        _id2name = _load_id2name()
        print(f"βœ… Model loaded and ready on {DEVICE}")
    return _model, _processor, _id2name
# ─────────────────────────────────────────────
# CLASSIFY IMAGE BYTES
# ─────────────────────────────────────────────
def classify_bytes(image_bytes: bytes):
    """Classify raw image bytes and render a top-3 summary plot.

    Decodes the image, runs the cached classifier, and returns the top-k
    predictions together with a matplotlib figure (input image + top-3
    probability bar chart) encoded as a base64 PNG data URI.

    Parameters
    ----------
    image_bytes : bytes
        Raw encoded image data in any format PIL can decode.

    Returns
    -------
    dict
        ``{"top1": <best prediction>, "top5": <list of up to 5 predictions>,
        "plot": "data:image/png;base64,..."}`` where each prediction is a
        dict with keys ``rank``, ``id``, ``label``, ``score``.
    """
    model, processor, id2name = load_classification_model()

    # Decode and normalize to RGB (handles grayscale / RGBA inputs).
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # BUG FIX: previously passed padding="True" — a *string* (always truthy)
    # and a kwarg image processors do not use. Dropped entirely.
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)

    # Top-k predictions; clamp k so a model with fewer than 5 classes
    # does not crash torch.topk.
    k = min(5, probs.shape[-1])
    topk = torch.topk(probs, k=k)
    indices = topk.indices[0].tolist()
    values = topk.values[0].tolist()
    results = [
        {
            "rank": rank,
            "id": int(idx),
            "label": id2name.get(int(idx), f"Class {idx}"),
            "score": float(prob),
        }
        for rank, (idx, prob) in enumerate(zip(indices, values), 1)
    ]

    # ───────────────────────────────
    # MATPLOTLIB TOP-3 PLOT
    # ───────────────────────────────
    top3 = results[:3]
    labels = [p["label"] for p in top3]
    probs_top3 = [p["score"] * 100 for p in top3]
    # NOTE: plt.style.use mutates global matplotlib state (kept from the
    # original behavior).
    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))

    # Left panel: the input image itself.
    ax[0].imshow(image)
    ax[0].axis("off")
    ax[0].set_title("Input Image", fontsize=12, weight="bold")

    # Right panel: horizontal bar chart, best prediction on top (hence the
    # [::-1] reversals — barh draws bottom-up).
    bars = ax[1].barh(labels[::-1], probs_top3[::-1],
                      color=["#C44E52", "#55A868", "#4C72B0"],
                      edgecolor="none", height=0.6)
    ax[1].set_xlim(0, 100)
    ax[1].set_xlabel("Probability (%)", fontsize=11)
    ax[1].set_title("Top-3 Predicted Species (DINOv2-G)", fontsize=12, weight="bold")
    for bar, prob in zip(bars, probs_top3[::-1]):
        ax[1].text(prob + 1, bar.get_y() + bar.get_height() / 2,
                   f"{prob:.1f}%", va="center", fontsize=10, weight="bold")
    fig.tight_layout()

    # Encode the figure as a base64 PNG data URI. Use fig.savefig (not
    # plt.savefig) so we don't depend on pyplot's "current figure" global.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    plot_b64 = base64.b64encode(buf.read()).decode("utf-8")
    buf.close()

    # ───────────────────────────────
    # FINAL OUTPUT
    # ───────────────────────────────
    return {
        "top1": results[0],
        "top5": results,
        "plot": f"data:image/png;base64,{plot_b64}"
    }