Spaces:
Build error
Build error
| from torchvision import transforms | |
| from torchvision import models | |
| from efficientnet_pytorch import EfficientNet | |
| import torch | |
| from CustomModels import DinoVisionClassifier | |
| from functools import lru_cache | |
# Maps a class index to its human-readable label.
classes = dict(enumerate(("Glas", "Organic", "Papier", "Restmüll", "Wertstoff")))

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel normalization statistics shared by both preprocessing
# pipelines (the standard ImageNet channel mean / std).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Default preprocessing: squash to a fixed 256x256 input with bicubic
# resampling, convert to a tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize(
            (256, 256),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Preprocessing used for the DINOv2 classifier: resize the shorter side to
# 256, take a 224x224 center crop, convert to a tensor, then normalize.
transform_dinov2 = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
def load_specific_model(model_name):
    """Build the requested classifier and load its fine-tuned weights.

    Args:
        model_name: One of "EfficientNet-B3", "EfficientNet-B4", "vgg19",
            "resnet50" or "dinov2_vits14".

    Returns:
        The model with its checkpoint loaded, switched to eval mode and
        moved to the module-level ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "EfficientNet-B3":
        current_model = EfficientNet.from_pretrained("efficientnet-b3", num_classes=len(classes))
        weights_path = "src/models/eff_b3_model.pt"
    elif model_name == "EfficientNet-B4":
        current_model = EfficientNet.from_pretrained("efficientnet-b4", num_classes=len(classes))
        weights_path = "src/models/eff_b4.pt"
    elif model_name == "vgg19":
        current_model = models.vgg19()
        # The whole classifier head is replaced by a single linear layer whose
        # input width is taken from the first layer of the stock head.
        in_features = current_model.classifier[0].in_features
        current_model.classifier = torch.nn.Linear(in_features, len(classes))
        weights_path = "src/models/vgg19.pt"
    elif model_name == "resnet50":
        current_model = models.resnet50()
        in_features = current_model.fc.in_features
        current_model.fc = torch.nn.Linear(in_features, len(classes))
        weights_path = "src/models/resnet50.pt"
    elif model_name == "dinov2_vits14":
        # The DINOv2 backbone is fetched through torch.hub and wrapped in the
        # project's classifier module.
        backbone = torch.hub.load('facebookresearch/dinov2', "dinov2_vits14")
        current_model = DinoVisionClassifier(backbone, num_classes=len(classes))
        weights_path = "src/models/dinov2_vits14_0.054_98.00.pth"
    else:
        # Fail early with a clear message instead of crashing later on
        # ``None.eval()``.
        supported = (
            "EfficientNet-B3",
            "EfficientNet-B4",
            "vgg19",
            "resnet50",
            "dinov2_vits14",
        )
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of: {', '.join(supported)}"
        )

    # Checkpoints are always deserialized onto the CPU first; the final
    # ``.to(device)`` below moves the model to the GPU when one is in use.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    current_model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    print(f"Loaded model {model_name}")
    return current_model.eval().to(device)
def inference(model, inp):
    """Classify a single image and return a confidence per class label.

    Args:
        model: A classifier as returned by ``load_specific_model``.
        inp: The image to classify.
            NOTE(review): assumed to be a PIL image, since the preprocessing
            pipelines apply ``Resize`` before ``ToTensor`` -- confirm against
            the caller.

    Returns:
        A dict mapping every label in ``classes`` to its softmax probability
        as a Python float.
    """
    from contextlib import nullcontext

    model.eval()

    # The Normalize step is configured for three channels, so grayscale,
    # palette or RGBA images have to be converted before preprocessing.
    if getattr(inp, "mode", "RGB") != "RGB":
        inp = inp.convert("RGB")

    # The DINOv2 classifier uses its own resize / center-crop pipeline.
    preprocess = transform_dinov2 if isinstance(model, DinoVisionClassifier) else transform
    batch = preprocess(inp).unsqueeze(0).to(device)

    # Mixed precision is only enabled when CUDA is available; on CPU the
    # forward pass runs under a no-op context.
    amp_context = torch.cuda.amp.autocast() if torch.cuda.is_available() else nullcontext()
    with torch.no_grad(), amp_context:
        logits = model(batch)[0]
        prediction = torch.nn.functional.softmax(logits, dim=0).cpu().numpy()

    return {label: float(prediction[index]) for index, label in classes.items()}