| from torch import nn | |
| import timm | |
| from configuration import CFG | |
class ImageEncoder(nn.Module):
    """
    Wrap a timm backbone that maps a batch of images to one feature vector each.

    The backbone is requested headless (``num_classes=0``) with average global
    pooling (``global_pool="avg"``), so ``forward`` yields pooled features
    rather than class logits.

    Args:
        model_name: timm architecture identifier forwarded to
            ``timm.create_model``.
        pretrained: whether timm should load pretrained weights.
        trainable: value assigned to ``requires_grad`` on every backbone
            parameter; ``False`` freezes the encoder.

    Note:
        Python evaluates the default values from ``CFG`` once, when this class
        is defined (at import time). Later changes to ``CFG`` do not change
        these defaults.
    """

    def __init__(
        self,
        model_name: str = CFG.model_name,
        pretrained: bool = CFG.pretrained,
        trainable: bool = CFG.trainable,
    ):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # no classification head -> plain feature extractor
            global_pool="avg",  # average-pool spatial features into one vector
        )
        # Freeze or unfreeze every backbone weight in a single pass.
        for weight in backbone.parameters():
            weight.requires_grad = trainable
        self.model = backbone

    def forward(self, x):
        """Return the pooled backbone features for the image batch ``x``.

        NOTE(review): presumably ``x`` is an image tensor shaped
        (batch, channels, height, width) and the result is
        (batch, num_features); confirm against the chosen backbone.
        """
        return self.model(x)