Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| from typing import Tuple | |
| class Encoder(nn.Module): | |
| """ | |
| Image encoder to obtain features from images using a pretrained ResNet-50 model. | |
| The last layer of ResNet-50 is removed, and a linear layer is added to transform | |
| the output to the desired feature dimension. | |
| Args: | |
| image_emb_dim (int): Final output dimension of image features. | |
| device (torch.device): Device to run the model on (CPU or GPU). | |
| """ | |
| def __init__(self, image_emb_dim: int, device: torch.device): | |
| super(Encoder, self).__init__() | |
| self.image_emb_dim = image_emb_dim | |
| self.device = device | |
| # Load pretrained ResNet-50 model and freeze its parameters | |
| resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) | |
| for param in resnet.parameters(): | |
| param.requires_grad_(False) | |
| # Remove the last layer of ResNet-50 | |
| modules = list(resnet.children())[:-1] | |
| self.resnet = nn.Sequential(*modules) | |
| # Define a final classifier | |
| self.fc = nn.Linear(resnet.fc.in_features, self.image_emb_dim) | |
| def forward(self, images: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Forward pass through the encoder. | |
| Args: | |
| images (torch.Tensor): Input images of shape (BATCH, 3, 224, 224). | |
| Returns: | |
| torch.Tensor: Image features of shape (BATCH, IMAGE_EMB_DIM). | |
| """ | |
| features = self.resnet(images) | |
| # Reshape features to (BATCH, 2048) | |
| features = features.reshape(features.size(0), -1).to(self.device) | |
| # Pass features through final linear layer | |
| features = self.fc(features).to(self.device) | |
| return features | |
| class Decoder(nn.Module): | |
| """ | |
| Decoder that uses an LSTM to generate captions from embedded words and encoded image features. | |
| The hidden and cell states of the LSTM are initialized using the encoded image features. | |
| Args: | |
| word_emb_dim (int): Dimension of word embeddings. | |
| hidden_dim (int): Dimension of the LSTM hidden state. | |
| num_layers (int): Number of LSTM layers. | |
| vocab_size (int): Size of the vocabulary (output dimension of the final linear layer). | |
| device (torch.device): Device to run the model on (CPU or GPU). | |
| """ | |
| def __init__(self, | |
| word_emb_dim: int, | |
| hidden_dim: int, | |
| num_layers: int, | |
| vocab_size: int, | |
| device: torch.device): | |
| super(Decoder, self).__init__() | |
| self.word_emb_dim = word_emb_dim | |
| self.hidden_dim = hidden_dim | |
| self.num_layers = num_layers | |
| self.vocab_size = vocab_size | |
| self.device = device | |
| # Initialize hidden and cell states | |
| self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim))) | |
| self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim))) | |
| # Define LSTM layer | |
| self.lstm = nn.LSTM(self.word_emb_dim, | |
| self.hidden_dim, | |
| num_layers=self.num_layers, | |
| bidirectional=False) | |
| # Define final linear layer with LogSoftmax activation | |
| self.fc = nn.Sequential( | |
| nn.Linear(self.hidden_dim, self.vocab_size), | |
| nn.LogSoftmax(dim=2) | |
| ) | |
| def forward(self, | |
| embedded_captions: torch.Tensor, | |
| hidden: torch.Tensor, | |
| cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | |
| """ | |
| Forward pass through the decoder. | |
| Args: | |
| embedded_captions (torch.Tensor): Embedded captions of shape (SEQ_LEN, BATCH, WORD_EMB_DIM). | |
| hidden (torch.Tensor): LSTM hidden state of shape (NUM_LAYER, BATCH, HIDDEN_DIM). | |
| cell (torch.Tensor): LSTM cell state of shape (NUM_LAYER, BATCH, HIDDEN_DIM). | |
| Returns: | |
| Tuple: | |
| - output (torch.Tensor): Output logits of shape (SEQ_LEN, BATCH, VOCAB_SIZE). | |
| - (hidden, cell) (Tuple[torch.Tensor, torch.Tensor]): Updated hidden and cell states. | |
| """ | |
| # Pass through LSTM | |
| output, (hidden, cell) = self.lstm(embedded_captions, (hidden, cell)) | |
| # Pass through final linear layer | |
| output = self.fc(output) | |
| return output, (hidden, cell) | |