import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models # Used for ResNet50
from .utils import get_logger # Import logger
logger = get_logger(__name__)
class EncoderCNN(nn.Module):
"""
Encoder using a pre-trained ResNet50 model.
The output feature maps are adaptively pooled to a fixed size
and then reshaped for the attention mechanism in the decoder.
"""
def __init__(self, encoded_image_size=14, fine_tune=True):
"""
Initializes the EncoderCNN.
Args:
encoded_image_size (int): The spatial dimension (e.g., 14x14) to which
the feature maps will be adaptively pooled.
fine_tune (bool): If True, allows the parameters of the pre-trained
ResNet to be updated during training. If False, they are frozen.
"""
super(EncoderCNN, self).__init__()
self.encoded_image_size = encoded_image_size
# Load pre-trained ResNet50 and remove the final classification layer
# We use the default recommended weights for ResNet50.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# Remove the average pooling layer and the fully connected layer at the end.
# We want the feature maps before these layers for spatial attention.
# The `modules` list will contain layers up to `layer4` (the last convolutional block).
modules = list(resnet.children())[:-2]
self.resnet = nn.Sequential(*modules)
# Freeze parameters of the pre-trained ResNet if fine_tune is False.
# This prevents updating their weights during training.
if not fine_tune:
for param in self.resnet.parameters():
param.requires_grad = False
logger.info("ResNet encoder base layers are frozen.")
else:
logger.info("ResNet encoder base layers are fine-tuning enabled.")
# Adaptive pool to a fixed size (e.g., 14x14).
# This ensures a consistent spatial dimension for the feature maps,
# regardless of the input image size, useful for attention.
self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
# The output feature dimension from ResNet50 before the last avg pool/fc is 2048.
self.encoder_dim = 2048
def forward(self, images):
"""
Forward pass through the ResNet encoder.
Args:
images (torch.Tensor): Input images, shape (batch_size, 3, H, W).
Returns:
torch.Tensor: Encoded image features,
shape (batch_size, encoder_dim, encoded_image_size, encoded_image_size).
"""
# Pass images through the ResNet feature extractor
out = self.resnet(images)
# Apply adaptive pooling to get a fixed spatial size (e.g., 14x14)
out = self.adaptive_pool(out)
# The output shape is (batch_size, encoder_dim, encoded_image_size, encoded_image_size)
return out
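# Illustrative shape check for the encoder (a sketch with assumed input sizes,
# not part of the model itself): for a batch of two 224x224 RGB images,
#   encoder = EncoderCNN(encoded_image_size=14, fine_tune=False)
#   feats = encoder(torch.randn(2, 3, 224, 224))
# `feats` has shape (2, 2048, 14, 14), i.e. (batch_size, encoder_dim,
# encoded_image_size, encoded_image_size); the decoder later flattens the
# spatial grid to (batch_size, 196, 2048) for attention.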
class Attention(nn.Module):
"""
Additive Attention Mechanism (Bahdanau style).
Calculates attention weights based on encoded image features and decoder's hidden state.
"""
def __init__(self, encoder_dim, decoder_dim, attention_dim):
"""
Initializes the Attention module.
Args:
encoder_dim (int): Feature size of encoded images (e.g., 2048 for ResNet50).
decoder_dim (int): Hidden state size of the decoder LSTM.
attention_dim (int): Size of the linear layers within the attention mechanism.
"""
super(Attention, self).__init__()
# Linear layer to transform encoder output for attention calculation
self.encoder_att = nn.Linear(encoder_dim, attention_dim)
# Linear layer to transform decoder hidden state for attention calculation
self.decoder_att = nn.Linear(decoder_dim, attention_dim)
# Linear layer to calculate attention "score" (or energy)
# This layer projects the combined features to a single scalar per pixel.
self.full_att = nn.Linear(attention_dim, 1)
self.relu = nn.ReLU()
# Softmax over the "num_pixels" dimension to get attention weights that sum to 1
self.softmax = nn.Softmax(dim=1)
def forward(self, encoder_out, decoder_hidden):
"""
Forward pass through the attention mechanism.
Args:
encoder_out (torch.Tensor): Encoded images, shape (batch_size, num_pixels, encoder_dim).
decoder_hidden (torch.Tensor): Previous decoder hidden state, shape (batch_size, decoder_dim).
Returns:
tuple:
- attention_weighted_encoding (torch.Tensor): Context vector,
shape (batch_size, encoder_dim).
- alpha (torch.Tensor): Attention weights (probability distribution over pixels),
shape (batch_size, num_pixels).
"""
# Transform encoder output: (batch_size, num_pixels, attention_dim)
att1 = self.encoder_att(encoder_out)
# Transform decoder hidden state, then unsqueeze to (batch_size, 1, attention_dim)
# for broadcasting during addition with att1
att2 = self.decoder_att(decoder_hidden).unsqueeze(1)
# Calculate attention scores: (batch_size, num_pixels)
# Sum of transformed encoder output and transformed decoder hidden state,
# passed through ReLU and then a linear layer to get a single score per pixel.
att = self.full_att(self.relu(att1 + att2)).squeeze(2)
# Apply softmax to get attention weights (alpha): (batch_size, num_pixels)
alpha = self.softmax(att)
# Calculate attention-weighted encoding: (batch_size, encoder_dim)
# This is the context vector, a weighted sum of the encoder features.
attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
return attention_weighted_encoding, alpha
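# The computation above is standard additive (Bahdanau) attention:
#   e_i = w^T ReLU(W_enc a_i + W_dec h),  alpha = softmax(e),  context = sum_i alpha_i * a_i
# where a_i is the encoder feature at pixel i and h is the decoder hidden state.
# Illustrative shapes (assumed sizes, not fixed by the module): with encoder_out
# of shape (8, 196, 2048) and decoder_hidden of shape (8, 512), the returned
# context vector is (8, 2048) and alpha is (8, 196).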
class DecoderWithAttention(nn.Module):
"""
LSTM Decoder with Attention mechanism.
Generates captions word by word, using the attention-weighted image features
and previously generated words.
"""
def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size,
encoder_dim=2048, dropout=0.5):
"""
Initializes the DecoderWithAttention.
Args:
attention_dim (int): Size of the attention linear layer.
embed_dim (int): Dimension of word embeddings.
decoder_dim (int): Hidden state size of the decoder LSTM.
vocab_size (int): Total size of the vocabulary.
encoder_dim (int): Feature size of encoded images (default 2048 for ResNet50).
dropout (float): Dropout rate for regularization.
"""
super(DecoderWithAttention, self).__init__()
self.encoder_dim = encoder_dim
self.attention_dim = attention_dim
self.embed_dim = embed_dim
self.decoder_dim = decoder_dim
self.vocab_size = vocab_size
self.dropout_rate = dropout
# Attention network
self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
# Word embedding layer
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.embedding_dropout = nn.Dropout(self.dropout_rate)
# LSTMCell for decoding
# Input to LSTMCell is the concatenation of word embedding and attention-weighted encoding
self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
# Linear layers to initialize the LSTM's hidden and cell states from the encoder output
self.init_h = nn.Linear(encoder_dim, decoder_dim)
self.init_c = nn.Linear(encoder_dim, decoder_dim)
# Linear layer to create a "gate" for the attention-weighted encoding (Visual Sentinel)
# This f_beta gate allows the model to decide how much of the attention-weighted
# context to use for generating the next word, enabling it to ignore irrelevant visual information.
self.f_beta = nn.Linear(decoder_dim, encoder_dim)
self.sigmoid = nn.Sigmoid() # Activation for the gate
# Linear layer to project decoder output to vocabulary size (scores for each word)
self.fc = nn.Linear(decoder_dim, vocab_size)
self.dropout_layer = nn.Dropout(self.dropout_rate)
# Initialize some weights
self.init_weights()
# A placeholder for max caption length during inference/generation
# This will typically be set by the calling model or config
self.max_caption_length_for_inference = 20
def init_weights(self):
"""Initializes some parameters with values from the uniform distribution."""
self.embedding.weight.data.uniform_(-0.1, 0.1)
self.fc.bias.data.fill_(0)
self.fc.weight.data.uniform_(-0.1, 0.1)
def load_pretrained_embeddings(self, embeddings):
"""
Loads pre-trained embeddings into the embedding layer.
Args:
embeddings (torch.Tensor): A tensor of pre-trained word embeddings.
"""
self.embedding.weight = nn.Parameter(embeddings)
# Optionally, freeze embeddings if they are pre-trained and you don't want to fine-tune them
# self.embedding.weight.requires_grad = False
def fine_tune_embeddings(self, fine_tune=True):
"""
Allows or disallows fine-tuning of the embedding layer.
Args:
fine_tune (bool): If True, embedding weights are trainable. If False, they are frozen.
"""
for p in self.embedding.parameters():
p.requires_grad = fine_tune
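# Typical usage of the two embedding helpers above (a sketch; `glove_weights`
# is a hypothetical (vocab_size, embed_dim) tensor of pre-trained vectors):
#   decoder.load_pretrained_embeddings(glove_weights)
#   decoder.fine_tune_embeddings(False)  # keep the pre-trained vectors frozen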
def init_hidden_state(self, encoder_out):
"""
Creates initial hidden and cell states for the LSTM from the encoded image.
Uses the mean of the encoder output features to initialize the LSTM states.
Args:
encoder_out (torch.Tensor): Encoded images, shape (batch_size, num_pixels, encoder_dim).
Returns:
tuple: (hidden state (h), cell state (c)), each of shape (batch_size, decoder_dim).
"""
# Calculate mean of encoder output across pixels
mean_encoder_out = encoder_out.mean(dim=1)
h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
c = self.init_c(mean_encoder_out) # (batch_size, decoder_dim)
return h, c
def forward(self, encoder_out, encoded_captions, caption_lengths):
"""
Forward pass through the decoder during training.
Args:
encoder_out (torch.Tensor): Encoded images from CNN,
shape (batch_size, encoder_dim, enc_image_size_H, enc_image_size_W).
encoded_captions (torch.Tensor): Captions, shape (batch_size, max_caption_length).
caption_lengths (torch.Tensor): Actual lengths of captions (before padding), shape (batch_size,).
Returns:
tuple:
- predictions (torch.Tensor): Predicted word scores,
shape (batch_size, max_decode_length_in_batch, vocab_size).
- encoded_captions (torch.Tensor): Captions sorted by length.
- decode_lengths (list): Actual decoding lengths for each caption in the batch.
- alphas (torch.Tensor): Attention weights,
shape (batch_size, max_decode_length_in_batch, num_pixels).
- sort_ind (torch.Tensor): Indices used to sort the batch.
"""
batch_size = encoder_out.size(0)
enc_image_h = encoder_out.size(2)
enc_image_w = encoder_out.size(3)
num_pixels = enc_image_h * enc_image_w
# Reshape encoder_out for attention: (batch_size, num_pixels, encoder_dim)
# Permute from (N, C, H, W) to (N, H, W, C) then flatten H*W
encoder_out = encoder_out.permute(0, 2, 3, 1).contiguous()
encoder_out = encoder_out.view(batch_size, num_pixels, self.encoder_dim)
# Sort the batch by decreasing caption length.
# This lets each decoding step operate on a contiguous prefix of still-active
# sequences, and it matches the ordering `pack_padded_sequence` expects when
# the loss is later computed on the packed predictions.
caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
encoder_out = encoder_out[sort_ind] # Apply sorting to encoder output
encoded_captions = encoded_captions[sort_ind] # Apply sorting to captions
# Embedding: (batch_size, max_caption_length, embed_dim)
embeddings = self.embedding(encoded_captions)
embeddings = self.embedding_dropout(embeddings)
# Initialize LSTM state (h, c) from the mean of encoder output
h, c = self.init_hidden_state(encoder_out)
# Create tensors to hold word predictions and attention weights.
# We predict up to (max_caption_length - 1) words (excluding the <START> token).
decode_lengths = (caption_lengths - 1).tolist() # Lengths of sequences to decode
max_decode_length = max(decode_lengths) # Max length in the current batch
predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size).to(encoder_out.device)
alphas = torch.zeros(batch_size, max_decode_length, num_pixels).to(encoder_out.device)
# For each time step in the decoding process
for t in range(max_decode_length):
# Number of sequences still being decoded at this time step.
# Because the batch is sorted by length, they occupy the first batch_size_t rows.
batch_size_t = sum([l > t for l in decode_lengths])
# Apply attention mechanism to the active sequences in the batch
attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
# Apply the sigmoid gate (f_beta) to the attention-weighted encoding
gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
attention_weighted_encoding = gate * attention_weighted_encoding
# Perform one step of LSTM decoding
# Input to LSTM: (current_word_embedding + attention_weighted_encoding)
h, c = self.decode_step(
torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
(h[:batch_size_t], c[:batch_size_t])
)
# Predict next word using the fully connected layer
preds = self.fc(self.dropout_layer(h))
# Store predictions and attention weights for the current time step
predictions[:batch_size_t, t, :] = preds
alphas[:batch_size_t, t, :] = alpha
return predictions, encoded_captions, decode_lengths, alphas, sort_ind
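# A typical training step (a sketch of one possible loss computation, not code
# from this project) consumes the tuple returned above: only non-padded
# positions are scored, and the attention weights are regularized so that each
# pixel receives roughly unit total attention over the caption:
#   from torch.nn.utils.rnn import pack_padded_sequence
#   targets = encoded_captions[:, 1:]  # predict every word after <START>
#   scores = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data
#   targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
#   loss = F.cross_entropy(scores, targets)
#   loss = loss + ((1.0 - alphas.sum(dim=1)) ** 2).mean()  # doubly stochastic attention term (weight of 1 assumed)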
class ImageCaptioningModel(nn.Module):
"""
Complete Image Captioning Model integrating EncoderCNN and DecoderWithAttention.
Provides methods for both training (forward pass) and inference (caption generation).
"""
def __init__(self, vocab_size, embed_dim=256, attention_dim=256, encoder_dim=2048,
decoder_dim=256, dropout=0.5, fine_tune_encoder=True, max_caption_length=20):
"""
Initializes the ImageCaptioningModel.
Args:
vocab_size (int): Total size of the vocabulary.
embed_dim (int): Dimension of word embeddings.
attention_dim (int): Size of the attention linear layer.
encoder_dim (int): Feature size of encoded images; overridden internally by the encoder's actual output dimension (2048 for ResNet50).
decoder_dim (int): Hidden state size of the decoder LSTM.
dropout (float): Dropout rate for regularization.
fine_tune_encoder (bool): If True, allows the encoder parameters to be updated.
max_caption_length (int): Maximum length of generated captions during inference.
"""
super(ImageCaptioningModel, self).__init__()
# Initialize the Encoder (ResNet50-based)
self.encoder = EncoderCNN(encoded_image_size=14, fine_tune=fine_tune_encoder)
# Ensure encoder_dim matches ResNet50 output dimension
self.encoder_dim = self.encoder.encoder_dim # This will be 2048
# Initialize the Decoder with Attention
self.decoder = DecoderWithAttention(
attention_dim=attention_dim,
embed_dim=embed_dim,
decoder_dim=decoder_dim,
vocab_size=vocab_size,
encoder_dim=self.encoder_dim, # Pass the correct encoder_dim
dropout=dropout
)
self.decoder.max_caption_length_for_inference = max_caption_length # Set max length for inference
def forward(self, images, captions, caption_lengths):
"""
Forward pass through the complete model for training.
Args:
images (torch.Tensor): Input images.
captions (torch.Tensor): Target captions.
caption_lengths (torch.Tensor): Actual lengths of captions.
Returns:
tuple: (predictions, encoded_captions, decode_lengths, alphas, sort_ind)
as returned by the decoder's forward pass.
"""
encoder_out = self.encoder(images) # Encode images
predictions, encoded_captions, decode_lengths, alphas, sort_ind = self.decoder(
encoder_out, captions, caption_lengths # Decode captions
)
return predictions, encoded_captions, decode_lengths, alphas, sort_ind
def generate_caption(self, image_tensor, vocabulary, device, beam_size=5, max_length=None):
"""
Performs beam search to generate a caption for a single image.
Runs in evaluation mode with gradients disabled.
Args:
image_tensor (torch.Tensor): Preprocessed image tensor (3, H, W). NOT batched.
vocabulary (COCOVocabulary): Vocabulary object.
device (torch.device): Device to run the model on (cpu/cuda).
beam_size (int): Size of beam for beam search.
max_length (int, optional): Maximum length of the generated caption.
If None, uses self.decoder.max_caption_length_for_inference.
Returns:
str: Generated caption string.
"""
self.eval() # Set model to evaluation mode
# Use the caller-provided max_length if given, otherwise fall back to the model's default
current_max_length = max_length if max_length is not None else self.decoder.max_caption_length_for_inference
with torch.no_grad():
# Add batch dimension and move to device for the encoder
# image_tensor goes from (C, H, W) to (1, C, H, W)
image_tensor_batched = image_tensor.unsqueeze(0).to(device)
# Get encoder output: (1, encoder_dim, encoded_image_size, encoded_image_size)
encoder_output_from_cnn = self.encoder(image_tensor_batched)
# Reshape encoder_output_from_cnn to (1, num_pixels, encoder_dim) for attention
# Permute from (N, C, H, W) to (N, H, W, C) then flatten H*W
encoder_out = encoder_output_from_cnn.permute(0, 2, 3, 1).contiguous()
encoder_out = encoder_out.view(1, -1, self.encoder_dim) # (1, num_pixels, encoder_dim)
# Expand for beam search: (beam_size, num_pixels, encoder_dim)
encoder_out = encoder_out.expand(beam_size, encoder_out.size(1), encoder_out.size(2))
# Tensor to store top k previous words at each step; initialized with <START> token for all beams
k_prev_words = torch.LongTensor([[vocabulary.word2idx['<START>']]] * beam_size).to(device)
# Tensor to store top k sequences; initially just the <START> token
seqs = k_prev_words
# Tensor to store top k sequences' scores (log probabilities); initially all zeros
top_k_scores = torch.zeros(beam_size, 1).to(device)
# Lists to store completed captions and their scores
complete_seqs = list()
complete_seqs_scores = list()
# Initialize hidden state and cell state for LSTM
# encoder_out is already expanded for beam_size, so init_hidden_state will work
h, c = self.decoder.init_hidden_state(encoder_out)
# Start decoding loop
step = 1
while True:
# Get embeddings for the previously predicted words
embeddings = self.decoder.embedding(k_prev_words).squeeze(1) # (beam_size, embed_dim)
# Calculate attention-weighted encoding and attention weights
attention_weighted_encoding, alpha = self.decoder.attention(encoder_out, h)
# Apply the sigmoid gate (f_beta) to the attention-weighted encoding
gate = self.decoder.sigmoid(self.decoder.f_beta(h))
attention_weighted_encoding = gate * attention_weighted_encoding
# Perform one step of LSTM decoding
h, c = self.decoder.decode_step(
torch.cat([embeddings, attention_weighted_encoding], dim=1),
(h, c)
) # (beam_size, decoder_dim)
# Get scores for the next word
scores = self.decoder.fc(h) # (beam_size, vocab_size)
scores = F.log_softmax(scores, dim=1) # Convert to log-probabilities
# Add current scores to previous scores for beam search
scores = top_k_scores.expand_as(scores) + scores # (beam_size, vocab_size)
# At the first step, all k candidates come from the same parent (<START>);
# at later steps they may come from different parent beams.
if step == 1:
# For the first step, select top 'beam_size' words from the first beam's scores
top_k_scores, top_k_words = scores[0].topk(beam_size, 0, True, True) # (beam_size)
else:
# Flatten scores to find the top 'beam_size' overall (from all current beams)
top_k_scores, top_k_words = scores.view(-1).topk(beam_size, 0, True, True) # (beam_size)
# Convert flattened indices back to a beam index (row) and a vocabulary index (column).
# The divisor must match the width of the flattened score matrix, i.e. the decoder's vocab size.
prev_word_inds = top_k_words // self.decoder.vocab_size # (beam_size)
next_word_inds = top_k_words % self.decoder.vocab_size # (beam_size)
# Add new words to sequences
seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (beam_size, step + 1)
# Identify completed sequences (where <END> is generated)
incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds)
if next_word != vocabulary.word2idx['<END>']]
complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
# Add complete sequences to their lists
if len(complete_inds) > 0:
complete_seqs.extend(seqs[complete_inds].tolist())
complete_seqs_scores.extend(top_k_scores[complete_inds])
# Update beam_size: number of active beams for the next step
beam_size -= len(complete_inds)
if beam_size == 0: # If all beams complete, break
break
# Filter seqs, hidden states, cell states, scores, and previous words for incomplete sequences
seqs = seqs[incomplete_inds]
h = h[prev_word_inds[incomplete_inds]]
c = c[prev_word_inds[incomplete_inds]]
top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) # Reshape for next step
k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
encoder_out = encoder_out[prev_word_inds[incomplete_inds]] # Propagate encoder_out for active beams
# Break if maximum caption length is exceeded
if step > current_max_length:
break
step += 1
# If no complete captions were found (e.g., all beams exceeded max_length before <END>),
# pick the best incomplete sequence found so far.
if not complete_seqs:
# Take the best sequence among the currently active (incomplete) beams
final_seqs = seqs.tolist()
final_scores = top_k_scores.squeeze(1).tolist()
if not final_seqs: # Fallback if even no incomplete sequences are available (shouldn't happen)
return ""
i = final_scores.index(max(final_scores))
best_seq = final_seqs[i]
else:
# Find the best caption among all completed sequences based on score
i = complete_seqs_scores.index(max(complete_seqs_scores))
best_seq = complete_seqs[i]
# Convert the best sequence of indices back to a human-readable caption
return vocabulary.indices_to_caption(best_seq)
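if __name__ == "__main__":
    # Minimal smoke test (a sketch with illustrative hyperparameters, not the
    # project's training configuration): one training-style forward pass on
    # random data. Constructing the model downloads the pre-trained ResNet50
    # weights on first use. Because of the relative import at the top of this
    # file, run it as a module (python -m ...), not as a standalone script.
    vocab_size = 1000
    model = ImageCaptioningModel(vocab_size=vocab_size, embed_dim=256,
                                 attention_dim=256, decoder_dim=256,
                                 dropout=0.5, fine_tune_encoder=False)
    images = torch.randn(4, 3, 224, 224)              # dummy image batch
    captions = torch.randint(0, vocab_size, (4, 20))  # dummy token ids
    caption_lengths = torch.tensor([20, 18, 15, 12])  # dummy true lengths
    preds, caps_sorted, decode_lengths, alphas, sort_ind = model(
        images, captions, caption_lengths)
    logger.info(f"predictions: {tuple(preds.shape)}, alphas: {tuple(alphas.shape)}")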