import gradio as gr
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import transforms
from PIL import Image
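
# --- Vocabulary ---------------------------------------------------------------
# Minimal two-way token/index mapping; the actual mappings are restored from
# the training checkpoint below.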
class Vocabulary:
    def __init__(self):
        self.itos = {}  # index -> token
        self.stoi = {}  # token -> index

    def load(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos
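
# --- Encoder ------------------------------------------------------------------
# Frozen ResNet-50 feature extractor followed by a trainable linear projection
# into the caption embedding space.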
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        # Pretrained ResNet-50 with the final classification layer removed.
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        # The backbone is kept frozen; only the projection head sees gradients.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.linear(features)
        features = self.bn(features)
        return features
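
# --- Decoder ------------------------------------------------------------------
# Single LSTM that consumes the image feature as its first input step, then
# autoregressively emits one vocabulary token per step.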
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # Training path: prepend the image feature to the caption embeddings.
        embeddings = self.embed(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, features, vocab, max_len=30):
        # Greedy decoding for a single image: feed the image feature first,
        # then feed each predicted token back in until <end> or max_len.
        output_ids = []
        states = None
        inputs = features.unsqueeze(1)
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)
            output_ids.append(predicted.item())
            if vocab.itos[predicted.item()] == "<end>":
                break
            inputs = self.embed(predicted).unsqueeze(1)
        return output_ids
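
# --- Inference setup ----------------------------------------------------------
# The hyperparameters below must match the values the checkpoint was trained
# with; mismatched sizes will fail at load_state_dict.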
checkpoint_path = "./checkpoints/caption_model_epoch30.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embed_size = 256
hidden_size = 512
num_layers = 1

checkpoint = torch.load(checkpoint_path, map_location=device)
vocab = Vocabulary()
vocab.load(checkpoint['vocab_stoi'], checkpoint['vocab_itos'])
vocab_size = len(vocab.stoi)

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
encoder.eval()
decoder.eval()
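
# Preprocessing must mirror the training pipeline. If training normalized with
# ImageNet mean/std, add the matching transforms.Normalize(...) here as well.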
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
def generate_caption(image):
    # Gradio passes the upload as a numpy array (or None if nothing was given).
    if image is None:
        return "Please upload an image."
    image = Image.fromarray(image).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = encoder(image)
        output_ids = decoder.sample(features, vocab)
    caption = []
    for idx in output_ids:
        word = vocab.itos[idx]
        if word == "<end>":
            break
        if word in ("<start>", "<pad>"):  # skip special tokens if the vocabulary uses them
            continue
        caption.append(word)
    return ' '.join(caption)
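
# --- Gradio UI ------------------------------------------------------------------
# Quick local smoke test without launching the UI (assumes a hypothetical
# sample image at ./sample.jpg):
#     import numpy as np
#     print(generate_caption(np.array(Image.open("./sample.jpg"))))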
demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Skin Disease Image Captioning",
    description="Upload an image of a skin disease to generate a descriptive caption from the trained CNN-LSTM captioning model."
)

if __name__ == "__main__":
    demo.launch()