Tin113 commited on
Commit
83e0b3c
·
verified ·
1 Parent(s): 66f53bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -252
app.py CHANGED
@@ -1,42 +1,33 @@
 
1
  import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
- import json
7
- import gradio as gr
8
- import os
9
- import sys
10
 
11
  # ============================================================================
12
  # 1. ĐỊNH NGHĨA LẠI CÁC CLASS MODEL (QUAN TRỌNG!)
13
  # (Copy từ code huấn luyện gốc, ĐÃ SỬA Attention theo lỗi trước)
14
  # ============================================================================
15
-
16
  # -----------------------
17
  # Attention Module
18
  # -----------------------
19
  class Attention(nn.Module):
20
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
21
  super(Attention, self).__init__()
22
- # Tên lớp Linear đã được sửa để khớp với file .pth của bạn
23
- self.cnn = nn.Linear(cnn_dim, attention_dim)
24
- self.lstm = nn.Linear(lstm_dim, attention_dim)
25
  self.attn = nn.Linear(attention_dim, 1)
26
- # Giả sử bạn có các lớp này trong code gốc đã dùng để train
27
- self.tanh = nn.Tanh()
28
- self.softmax = nn.Softmax(dim=1)
29
 
30
  def forward(self, cnn_features, lstm_features):
31
- # Sử dụng tên lớp Linear đã sửa
32
- cnn_proj = self.cnn(cnn_features)
33
- lstm_proj = self.lstm(lstm_features)
34
- combined = self.tanh(cnn_proj + lstm_proj) # Broadcasting
35
- attn_logits = self.attn(combined)
36
- attn_weights = self.softmax(attn_logits)
37
- attended_features = (attn_weights * lstm_features).sum(dim=1)
38
  return attended_features
39
-
40
  # -----------------------
41
  # VQA Model
42
  # -----------------------
@@ -46,264 +37,190 @@ class VQAModel(nn.Module):
46
  self.vocab_size = vocab_size
47
  self.max_seq_len = max_seq_len
48
 
49
- # --- CNN Encoder: ĐỔI TÊN TRỞ LẠI THÀNH self.cnn ---
50
- self.cnn = nn.Sequential( # Đổi tên lại thành self.cnn
51
- nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
52
- nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
53
- nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
54
- nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1), nn.ReLU(),
 
 
 
 
 
 
 
55
  nn.AdaptiveAvgPool2d((1, 1))
56
  )
57
- # ------------------------------------------------
58
 
 
59
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
 
 
60
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
61
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
 
 
62
  self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
 
 
 
 
 
63
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
64
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
65
  self.fc_out = nn.Linear(lstm_units, vocab_size)
66
  self.dropout = nn.Dropout(0.5)
67
 
68
- # Hàm forward không bị ảnh hưởng vì không gọi trực tiếp
69
  def forward(self, image, caption, question, answer_input):
70
- # Logic forward thể vẫn dùng tên biến local cnn_features
71
- # nhưng self.cnn để gọi mạng Sequential thì đã khớp tên
72
- cnn_features = self.cnn(image) # Gọi self.cnn mới đúng tên
73
- # ... (phần còn lại của forward giữ nguyên) ...
74
- cnn_features = cnn_features.view(cnn_features.size(0), -1)
75
- cap_embed = self.embedding(caption)
76
- cap_output, _ = self.caption_lstm(cap_embed)
77
- q_embed = self.embedding(question)
78
- q_output, _ = self.question_lstm(q_embed)
79
- cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output)
80
- q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
81
- q_last = q_output[:, -1, :]
82
- context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
83
- answer_embed = self.embedding(answer_input)
84
- context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
85
- decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
86
- decoder_in = self.decoder_input_proj(decoder_in)
87
- decoder_output, _ = self.decoder_lstm(decoder_in)
88
- output = self.fc_out(self.dropout(decoder_output))
89
- return output
90
 
91
- # ----------------------------------------------------------------------------
 
 
92
 
93
- # ============================================================================
94
- # 2. CẤU HÌNH LOAD MODEL/VOCAB
95
- # ============================================================================
96
- # !! THAY ĐỔI TÊN FILE NẾU CẦN !!
97
- MODEL_PATH = "vqa_model.pth" # Đảm bảo tên này khớp file bạn upload
98
- VOCAB_PATH = "vqa_custom_cnn_vocab.json" # Đảm bảo tên này khớp file bạn upload
99
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100
 
101
- # --- Hàm load ---
102
- def load_model_and_vocab(model_path, vocab_path, device):
103
- print(f"Attempting to load vocabulary from: {vocab_path}")
104
- if not os.path.exists(vocab_path):
105
- print(f"Error: Vocabulary file not found at {vocab_path}")
106
- return None, None, None
107
- try:
108
- with open(vocab_path, 'r') as f:
109
- vocab_data = json.load(f)
110
- word_to_idx = vocab_data['word_to_idx']
111
- # Chuyển key của idx_to_word thành integer để tra cứu bằng index
112
- idx_to_word = {int(k): v for k, v in vocab_data['idx_to_word'].items()}
113
- vocab_size = len(word_to_idx)
114
- print(f"Vocabulary loaded successfully. Size: {vocab_size}")
115
- except Exception as e:
116
- print(f"Error loading or processing vocabulary: {e}")
117
- return None, None, None
118
 
119
- print(f"Attempting to load model from: {model_path}")
120
- if not os.path.exists(model_path):
121
- print(f"Error: Model file not found at {model_path}")
122
- return None, None, None
123
- try:
124
- # Khởi tạo model với các tham số chính xác
125
- # Lấy các giá trị này từ lúc bạn huấn luyện model gốc
126
- model = VQAModel(vocab_size=vocab_size,
127
- embedding_dim=256, # Xác nhận giá trị này
128
- lstm_units=256, # Xác nhận giá trị này
129
- cnn_output_dim=512, # Xác nhận giá trị này
130
- attention_dim=256, # Xác nhận giá trị này
131
- max_seq_len=30) # Xác nhận giá trị này
132
 
133
- model.load_state_dict(torch.load(model_path, map_location=device))
134
- model.to(device)
135
- model.eval() # Quan trọng: Chuyển sang chế độ đánh giá
136
- print(f"Model loaded successfully from {model_path} to {device}")
137
- return model, word_to_idx, idx_to_word
138
- except Exception as e:
139
- print(f"Error loading model state_dict: {e}")
140
- # Có thể in traceback để debug kỹ hơn nếu cần
141
- # import traceback
142
- # traceback.print_exc()
143
- return None, None, None
144
 
145
- # --- Load model vocab một lần khi app khởi động ---
146
- model, word_to_idx, idx_to_word = load_model_and_vocab(MODEL_PATH, VOCAB_PATH, DEVICE)
 
 
 
 
147
 
148
- # ============================================================================
149
- # 3. ĐỊNH NGHĨA TRANSFORM (Lấy từ hàm train_vqa của bạn)
150
- # ============================================================================
151
- # Đảm bảo transform này giống hệt lúc bạn huấn luyện
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  transform = transforms.Compose([
153
  transforms.Resize((224, 224)),
154
  transforms.ToTensor(),
155
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
156
  ])
157
 
158
- # ============================================================================
159
- # 4. HÀM DỰ ĐOÁN CHO GRADIO (BẮT CHƯỚC LOGIC model.predict GỐC)
160
- # ============================================================================
161
- # Hàm này sẽ được gọi bởi Gradio Interface
162
- def predict_vqa(image, question_str):
163
- print("--- Received request ---")
164
- if model is None or word_to_idx is None:
165
- print("Error: Model or vocabulary not loaded.")
166
- return "Lỗi: Model hoặc từ điển chưa được tải."
167
- if image is None:
168
- print("Error: No image provided.")
169
- return "Lỗi: Vui lòng cung cấp ảnh."
170
- if not question_str or not question_str.strip():
171
- print("Error: No question provided.")
172
- return "Lỗi: Vui lòng nhập câu hỏi."
173
-
174
- print(f"Input question: {question_str}")
175
-
176
- # --- 1. Tiền xử ảnh ---
177
- try:
178
- image_tensor = transform(image).unsqueeze(0).to(DEVICE)
179
- print(f"Image transformed, shape: {image_tensor.shape}")
180
- except Exception as e:
181
- print(f"Error transforming image: {e}")
182
- return f"Lỗi xử lý ảnh: {e}"
183
-
184
- # --- 2. Tiền xử lý câu hỏi ---
185
- try:
186
- question_tokens = question_str.lower().split()
187
- unk_idx = word_to_idx.get('<UNK>', word_to_idx.get('<PAD>', 0))
188
- question_seq = [word_to_idx.get(word, unk_idx) for word in question_tokens]
189
- if not question_seq: question_seq = [unk_idx] # Tránh sequence rỗng
190
- question_tensor = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
191
- print(f"Question tensor created, shape: {question_tensor.shape}")
192
- except Exception as e:
193
- print(f"Error processing question: {e}")
194
- return f"Lỗi xử lý câu hỏi: {e}"
195
-
196
- # --- 3. Chạy Inference (Logic từ model.predict gốc) ---
197
- start_token_idx = word_to_idx['<START>']
198
- end_token_idx = word_to_idx['<END>']
199
- max_len = model.max_seq_len # Lấy max_len từ model đã load
200
- generated_indices = []
201
- # Bắt đầu decoder input với <START> token
202
- decoder_input_tensor = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)
203
- # Hidden state của decoder LSTM (khởi tạo là None, giống predict gốc)
204
- hidden_state = None
205
-
206
- try:
207
- with torch.no_grad(): # Tắt gradient calculation
208
- print("Encoding image...")
209
- # Sử dụng self.cnn_net thay vì self.cnn
210
-
211
- cnn_features = model.cnn(image_tensor)
212
- cnn_features = cnn_features.view(cnn_features.size(0), -1)
213
- print(f"CNN features shape: {cnn_features.shape}")
214
-
215
- print("Encoding question...")
216
- q_embed = model.embedding(question_tensor)
217
- q_output, _ = model.question_lstm(q_embed) # (1, q_seq_len, lstm_units)
218
- print(f"Question LSTM output shape: {q_output.shape}")
219
-
220
- print("Calculating attention...")
221
- # Chú ý unsqueeze(1) cho cnn_features khi đưa vào attention
222
- q_attended = model.attention(cnn_features.unsqueeze(1), q_output) # (1, lstm_units)
223
- q_last = q_output[:, -1, :] # (1, lstm_units)
224
- print(f"Attended question features shape: {q_attended.shape}")
225
- print(f"Last question LSTM state shape: {q_last.shape}")
226
-
227
- # --- Context Vector THEO LOGIC model.predict GỐC ---
228
- context = torch.cat([q_attended, q_attended, q_last], dim=-1) # (1, 3*lstm_units)
229
- print(f"Context vector shape: {context.shape}")
230
-
231
- print("Starting decoder loop...")
232
- for i in range(max_len):
233
- print(f"Decoder step {i+1}/{max_len}")
234
- current_word_embed = model.embedding(decoder_input_tensor) # (1, 1, embedding_dim)
235
-
236
- # Context cần unsqueeze để có chiều seq_len=1 trước khi repeat/cat
237
- context_repeated = context.unsqueeze(1) # (1, 1, 3*lstm_units)
238
-
239
- # Input cho lớp chiếu của decoder
240
- decoder_proj_input = torch.cat([current_word_embed, context_repeated], dim=-1)
241
- decoder_lstm_input = model.decoder_input_proj(decoder_proj_input) # (1, 1, lstm_units)
242
-
243
- # Chạy Decoder LSTM
244
- decoder_output, hidden_state = model.decoder_lstm(decoder_lstm_input, hidden_state) # hidden_state được cập nhật
245
-
246
- # Lấy Logits từ output của step này
247
- output_logits = model.fc_out(decoder_output.squeeze(1)) # (1, vocab_size)
248
- predicted_idx = output_logits.argmax(dim=-1).item()
249
- print(f"Predicted index: {predicted_idx}")
250
-
251
- if predicted_idx == end_token_idx:
252
- print("End token detected.")
253
- break
254
- generated_indices.append(predicted_idx)
255
- # Input cho bước tiếp theo là từ vừa dự đoán
256
- decoder_input_tensor = torch.tensor([[predicted_idx]], dtype=torch.long).to(DEVICE)
257
- print("Decoder loop finished.")
258
-
259
- except Exception as e:
260
- print(f"Error during model inference: {e}")
261
- # In traceback đầy đủ để debug
262
- import traceback
263
- traceback.print_exc()
264
- return f"Lỗi trong quá trình dự đoán: {e}"
265
-
266
- # --- 4. Decode Output ---
267
- try:
268
- answer_words = [idx_to_word.get(idx, '<UNK>') for idx in generated_indices]
269
- final_answer = ' '.join(answer_words) if answer_words else "(Không tạo được câu trả lời)"
270
- print(f"Decoded answer: {final_answer}")
271
- return final_answer
272
- except Exception as e:
273
- print(f"Error decoding answer: {e}")
274
- return f"Lỗi giải mã câu trả lời: {e}"
275
-
276
- # ============================================================================
277
- # 5. TẠO GRADIO INTERFACE (Đảm bảo ở global scope)
278
- # ============================================================================
279
- # Chỉ định nghĩa interface nếu model đã load thành công
280
- if model is not None and word_to_idx is not None:
281
- print("Defining Gradio interface...")
282
- title = "VQA for Animal"
283
- description = "Tải lên ảnh con vật và nhập câu hỏi để nhận câu trả lời. (CHỈ HỖ TRỢ TIẾNG ANH)"
284
- # examples = [ # Optional: Thêm ví dụ nếu bạn upload ảnh tương ứng
285
- # ["zebra.jpg", "what animal is this?"]
286
- # ]
287
-
288
- # Định nghĩa Interface ở global scope
289
  iface = gr.Interface(
290
- fn=predict_vqa,
291
  inputs=[
292
- gr.Image(type="pil", label="Image"), # Input là PIL Image
293
- gr.Textbox(lines=2, placeholder="Enter question here...", label="Câu hỏi")
294
  ],
295
- outputs=gr.Textbox(label="Câu trả lời"),
296
- title=title,
297
- description=description,
298
- # examples=examples,
299
- allow_flagging='never' # Tắt flagging
300
  )
301
- print("Gradio interface defined.")
302
- else:
303
- print("Skipping Gradio interface definition due to load errors.")
304
- # Có thể định nghĩa một interface báo lỗi nếu muốn
305
- def error_interface(*args):
306
- return "Lỗi nghiêm trọng: Không thể tải model hoặc từ điển. Vui lòng kiểm tra logs của Space."
307
- iface = gr.Interface(fn=error_interface, inputs=[], outputs="text", title="Lỗi Load Model")
308
-
309
- # Không cần if __name__ == "__main__": iface.launch() cho Spaces
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import torch
3
+ import json
 
4
  from torchvision import transforms
5
  from PIL import Image
6
+ import numpy as np
 
 
 
7
 
8
  # ============================================================================
9
  # 1. ĐỊNH NGHĨA LẠI CÁC CLASS MODEL (QUAN TRỌNG!)
10
  # (Copy từ code huấn luyện gốc, ĐÃ SỬA Attention theo lỗi trước)
11
  # ============================================================================
 
12
  # -----------------------
13
  # Attention Module
14
  # -----------------------
15
  class Attention(nn.Module):
16
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
17
  super(Attention, self).__init__()
18
+ self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
19
+ self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
 
20
  self.attn = nn.Linear(attention_dim, 1)
 
 
 
21
 
22
  def forward(self, cnn_features, lstm_features):
23
+ # cnn_features: (batch, 1, cnn_dim)
24
+ # lstm_features: (batch, seq_len, lstm_dim)
25
+ cnn_proj = self.cnn_proj(cnn_features) # (batch, 1, attention_dim)
26
+ lstm_proj = self.lstm_proj(lstm_features) # (batch, seq_len, attention_dim)
27
+ combined = torch.tanh(cnn_proj + lstm_proj) # (batch, seq_len, attention_dim)
28
+ attn_weights = F.softmax(self.attn(combined), dim=1) # (batch, seq_len, 1)
29
+ attended_features = (attn_weights * lstm_features).sum(dim=1) # (batch, lstm_dim)
30
  return attended_features
 
31
  # -----------------------
32
  # VQA Model
33
  # -----------------------
 
37
  self.vocab_size = vocab_size
38
  self.max_seq_len = max_seq_len
39
 
40
+ # CNN Encoder: Trích xuất đặc trưng ảnh
41
+ self.cnn = nn.Sequential(
42
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
43
+ nn.ReLU(),
44
+ nn.MaxPool2d(2),
45
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
46
+ nn.ReLU(),
47
+ nn.MaxPool2d(2),
48
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
49
+ nn.ReLU(),
50
+ nn.MaxPool2d(2),
51
+ nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
52
+ nn.ReLU(),
53
  nn.AdaptiveAvgPool2d((1, 1))
54
  )
 
55
 
56
+ # Text Embedding
57
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
58
+
59
+ # LSTM Encoders cho caption và question
60
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
61
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
62
+
63
+ # Attention cho từng kênh
64
  self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
65
+
66
+ # Decoder: sử dụng teacher forcing
67
+ # Context vector: kết hợp của attention từ caption, attention từ question và trạng thái cuối của question
68
+ # Kích thước context = lstm_units + lstm_units + lstm_units = 3 * lstm_units (ví dụ 768 nếu lstm_units=256)
69
+ # Kết hợp với embedding của câu trả lời (embedding_dim) => đầu vào của decoder = embedding_dim + 3*lstm_units
70
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
71
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
72
  self.fc_out = nn.Linear(lstm_units, vocab_size)
73
  self.dropout = nn.Dropout(0.5)
74
 
 
75
  def forward(self, image, caption, question, answer_input):
76
+ # --- CNN Encoder ---
77
+ cnn_features = self.cnn(image) # (batch, cnn_output_dim, 1, 1)
78
+ cnn_features = cnn_features.view(cnn_features.size(0), -1) # (batch, cnn_output_dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ # --- Text Encoders ---
81
+ cap_embed = self.embedding(caption) # (batch, cap_seq_len, embedding_dim)
82
+ cap_output, _ = self.caption_lstm(cap_embed) # (batch, cap_seq_len, lstm_units)
83
 
84
+ q_embed = self.embedding(question) # (batch, q_seq_len, embedding_dim)
85
+ q_output, _ = self.question_lstm(q_embed) # (batch, q_seq_len, lstm_units)
 
 
 
 
 
86
 
87
+ # --- Attention ---
88
+ cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output) # (batch, lstm_units)
89
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output) # (batch, lstm_units)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ q_last = q_output[:, -1, :] # (batch, lstm_units)
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ # Context vector: (batch, 3*lstm_units)
94
+ context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
 
 
 
 
 
 
 
 
 
95
 
96
+ # --- Decoder với Teacher Forcing ---
97
+ # answer_input: (batch, ans_seq_len)
98
+ answer_embed = self.embedding(answer_input) # (batch, ans_seq_len, embedding_dim)
99
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1) # (batch, ans_seq_len, 3*lstm_units)
100
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1) # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
101
+ decoder_in = self.decoder_input_proj(decoder_in) # (batch, ans_seq_len, lstm_units)
102
 
103
+ decoder_output, _ = self.decoder_lstm(decoder_in) # (batch, ans_seq_len, lstm_units)
104
+ output = self.fc_out(self.dropout(decoder_output)) # (batch, ans_seq_len, vocab_size)
105
+ return output
106
+
107
+ def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
108
+ self.eval()
109
+ self.to(device)
110
+ # Kiểm tra nếu image không có batch dimension thì thêm
111
+ if image.dim() == 3:
112
+ image = image.unsqueeze(0)
113
+ image = image.to(device)
114
+
115
+ question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
116
+ question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
117
+
118
+ # Encode image và question
119
+ cnn_features = self.cnn(image)
120
+ cnn_features = cnn_features.view(cnn_features.size(0), -1)
121
+ q_embed = self.embedding(question)
122
+ q_output, _ = self.question_lstm(q_embed)
123
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
124
+ q_last = q_output[:, -1, :]
125
+
126
+ # Ở predict, sử dụng một context vector đơn giản từ question (hoặc kết hợp với các thành phần khác nếu có)
127
+ context = torch.cat([q_attended, q_attended, q_last], dim=-1) # (1, 3*lstm_units)
128
+
129
+ # Khởi tạo câu trả lời với token <START>
130
+ answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
131
+ answer_words = []
132
+
133
+ hidden = None
134
+ for _ in range(self.max_seq_len):
135
+ answer_embed = self.embedding(answer_input) # (1, seq_len, embedding_dim)
136
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
137
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
138
+ decoder_in = self.decoder_input_proj(decoder_in)
139
+ decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
140
+ output = self.fc_out(decoder_output[:, -1, :])
141
+ next_word_idx = output.argmax(dim=-1).item()
142
+ if next_word_idx == word_to_idx['<END>']:
143
+ break
144
+ answer_words.append(idx_to_word[next_word_idx])
145
+ answer_input = torch.cat([answer_input, torch.tensor([[next_word_idx]], dtype=torch.long).to(device)], dim=1)
146
+
147
+ return ' '.join(answer_words)
148
+
149
+
150
+
151
+
152
+ # Hàm load mô hình
153
+ def load_model(model_path, word_to_idx_path, idx_to_word_path, device='cpu'):
154
+ # Load từ điển
155
+ with open(word_to_idx_path, 'r') as f:
156
+ word_to_idx = json.load(f)
157
+
158
+ with open(idx_to_word_path, 'r') as f:
159
+ idx_to_word = json.load(f)
160
+
161
+ # Khởi tạo mô hình
162
+ model = VQAModel(vocab_size=len(word_to_idx))
163
+ model.load_state_dict(torch.load(model_path, map_location=device))
164
+ model.to(device)
165
+ model.eval()
166
+
167
+ return model, word_to_idx, idx_to_word
168
+
169
+ # Transform ảnh
170
  transform = transforms.Compose([
171
  transforms.Resize((224, 224)),
172
  transforms.ToTensor(),
173
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
174
  ])
175
 
176
+ # Hàm dự đoán
177
+ def predict(image, question, model, word_to_idx, idx_to_word, device='cpu'):
178
+ # Chuyển đổi ảnh
179
+ image = transform(image).unsqueeze(0).to(device)
180
+
181
+ # Dự đoán
182
+ answer = model.predict(image, question, word_to_idx, idx_to_word, device)
183
+ return answer
184
+
185
+
186
+ # Tạo giao diện Gradio
187
+ def create_interface(model, word_to_idx, idx_to_word, device='cpu'):
188
+ def vqa_interface(image, question):
189
+ answer = predict(image, question, model, word_to_idx, idx_to_word, device)
190
+ return answer
191
+
192
+ examples = [
193
+ ["example1.jpg", "What color is the animal?"],
194
+ ["example2.jpg", "Is this a cat or a dog?"]
195
+ ]
196
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  iface = gr.Interface(
198
+ fn=vqa_interface,
199
  inputs=[
200
+ gr.Image(type="pil", label="Upload an image"),
201
+ gr.Textbox(label="Ask a question about the image")
202
  ],
203
+ outputs=gr.Textbox(label="Answer"),
204
+ examples=examples,
205
+ title="Visual Question Answering System",
206
+ description="Upload an image and ask a question about it. The model will try to answer."
 
207
  )
208
+
209
+ return iface
210
+
211
+
212
+ # Main
213
+ if __name__ == "__main__":
214
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
215
+
216
+ # Load hình
217
+ model, word_to_idx, idx_to_word = load_model(
218
+ "vqa_model.pth",
219
+ "word_to_idx.json",
220
+ "idx_to_word.json",
221
+ device
222
+ )
223
+
224
+ # Tạo và chạy giao diện
225
+ iface = create_interface(model, word_to_idx, idx_to_word, device)
226
+ iface.launch()