import torch from torchvision import transforms from torchvision import models ENCODING_DEVICE = 'cuda' class Extractor: """ Features extractor boilerplate """ def __init__(self,): self.name='' self.input_shape=None self.output_shape=None self.model = None self.preprocess = None self.device = None class ResnetEncoder(Extractor): def __init__(self): super().__init__() pass """ Legacy code """ # class ResnetEncoder(Extractor): # def __init__(self,): # super().__init__() # self.device = ENCODING_DEVICE # self.model = models.resnet18(weights='ResNet18_Weights.DEFAULT') # self.model = torch.nn.Sequential(*(list(self.model.children())[:-1])).to(self.device) #remove classifier # self.model.eval() # self.preprocess = transforms.Compose([ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ]) # def encode(self,images:torch.Tensor)->torch.Tensor: # """ # args: # image(torch.Tensor): Batch of images with shape (B,3,H,W) with uint8 values # return: # output(torch.Tensor): Batch of encoded images (images features) with shape(B,512) # """ # assert images.dim()==4 # assert images.shape[1]==3 # with torch.inference_mode(): # output = images.clone().to(self.device) # output = self.preprocess(output) # output = self.model(output) # output = torch.flatten(output,start_dim=-3,end_dim=-1) # return output