"""Gradio demo that captions an uploaded image with a ResNet-50 + LSTM model.

Importing this module loads the checkpoint at ``checkpoint_path`` and builds
the encoder/decoder; running it as a script launches the Gradio interface.
"""

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import ResNet50_Weights, resnet50

# Vocabulary entry that stops caption generation/decoding.
# NOTE(review): an empty string is an unusual stop token and may be the residue
# of a stripped marker such as "<EOS>" -- confirm against the training
# vocabulary and change it here (the single place it is defined) if so.
END_TOKEN = ""


class Vocabulary:
    """Container for the token/id lookup tables stored in the checkpoint."""

    def __init__(self):
        # itos is indexed by integer token id and yields the token string.
        self.itos = {}
        # stoi is presumably the inverse mapping (token -> id); only its
        # length is used in this file.
        self.stoi = {}

    def load(self, stoi, itos):
        """Replace both lookup tables with the given mappings."""
        self.stoi = stoi
        self.itos = itos


class EncoderCNN(nn.Module):
    """ResNet-50 feature extractor followed by a linear projection and BatchNorm."""

    def __init__(self, embed_size, pretrained=True):
        """Build the encoder.

        Args:
            embed_size: Size of the output feature vector.
            pretrained: When True (default, the original behaviour) the
                backbone is initialised with ``ResNet50_Weights.DEFAULT``,
                which may download weights. Pass False when a full state
                dict is loaded right afterwards, to skip that download.
        """
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        resnet = resnet50(weights=weights)
        # Keep every child except the last one (the `fc` classifier head).
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Return one ``embed_size`` feature vector per image in the batch."""
        # The backbone runs without gradient tracking, so only the linear
        # and BatchNorm layers can receive gradients.
        with torch.no_grad():
            features = self.resnet(images)
        # Flatten everything after the batch dimension.
        features = features.view(features.size(0), -1)
        features = self.linear(features)
        features = self.bn(features)
        return features


class DecoderRNN(nn.Module):
    """LSTM decoder that turns image features into vocabulary logits."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Return logits for the image feature step followed by each caption token.

        The image features are prepended as the first time step of the
        embedded caption sequence before running the LSTM.
        """
        embeddings = self.embed(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs

    def sample(self, features, vocab, max_len=30, end_token=END_TOKEN):
        """Greedily decode token ids for a single image.

        Args:
            features: Encoder output for exactly one image (``.item()`` is
                called on the prediction, so the batch size must be 1).
            vocab: Object whose ``itos`` maps a token id to its string.
            max_len: Maximum number of tokens to generate.
            end_token: Token string that stops generation; it is still
                included as the last id of the returned list.

        Returns:
            List of predicted token ids, at most ``max_len`` long.
        """
        output_ids = []
        states = None
        inputs = features.unsqueeze(1)
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)
            # Convert to a Python int once; it is used for both the output
            # list and the vocabulary lookup.
            token_id = predicted.item()
            output_ids.append(token_id)
            if vocab.itos[token_id] == end_token:
                break
            # Feed the predicted token back in as the next input step.
            inputs = self.embed(predicted).unsqueeze(1)
        return output_ids


checkpoint_path = "./checkpoints/caption_model_epoch30.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Must match the hyperparameters the checkpoint was trained with, otherwise
# load_state_dict below fails on mismatched tensor shapes.
embed_size = 256
hidden_size = 512
num_layers = 1

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
checkpoint = torch.load(checkpoint_path, map_location=device)

vocab = Vocabulary()
vocab.load(checkpoint['vocab_stoi'], checkpoint['vocab_itos'])
vocab_size = len(vocab.stoi)

# pretrained=False: every backbone weight is overwritten by the checkpoint's
# state dict on the next lines, so fetching ImageNet weights would be wasted.
encoder = EncoderCNN(embed_size, pretrained=False).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

encoder.eval()
decoder.eval()

# NOTE(review): there is no Normalize step here; confirm this matches the
# transform used at training time.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def generate_caption(image):
    """Generate a caption for an image supplied by the Gradio interface.

    Args:
        image: Image as a numpy array, or None when nothing was uploaded.

    Returns:
        The caption words joined by single spaces.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when the user submits without an image; fail
        # with a readable message instead of an error inside PIL.
        raise gr.Error("Please upload an image before requesting a caption.")

    image = Image.fromarray(image).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = encoder(image)
        output_ids = decoder.sample(features, vocab)

    caption = []
    for idx in output_ids:
        word = vocab.itos[idx]
        if word == END_TOKEN:
            break
        caption.append(word)
    return ' '.join(caption)


demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Skin Disease Image Captioning",
    description="Upload an image of a skin disease to generate a descriptive caption using your trained model.",
)

if __name__ == "__main__":
    demo.launch()