| import torch | |
| import torchvision | |
| import cv2 | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| from .utils import to_rgb | |
def create_effNetv2s():
    """Build a torchvision EfficientNetV2-S with a sigmoid classification head.

    The network is created with the ``IMAGENET1K_V1`` weights, then the
    last classifier layer (index 1) is swapped for a ``Linear`` layer with
    ``NUM_CLASSES`` outputs followed by ``Sigmoid``.

    Returns:
        The modified model.
    """
    net = torchvision.models.efficientnet_v2_s(weights='IMAGENET1K_V1')
    # Reuse the input width of the layer being replaced.
    in_dim = net.classifier[1].in_features
    head = nn.Sequential(
        nn.Linear(in_dim, NUM_CLASSES),
        nn.Sigmoid(),
    )
    net.classifier[1] = head
    return net
def create_convnet():
    """Build a torchvision ConvNeXt-Base with a sigmoid classification head.

    The network is created with the ``IMAGENET1K_V1`` weights, then the
    last classifier layer (index 2) is swapped for a ``Linear`` layer with
    ``NUM_CLASSES`` outputs followed by ``Sigmoid``.

    Returns:
        The modified model.
    """
    net = torchvision.models.convnext_base(weights='IMAGENET1K_V1')
    # Reuse the input width of the layer being replaced.
    in_dim = net.classifier[2].in_features
    head = nn.Sequential(
        nn.Linear(in_dim, NUM_CLASSES),
        nn.Sigmoid(),
    )
    net.classifier[2] = head
    return net
def create_model(model_name):
    """Instantiate the named architecture, load its checkpoint, move it to DEVICE.

    Args:
        model_name: Key that must be present in both ``_MODEL`` (factory
            functions) and ``_WEIGHT`` (checkpoint paths).

    Returns:
        The model with the checkpoint's state dict loaded, placed on
        ``DEVICE``.

    Raises:
        KeyError: If ``model_name`` has no registered factory or no
            registered checkpoint path.
    """
    # Check both registries before building anything, so a typo fails
    # immediately with a helpful message rather than after the backbone
    # has already been constructed.
    if model_name not in _MODEL or model_name not in _WEIGHT:
        available = sorted(set(_MODEL) & set(_WEIGHT))
        raise KeyError(
            f"unknown model_name {model_name!r}; available models: {available}"
        )

    model = _MODEL[model_name]()
    # NOTE(review): torch.load unpickles the file, so _WEIGHT must only point
    # at trusted checkpoints (consider weights_only=True if the installed
    # torch version supports it).
    # The checkpoint is mapped to CPU first so it loads on machines without
    # CUDA; the model is moved to DEVICE afterwards.
    state_dict = torch.load(_WEIGHT[model_name], map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    return model
def create_transform():
    """Return the preprocessing pipeline applied to a PIL image before inference.

    Steps, in order: resize to ``(HEIGHT, WEIGHT)`` (``WEIGHT`` is passed as
    the second size element, i.e. the width), convert to a tensor, and
    normalize every channel with the same mean and standard deviation.
    """
    channel_mean = (0.6078, 0.6078, 0.6078)
    channel_std = (0.1932, 0.1932, 0.1932)
    steps = [
        transforms.Resize((HEIGHT, WEIGHT)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)
# Target size passed to transforms.Resize as (HEIGHT, WEIGHT).
# NOTE: WEIGHT is the image *width*; it is unrelated to the _WEIGHT
# checkpoint registry below.
HEIGHT = 224
WEIGHT = 224

# Number of outputs of the replacement classification heads.
NUM_CLASSES = 44

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model name -> factory function that builds the architecture.
_MODEL = {
    "effNetv2s": create_effNetv2s,
    "convnet": create_convnet,
}

# Model name -> checkpoint path (relative to the current working directory).
_WEIGHT = {
    "effNetv2s": './classification/weights/effnetv2s.pt',
    "convnet": './classification/weights/convnet.pt',
}
class Classifier():
    """Wrap a trained classification model behind a single ``predict`` call."""

    def __init__(self, model_name="effNetv2s"):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_name: Name forwarded to ``create_model``.
        """
        self.model = create_model(model_name)
        self.transform = create_transform()

    def predict(self, image):
        """Run the model on one image and return its output as a NumPy array.

        Args:
            image: cv2 image; it is converted with ``to_rgb`` and then to a
                PIL image before the transform is applied.

        Returns:
            The model output moved to the CPU as a NumPy array. The image is
            fed as a batch of one, so the leading batch dimension is kept
            (presumably shape ``(1, NUM_CLASSES)`` -- confirm against the
            loaded model).

        Raises:
            ValueError: If ``image`` is ``None`` (e.g. a failed image read).
        """
        # Fail early with a clear message instead of an opaque error from
        # the conversion helpers further down.
        if image is None:
            raise ValueError("image is None; expected a cv2 image array")

        rgb = to_rgb(image)
        pil_image = Image.fromarray(rgb)
        tensor = self.transform(pil_image)

        # Inference mode for layers such as dropout / batch norm.
        self.model.eval()
        with torch.no_grad():
            # unsqueeze(0) adds the batch dimension the model expects.
            batch = tensor.unsqueeze(0).to(DEVICE)
            out = self.model(batch).cpu().numpy()
        return out