File size: 775 Bytes
a8e2ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

# Use the GPU when PyTorch reports a usable CUDA device; otherwise fall back to the CPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"


class Resnet:
    """Pretrained ResNet-50 with its last child module removed, used to embed images."""

    def __init__(self):
        """Load the pretrained network, truncate it, and build the preprocessing."""
        # Remember where the model lives so inputs can be moved to the same device.
        self.device = device
        self.model = models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(self.device)
        # Keep every child except the last one (the classifier in torchvision's
        # ResNet), so the forward pass yields features instead of class scores.
        self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
        # Inference only: eval() fixes batch-norm statistics.
        self.model.eval()
        # NOTE(review): the preprocessing documented for these pretrained weights
        # also normalizes with the ImageNet mean/std; none is applied here. Adding
        # it would change every embedding, so confirm with consumers first.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def get_embedding(self, image):
        """Return the truncated network's output for a single image.

        Args:
            image: Any input accepted by ``self.transform`` -- presumably a
                3-channel PIL image; confirm against callers.

        Returns:
            The model output for a batch of one with all size-1 dimensions
            squeezed out. The tensor stays on the model's device.
        """
        # Add the batch dimension, then move the input to the model's device;
        # without the move a CUDA model rejects the CPU tensor.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            embedding = self.model(img_tensor).squeeze()
        return embedding