# dino_plankton_classifier / inference.py
# Uploader: danielaivanova
# Commit message: "Upload folder using huggingface_hub"
# Commit: dff7e68 (verified)
import torch
from transformers import AutoModel, AutoImageProcessor
from model import DinoV3LinearMultiLinear
def load_model(
    weights_path,
    device="cuda",
    *,
    config_path="config.json",
    labels_path="id2label.json",
):
    """
    Load the pre-trained classifier together with its preprocessor and labels.

    Args:
        weights_path: Path to the saved weights (.pt file). The checkpoint
            must be a mapping with a "model_state_dict" entry.
        device: Device to load model on ('cuda' or 'cpu')
        config_path: Path to the JSON config providing "model_name" and
            "num_classes". Defaults to "config.json", resolved against the
            current working directory.
        labels_path: Path to the JSON file holding the class-index -> label
            mapping. Defaults to "id2label.json", resolved against the
            current working directory.

    Returns:
        model: Loaded DinoV3LinearMultiLinear model in eval mode
        processor: Image processor for preprocessing input images
        id2label: Parsed contents of ``labels_path`` (JSON object keys are
            always strings)
    """
    import json

    # Load config. JSON is UTF-8 by specification, so do not rely on the
    # platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Load backbone
    backbone = AutoModel.from_pretrained(config["model_name"])
    hidden_size = backbone.config.hidden_size

    # Instantiate classifier head
    model = DinoV3LinearMultiLinear(
        backbone=backbone,
        num_classes=config["num_classes"],
        hidden_size=hidden_size,
        freeze_backbone=True,
    )

    # Load trained weights.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(weights_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # Load image processor
    processor = AutoImageProcessor.from_pretrained(config["model_name"])

    # Load labels (may contain non-ASCII names, hence the explicit encoding)
    with open(labels_path, "r", encoding="utf-8") as f:
        id2label = json.load(f)

    return model, processor, id2label
def probs_to_labels(probs, id2label):
    """
    Map each row of a probability tensor to the label of its top class.

    Args:
        probs: Tensor whose argmax is taken along dim 1.
        id2label: Mapping from stringified class index to label.

    Returns:
        List with one label per entry of the argmax result.
    """
    labels = []
    for best_class in probs.argmax(dim=1):
        # Keys are looked up as strings, e.g. "3" rather than 3.
        key = str(best_class.item())
        labels.append(id2label[key])
    return labels