File size: 7,291 Bytes
601cad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import torch
import torch.nn as nn
import torch.nn.functional as F

from cnn_encoder import CNNEncoder
from vit_encoder import ViTEncoder
from transformer_encoder import TransformerEncoder
from transformer_decoder import TransformerDecoder


class ImageCaptioningModel(nn.Module):
    """Encoder-decoder image captioning model.

    A visual backbone (CNN by default, ViT when ``use_vit=True``) extracts
    a sequence of image features, a Transformer encoder refines them, and a
    Transformer decoder generates caption tokens autoregressively.
    """

    # Tokens stripped from decoded captions before joining into a string.
    _SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>")

    def __init__(
        self,
        vocab_size,
        pad_id,
        d_model=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        num_heads=8,
        dim_ff=2048,
        max_seq_len=50,
        dropout=0.1,
        freeze_backbone=True,
        use_vit=False,
        encoder_max_len=200,
    ):
        """Build the captioning model.

        Args:
            vocab_size: Size of the caption vocabulary.
            pad_id: Index of the padding token.
            d_model: Shared embedding / hidden dimension.
            num_encoder_layers: Depth of the Transformer encoder.
            num_decoder_layers: Depth of the Transformer decoder.
            num_heads: Attention heads in both Transformer stacks.
            dim_ff: Feed-forward dimension inside Transformer layers.
            max_seq_len: Maximum caption (decoder) sequence length.
            dropout: Dropout probability in the Transformer stacks.
            freeze_backbone: Freeze the visual backbone's weights.
            use_vit: Use the ViT backbone instead of the CNN backbone.
            encoder_max_len: Maximum number of image-feature positions the
                Transformer encoder supports (previously hard-coded to 200).
        """
        super().__init__()

        self.use_vit = use_vit

        # Visual backbone: maps images to a sequence of d_model features.
        if self.use_vit:
            self.encoder = ViTEncoder(d_model=d_model, freeze_backbone=freeze_backbone)
        else:
            self.encoder = CNNEncoder(d_model=d_model, freeze_backbone=freeze_backbone)

        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=encoder_max_len,
            dropout=dropout,
            use_vit=self.use_vit,
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            pad_id=pad_id,
            d_model=d_model,
            num_layers=num_decoder_layers,
            num_heads=num_heads,
            dim_ff=dim_ff,
            max_len=max_seq_len,
            dropout=dropout,
        )
        self.d_model = d_model

    def generate_square_subsequent_mask(self, sz):
        """Return the decoder's causal (square subsequent) mask of size ``sz``."""
        return self.decoder.generate_square_subsequent_mask(sz)

    def unfreeze_encoder(self, unfreeze=True):
        """(Un)freeze the visual backbone, e.g. for fine-tuning."""
        self.encoder.unfreeze_backbone(unfreeze)

    def encode_image(self, images):
        """Run images through the backbone and the Transformer encoder.

        Args:
            images: Batch of preprocessed images, shape (B, 3, H, W).

        Returns:
            Encoded image features used as decoder memory.
        """
        img_features = self.encoder(images)
        return self.transformer_encoder(img_features)

    def forward(self, images, captions, tgt_mask=None, tgt_padding_mask=None):
        """Teacher-forced forward pass: images + captions -> token logits."""
        img_features = self.encode_image(images)
        return self.decoder(
            captions=captions,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=tgt_padding_mask,
        )

    def _decode_step(self, seq, img_features, device):
        """Decoder logits for one partial token sequence (inference helper).

        Args:
            seq: List of token indices generated so far (starts with <bos>).
            img_features: Encoder output from :meth:`encode_image`.
            device: Device to run inference on.

        Returns:
            Logits tensor of shape (1, len(seq), vocab_size).
        """
        tgt_tensor = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
        tgt_mask = self.generate_square_subsequent_mask(len(seq)).to(device)
        return self.decoder(
            captions=tgt_tensor,
            img_features=img_features,
            tgt_mask=tgt_mask,
            tgt_padding_mask=None,
        )

    def _indices_to_caption(self, indices, vocab):
        """Convert token indices to a caption string, dropping special tokens."""
        words = (vocab.idx2word.get(idx, "<unk>") for idx in indices)
        return " ".join(w for w in words if w not in self._SPECIAL_TOKENS)

    def predict_caption_beam(self, image, vocab, beam_width=5, max_len=50, alpha=0.7, device="cpu"):
        """Generate a caption using beam search decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            beam_width: Number of candidate sequences kept at each step.
            max_len: Maximum caption length.
            alpha: Length normalization penalty; higher favors longer captions.
            device: Device to run inference on.

        Returns:
            The highest-scoring caption as a string.

        Note:
            Switches the model to eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)

            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]

            # Each beam: (cumulative log-probability, token index list).
            beams = [(0.0, [bos_idx])]
            completed = []

            for _ in range(max_len):
                candidates = []

                for score, seq in beams:
                    # A beam ending in <eos> is retired rather than expanded.
                    if seq[-1] == eos_idx:
                        completed.append((score, seq))
                        continue

                    logits = self._decode_step(seq, img_features, device)

                    # Log-probabilities over the vocabulary for the next token.
                    log_probs = F.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

                    # Expand this beam with its top-k continuations.
                    topk_log_probs, topk_indices = log_probs.topk(beam_width)
                    for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                        candidates.append((score + log_p, seq + [idx]))

                # Keep only the best beam_width expansions across all beams.
                candidates.sort(key=lambda c: c[0], reverse=True)
                beams = candidates[:beam_width]

                # Early stop: every beam has finished.
                if not beams:
                    break

            # Unfinished beams (hit max_len) still compete for the best caption.
            completed.extend(beams)

            # Length-normalized score: log-prob / (length ** alpha).
            completed.sort(key=lambda c: c[0] / (len(c[1]) ** alpha), reverse=True)
            best_seq = completed[0][1]

            return self._indices_to_caption(best_seq, vocab)

    def predict_caption(self, image, vocab, max_len=50, device="cpu"):
        """Generate a caption using greedy decoding.

        Args:
            image: Preprocessed image tensor of shape (1, 3, H, W).
            vocab: Vocabulary object with word2idx and idx2word mappings.
            max_len: Maximum caption length.
            device: Device to run inference on.

        Returns:
            The generated caption as a string.

        Note:
            Switches the model to eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            img_features = self.encode_image(image)

            bos_idx = vocab.word2idx["<bos>"]
            eos_idx = vocab.word2idx["<eos>"]

            seq = [bos_idx]
            for _ in range(max_len):
                logits = self._decode_step(seq, img_features, device)

                # Pick the single most likely next token.
                next_id = logits[:, -1, :].argmax(dim=-1).item()
                if next_id == eos_idx:
                    break
                seq.append(next_id)

            return self._indices_to_caption(seq, vocab)