Spaces:
Configuration error
Configuration error
| import torch | |
| from torchvision.models import efficientnet_b7 | |
| from torchvision.transforms import InterpolationMode | |
| from torchvision import transforms | |
| from PIL import Image | |
| from torch.nn import functional as F | |
class EfficientNetB7:
    """Image classifier built on torchvision's EfficientNet-B7 with a custom head.

    The backbone is constructed by ``efficientnet_b7()`` and its final linear
    layer is replaced by a ``num_classes``-way head. Fine-tuned weights are then
    loaded from ``weights_path``. Inference runs on CUDA when it is available,
    otherwise on CPU.
    """

    def __init__(self, weights_path: str, num_classes: int = 7):
        """Build the network and load fine-tuned weights.

        Args:
            weights_path: Path to a ``state_dict`` whose keys and shapes match
                an EfficientNet-B7 with a ``num_classes``-way classifier head.
            num_classes: Number of outputs of the replacement head. Defaults to
                7, the value previously hard-coded.

        Raises:
            RuntimeError: Raised by ``load_state_dict`` when the checkpoint does
                not match the constructed model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Resize the shorter side to 600 px (bicubic) and center-crop to 600x600.
        # Then convert to a tensor and normalize with the standard ImageNet
        # channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize(600, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(600),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.model = efficientnet_b7()
        # Read the head's input width from the existing layer instead of
        # hard-coding it (2560 for EfficientNet-B7).
        head_in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(head_in_features, num_classes)

        # SECURITY NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints. On torch versions that support it, consider weights_only=True.
        state_dict = torch.load(weights_path, map_location=torch.device(self.device))
        self.model.load_state_dict(state_dict)
        # Without this move the parameters stay on CPU, so the detected CUDA
        # device would never actually be used.
        self.model.to(self.device)
        self.model.eval()

    def make_prediction(self, image: "Image.Image"):
        """Return class probabilities for a single PIL image.

        Args:
            image: A PIL image of any mode. It is converted to RGB first, so
                grayscale, RGBA and palette images are accepted.

        Returns:
            A CPU tensor of shape ``(1, num_classes)`` holding softmax
            probabilities; each row sums to 1.
        """
        # Normalize() expects exactly 3 channels; convert non-RGB modes up front.
        batch = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probabilities = F.softmax(self.model(batch), dim=1)
        # Hand results back on CPU so callers can use .numpy() and similar
        # regardless of where inference ran.
        return probabilities.cpu()