# keduClassifier / inference.py
# Hajorda: "Upload inference.py with huggingface_hub" (commit 3696b3b, verified)
import torch
import pytorch_lightning as pl
from timm import create_model
import torch.nn as nn
from box import Box
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import pickle
from PIL import Image
import numpy as np
import os
# --- Inference-time configuration ---
# Only the settings read by PetBreedModel and predict_breed are repeated here.
cfg_dict_for_inference = dict(
    model_name='swin_tiny_patch4_window7_224',  # architecture name given to timm
    dropout_backbone=0.1,  # forwarded to timm as drop_rate
    dropout_fc=0.2,  # dropout inside the classification head
    img_size=(224, 224),  # unpacked as (height, width)
    num_classes=37,  # output width of the final linear layer
)
# Wrapped in a Box so the settings can be read as attributes (cfg.model_name, ...).
cfg_inference = Box(cfg_dict_for_inference)
# --- PetBreedModel class definition (copied from Cell 5) ---
class PetBreedModel(pl.LightningModule):
    """Image classifier: a timm backbone followed by a two-layer MLP head.

    The config object must expose ``model_name``, ``dropout_backbone``,
    ``dropout_fc``, ``img_size`` and ``num_classes`` as attributes.
    """

    def __init__(self, cfg: Box):
        super().__init__()
        self.cfg = cfg

        # num_classes=0 asks timm for the model without its own classifier,
        # so the backbone's forward pass yields feature vectors.
        self.backbone = create_model(
            self.cfg.model_name,
            pretrained=False,
            num_classes=0,
            in_chans=3,
            drop_rate=self.cfg.dropout_backbone,
        )

        # Any img_size that is not a tuple falls back to 224x224.
        if isinstance(self.cfg.img_size, tuple):
            height, width = self.cfg.img_size
        else:
            height, width = 224, 224

        # Run one random image through the backbone to discover the width of
        # its output features; gradients are not needed for this probe.
        probe = torch.randn(1, 3, height, width)
        with torch.no_grad():
            feature_dim = self.backbone(probe).shape[-1]

        # Head: feature_dim -> feature_dim // 2 -> num_classes.
        hidden_dim = feature_dim // 2
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.cfg.dropout_fc),
            nn.Linear(hidden_dim, self.cfg.num_classes),
        )

    def forward(self, x):
        """Return the head's output (class logits) for the input batch ``x``."""
        return self.fc(self.backbone(x))
# --- End of PetBreedModel definition ---
def load_model_from_hf(repo_id="Hajorda/keduClassifier", ckpt_filename="pytorch_model.ckpt"):
    """Download a checkpoint from the Hugging Face Hub and rebuild the model.

    Args:
        repo_id: Hub repository that holds the checkpoint.
        ckpt_filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``PetBreedModel`` built with ``cfg_inference`` and switched to
        evaluation mode.

    Raises:
        ValueError: If ``cfg_inference.num_classes`` is ``None``.
    """
    # Imported here so huggingface_hub is only needed when a download happens.
    from huggingface_hub import hf_hub_download

    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=ckpt_filename)

    # The classification head cannot be sized without a class count.
    if cfg_inference.num_classes is None:
        raise ValueError("num_classes must be set in cfg_inference to load the model.")

    # strict=False: state-dict keys that do not line up with the model are
    # tolerated instead of raising.
    model = PetBreedModel.load_from_checkpoint(
        checkpoint_path,
        cfg=cfg_inference,
        strict=False,
    )
    model.eval()
    return model
def load_label_encoder_from_hf(repo_id="Hajorda/keduClassifier", le_filename="label_encoder.pkl"):
    """Download the pickled label encoder from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds the pickle file.
        le_filename: Name of the pickle file inside that repository.

    Returns:
        The object stored in the pickle file.
    """
    # Imported here so huggingface_hub is only needed when a download happens.
    from huggingface_hub import hf_hub_download

    encoder_path = hf_hub_download(repo_id=repo_id, filename=le_filename)

    # NOTE(security): unpickling can execute arbitrary code contained in the
    # file, so only point this at repositories you trust.
    with open(encoder_path, 'rb') as handle:
        return pickle.load(handle)
def predict_breed(image_path, model, label_encoder, device='cpu'):
    """Predict the breed shown in an image file.

    Args:
        image_path: Path (``str`` or ``os.PathLike``) of the image to classify.
        model: Classifier that returns one row of logits per input image.
        label_encoder: Object whose ``inverse_transform`` maps class indices
            back to breed names.
        device: Torch device on which the forward pass is run.

    Returns:
        A ``(breed_name, confidence)`` tuple; ``confidence`` is the softmax
        probability of the top class as a Python float.

    Raises:
        FileNotFoundError: If ``image_path`` does not point to an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    # Validate the input before touching the model so bad paths fail fast
    # with a clear message.
    path = os.fspath(image_path)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Image file not found: {path}")

    # cv2.imread reports failure by returning None instead of raising, which
    # would otherwise surface later as an opaque error inside cvtColor.
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        raise ValueError(f"Could not decode image file: {path}")
    # OpenCV loads images as BGR; convert to RGB channel order.
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    # Same size rule as the model: a non-tuple img_size falls back to 224x224.
    size = cfg_inference.img_size
    height, width = size if isinstance(size, tuple) else (224, 224)

    preprocess = A.Compose([
        A.Resize(height=height, width=width),
        # Standard ImageNet channel mean/std.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    # unsqueeze(0) adds the batch dimension the model expects.
    input_tensor = preprocess(image=rgb_image)['image'].unsqueeze(0).to(device)

    model.to(device)
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits, dim=1)
        top_prob, top_class_idx = torch.max(probabilities, dim=1)

    predicted_breed_id = top_class_idx.item()
    predicted_breed_name = label_encoder.inverse_transform([predicted_breed_id])[0]
    confidence = top_prob.item()
    return predicted_breed_name, confidence