import torch
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from PIL import Image
import torch.nn.functional as F


class ConvNextBase:
    """ConvNeXt-Base image classifier with a custom-sized output head.

    Builds torchvision's ``convnext_base`` architecture (no pretrained
    weights), replaces the final linear layer of its classifier with one
    that produces ``num_classes`` logits, loads fine-tuned weights from disk,
    and uses the preprocessing pipeline bundled with
    ``ConvNeXt_Base_Weights.IMAGENET1K_V1``.
    """

    def __init__(self, weights_path: str, num_classes: int = 7):
        """Build the network, load its weights and put it in eval mode.

        Args:
            weights_path: Path to a ``state_dict`` for this architecture with
                a ``num_classes``-way head.
            num_classes: Number of output classes; must match the checkpoint.
                Defaults to 7, the value this class originally hard-coded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_classes = num_classes

        self.model = convnext_base()
        # Index 2 of torchvision's ConvNeXt classifier is the final Linear layer;
        # replace it so the head matches the fine-tuned label set.
        in_features = self.model.classifier[2].in_features
        self.model.classifier[2] = torch.nn.Linear(in_features, num_classes)

        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints, or pass weights_only=True where the torch version supports it.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # load_state_dict copies values into the existing (CPU) parameters, so
        # the model must be moved explicitly for self.device to take effect.
        self.model.to(self.device)
        self.model.eval()

        self.transform = ConvNeXt_Base_Weights.IMAGENET1K_V1.transforms()

    def make_prediction(self, image: Image.Image) -> torch.Tensor:
        """Return class probabilities for a single PIL image.

        Args:
            image: Input image. Non-RGB images (grayscale, RGBA, palette, ...)
                are converted to RGB first, because the ImageNet normalisation
                in the preprocessing expects exactly three channels.

        Returns:
            A CPU tensor of shape ``(1, num_classes)`` holding softmax
            probabilities.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Add the batch dimension and place the input on the model's device.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = F.softmax(self.model(batch), dim=1)
        # Always hand back a CPU tensor, regardless of where inference ran.
        return probs.cpu()