import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from PIL import Image


# ===========
# Vocabulary
# ===========
class Vocabulary:
    def __init__(self):
        self.itos = {}
        self.stoi = {}

    def load(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos


# ===========
# Encoder
# ===========
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]  # drop the final FC layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        # Keep the pretrained backbone frozen; only linear/bn build a graph.
        with torch.no_grad():
            features = self.resnet(images)               # [B, 2048, 1, 1]
        features = features.view(features.size(0), -1)   # [B, 2048]
        features = self.linear(features)                 # [B, embed_size]
        features = self.bn(features)                     # [B, embed_size]
        return features


# ===========
# Decoder
# ===========
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        # Prepend the image features as the first "token" of the sequence.
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, features, vocab, max_len=30):
        """Generate a caption for the given image features using greedy search."""
        output_ids = []
        states = None
        inputs = features.unsqueeze(1)                   # [B, 1, embed_size]
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)  # [B, 1, hidden]
            outputs = self.linear(hiddens.squeeze(1))    # [B, vocab_size]
            predicted = outputs.argmax(1)                # [B]
            output_ids.append(predicted.item())          # assumes batch size 1
            # Stop once the end-of-sequence token is produced. The original
            # compared against an empty string, which looks like the token
            # (e.g. "<end>") was stripped; adjust this to match whatever end
            # token your vocabulary actually uses.
            if vocab.itos[predicted.item()] == "<end>":
                break
            inputs = self.embed(predicted).unsqueeze(1)  # feed prediction back in
        return output_ids


# ===========
# Predict block
# ===========
def predict(image_path, checkpoint_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embed_size = 256
    hidden_size = 512
    num_layers = 1

    # === Load checkpoint ===
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # === Load vocab ===
    vocab = Vocabulary()
    vocab.load(checkpoint['vocab_stoi'], checkpoint['vocab_itos'])
    vocab_size = len(vocab.stoi)

    # === Load models ===
    encoder = EncoderCNN(embed_size).to(device)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    encoder.eval()
    decoder.eval()

    # === Image transform ===
    # NOTE: this must match the transform used during training; if the encoder
    # was trained with ImageNet normalization, add the matching
    # transforms.Normalize(...) here as well.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    # === Encode ===
    with torch.no_grad():
        features = encoder(image)
        # === Decode ===
        output_ids = decoder.sample(features, vocab)

    # === Convert IDs to words ===
    caption = []
    for idx in output_ids:
        word = vocab.itos[idx]
        if word == "<end>":  # same end-token assumption as in sample()
            break
        caption.append(word)

    final_caption = ' '.join(caption)
    print(f"\nšŸ“ Predicted caption: {final_caption}\n")


if __name__ == "__main__":
    # āœ… Change these!
    image_path = r"C:\Users\Jayasimma D\Documents\Skin_Disease_Captioning\Dataset\images\train\Albinism\Albinism2.jpg"  # šŸ” your test image path
    checkpoint_path = "./checkpoints/caption_model_epoch5.pth"
    predict(image_path, checkpoint_path)