import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as transforms

from dataset import build_vocab_from_json, CaptionDataset, my_collate_fn


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # drop the final FC layer, keep the pooled features
        self.resnet = nn.Sequential(*modules)
        for param in self.resnet.parameters():
            param.requires_grad_(False)  # backbone stays frozen; only the projection head is trained
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        # [B, 2048, 1, 1] -> [B, 2048]; .squeeze() here would also drop the
        # batch dimension when B == 1 and crash the linear layer
        features = features.flatten(1)
        features = self.linear(features)
        features = self.bn(features)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # Teacher forcing: drop the last caption token so the input sequence
        # [image, <start>, w1, ...] lines up with the targets shifted by one
        embeddings = self.embed(captions[:, :-1])
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)  # image feature is the t=0 input
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)  # [B, T, vocab_size]
        return outputs
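

# A minimal greedy-decoding sketch for inference. It assumes the vocab built
# in dataset.py exposes stoi/itos mappings (both are saved in the checkpoints
# below) and uses "<start>"/"<end>" marker tokens; adjust the token names if
# yours differ.
@torch.no_grad()
def greedy_caption(encoder, decoder, image, vocab, max_len=20):
    """Caption one preprocessed image tensor of shape [3, 224, 224]."""
    encoder.eval()
    decoder.eval()
    feature = encoder(image.unsqueeze(0))            # [1, embed_size]
    # Prime the LSTM with the image feature; that step's output was never
    # supervised during training, so decoding starts from <start> instead
    _, states = decoder.lstm(feature.unsqueeze(1))
    token = torch.tensor([vocab.stoi["<start>"]], device=image.device)
    words = []
    for _ in range(max_len):
        inputs = decoder.embed(token).unsqueeze(1)   # [1, 1, embed_size]
        hiddens, states = decoder.lstm(inputs, states)
        logits = decoder.linear(hiddens.squeeze(1))  # [1, vocab_size]
        token = logits.argmax(1)
        word = vocab.itos[token.item()]
        if word == "<end>":
            break
        words.append(word)
    return " ".join(words)
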

embed_size = 256
hidden_size = 512
num_layers = 1
learning_rate = 3e-4
num_epochs = 30
batch_size = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

captions_train_json = "./Dataset/annotations/captions_train.json"
images_train_dir = "./Dataset/images/train/"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # The pretrained ResNet-50 weights expect ImageNet-normalized inputs
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

vocab = build_vocab_from_json(captions_train_json, freq_threshold=2)
vocab_size = len(vocab)

train_dataset = CaptionDataset(
    images_dir=images_train_dir,
    captions_file=captions_train_json,
    vocab=vocab,
    transform=transform
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=my_collate_fn,
    drop_last=True,  # BatchNorm1d in the encoder cannot handle a trailing batch of size 1
)


encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # assumes my_collate_fn pads captions with index 0
# Only the decoder and the encoder's trainable head (linear + BN) are optimized
params = (list(decoder.parameters())
          + list(encoder.linear.parameters())
          + list(encoder.bn.parameters()))
optimizer = optim.Adam(params, lr=learning_rate)


encoder.train()
decoder.train()

os.makedirs("checkpoints", exist_ok=True)

for epoch in range(num_epochs):
    for idx, (imgs, captions) in enumerate(train_loader):
        imgs, captions = imgs.to(device), captions.to(device)

        features = encoder(imgs)
        outputs = decoder(features, captions)

        # Drop the t=0 prediction (made from the image feature, never
        # supervised) so that outputs[:, t] predicts captions[:, t + 1]
        outputs = outputs[:, 1:, :]  # [B, T-1, vocab_size]

        outputs = outputs.reshape(-1, vocab_size)
        targets = captions[:, 1:].reshape(-1)

        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 50 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{idx}/{len(train_loader)}] Loss: {loss.item():.4f}")

    torch.save({
        'epoch': epoch + 1,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab_stoi': vocab.stoi,
        'vocab_itos': vocab.itos,
    }, f"checkpoints/caption_model_epoch{epoch+1}.pth")

    print(f"✅ Saved model to checkpoints/caption_model_epoch{epoch+1}.pth")

print("Training complete ✅")