File size: 691 Bytes
a8e2ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torchvision import models, transforms
import torch

# Prefer CUDA whenever PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class EfficientNet:
    """Image embedder built on torchvision's pretrained EfficientNet-B0.

    The classification head is replaced with ``torch.nn.Identity``, so the
    model's forward pass returns the features that would have been fed to the
    classifier instead of class logits.
    """

    def __init__(self):
        # Remember the target device so inputs can be moved next to the model.
        self.device = device
        self.model = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.DEFAULT
        ).to(self.device)
        # Drop the classifier head: forward() now yields feature vectors.
        self.model.classifier = torch.nn.Identity()
        # Evaluation behaviour for dropout / batch-norm layers.
        self.model.eval()
        # NOTE(review): torchvision's reference preprocessing for these weights
        # (EfficientNet_B0_Weights.DEFAULT.transforms()) also normalizes with
        # the ImageNet mean/std; this pipeline does not. Left unchanged so
        # embeddings stay comparable with previously computed ones — confirm
        # whether the omission is intentional.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _to_rgb(image):
        """Return *image* converted to RGB if it is a non-RGB PIL-style image."""
        mode = getattr(image, "mode", None)
        # Only trust string modes: torch.Tensor exposes ``mode`` as a method.
        if isinstance(mode, str) and mode != "RGB":
            return image.convert("RGB")
        return image

    def get_embedding(self, image):
        """Return the embedding of a single image as a batch of one.

        Args:
            image: Input accepted by ``self.transform`` (the default pipeline
                is ``Resize`` + ``ToTensor``, presumably fed a PIL image).
                Images whose ``mode`` is not "RGB" (e.g. "L", "P", "RGBA")
                are converted to RGB first, because the network expects
                3 input channels.

        Returns:
            The model output for the image with a leading batch dimension of
            1, on ``self.device``, computed without gradient tracking.
            Presumably shape ``(1, 1280)`` for B0 — confirm.
        """
        image = self._to_rgb(image)
        # The model lives on self.device; the input must be moved there too,
        # otherwise a CUDA model fed a CPU tensor raises a device mismatch.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            embedding = self.model(img_tensor)
        return embedding