import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from PIL import Image

# ===========
# Vocabulary
# ===========

class Vocabulary:
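    """Minimal vocabulary holder restored from the training checkpoint:
    `stoi` maps token -> index and `itos` maps index -> token."""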
    def __init__(self):
        self.itos = {}
        self.stoi = {}

    def load(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos

# ===========
# Encoder
# ===========

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
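        # Pretrained ResNet-50 with its final FC classifier removed; the
        # remaining layers yield a 2048-dim pooled feature per image.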
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)  # [B, 2048, 1, 1]
        features = features.view(features.size(0), -1)  # [B, 2048]
        features = self.linear(features)                # [B, embed_size]
        features = self.bn(features)                    # [B, embed_size]
        return features

# ===========
# Decoder
# ===========

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
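        # Token embedding -> LSTM -> linear projection to vocabulary logits.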
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)                           # [B, T, embed_size]
        # Prepend the image feature as the first step of the input sequence.
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)  # [B, T+1, embed_size]
        hiddens, _ = self.lstm(inputs)                              # [B, T+1, hidden_size]
        outputs = self.linear(hiddens)                              # [B, T+1, vocab_size]
        return outputs

    def sample(self, features, vocab, max_len=30):
        """Generate a caption (as token ids) for a single image's features
        using greedy search, stopping at <end> or after max_len steps."""
        output_ids = []
        states = None

        inputs = features.unsqueeze(1)  # [B, 1, embed_size]

        # Greedily pick the argmax token at each step and feed its embedding
        # back in as the next LSTM input.
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)  # [B, 1, hidden]
            outputs = self.linear(hiddens.squeeze(1))    # [B, vocab_size]
            predicted = outputs.argmax(1)                # [B]
            output_ids.append(predicted.item())          # .item() assumes batch size 1

            if vocab.itos[predicted.item()] == "<end>":
                break

            inputs = self.embed(predicted).unsqueeze(1)

        return output_ids

# ===========
# Predict block
# ===========

def predict(image_path, checkpoint_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embed_size = 256
    hidden_size = 512
    num_layers = 1

    # === Load checkpoint ===
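    # NOTE: PyTorch >= 2.6 defaults torch.load to weights_only=True, which
    # rejects pickled Python objects such as the vocab dicts; on those
    # versions, pass weights_only=False (only for checkpoints you trust).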
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # === Load vocab ===
    vocab = Vocabulary()
    vocab.load(checkpoint['vocab_stoi'], checkpoint['vocab_itos'])
    vocab_size = len(vocab.stoi)

    # === Load models ===
    encoder = EncoderCNN(embed_size).to(device)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    encoder.eval()
    decoder.eval()

    # === Image transform ===
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
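    # NOTE: this transform must match the one used at training time; if
    # training normalized with the ImageNet mean/std, add the identical
    # transforms.Normalize step here as well.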

    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    # === Encode and decode (no gradients needed at inference) ===
    with torch.no_grad():
        features = encoder(image)
        output_ids = decoder.sample(features, vocab)

    # === Convert IDs to words ===
    caption = []
    for idx in output_ids:
        word = vocab.itos[idx]
        if word == "<end>":
            break
        caption.append(word)

    final_caption = ' '.join(caption)
    print(f"\n📝 Predicted caption: {final_caption}\n")
    return final_caption

if __name__ == "__main__":
    # ✅ Change these!
    image_path = r"C:\Users\Jayasimma D\Documents\Skin_Disease_Captioning\Dataset\images\train\Albinism\Albinism2.jpg"  # 🔍 your test image path
    checkpoint_path = "./checkpoints/caption_model_epoch5.pth"

    predict(image_path, checkpoint_path)