Tin113 commited on
Commit
c5676e0
·
verified ·
1 Parent(s): f2c0061

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -154
app.py CHANGED
@@ -1,198 +1,274 @@
1
  import torch
2
- import gradio as gr
3
- from PIL import Image
4
  from torchvision import transforms
 
 
 
 
 
5
 
6
- import torch.nn as nn
7
- import torchvision.models as models
 
 
 
 
8
 
9
- # -----------------------
10
- # Attention Module
11
- # -----------------------
12
  class Attention(nn.Module):
 
13
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
14
  super(Attention, self).__init__()
15
- self.cnn = nn.Linear(cnn_dim, attention_dim)
16
- self.lstm = nn.Linear(lstm_dim, attention_dim)
17
  self.attn = nn.Linear(attention_dim, 1)
 
 
 
18
 
19
  def forward(self, cnn_features, lstm_features):
20
- # cnn_features: (batch, 1, cnn_dim)
21
- # lstm_features: (batch, seq_len, lstm_dim)
22
- cnn = self.cnn(cnn_features) # (batch, 1, attention_dim)
23
- lstm = self.lstm(lstm_features) # (batch, seq_len, attention_dim)
24
- combined = torch.tanh(cnn + lstm) # (batch, seq_len, attention_dim)
25
- attn_weights = F.softmax(self.attn(combined), dim=1) # (batch, seq_len, 1)
26
- attended_features = (attn_weights * lstm_features).sum(dim=1) # (batch, lstm_dim)
27
  return attended_features
28
-
29
- # -----------------------
30
- # VQA Model
31
- # -----------------------
32
  class VQAModel(nn.Module):
 
 
33
  def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
34
  super(VQAModel, self).__init__()
35
  self.vocab_size = vocab_size
36
  self.max_seq_len = max_seq_len
37
 
38
- # CNN Encoder: Trích xuất đặc trưng ảnh
39
  self.cnn = nn.Sequential(
40
- nn.Conv2d(3, 32, kernel_size=3, padding=1),
41
- nn.ReLU(),
42
- nn.MaxPool2d(2),
43
- nn.Conv2d(32, 64, kernel_size=3, padding=1),
44
- nn.ReLU(),
45
- nn.MaxPool2d(2),
46
- nn.Conv2d(64, 128, kernel_size=3, padding=1),
47
- nn.ReLU(),
48
- nn.MaxPool2d(2),
49
- nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
50
- nn.ReLU(),
51
  nn.AdaptiveAvgPool2d((1, 1))
52
  )
53
-
54
- # Text Embedding
55
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
56
-
57
- # LSTM Encoders cho caption và question
58
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
59
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
60
-
61
- # Attention cho từng kênh
62
  self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
63
-
64
- # Decoder: sử dụng teacher forcing
65
- # Context vector: kết hợp của attention từ caption, attention từ question và trạng thái cuối của question
66
- # Kích thước context = lstm_units + lstm_units + lstm_units = 3 * lstm_units (ví dụ 768 nếu lstm_units=256)
67
- # Kết hợp với embedding của câu trả lời (embedding_dim) => đầu vào của decoder = embedding_dim + 3*lstm_units
68
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
69
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
70
  self.fc_out = nn.Linear(lstm_units, vocab_size)
71
- self.dropout = nn.Dropout(0.5)
72
 
 
 
73
  def forward(self, image, caption, question, answer_input):
74
- # --- CNN Encoder ---
75
- cnn_features = self.cnn(image) # (batch, cnn_output_dim, 1, 1)
76
- cnn_features = cnn_features.view(cnn_features.size(0), -1) # (batch, cnn_output_dim)
77
-
78
- # --- Text Encoders ---
79
- cap_embed = self.embedding(caption) # (batch, cap_seq_len, embedding_dim)
80
- cap_output, _ = self.caption_lstm(cap_embed) # (batch, cap_seq_len, lstm_units)
81
-
82
- q_embed = self.embedding(question) # (batch, q_seq_len, embedding_dim)
83
- q_output, _ = self.question_lstm(q_embed) # (batch, q_seq_len, lstm_units)
84
-
85
- # --- Attention ---
86
- cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output) # (batch, lstm_units)
87
- q_attended = self.attention(cnn_features.unsqueeze(1), q_output) # (batch, lstm_units)
88
-
89
- q_last = q_output[:, -1, :] # (batch, lstm_units)
90
-
91
- # Context vector: (batch, 3*lstm_units)
92
- context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
93
-
94
- # --- Decoder với Teacher Forcing ---
95
- # answer_input: (batch, ans_seq_len)
96
- answer_embed = self.embedding(answer_input) # (batch, ans_seq_len, embedding_dim)
97
- context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1) # (batch, ans_seq_len, 3*lstm_units)
98
- decoder_in = torch.cat([answer_embed, context_repeated], dim=-1) # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
99
- decoder_in = self.decoder_input_proj(decoder_in) # (batch, ans_seq_len, lstm_units)
100
-
101
- decoder_output, _ = self.decoder_lstm(decoder_in) # (batch, ans_seq_len, lstm_units)
102
- output = self.fc_out(self.dropout(decoder_output)) # (batch, ans_seq_len, vocab_size)
103
- return output
104
-
105
- def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
106
- self.eval()
107
- self.to(device)
108
-
109
- image = image.unsqueeze(0).to(device)
110
- question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
111
- question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
112
-
113
- # Encode image & question
114
- cnn_features = self.cnn(image)
115
- cnn_features = cnn_features.view(cnn_features.size(0), -1)
116
- q_embed = self.embedding(question)
117
- q_output, _ = self.question_lstm(q_embed)
118
- q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
119
- q_last = q_output[:, -1, :]
120
- # Ở predict, ta tạo context vector từ q_attended lặp lại (chỉ dùng question cho ví dụ)
121
- context = torch.cat([q_attended, q_attended, q_last], dim=-1) # (1, 3*lstm_units)
122
-
123
- # Khởi tạo câu trả lời với token <START>
124
- answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
125
- answer_words = []
126
-
127
- hidden = None
128
- for _ in range(self.max_seq_len):
129
- answer_embed = self.embedding(answer_input) # (1, seq_len, embedding_dim)
130
- context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
131
- decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
132
- decoder_in = self.decoder_input_proj(decoder_in)
133
- decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
134
- output = self.fc_out(decoder_output[:, -1, :])
135
- next_word_idx = output.argmax(dim=-1).item()
136
- if next_word_idx == word_to_idx['<END>']:
137
- break
138
- answer_words.append(idx_to_word[next_word_idx])
139
- answer_input = torch.cat([answer_input, torch.tensor([[next_word_idx]], dtype=torch.long).to(device)], dim=1)
 
 
 
140
 
141
- return ' '.join(answer_words)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
 
 
 
 
143
 
144
- # Load hình từ Hugging Face Model Hub hoặc local
145
- device = "cuda" if torch.cuda.is_available() else "cpu"
146
 
147
- # Nếu dùng Model Hub, tải từ Hugging Face (bỏ comment nếu cần)
148
- # from huggingface_hub import hf_hub_download
149
- # model_path = hf_hub_download("your-username/VQA-Fruits-Model", "vqa_model.pth")
150
- # word_to_idx_path = hf_hub_download("your-username/VQA-Fruits-Model", "word_to_idx.pth")
151
- # idx_to_word_path = hf_hub_download("your-username/VQA-Fruits-Model", "idx_to_word.pth")
152
 
153
- # Nếu dùng file upload trực tiếp vào Space, dùng cách này:
154
- model_path = "vqa_model.pth"
155
- word_to_idx_path = "word_to_idx.pth"
156
- idx_to_word_path = "idx_to_word.pth"
157
 
158
- # Load word_to_idx idx_to_word
159
- word_to_idx = torch.load(word_to_idx_path, map_location=device)
160
- idx_to_word = torch.load(idx_to_word_path, map_location=device)
161
 
162
- # Khởi tạo mô hình
163
- vocab_size = len(word_to_idx)
164
- model = VQAModel(vocab_size) # ⚠️ Bạn cần định nghĩa class VQAModel
165
- model.load_state_dict(torch.load(model_path, map_location=device))
166
- model.to(device)
167
- model.eval()
168
 
169
- # Chuẩn bị tiền xử ảnh
170
- transform = transforms.Compose([
171
- transforms.Resize((224, 224)),
172
- transforms.ToTensor(),
173
- ])
174
 
175
- # Hàm dự đoán VQA
176
- def predict(image, question):
177
- image = transform(image).unsqueeze(0).to(device)
178
- question_tokens = [word_to_idx.get(word, 0) for word in question.lower().split()]
179
- question_tensor = torch.tensor(question_tokens).unsqueeze(0).to(device)
180
 
181
- with torch.no_grad():
182
- output = model(image, question_tensor)
183
- predicted_idx = torch.argmax(output, dim=1).item()
 
 
 
 
 
 
 
 
184
 
185
- answer = idx_to_word[predicted_idx]
186
- return answer
187
 
188
- # Giao diện Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  iface = gr.Interface(
190
- fn=predict,
191
- inputs=[gr.Image(type="pil"), gr.Textbox(label="Câu hỏi")],
192
- outputs=gr.Textbox(label="Câu trả lời"),
193
- title="VQA for Animal",
194
- description="Tải lên ảnh con vật và nhập câu hỏi để nhận câu trả lời. (CHỈ HỖ TRỢ TIẾNG ANH)",
 
 
 
 
 
195
  )
196
 
197
- # Chạy ứng dụng
198
- iface.launch()
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
  from torchvision import transforms
5
+ from PIL import Image
6
+ import json
7
+ import gradio as gr
8
+ import os
9
+ import sys
10
 
11
+ # ============================================================================
12
+ # 1. ĐỊNH NGHĨA LẠI CÁC CLASS MODEL (QUAN TRỌNG!)
13
+ # ============================================================================
14
+ # SAO CHÉP VÀ DÁN TOÀN BỘ ĐỊNH NGHĨA CỦA CLASS Attention và VQAModel
15
+ # (phiên bản gốc có caption, CNN tự định nghĩa) TỪ SCRIPT HUẤN LUYỆN VÀO ĐÂY.
16
+ # Nếu không có các định nghĩa này, torch.load sẽ không hoạt động.
17
 
18
+ # --- Ví dụ (BẠN CẦN DÁN CODE ĐẦY ĐỦ CỦA BẠN VÀO) ---
 
 
19
  class Attention(nn.Module):
20
+ # ... (Dán code class Attention của bạn vào đây) ...
21
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
22
  super(Attention, self).__init__()
23
+ self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
24
+ self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
25
  self.attn = nn.Linear(attention_dim, 1)
26
+ # Thêm các lớp kích hoạt nếu có trong code gốc của bạn
27
+ self.tanh = nn.Tanh()
28
+ self.softmax = nn.Softmax(dim=1)
29
 
30
  def forward(self, cnn_features, lstm_features):
31
+ cnn_proj = self.cnn_proj(cnn_features)
32
+ lstm_proj = self.lstm_proj(lstm_features)
33
+ # Đảm bảo broadcasting hoạt động đúng
34
+ combined = self.tanh(cnn_proj + lstm_proj) # cnn_proj sẽ được broadcast
35
+ attn_logits = self.attn(combined)
36
+ attn_weights = self.softmax(attn_logits)
37
+ attended_features = (attn_weights * lstm_features).sum(dim=1)
38
  return attended_features
39
+
 
 
 
40
  class VQAModel(nn.Module):
41
+ # ... (Dán code class VQAModel gốc của bạn vào đây) ...
42
+ # Đảm bảo các tham số mặc định khớp với lúc bạn lưu model
43
  def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
44
  super(VQAModel, self).__init__()
45
  self.vocab_size = vocab_size
46
  self.max_seq_len = max_seq_len
47
 
48
+ # CNN Encoder (giống hệt lúc train)
49
  self.cnn = nn.Sequential(
50
+ nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
51
+ nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
52
+ nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
53
+ nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1), nn.ReLU(),
 
 
 
 
 
 
 
54
  nn.AdaptiveAvgPool2d((1, 1))
55
  )
 
 
56
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
57
+ # Có caption_lstm trong định nghĩa model gốc
 
58
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
59
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
 
 
60
  self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
61
+ # Kích thước input decoder dựa trên context gốc (có cả caption)
 
 
 
 
62
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
63
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
64
  self.fc_out = nn.Linear(lstm_units, vocab_size)
65
+ self.dropout = nn.Dropout(0.5) # Tự động tắt khi model.eval()
66
 
67
+ # Hàm forward không thực sự được gọi trong predict_gradio theo cách làm này
68
+ # Nhưng nó cần tồn tại để model load đúng cấu trúc
69
  def forward(self, image, caption, question, answer_input):
70
+ raise NotImplementedError("Use the specific prediction logic for Gradio.")
71
+ # ----------------------------------------------------------------------------
72
+
73
+ # ============================================================================
74
+ # 2. CẤU HÌNH LOAD MODEL/VOCAB
75
+ # ============================================================================
76
+ MODEL_PATH = "vqa_custom_cnn_model.pth" # Tên file model của bạn
77
+ VOCAB_PATH = "vqa_custom_cnn_vocab.json" # Tên file vocab của bạn
78
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+
80
+ # --- Hàm load ---
81
+ def load_model_and_vocab(model_path, vocab_path, device):
82
+ if not os.path.exists(vocab_path):
83
+ print(f"Error: Vocabulary file not found at {vocab_path}")
84
+ return None, None, None
85
+ try:
86
+ with open(vocab_path, 'r') as f:
87
+ vocab_data = json.load(f)
88
+ word_to_idx = vocab_data['word_to_idx']
89
+ # Đảm bảo idx_to_word có key là integer nếu dùng get(int_key)
90
+ # Hoặc chuyển index sang string nếu key là string
91
+ idx_to_word = {int(k): v for k, v in vocab_data['idx_to_word'].items()}
92
+ vocab_size = len(word_to_idx)
93
+ except Exception as e:
94
+ print(f"Error loading vocabulary: {e}")
95
+ return None, None, None
96
+
97
+ if not os.path.exists(model_path):
98
+ print(f"Error: Model file not found at {model_path}")
99
+ return None, None, None
100
+ try:
101
+ # Khởi tạo model với các tham số đúng
102
+ # Cần lấy các giá trị dim từ lúc bạn train model gốc
103
+ model = VQAModel(vocab_size=vocab_size,
104
+ embedding_dim=256, # Giả định, thay đổi nếu khác
105
+ lstm_units=256, # Giả định, thay đổi nếu khác
106
+ cnn_output_dim=512, # Giả định, thay đổi nếu khác
107
+ attention_dim=256, # Giả định, thay đổi nếu khác
108
+ max_seq_len=30) # Giả định, thay đổi nếu khác
109
+
110
+ model.load_state_dict(torch.load(model_path, map_location=device))
111
+ model.to(device)
112
+ model.eval() # QUAN TRỌNG: Chuyển sang chế độ đánh giá
113
+ print(f"Model loaded successfully from {model_path}")
114
+ return model, word_to_idx, idx_to_word
115
+ except Exception as e:
116
+ print(f"Error loading model: {e}")
117
+ # thể in traceback để debug kỹ hơn nếu cần
118
+ # import traceback
119
+ # traceback.print_exc()
120
+ return None, None, None
121
+
122
+ # --- Load model và vocab một lần khi app khởi động ---
123
+ model, word_to_idx, idx_to_word = load_model_and_vocab(MODEL_PATH, VOCAB_PATH, DEVICE)
124
+
125
+ # Thoát nếu không load được model/vocab
126
+ if model is None or word_to_idx is None:
127
+ print("Exiting because model or vocabulary failed to load.")
128
+ sys.exit(1)
129
+
130
+ # ============================================================================
131
+ # 3. ĐỊNH NGHĨA TRANSFORM (PHẢI GIỐNG HỆT LÚC TRAIN)
132
+ # ============================================================================
133
+ # Sử dụng lại transform bạn đã dùng trong hàm train_vqa
134
+ transform = transforms.Compose([
135
+ transforms.Resize((224, 224)),
136
+ transforms.ToTensor(),
137
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
138
+ ])
139
 
140
+ # ============================================================================
141
+ # 4. HÀM DỰ ĐOÁN CHO GRADIO
142
+ # ============================================================================
143
+ def predict_vqa(image, question):
144
+ """Hàm xử lý input từ Gradio và trả về dự đoán."""
145
+ if image is None or not question.strip():
146
+ return "Lỗi: Vui lòng cung cấp cả ảnh và câu hỏi."
147
+
148
+ # --- 1. Tiền xử lý ảnh ---
149
+ try:
150
+ # Gradio truyền vào PIL Image
151
+ image_tensor = transform(image).unsqueeze(0).to(DEVICE)
152
+ except Exception as e:
153
+ return f"Lỗi xử lý ảnh: {e}"
154
+
155
+ # --- 2. Tiền xử lý câu hỏi ---
156
+ question_tokens = question.lower().split()
157
+ # Sử dụng PAD index cho từ không biết nếu UNK không có
158
+ unk_idx = word_to_idx.get('<UNK>', word_to_idx.get('<PAD>', 0))
159
+ question_seq = [word_to_idx.get(word, unk_idx) for word in question_tokens]
160
+ if not question_seq:
161
+ question_seq = [unk_idx] # Xử lý câu hỏi rỗng
162
+ question_tensor = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
163
+
164
+ # --- 3. Chạy Inference (Bắt chước logic của model.predict gốc) ---
165
+ start_token_idx = word_to_idx['<START>']
166
+ end_token_idx = word_to_idx['<END>']
167
+ max_len = model.max_seq_len
168
+
169
+ generated_indices = [] # Không cần thêm START ở đây
170
+
171
+ # Bắt đầu giải mã với token <START>
172
+ decoder_input = torch.tensor([[start_token_idx]], dtype=torch.long).to(DEVICE)
173
+ # Hidden state của decoder LSTM sẽ được khởi tạo lại ở mỗi bước trong cách làm này
174
+ # (hoặc cần được truyền và cập nhật nếu logic predict gốc làm vậy)
175
+ # Logic predict gốc không truyền hidden state rõ ràng, nên ta cũng không cần
176
+ hidden_state = None
177
 
178
+ with torch.no_grad():
179
+ # Encode ảnh và câu hỏi một lần
180
+ cnn_features = model.cnn(image_tensor) # (1, cnn_output_dim, 1, 1)
181
+ cnn_features = cnn_features.view(cnn_features.size(0), -1) # (1, cnn_output_dim)
182
 
183
+ q_embed = model.embedding(question_tensor) # (1, q_seq_len, embedding_dim)
184
+ q_output, _ = model.question_lstm(q_embed) # (1, q_seq_len, lstm_units)
185
 
186
+ # Attention chỉ với question
187
+ # Cần unsqueeze cnn_features để có chiều seq_len=1
188
+ q_attended = model.attention(cnn_features.unsqueeze(1), q_output) # (1, lstm_units)
 
 
189
 
190
+ # Trạng thái cuối của LSTM question (lấy từ output)
191
+ q_last = q_output[:, -1, :] # (1, lstm_units)
 
 
192
 
193
+ # --- Context Vector (THEO LOGIC model.predict GỐC) ---
194
+ # Sử dụng q_attended hai lần, bỏ qua caption hoàn toàn trong inference này
195
+ context = torch.cat([q_attended, q_attended, q_last], dim=-1) # (1, 3*lstm_units)
196
 
197
+ for _ in range(max_len):
198
+ # --- Chuẩn bị input cho decoder ở bước này ---
199
+ current_word_embed = model.embedding(decoder_input) # (1, 1, embedding_dim)
 
 
 
200
 
201
+ # Lặp context cho bước thời gian hiện tại (batch=1, seq_len=1)
202
+ context_repeated = context.unsqueeze(1) # (1, 1, 3*lstm_units)
 
 
 
203
 
204
+ # Input cho lớp chiếu của decoder
205
+ decoder_proj_input = torch.cat([current_word_embed, context_repeated], dim=-1)
206
+ decoder_lstm_input = model.decoder_input_proj(decoder_proj_input) # (1, 1, lstm_units)
 
 
207
 
208
+ # --- Chạy Decoder LSTM ---
209
+ # Logic predict gốc truyền hidden state, ta cần làm tương tự nếu muốn khớp 100%
210
+ # Hoặc nếu không truyền, LSTM sẽ tự khởi tạo state (có thể hơi khác kết quả)
211
+ # Giả sử logic gốc có truyền hidden state:
212
+ decoder_output, hidden_state = model.decoder_lstm(decoder_lstm_input, hidden_state) # Update hidden
213
+
214
+ # --- Lấy Logits và dự đoán ---
215
+ # Logic predict gốc lấy output của bước cuối cùng [-1]
216
+ # Vì ta đang chạy từng bước, output chỉ có 1 bước thời gian -> dùng squeeze(1)
217
+ output_logits = model.fc_out(decoder_output.squeeze(1)) # (1, vocab_size)
218
+ predicted_idx = output_logits.argmax(dim=-1).item()
219
 
220
+ if predicted_idx == end_token_idx:
221
+ break
222
 
223
+ generated_indices.append(predicted_idx)
224
+
225
+ # Chuẩn bị input cho bước tiếp theo
226
+ decoder_input = torch.tensor([[predicted_idx]], dtype=torch.long).to(DEVICE)
227
+
228
+ # --- 4. Decode Output ---
229
+ answer_words = [idx_to_word.get(idx, '<UNK>') for idx in generated_indices]
230
+ return ' '.join(answer_words) if answer_words else "(No answer generated)"
231
+
232
+ # ============================================================================
233
+ # 5. TẠO VÀ CHẠY GRADIO INTERFACE
234
+ # ============================================================================
235
+ title = "Visual Question Answering Demo"
236
+ description = """
237
+ Upload một ảnh và đặt câu hỏi về nội dung của ảnh đó.
238
+ Model này sử dụng CNN tùy chỉnh và LSTM với Attention (phiên bản gốc).
239
+ Lưu ý: Inference hiện tại dựa trên logic của hàm `predict` gốc, có thể không sử dụng caption.
240
+ """
241
+
242
+ # Ví dụ để người dùng thử
243
+ examples = [
244
+ ["path/to/your/example/cat_image.jpg", "what animal is in the picture"],
245
+ ["path/to/your/example/car_image.png", "what color is the car"],
246
+ # Thêm URL nếu muốn
247
+ # ["https://example.com/some_image.jpg", "how many people are there"]
248
+ ]
249
+ # Bạn cần thay đổi đường dẫn trong 'examples' thành đường dẫn thực tế
250
+ # tới file ảnh MÀ BẠN SẼ UPLOAD lên Space cùng với code.
251
+
252
+ # Tạo Interface
253
  iface = gr.Interface(
254
+ fn=predict_vqa,
255
+ inputs=[
256
+ gr.Image(type="pil", label="Input Image"), # Nhận PIL Image
257
+ gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn ở đây...", label="Question")
258
+ ],
259
+ outputs=gr.Textbox(label="Predicted Answer"),
260
+ title=title,
261
+ description=description,
262
+ examples=examples, # Cung cấp ví dụ (đảm bảo file ảnh ví dụ tồn tại trên Space)
263
+ allow_flagging='never' # Tắt flagging nếu không cần
264
  )
265
 
266
+ # Chạy app (Trong Hugging Face Spaces, nó sẽ tự chạy file này)
267
+ if __name__ == "__main__":
268
+ # if model is not None: # Kiểm tra lại lần nữa trước khi chạy
269
+ # iface.launch() # Không cần launch() ở đây khi deploy lên Spaces
270
+ # else:
271
+ # print("Cannot launch Gradio interface because model/vocab failed to load.")
272
+ # Dòng iface.launch() chỉ cần khi bạn chạy cục bộ để test.
273
+ # Trên Spaces, Gradio tự động tìm và chạy interface được định nghĩa.
274
+ pass # Để trống hoặc thêm logic chạy cục bộ nếu muốn