import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as transforms

from dataset import build_vocab_from_json, CaptionDataset, my_collate_fn


class EncoderCNN(nn.Module):
    """ResNet-50 feature extractor followed by a trainable projection."""

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # remove the final FC layer
        self.resnet = nn.Sequential(*modules)
        for param in self.resnet.parameters():
            param.requires_grad_(False)  # backbone is frozen; only linear/bn are trained
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)   # [B, 2048, 1, 1]
        # flatten(1) keeps the batch dim even when B == 1, unlike a bare .squeeze(),
        # which would collapse a batch of one down to a 1-D tensor.
        features = features.flatten(1)       # [B, 2048]
        features = self.linear(features)     # [B, embed_size]
        features = self.bn(features)
        return features


class DecoderRNN(nn.Module):
    """LSTM decoder that conditions on the image feature at time step 0."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # Exclude the last token so inputs and shifted targets stay aligned.
        embeddings = self.embed(captions[:, :-1])                    # [B, T-1, E]
        # Prepend the image feature as the input at t = 0.
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)   # [B, T, E]
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)                               # [B, T, vocab_size]
        return outputs


# Hyperparameters
embed_size = 256
hidden_size = 512
num_layers = 1
learning_rate = 3e-4
num_epochs = 30
batch_size = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

captions_train_json = "./Dataset/annotations/captions_train.json"
images_train_dir = "./Dataset/images/train/"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
vocab_size = len(vocab)

train_dataset = CaptionDataset(
    images_dir=images_train_dir,
    captions_file=captions_train_json,
    vocab=vocab,
    transform=transform
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate_fn,
    drop_last=True  # a trailing batch of size 1 would crash BatchNorm1d in train mode
)

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # index 0 is assumed to be <PAD>
# Only the decoder plus the encoder's projection and batch norm are optimized;
# the frozen ResNet backbone is deliberately left out.
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

encoder.train()
decoder.train()

os.makedirs("checkpoints", exist_ok=True)

for epoch in range(num_epochs):
    for idx, (imgs, captions) in enumerate(train_loader):
        imgs, captions = imgs.to(device), captions.to(device)

        features = encoder(imgs)
        outputs = decoder(features, captions)

        # Drop the t = 0 output (produced from the image feature) so that each
        # remaining step t predicts token t + 1; the targets are the captions
        # shifted left by one position.
        outputs = outputs[:, 1:, :]               # [B, T-1, vocab_size]
        outputs = outputs.reshape(-1, vocab_size)
        targets = captions[:, 1:].reshape(-1)

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] "
                  f"Batch [{idx}/{len(train_loader)}] "
                  f"Loss: {loss.item():.4f}")

    torch.save({
        'epoch': epoch + 1,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab_stoi': vocab.stoi,
        'vocab_itos': vocab.itos,
    }, f"checkpoints/caption_model_epoch{epoch+1}.pth")
    print(f"✅ Saved model to checkpoints/caption_model_epoch{epoch+1}.pth")

print("Training complete ✅")
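

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original script): greedy caption generation
# with the trained encoder/decoder. It assumes the vocab exposes stoi/itos
# (as saved in the checkpoint above, with itos mapping indices to words) and
# uses "<SOS>"/"<EOS>" as the start/end token strings; adjust both to match
# whatever dataset.py actually defines.
# ---------------------------------------------------------------------------
@torch.no_grad()
def generate_caption(encoder, decoder, image, vocab, max_len=20):
    """Greedily decode a caption for one preprocessed image tensor [3, 224, 224]."""
    encoder.eval()  # eval mode so BatchNorm1d accepts a batch of one
    decoder.eval()
    feature = encoder(image.unsqueeze(0).to(device))  # [1, embed_size]
    inputs = feature.unsqueeze(1)                     # [1, 1, embed_size]
    states = None
    tokens = []
    for _ in range(max_len):
        hiddens, states = decoder.lstm(inputs, states)   # [1, 1, hidden_size]
        logits = decoder.linear(hiddens.squeeze(1))      # [1, vocab_size]
        predicted = logits.argmax(dim=1)                 # greedy choice, shape [1]
        word = vocab.itos[predicted.item()]
        if word == "<EOS>":                              # assumed end-of-sequence token
            break
        if word != "<SOS>":                              # assumed start token; skip it
            tokens.append(word)
        inputs = decoder.embed(predicted).unsqueeze(1)   # feed the prediction back in
    return " ".join(tokens)

# Example usage, assuming a checkpoint and the same transform as training:
#   img = transform(Image.open("example.jpg").convert("RGB"))
#   print(generate_caption(encoder, decoder, img, vocab))
# Remember to call encoder.train()/decoder.train() before resuming training.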