Spaces: Configuration error
| import torch | |
| from torchvision.models import convnext_base, ConvNeXt_Base_Weights | |
| from PIL import Image | |
| import torch.nn.functional as F | |
class ConvNextBase:
    """ConvNeXt-Base image classifier whose final linear layer is replaced and
    loaded from a local checkpoint.

    The backbone is built without pretrained weights; all parameters come from
    the state_dict at ``weights_path``. Inputs are preprocessed with the preset
    transforms bundled with torchvision's ``ConvNeXt_Base_Weights.IMAGENET1K_V1``.
    """

    def __init__(self, weights_path: str, num_classes: int = 7):
        """Build the network, load the checkpoint, and move it to the best device.

        Args:
            weights_path: Path to a state_dict for a ConvNeXt-Base whose
                ``classifier[2]`` outputs ``num_classes`` logits. Loading is
                strict, so the keys and shapes must match.
            num_classes: Output size of the replacement head. The default of 7
                matches the original hard-coded head.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = convnext_base()
        # Swap the stock ImageNet head for one sized to this task's label set.
        self.model.classifier[2] = torch.nn.Linear(
            self.model.classifier[2].in_features, num_classes
        )
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Without this the model stays on CPU and self.device is never used.
        self.model.to(self.device)
        self.model.eval()
        self.transform = ConvNeXt_Base_Weights.IMAGENET1K_V1.transforms()

    def make_prediction(self, image: "Image.Image") -> torch.Tensor:
        """Return class probabilities for a single PIL image.

        Args:
            image: A PIL image. Non-RGB modes are converted to RGB first.

        Returns:
            A CPU tensor of shape (1, num_classes) holding softmax probabilities.
        """
        # The preset normalizes with 3-channel mean/std, so greyscale, RGBA and
        # palette images must be converted to RGB first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Add a batch dimension and place the input on the same device as the model.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = F.softmax(self.model(batch), dim=1)
        # Callers previously always received a CPU tensor; keep that contract.
        return probs.cpu()