# photo-enhancer / src/envs/features_extractor.py
# Author: zakaria-narjis
# Commit 998f96a: "add src and models"
import torch
from torchvision import transforms
from torchvision import models
# Device on which feature-extraction models run. Use CUDA when it is
# available and fall back to the CPU otherwise, so moving tensors/models to
# this device does not fail on CPU-only machines.
ENCODING_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class Extractor:
    """
    Features extractor boilerplate.

    Base class that only declares the attributes a concrete feature
    extractor is expected to populate. Every attribute starts out empty
    (``''`` for ``name``, ``None`` for everything else); subclasses are
    responsible for filling them in.
    """

    def __init__(self) -> None:
        # Name of the extractor.
        self.name: str = ''
        # Expected input shape; left unset (None) by the base class.
        self.input_shape = None
        # Produced feature shape; left unset (None) by the base class.
        self.output_shape = None
        # Model that computes the features; left unset (None) by the base class.
        self.model = None
        # Transform applied to inputs before the model; left unset (None).
        self.preprocess = None
        # Device the model/inputs are placed on; left unset (None).
        self.device = None
class ResnetEncoder(Extractor):
    """
    ResNet encoder placeholder.

    Currently only runs the base ``Extractor`` initializer: no model,
    preprocessing or device is configured here, so all attributes keep the
    defaults set by ``Extractor.__init__``.
    """

    def __init__(self) -> None:
        super().__init__()
"""
Legacy code
"""
# class ResnetEncoder(Extractor):
# def __init__(self,):
# super().__init__()
# self.device = ENCODING_DEVICE
# self.model = models.resnet18(weights='ResNet18_Weights.DEFAULT')
# self.model = torch.nn.Sequential(*(list(self.model.children())[:-1])).to(self.device) #remove classifier
# self.model.eval()
# self.preprocess = transforms.Compose([
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# def encode(self,images:torch.Tensor)->torch.Tensor:
# """
# args:
#             images(torch.Tensor): Batch of images with shape (B,3,H,W); must be a floating-point tensor, since transforms.Normalize rejects uint8 input
# return:
# output(torch.Tensor): Batch of encoded images (images features) with shape(B,512)
# """
# assert images.dim()==4
# assert images.shape[1]==3
# with torch.inference_mode():
# output = images.clone().to(self.device)
# output = self.preprocess(output)
# output = self.model(output)
# output = torch.flatten(output,start_dim=-3,end_dim=-1)
# return output