import torch
import torch.nn as nn
from torchvision import models
from transformers import PreTrainedModel, PretrainedConfig


class ImageCaptioningConfig(PretrainedConfig):
    """Hyperparameters for the DenseNet-121 + LSTM image captioning model."""

    model_type = "image-captioning"

    def __init__(self, feature_dim=1024, embedding_dim=256, hidden_dim=512,
                 vocab_size=5000, num_layers=1, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.feature_dim = feature_dim      # size of the encoder output (DenseNet-121 gives 1024)
        self.embedding_dim = embedding_dim  # word / image-feature embedding size
        self.hidden_dim = hidden_dim        # LSTM hidden size per direction
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.dropout = dropout


class LSTMDecoder(nn.Module):
    def __init__(self, feature_dim, embedding_dim, hidden_dim, vocab_size, num_layers=1, dropout=0.5):
        super().__init__()
        # Project the encoder's image feature into the word-embedding space so it can be
        # prepended to the caption sequence as the first timestep.
        self.feature_project = nn.Linear(feature_dim, embedding_dim)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Note: nn.LSTM only applies dropout between stacked layers, so with num_layers=1
        # the dropout argument has no effect (PyTorch emits a warning).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        # The bidirectional LSTM doubles the feature size fed to the output projection.
        self.fc = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, image_features, captions):
        # (batch, feature_dim) -> (batch, 1, embedding_dim): the image acts as the first "token".
        image_features = self.feature_project(image_features).unsqueeze(1)
        embeddings = self.embedding(captions)                        # (batch, seq_len, embedding_dim)
        embeddings = torch.cat((image_features, embeddings), dim=1)  # (batch, seq_len + 1, embedding_dim)
        lstm_out, _ = self.lstm(embeddings)                          # (batch, seq_len + 1, hidden_dim * 2)
        lstm_out = self.dropout(lstm_out)
        outputs = self.fc(lstm_out)                                  # (batch, seq_len + 1, vocab_size)
        return outputs


class ImageCaptioningModel(PreTrainedModel):
    config_class = ImageCaptioningConfig

    def __init__(self, config: ImageCaptioningConfig):
        super().__init__(config)

        # Frozen CNN encoder: DenseNet-121 pretrained on ImageNet, with the classifier
        # head replaced by Identity so it returns the 1024-d pooled feature vector.
        # (On torchvision >= 0.13, `pretrained=True` is deprecated in favour of
        # `weights=models.DenseNet121_Weights.DEFAULT`.)
        self.encoder = models.densenet121(pretrained=True)
        self.encoder.classifier = nn.Identity()
        self.encoder.eval()

        self.decoder = LSTMDecoder(
            feature_dim=config.feature_dim,
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            vocab_size=config.vocab_size,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )

    def forward(self, image, caption):
        # The encoder is kept frozen: features are extracted under no_grad, so only
        # the decoder receives gradients during training.
        with torch.no_grad():
            image_features = self.encoder(image)
        output = self.decoder(image_features, caption)
        return output
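

# A minimal smoke-test sketch, not part of the model definition: it assumes 224x224 RGB
# inputs and integer caption token ids below `vocab_size`, and it downloads the DenseNet-121
# ImageNet weights on first run. The batch size, sequence length, and the `__main__` guard
# are illustrative choices, not requirements of the model.
if __name__ == "__main__":
    config = ImageCaptioningConfig()
    model = ImageCaptioningModel(config)

    images = torch.randn(2, 3, 224, 224)                      # dummy image batch
    captions = torch.randint(0, config.vocab_size, (2, 20))   # dummy caption token ids, seq_len = 20

    logits = model(images, captions)
    # One extra timestep comes from the image feature prepended to the caption sequence.
    print(logits.shape)  # torch.Size([2, 21, 5000])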