|
|
|
|
|
import os |
|
|
import io |
|
|
import json |
|
|
import base64 |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_REPO = "Arew99/dinov2-costum" |
|
|
ID2NAME_PATH = os.path.join(os.path.dirname(__file__), "id2name.json") |
|
|
|
|
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
print(f"π§ Using device: {DEVICE}") |
|
|
|
|
|
_model = None |
|
|
_processor = None |
|
|
_id2name = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_id2name() -> dict:
    """Load the class-id -> species-name mapping from ID2NAME_PATH.

    Returns:
        dict[int, str]: mapping of integer class id to label. If the JSON
        file is missing, returns 101 placeholder labels ("Class 0" ...
        "Class 100") so the pipeline still runs end-to-end.
    """
    if os.path.exists(ID2NAME_PATH):
        # Explicit encoding: JSON is UTF-8 by spec; don't rely on the
        # platform default (the original omitted it).
        with open(ID2NAME_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
        # JSON object keys are always strings; coerce back to int ids.
        return {int(k): v for k, v in data.items()}

    # NOTE(review): original warning string contained mojibake; rewritten
    # in plain ASCII with the same meaning.
    print("[warn] id2name.json not found - using placeholder labels.")
    return {i: f"Class {i}" for i in range(101)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_classification_model():
    """Load (once) and return the classifier, its processor, and label map.

    Lazily initializes the module-level singletons `_model`, `_processor`,
    and `_id2name` on first call; subsequent calls return the cached
    objects without touching the network.

    Returns:
        tuple: (model, processor, id2name) — the eval-mode
        AutoModelForImageClassification on DEVICE, its AutoImageProcessor,
        and a dict[int, str] of class labels.
    """
    global _model, _processor, _id2name

    # Fast path: already loaded.
    if _model is not None:
        return _model, _processor, _id2name

    print(f"[info] Loading model from Hugging Face repo: {MODEL_REPO}")
    _processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
    _model = AutoModelForImageClassification.from_pretrained(
        MODEL_REPO,
        # Classifier head size differs from the base checkpoint config.
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    _model.eval()
    _id2name = _load_id2name()

    # BUG FIX: the original f-string here was split across two physical
    # lines by a corrupted emoji, producing a SyntaxError. Rejoined on one
    # line with plain ASCII.
    print(f"[info] Model loaded and ready on {DEVICE}")
    return _model, _processor, _id2name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_bytes(image_bytes: bytes) -> dict:
    """Classify a raw image byte string and render a result plot.

    Args:
        image_bytes: Encoded image data in any format PIL can decode.

    Returns:
        dict with keys:
            "top1": best prediction ({"rank", "id", "label", "score"}),
            "top5": list of the top-k predictions (k <= 5, capped at the
                model's number of classes),
            "plot": data-URI string containing a base64 PNG showing the
                input image next to a top-3 probability bar chart.
    """
    model, processor, id2name = load_classification_model()

    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # BUG FIX: the original passed padding="True" — a *string*, not a
    # boolean, and image processors accept no `padding` kwarg anyway.
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)

    # Cap k at the number of classes so torch.topk cannot raise for
    # models with fewer than 5 labels.
    k = min(5, probs.shape[-1])
    topk = torch.topk(probs, k=k)
    indices = topk.indices[0].tolist()
    values = topk.values[0].tolist()

    results = [
        {
            "rank": rank,
            "id": int(idx),
            "label": id2name.get(int(idx), f"Class {idx}"),
            "score": float(prob),
        }
        for rank, (idx, prob) in enumerate(zip(indices, values), 1)
    ]

    plot_b64 = _render_topk_plot(image, results)

    return {
        "top1": results[0],
        "top5": results,
        "plot": f"data:image/png;base64,{plot_b64}",
    }


def _render_topk_plot(image, results) -> str:
    """Render the input image and a top-3 bar chart; return base64 PNG.

    Args:
        image: PIL RGB image that was classified.
        results: prediction dicts (ordered best-first) with "label" and
            "score" keys; only the first three entries are plotted.

    Returns:
        str: base64-encoded PNG of the two-panel matplotlib figure.
    """
    top3 = results[:3]
    labels = [p["label"] for p in top3]
    probs_top3 = [p["score"] * 100 for p in top3]

    plt.style.use("seaborn-v0_8-whitegrid")
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))

    # Left panel: the classified image.
    ax[0].imshow(image)
    ax[0].axis("off")
    ax[0].set_title("Input Image", fontsize=12, weight="bold")

    # Right panel: horizontal bars, reversed so the best prediction is
    # drawn at the top of the chart.
    bars = ax[1].barh(labels[::-1], probs_top3[::-1],
                      color=["#C44E52", "#55A868", "#4C72B0"],
                      edgecolor="none", height=0.6)
    ax[1].set_xlim(0, 100)
    ax[1].set_xlabel("Probability (%)", fontsize=11)
    ax[1].set_title("Top-3 Predicted Species (DINOv2-G)", fontsize=12, weight="bold")

    # Annotate each bar with its percentage, just past the bar's end.
    for bar, prob in zip(bars, probs_top3[::-1]):
        ax[1].text(prob + 1, bar.get_y() + bar.get_height() / 2,
                   f"{prob:.1f}%", va="center", fontsize=10, weight="bold")

    plt.tight_layout()

    # Encode the figure as PNG in memory; close it to avoid leaking
    # figures across repeated calls.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    plot_b64 = base64.b64encode(buf.read()).decode("utf-8")
    buf.close()
    return plot_b64