import torch
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms
from torchvision.models import efficientnet_b7
from torchvision.transforms import InterpolationMode


class EfficientNetB7:
    """EfficientNet-B7 classifier wrapper that loads fine-tuned weights from disk.

    Builds the torchvision ``efficientnet_b7`` architecture, replaces its final
    classifier layer with a ``num_classes``-way linear layer, loads a state dict
    from ``weights_path`` and switches the model to eval mode.
    """

    def __init__(self, weights_path: str, num_classes: int = 7):
        """Build the preprocessing pipeline and load the model.

        Args:
            weights_path: Path to a state dict loadable by ``torch.load`` whose
                keys match ``efficientnet_b7`` with a ``num_classes``-output head.
            num_classes: Output size of the replaced classifier layer
                (defaults to 7, the previously hard-coded value).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.transform = transforms.Compose([
            # An int size resizes the shorter edge to 600 (aspect ratio kept);
            # the center crop then produces a 600x600 input.
            transforms.Resize(600, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(600),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.model = efficientnet_b7()
        # Read the head's input width from the existing layer rather than hard-coding 2560.
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(in_features, num_classes)

        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
        # (consider weights_only=True on torch versions that support it).
        state_dict = torch.load(weights_path, map_location=torch.device(self.device))
        self.model.load_state_dict(state_dict)
        # Move the model to the selected device; previously it always stayed on CPU.
        self.model.to(self.device)
        self.model.eval()

    def make_prediction(self, image: Image.Image) -> torch.Tensor:
        """Return softmax class probabilities for a single image.

        Args:
            image: Input image (typically a PIL image). Non-RGB PIL images are
                converted to RGB, since the normalization expects 3 channels.

        Returns:
            A ``(1, num_classes)`` tensor of probabilities, on the CPU.
        """
        if isinstance(image, Image.Image) and image.mode != "RGB":
            # Grayscale / RGBA / palette images would otherwise break Normalize.
            image = image.convert("RGB")
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            pred = F.softmax(self.model(batch), dim=1)
        # Keep the original contract of returning a CPU tensor regardless of device.
        return pred.cpu()