File size: 4,074 Bytes
9f93006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
import torch.nn as nn
import pickle
from torchvision import models, transforms
from PIL import Image

class Config:
    """Static hyperparameters used to construct the captioning model."""

    # Width of the flattened image feature vector fed into the encoder.
    feature_dim: int = 2048
    # Dimensionality of the decoder's word embeddings (LSTM input size).
    embed_size: int = 300
    # Encoder output width, which is also the LSTM hidden-state width.
    hidden_size: int = 512
    # Number of stacked LSTM layers in the decoder.
    num_layers: int = 1

class Encoder(nn.Module):
    """Project image feature vectors into the decoder's hidden space.

    The pipeline is: linear projection -> batch norm -> ReLU -> dropout (p=0.5).
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the projection layers.

        Args:
            input_dim: Size of each incoming feature vector.
            hidden_dim: Size of each projected output vector.
        """
        super().__init__()
        # Attribute names are also the state-dict key prefixes, so they are
        # kept stable for compatibility with saved checkpoints.
        self.linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.bn = nn.BatchNorm1d(num_features=hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, images):
        """Return the projected, normalized and activated features.

        Args:
            images: Tensor whose last dimension equals ``input_dim``
                (assumed shape ``(batch, input_dim)`` because of BatchNorm1d).

        Returns:
            Tensor with last dimension ``hidden_dim``.
        """
        projected = self.linear(images)
        normalized = self.bn(projected)
        activated = self.relu(normalized)
        return self.dropout(activated)

class Decoder(nn.Module):
    """Caption decoder layers: word embedding, LSTM, and vocabulary projection."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the decoder layers.

        Args:
            embed_size: Dimensionality of each word embedding.
            hidden_size: Width of the LSTM hidden state.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        # Attribute names are also the state-dict key prefixes; keep them
        # unchanged so saved checkpoints continue to load.
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
        )
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, features, captions):
        """Placeholder forward pass: does no computation and returns ``None``.

        Both arguments are ignored; the submodules are meant to be driven
        individually by the caller.
        """
        return None

class Seq2Seq(nn.Module):
    """Container pairing the feature ``Encoder`` with the caption ``Decoder``.

    No ``forward`` is defined here; callers access ``.encoder`` and
    ``.decoder`` directly.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, feature_dim):
        """Build both sub-networks.

        Args:
            embed_size: Word-embedding dimensionality for the decoder.
            hidden_size: Encoder output width and decoder LSTM hidden width.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers in the decoder.
            feature_dim: Width of the image feature vector given to the encoder.
        """
        super().__init__()
        self.encoder = Encoder(input_dim=feature_dim, hidden_dim=hidden_size)
        self.decoder = Decoder(
            embed_size=embed_size,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_layers=num_layers,
        )

# All modules and tensors used for inference are placed on this device.
device = torch.device("cpu")

# Load the vocabulary lookup tables: `stoi` maps token -> index and `itos`
# maps index -> token (as used by generate_caption).
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only use a vocabulary file that comes from a trusted source.
with open('vocab_safe.pkl', 'rb') as f:
    vocab_data = pickle.load(f)
itos = vocab_data['itos']
stoi = vocab_data['stoi']
vocab_size = len(itos)

# Rebuild the captioning network and restore its trained weights.
# NOTE(security): torch.load also unpickles; consider `weights_only=True`
# on PyTorch versions that support it if the checkpoint is not trusted.
model = Seq2Seq(Config.embed_size, Config.hidden_size, vocab_size, Config.num_layers, Config.feature_dim)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
# Keep the model on the same device as the feature extractor and the input
# tensors, so changing `device` above does not cause a device mismatch.
model.to(device)
model.eval()

# Pretrained ResNet-50 used as a fixed feature extractor: the final child
# module (the classifier head) is dropped so it outputs pooled features.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(device)
resnet.eval()

# Image preprocessing: resize to 224x224, convert to a tensor, then apply
# per-channel mean/std normalization (the widely used ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

def generate_caption(image, max_len=20):
    """Generate a caption for an image using greedy decoding.

    The image is converted to RGB, preprocessed with ``transform``, and passed
    through ``resnet`` and ``model.encoder``. The encoder output initializes
    both LSTM states, and tokens are then decoded one at a time, always
    taking the highest-scoring word, until ``<end>`` is predicted or
    ``max_len`` steps have run.

    Args:
        image: PIL image to describe, or ``None`` if nothing was uploaded.
        max_len: Maximum number of decoding steps. Defaults to 20.

    Returns:
        The caption, capitalized and terminated with a period; an empty
        string if no words were produced; a prompt message when ``image`` is
        ``None``; or ``"Error: <message>"`` if any step raises.
    """
    if image is None:
        return "Please upload an image first."

    # Broad catch is deliberate: this is the UI callback boundary, and any
    # failure is reported back to the user as text instead of crashing.
    try:
        pixels = transform(image.convert('RGB')).unsqueeze(0).to(device)

        # Look up the stop token once rather than on every decoding step.
        end_idx = stoi['<end>']
        words = []

        with torch.no_grad():
            features = resnet(pixels).view(1, -1)

            # The encoder output seeds both the hidden and the cell state;
            # unsqueeze(0) adds the leading layer axis the LSTM expects.
            state = model.encoder(features).unsqueeze(0)
            h, c = state, state

            token = torch.tensor([stoi['<start>']], device=device)
            for _ in range(max_len):
                embedded = model.decoder.embed(token).view(1, 1, -1)
                output, (h, c) = model.decoder.lstm(embedded, (h, c))
                idx = model.decoder.linear(output).argmax(2).item()

                if idx == end_idx:
                    break

                words.append(itos.get(idx, "<unk>"))
                # Feed the predicted word back in as the next input.
                token = torch.tensor([idx], device=device)

        final_caption = " ".join(words).strip().capitalize()
        if final_caption:
            final_caption += "."
        return final_caption

    except Exception as e:
        return f"Error: {e}"

# Build the Gradio interface. Components are laid out in the order they are
# created inside the nested context managers, so statement order matters.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header rendered as Markdown.
    gr.Markdown(
        """
        # 🖼️ Image Captioning Generator
        Upload an image to generate a descriptive caption.
        """
    )
    # Single row with two columns: inputs on the left, output on the right.
    with gr.Row():
        with gr.Column():
            # type="pil" hands the upload to the callback as a PIL image,
            # which generate_caption relies on (it calls .convert on it).
            image_input = gr.Image(type="pil", label="Upload Image")
            generate_btn = gr.Button("✨ Generate Caption", variant="primary")
        with gr.Column():
            # Read-only text box that displays the generated caption.
            caption_output = gr.Textbox(label="Generated Caption", lines=4, interactive=False)
            
    # Clicking the button runs generate_caption on the uploaded image and
    # writes its return value into the caption text box.
    generate_btn.click(
        fn=generate_caption,
        inputs=image_input,
        outputs=caption_output
    )

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()