File size: 9,976 Bytes
83e0b3c
0b67e8a
59805ac
 
83e0b3c
0b67e8a
c5676e0
83e0b3c
0b67e8a
313bd15
c574d51
313bd15
 
 
4e6e66e
 
 
83e0b3c
 
4e6e66e
 
 
83e0b3c
 
 
 
 
 
 
4e6e66e
76b9d33
 
 
4029376
b25b382
4029376
b25b382
bcf968d
4029376
83e0b3c
 
 
 
 
 
 
 
 
 
 
 
 
b25b382
 
bcf968d
83e0b3c
b25b382
83e0b3c
 
b25b382
 
83e0b3c
 
b25b382
83e0b3c
 
 
 
 
b25b382
 
 
bcf968d
b25b382
 
83e0b3c
 
 
bcf968d
83e0b3c
 
 
c5676e0
83e0b3c
 
c5676e0
83e0b3c
 
 
c5676e0
83e0b3c
c5676e0
83e0b3c
 
c5676e0
83e0b3c
 
 
 
 
 
c5676e0
83e0b3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d88eb93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83e0b3c
d88eb93
c5676e0
 
 
 
 
d88eb93
8f08c02
83e0b3c
d88eb93
 
 
 
 
 
 
 
8f08c02
 
 
 
 
 
 
 
 
 
 
 
 
d88eb93
8f08c02
 
d88eb93
8f08c02
 
d88eb93
 
c574d51
39d6d29
d88eb93
8f08c02
d88eb93
8f08c02
83e0b3c
 
d88eb93
8f08c02
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from torchvision import transforms
from PIL import Image
import numpy as np

# ============================================================================

# -----------------------
# Attention Module
# -----------------------
class Attention(nn.Module):
    """Additive (tanh) attention that pools a sequence of LSTM outputs.

    A single query vector (the image feature, in this file) is projected and
    added to every projected sequence step. A linear scorer turns each sum into
    one score per step. The scores are softmax-normalized over the sequence
    axis and used to take a weighted sum of the *unprojected* sequence features.

    Args:
        cnn_dim: Size of the query feature vector.
        lstm_dim: Size of each sequence feature vector.
        attention_dim: Size of the shared projection space.
    """

    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super().__init__()
        # Submodule names form the state_dict keys of saved checkpoints; keep them stable.
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        """Return the attention-weighted sum of ``lstm_features``.

        Args:
            cnn_features: Query tensor ``(batch, 1, cnn_dim)``; the singleton
                axis broadcasts against every sequence step.
            lstm_features: Sequence tensor ``(batch, seq_len, lstm_dim)``.

        Returns:
            Tensor ``(batch, lstm_dim)``.
        """
        query = self.cnn_proj(cnn_features)            # (batch, 1, attention_dim)
        keys = self.lstm_proj(lstm_features)           # (batch, seq_len, attention_dim)
        scores = self.attn(torch.tanh(query + keys))   # (batch, seq_len, 1)
        weights = F.softmax(scores, dim=1)             # normalize over seq_len
        return torch.sum(weights * lstm_features, dim=1)
# -----------------------
# VQA Model
# -----------------------
class VQAModel(nn.Module):
    """CNN + LSTM visual question answering model with an LSTM answer decoder.

    The image is encoded by a small CNN into one ``cnn_output_dim`` vector.
    Caption and question tokens share one embedding table and are encoded by
    separate LSTMs. A single ``Attention`` module pools each text sequence using
    the image vector as the query. The two pooled vectors plus the question
    LSTM's last output form a ``3 * lstm_units`` context vector that is
    concatenated to every decoder input step.

    Args:
        vocab_size: Number of tokens in the shared vocabulary.
        embedding_dim: Size of token embeddings.
        lstm_units: Hidden size of every LSTM (encoders and decoder).
        cnn_output_dim: Channels of the last conv layer, i.e. image feature size.
        attention_dim: Hidden size of the attention scorer.
        max_seq_len: Maximum number of tokens ``predict`` will generate.
    """

    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # CNN encoder: extracts a single image feature vector.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Token embedding shared by caption, question and answer.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Separate LSTM encoders for the caption and the question.
        self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)

        # One attention module, reused for both the caption and the question.
        self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)

        # Decoder (trained with teacher forcing).
        # Context = caption attention + question attention + last question state
        #         = 3 * lstm_units (e.g. 768 when lstm_units=256).
        # Each decoder step sees [answer token embedding ; context]
        #         = embedding_dim + 3 * lstm_units, projected down to lstm_units.
        self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, image, caption, question, answer_input):
        """Compute teacher-forced answer logits.

        Args:
            image: Float tensor ``(batch, 3, H, W)``.
            caption: Token ids ``(batch, cap_len)``.
            question: Token ids ``(batch, q_len)``.
            answer_input: Decoder input ids ``(batch, ans_len)``. Presumably the
                answer shifted right and starting with ``<START>``; confirm
                against the training script.

        Returns:
            Logits ``(batch, ans_len, vocab_size)``.
        """
        # --- Image encoder ---
        cnn_features = self.cnn(image)  # (batch, cnn_output_dim, 1, 1)
        cnn_features = cnn_features.view(cnn_features.size(0), -1)  # (batch, cnn_output_dim)

        # --- Text encoders ---
        cap_embed = self.embedding(caption)  # (batch, cap_seq_len, embedding_dim)
        cap_output, _ = self.caption_lstm(cap_embed)  # (batch, cap_seq_len, lstm_units)

        q_embed = self.embedding(question)  # (batch, q_seq_len, embedding_dim)
        q_output, _ = self.question_lstm(q_embed)  # (batch, q_seq_len, lstm_units)

        # --- Attention ---
        cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output)  # (batch, lstm_units)
        q_attended = self.attention(cnn_features.unsqueeze(1), q_output)      # (batch, lstm_units)

        # NOTE(review): the last time step may be a <PAD> token if questions are
        # right-padded; no mask is applied here or in the attention.
        q_last = q_output[:, -1, :]  # (batch, lstm_units)

        # Context vector: (batch, 3*lstm_units)
        context = torch.cat([cap_attended, q_attended, q_last], dim=-1)

        # --- Decoder with teacher forcing ---
        answer_embed = self.embedding(answer_input)  # (batch, ans_seq_len, embedding_dim)
        context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)  # (batch, ans_seq_len, 3*lstm_units)
        decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)  # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
        decoder_in = self.decoder_input_proj(decoder_in)  # (batch, ans_seq_len, lstm_units)

        decoder_output, _ = self.decoder_lstm(decoder_in)  # (batch, ans_seq_len, lstm_units)
        output = self.fc_out(self.dropout(decoder_output))  # (batch, ans_seq_len, vocab_size)
        return output

    def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Greedily decode an answer for one image/question pair.

        Args:
            image: Preprocessed image tensor, ``(3, H, W)`` or ``(1, 3, H, W)``.
                Only a single image is supported (``.item()`` is used on the
                argmax).
            question: Raw question text; it is lower-cased and split on
                whitespace. Unknown words map to ``<PAD>``.
            word_to_idx: Token -> id mapping; must contain ``<PAD>``,
                ``<START>`` and ``<END>``.
            idx_to_word: Id -> token mapping.
            device: Device to run on; the model is moved there.

        Returns:
            The generated words joined by spaces, without ``<END>``. At most
            ``max_seq_len`` words.

        Raises:
            ValueError: If ``question`` contains no words.
            KeyError: If a required special token is missing from ``word_to_idx``.
        """
        tokens = question.lower().split()
        if not tokens:
            # An empty sequence would otherwise fail deep inside the LSTM.
            raise ValueError("question must contain at least one word")

        self.eval()
        self.to(device)
        # Add a batch dimension if the image does not have one.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(device)

        pad_idx = word_to_idx['<PAD>']
        end_idx = word_to_idx['<END>']
        # NOTE(review): unknown words map to <PAD>; confirm this matches training.
        question_ids = torch.tensor(
            [[word_to_idx.get(word, pad_idx) for word in tokens]],
            dtype=torch.long,
            device=device,
        )

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            # Encode the image and the question.
            cnn_features = self.cnn(image)
            cnn_features = cnn_features.view(cnn_features.size(0), -1)
            q_output, _ = self.question_lstm(self.embedding(question_ids))
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]

            # There is no caption at inference time. Duplicate the question
            # attention into the caption slot to keep the 3*lstm_units context size.
            context = torch.cat([q_attended, q_attended, q_last], dim=-1)  # (1, 3*lstm_units)
            step_context = context.unsqueeze(1)  # (1, 1, 3*lstm_units)

            # Decode one token per step, carrying the LSTM state forward. Only the
            # newest token is fed each step. Re-feeding the whole prefix together
            # with the carried state would consume earlier tokens twice and
            # diverge from the teacher-forced forward().
            next_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long, device=device)
            hidden = None
            answer_words = []
            for _ in range(self.max_seq_len):
                step_in = torch.cat([self.embedding(next_input), step_context], dim=-1)
                step_in = self.decoder_input_proj(step_in)
                decoder_output, hidden = self.decoder_lstm(step_in, hidden)
                logits = self.fc_out(decoder_output[:, -1, :])
                next_idx = logits.argmax(dim=-1).item()
                if next_idx == end_idx:
                    break
                answer_words.append(idx_to_word[next_idx])
                next_input = torch.tensor([[next_idx]], dtype=torch.long, device=device)

        return ' '.join(answer_words)



        
def load_model(model_path, word_to_idx_path, idx_to_word_path, device='cpu'):
    """Load a trained VQAModel together with its vocabulary mappings.

    The model is built with default hyperparameters and ``vocab_size`` taken
    from ``len(word_to_idx)``. Those defaults must match the ones used at
    training time, or ``load_state_dict`` will fail.

    Args:
        model_path: Path to the saved state_dict.
        word_to_idx_path: Path to the token -> id mapping saved with ``torch.save``.
        idx_to_word_path: Path to the id -> token mapping saved with ``torch.save``.
        device: Device used for ``map_location`` and for the returned model.

    Returns:
        ``(model, word_to_idx, idx_to_word)`` with the model in eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise; the error is
        printed first and then re-raised.
    """
    try:
        # NOTE(review): torch.load unpickles these files; only load trusted artifacts.
        word_to_idx = torch.load(word_to_idx_path, map_location=device)
        idx_to_word = torch.load(idx_to_word_path, map_location=device)

        model = VQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load(model_path, map_location=device))
        # nn.Module.to / .eval return the module itself, so they can be chained.
        model.to(device).eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return model, word_to_idx, idx_to_word
        
def predict(image, question, model, word_to_idx, idx_to_word, device='cpu'):
    """Preprocess a PIL image and ask ``model`` the given question.

    Uses the module-level ``transform`` for preprocessing. Any exception is
    printed and replaced by a fixed error string, so callers always get text back.

    Returns:
        The model's answer, or ``"Error generating answer"`` on failure.
    """
    try:
        batch = transform(image).unsqueeze(0).to(device)
        return model.predict(batch, question, word_to_idx, idx_to_word, device)
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error generating answer"

# Image preprocessing pipeline: resize to 224x224, convert to a float tensor,
# then normalize each channel with the commonly used ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def create_interface():
    """Build the Gradio UI for the VQA model.

    Loads ``vqa_model.pth``, ``word_to_idx.pth`` and ``idx_to_word.pth`` from
    the working directory on CPU. If setup fails, the error is printed and a
    placeholder interface that only reports the failure is returned instead.

    Returns:
        A ``gr.Interface`` ready to ``launch()``.
    """
    device = 'cpu'  # Always use CPU on Spaces.

    try:
        model, word_to_idx, idx_to_word = load_model(
            "vqa_model.pth",
            "word_to_idx.pth",
            "idx_to_word.pth",
            device
        )

        # The name stays ``predict``: Gradio derives the default API endpoint
        # name from the function name, so renaming it would break API clients.
        # It shadows the module-level ``predict`` only inside this function.
        def predict(image, question):
            """Gradio callback: answer ``question`` about the uploaded PIL ``image``."""
            # Give a friendly message for incomplete input instead of an internal exception.
            if image is None:
                return "Error: please upload an image."
            if not question or not question.strip():
                return "Error: please enter a question."
            try:
                # Reuse the module-level preprocessing pipeline instead of
                # rebuilding an identical transforms.Compose on every request.
                batch = transform(image).unsqueeze(0).to(device)
                return model.predict(batch, question, word_to_idx, idx_to_word, device)
            except Exception as e:
                return f"Error: {str(e)}"

        iface = gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Question")
            ],
            outputs=gr.Textbox(label="Answer"),
            title="Visual Question Answering",
            description="Tải ảnh về động vật lên và đặt câu hỏi liên quan (CHỈ HỖ TRỢ TIẾNG ANH)"
        )
        return iface
    except Exception as e:
        # Previously swallowed silently; log it so the failure is diagnosable.
        print(f"Failed to build interface: {e}")
        return gr.Interface(lambda: "Model failed to load", None, "text")

if __name__ == "__main__":
    # Listen on all network interfaces. Port 7860 is presumably the port the
    # hosting Space expects; confirm if deploying elsewhere.
    app = create_interface()
    app.launch(server_name="0.0.0.0", server_port=7860)