File size: 6,383 Bytes
cd94662
 
 
 
84165ce
cd94662
f097a6b
 
 
 
cd94662
 
 
 
84165ce
cd94662
84165ce
cd94662
 
 
 
 
84165ce
 
 
 
 
cd94662
84165ce
624c509
 
 
2fccb7e
624c509
 
cd94662
 
 
84165ce
f097a6b
 
 
84165ce
cd94662
 
 
f097a6b
cd94662
 
f097a6b
624c509
cd94662
624c509
 
cd94662
 
 
 
84165ce
624c509
84165ce
 
 
 
 
 
 
 
 
 
f097a6b
84165ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd94662
84165ce
f097a6b
84165ce
 
 
cd94662
84165ce
 
cd94662
2fccb7e
b2e80e9
cd94662
 
 
 
f097a6b
84165ce
 
 
 
f097a6b
cd94662
84165ce
 
 
 
 
 
 
 
 
 
cd94662
 
 
84165ce
f097a6b
cd94662
 
f097a6b
cd94662
84165ce
cd94662
 
 
f097a6b
cd94662
84165ce
f097a6b
 
84165ce
cd94662
84165ce
cd94662
84165ce
f097a6b
84165ce
 
 
 
cd94662
84165ce
 
 
cd94662
f097a6b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import warnings

# Silence every UserWarning for the whole process; the author considered them unnecessary noise.
warnings.simplefilter("ignore", category=UserWarning)

# -----------------------
# Attention Module
# -----------------------
class Attention_PT(nn.Module):
    """Additive attention that pools a sequence of LSTM states using CNN features.

    Both inputs are projected into a shared ``attention_dim`` space, combined
    with ``tanh``, scored by a single linear unit, and softmax-normalized over
    dimension 1. The result is the score-weighted sum of ``lstm_features``
    over that dimension.
    """

    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super().__init__()
        # Creation order and attribute names are kept stable because they
        # define the state_dict keys used by saved checkpoints.
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        """Return the attention-pooled ``lstm_features``.

        Assumes ``cnn_features`` broadcasts against ``lstm_features`` after
        projection. In this file the VQA model passes (batch, 1, cnn_dim) and
        (batch, seq, lstm_dim) — TODO confirm for other callers.
        """
        energy = torch.tanh(self.cnn_proj(cnn_features) + self.lstm_proj(lstm_features))
        # One score per position, normalized across dimension 1 (the sequence axis).
        weights = self.attn(energy).softmax(dim=1)
        return torch.sum(weights * lstm_features, dim=1)

# -----------------------
# Pre-trained VQA Model
# -----------------------
class PretrainedVQAModel(nn.Module):
    """Visual-question-answering model with a greedy LSTM answer decoder.

    The image is encoded by a ResNet18 backbone with its final layer removed.
    The question is encoded by an embedding + LSTM. The two are fused with
    ``Attention_PT``, and ``predict`` generates the answer token by token.
    """

    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, attention_dim=256, max_seq_len=30):
        super(PretrainedVQAModel, self).__init__()
        self.vocab_size = vocab_size
        # Upper bound on the number of answer tokens generated by predict().
        self.max_seq_len = max_seq_len

        # Pre-trained CNN Encoder (ResNet18) with the last child (the
        # classifier) dropped; its pooled output is treated as 512-dim.
        # NOTE: this requests torchvision's default pretrained weights, which
        # may be downloaded on first use even if a checkpoint is loaded later.
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])
        self.cnn_output_dim = 512

        # Text Embedding (shared by the question encoder and the answer decoder)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM Encoder
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)

        # Attention
        self.attention = Attention_PT(self.cnn_output_dim, lstm_units, attention_dim)

        # Decoder: input is [token embedding, attended question, last question state]
        self.decoder_input_proj = nn.Linear(embedding_dim + 2 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        # Not referenced by predict(); presumably used by a training loop elsewhere — verify.
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _lookup_word(idx_to_word, idx):
        """Return the token for ``idx`` from a vocabulary keyed by str or int."""
        word = idx_to_word.get(str(idx))
        if word is None:
            word = idx_to_word[idx]
        return word

    def predict(self, image, question, word_to_idx, idx_to_word, device='cpu'):
        """Greedily generate an answer to ``question`` about ``image``.

        Args:
            image: Image tensor. A 3-D tensor gets a batch dimension added;
                presumably (C, H, W) normalized for ResNet — confirm with caller.
            question: Question text, lower-cased and split on whitespace.
                Words missing from ``word_to_idx`` map to ``<PAD>``. An
                empty/``None`` question is encoded as a single ``<PAD>``.
            word_to_idx: Token -> index mapping; must contain ``<PAD>``,
                ``<START>`` and ``<END>``.
            idx_to_word: Index -> token mapping with str or int keys.
            device: Device on which inference runs.

        Returns:
            The generated words joined by single spaces. Generation stops at
            ``<END>`` or after ``max_seq_len`` tokens; returns "" if ``<END>``
            comes first.
        """
        self.eval()
        with torch.no_grad():
            if image.dim() == 3:
                image = image.unsqueeze(0)
            image = image.to(device)

            pad_idx = word_to_idx['<PAD>']
            end_idx = word_to_idx['<END>']

            # Tokenize. A zero-length sequence would be rejected by the LSTM,
            # so an empty question falls back to a single <PAD> token.
            question_seq = [word_to_idx.get(word, pad_idx)
                            for word in (question or '').lower().split()]
            if not question_seq:
                question_seq = [pad_idx]
            question_ids = torch.tensor([question_seq], dtype=torch.long, device=device)

            # Encode image + question into one context vector per sample.
            cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
            q_output, _ = self.question_lstm(self.embedding(question_ids))
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]
            # Shape (batch, 1, 2 * lstm_units): concatenated with each 1-token step.
            context = torch.cat([q_attended, q_last], dim=-1).unsqueeze(1)

            # Greedy decoding. The decoder's (h, c) state is carried across
            # steps so each prediction is conditioned on the whole answer so
            # far, not only on the most recent token.
            token = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long, device=device)
            hidden = None
            answer_words = []
            for _ in range(self.max_seq_len):
                decoder_in = torch.cat([self.embedding(token), context], dim=-1)
                decoder_in = self.decoder_input_proj(decoder_in)
                decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
                logits = self.fc_out(decoder_output[:, -1, :])
                next_word_idx = logits.argmax(dim=-1).item()

                if next_word_idx == end_idx:
                    break

                answer_words.append(self._lookup_word(idx_to_word, next_word_idx))
                token = torch.tensor([[next_word_idx]], dtype=torch.long, device=device)

            return ' '.join(answer_words)

# -----------------------
# Model Loader
# -----------------------
def load_model(word_to_idx_path="word_to_idx.pth",
               idx_to_word_path="idx_to_word.pth",
               model_path="vqa_pretrain_model.pth",
               device='cpu'):
    """Load the vocabularies and trained VQA weights from disk.

    Args:
        word_to_idx_path: File holding the token -> index vocabulary.
        idx_to_word_path: File holding the index -> token vocabulary.
        model_path: File holding the ``PretrainedVQAModel`` state dict.
        device: ``map_location`` for loading, and the device the model is moved to.

    Returns:
        ``(model, word_to_idx, idx_to_word)``, with the model in eval mode.

    Raises:
        RuntimeError: If any load or construction step fails. The underlying
            exception is chained as ``__cause__`` so its traceback is kept.
    """
    try:
        word_to_idx = torch.load(word_to_idx_path, map_location=device)
        idx_to_word = torch.load(idx_to_word_path, map_location=device)

        # Vocabulary size determines the embedding / output layer shapes.
        model = PretrainedVQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {str(e)}") from e
    return model, word_to_idx, idx_to_word

# -----------------------
# Gradio Interface
# -----------------------
def create_app():
    try:
        model, word_to_idx, idx_to_word = load_model()
        
        def preprocess_image(image):
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
            return transform(image).unsqueeze(0)
        
        def predict(image, question):
            try:
                image_tensor = preprocess_image(image)
                answer = model.predict(image_tensor, question, word_to_idx, idx_to_word)
                return answer
            except Exception as e:
                return f"Error: {str(e)}"
        
        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Your Question", placeholder="What is in this image?")
            ],
            outputs=gr.Textbox(label="Generated Answer"),
            title="Visual Question Answering",
            description="Upload an image and ask questions about its content",
            allow_flagging="never"
        )
        
    except Exception as e:
        return gr.Interface(
            lambda: f"Initialization failed: {str(e)}",
            inputs=None,
            outputs="text",
            title="Error"
        )

# -----------------------
# Main Execution
# -----------------------
if __name__ == "__main__":
    # Build the interface and serve it on every network interface, port 7860.
    create_app().launch(server_name="0.0.0.0", server_port=7860)