import logging
import pickle

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

logger = logging.getLogger(__name__)


class Config:
    """Model and decoding settings.

    The four dimension fields define the layer shapes built below, so they must
    match the checkpoint loaded with ``load_state_dict``.
    """

    embed_size = 300
    hidden_size = 512
    num_layers = 1
    feature_dim = 2048  # length of the flattened ResNet-50 pooled feature vector
    max_caption_len = 20  # hard cap on generated words per caption
    # NOTE(review): both token strings are empty in the original source, so the
    # start and end ids coincide. They look like "<start>"/"<end>"-style tokens
    # whose angle-bracket text was stripped -- confirm against vocab_safe.pkl.
    start_token = ""
    end_token = ""


class Encoder(nn.Module):
    """Project an image feature vector to the decoder's hidden size."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names must stay as-is: they form the checkpoint's state_dict keys.
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Apply linear -> batch-norm -> ReLU -> dropout.

        ``images`` is the flattened feature tensor; in this file it is the
        (1, feature_dim) ResNet output.
        """
        x = self.bn(self.linear(images))
        return self.dropout(self.relu(x))


class Decoder(nn.Module):
    """Word embedding, LSTM and vocabulary projection.

    This app uses only the sub-modules; ``generate_caption`` drives them one
    step at a time for greedy decoding.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # The training-time forward pass is not part of this inference app;
        # fail loudly rather than silently returning None.
        raise NotImplementedError(
            "Decoder.forward is not implemented here; use generate_caption() for inference."
        )


class Seq2Seq(nn.Module):
    """Container holding the encoder and decoder so checkpoint keys line up."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, feature_dim):
        super().__init__()
        self.encoder = Encoder(feature_dim, hidden_size)
        self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)


device = torch.device("cpu")

# SECURITY: pickle.load and torch.load can execute arbitrary code from the file.
# These are trusted local artifacts -- never point them at user-supplied files.
with open("vocab_safe.pkl", "rb") as f:
    vocab_data = pickle.load(f)

itos = vocab_data["itos"]  # index -> word; looked up with .get() below
stoi = vocab_data["stoi"]  # word -> index
vocab_size = len(itos)

model = Seq2Seq(
    Config.embed_size,
    Config.hidden_size,
    vocab_size,
    Config.num_layers,
    Config.feature_dim,
).to(device)
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()

# Pretrained ResNet-50 with its final (classification) layer removed, so it
# outputs the pooled feature map that feeds the Encoder.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(device)
resnet.eval()

# Resize to 224x224 and normalize with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def _extract_features(image):
    """Return the flattened (1, feature_dim) ResNet features for a PIL image."""
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    return resnet(img_tensor).view(1, -1)


def _greedy_decode(features, max_len=Config.max_caption_len):
    """Greedily decode at most ``max_len`` words from image ``features``.

    Stops early when the end token is predicted. Ids that map to an empty
    string (or are missing from ``itos``) are skipped.
    """
    # The encoder output seeds both LSTM states. Its shape (1, 1, hidden) fits
    # the LSTM's (num_layers, batch, hidden) only while Config.num_layers == 1.
    enc_out = model.encoder(features).unsqueeze(0)
    h, c = enc_out, enc_out
    start_idx = stoi[Config.start_token]
    end_idx = stoi[Config.end_token]

    word = torch.tensor([start_idx], device=device)
    words = []
    for _ in range(max_len):
        embed = model.decoder.embed(word).view(1, 1, -1)
        output, (h, c) = model.decoder.lstm(embed, (h, c))
        idx = model.decoder.linear(output).argmax(-1).item()
        if idx == end_idx:
            break
        word_str = itos.get(idx, "")
        if word_str:  # skip empty entries so the join has no double spaces
            words.append(word_str)
        word = torch.tensor([idx], device=device)
    return words


def generate_caption(image):
    """Gradio callback: caption ``image`` or return a readable message.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            nothing has been uploaded.

    Returns:
        One of:
        - the caption with its first letter upper-cased and a trailing ".";
        - an empty string if no words were generated;
        - an ``"Error: ..."`` string if inference fails (the traceback is logged).
    """
    if image is None:
        return "Please upload an image first."
    try:
        with torch.no_grad():
            words = _greedy_decode(_extract_features(image))
        caption = " ".join(words).strip()
    except Exception as e:  # UI boundary: report in the textbox, keep the app alive
        logger.exception("Caption generation failed")
        return f"Error: {e}"

    if not caption:
        return caption
    # Upper-case only the first letter; str.capitalize() would also lower-case
    # the rest of the caption (proper nouns, acronyms).
    return caption[0].upper() + caption[1:] + "."


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🖼️ Image Captioning Generator\n"
        "Upload an image to generate a descriptive caption."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            generate_btn = gr.Button("✨ Generate Caption", variant="primary")
        with gr.Column():
            caption_output = gr.Textbox(
                label="Generated Caption", lines=4, interactive=False
            )

    generate_btn.click(
        fn=generate_caption,
        inputs=image_input,
        outputs=caption_output,
    )

if __name__ == "__main__":
    demo.launch()