Tin113 committed on
Commit
bcf968d
·
verified ·
1 Parent(s): e7520bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -39
app.py CHANGED
@@ -40,23 +40,22 @@ class Attention(nn.Module):
40
  # -----------------------
41
  # VQA Model
42
  # -----------------------
43
- # (Copy nguyên văn từ code gốc bạn cung cấp)
44
  class VQAModel(nn.Module):
45
- # !! QUAN TRỌNG: Đảm bảo các giá trị mặc định này (hoặc giá trị bạn truyền vào khi load)
46
- # KHỚP VỚI CÁCH BẠN KHỞI TẠO MODEL KHI LƯU FILE .pth !!
47
  def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
48
  super(VQAModel, self).__init__()
49
  self.vocab_size = vocab_size
50
- self.max_seq_len = max_seq_len # Lưu lại để dùng trong predict
51
 
52
- # CNN Encoder (giống hệt lúc train)
53
- self.cnn_net = nn.Sequential( # Đổi tên từ self.cnn thành self.cnn_net để tránh trùng tên biến local trong forward
54
  nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
55
  nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
56
  nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
57
  nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1), nn.ReLU(),
58
  nn.AdaptiveAvgPool2d((1, 1))
59
  )
 
 
60
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
61
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
62
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
@@ -64,41 +63,31 @@ class VQAModel(nn.Module):
64
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
65
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
66
  self.fc_out = nn.Linear(lstm_units, vocab_size)
67
- self.dropout = nn.Dropout(0.5) # Tự động tắt khi model.eval()
68
 
69
- # Hàm forward không dùng trực tiếp khi predict, nhưng cần tồn tại
70
  def forward(self, image, caption, question, answer_input):
71
- # --- CNN Encoder ---
72
- # Sử dụng self.cnn_net để gọi Sequential
73
- cnn_features = self.cnn_net(image) # (batch, cnn_output_dim, 1, 1)
74
- cnn_features = cnn_features.view(cnn_features.size(0), -1) # (batch, cnn_output_dim)
75
-
76
- # --- Text Encoders ---
77
- cap_embed = self.embedding(caption) # (batch, cap_seq_len, embedding_dim)
78
- cap_output, _ = self.caption_lstm(cap_embed) # (batch, cap_seq_len, lstm_units)
79
-
80
- q_embed = self.embedding(question) # (batch, q_seq_len, embedding_dim)
81
- q_output, _ = self.question_lstm(q_embed) # (batch, q_seq_len, lstm_units)
82
-
83
- # --- Attention ---
84
- # Chắc chắn self.attention được gọi đúng
85
- cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output) # (batch, lstm_units)
86
- q_attended = self.attention(cnn_features.unsqueeze(1), q_output) # (batch, lstm_units)
87
-
88
- q_last = q_output[:, -1, :] # (batch, lstm_units)
89
-
90
- # Context vector: (batch, 3*lstm_units)
91
  context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
92
-
93
- # --- Decoder với Teacher Forcing ---
94
- answer_embed = self.embedding(answer_input) # (batch, ans_seq_len, embedding_dim)
95
- context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1) # (batch, ans_seq_len, 3*lstm_units)
96
- decoder_in = torch.cat([answer_embed, context_repeated], dim=-1) # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
97
- decoder_in = self.decoder_input_proj(decoder_in) # (batch, ans_seq_len, lstm_units)
98
-
99
- decoder_output, _ = self.decoder_lstm(decoder_in) # (batch, ans_seq_len, lstm_units)
100
- output = self.fc_out(self.dropout(decoder_output)) # (batch, ans_seq_len, vocab_size)
101
  return output
 
102
  # ----------------------------------------------------------------------------
103
 
104
  # ============================================================================
@@ -218,8 +207,9 @@ def predict_vqa(image, question_str):
218
  with torch.no_grad(): # Tắt gradient calculation
219
  print("Encoding image...")
220
  # Sử dụng self.cnn_net thay vì self.cnn
221
- cnn_features = model.cnn_net(image_tensor)
222
- cnn_features = cnn_features.view(cnn_features.size(0), -1) # (1, cnn_output_dim)
 
223
  print(f"CNN features shape: {cnn_features.shape}")
224
 
225
  print("Encoding question...")
 
40
  # -----------------------
41
  # VQA Model
42
  # -----------------------
 
43
  class VQAModel(nn.Module):
 
 
44
  def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
45
  super(VQAModel, self).__init__()
46
  self.vocab_size = vocab_size
47
+ self.max_seq_len = max_seq_len
48
 
49
+ # --- CNN Encoder: ĐỔI TÊN TRỞ LẠI THÀNH self.cnn ---
50
+ self.cnn = nn.Sequential( # Đổi tên lại thành self.cnn
51
  nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
52
  nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
53
  nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
54
  nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1), nn.ReLU(),
55
  nn.AdaptiveAvgPool2d((1, 1))
56
  )
57
+ # ------------------------------------------------
58
+
59
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
60
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
61
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
 
63
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
64
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
65
  self.fc_out = nn.Linear(lstm_units, vocab_size)
66
+ self.dropout = nn.Dropout(0.5)
67
 
68
+ # The forward method is not affected, since it is not called directly
69
  def forward(self, image, caption, question, answer_input):
70
+ # The forward logic can still use the local variable name cnn_features,
71
+ # but the Sequential network is now called through self.cnn, so the attribute name matches
72
+ cnn_features = self.cnn(image) # Gọi self.cnn mới đúng tên
73
+ # ... (phần còn lại của forward giữ nguyên) ...
74
+ cnn_features = cnn_features.view(cnn_features.size(0), -1)
75
+ cap_embed = self.embedding(caption)
76
+ cap_output, _ = self.caption_lstm(cap_embed)
77
+ q_embed = self.embedding(question)
78
+ q_output, _ = self.question_lstm(q_embed)
79
+ cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output)
80
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
81
+ q_last = q_output[:, -1, :]
 
 
 
 
 
 
 
 
82
  context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
83
+ answer_embed = self.embedding(answer_input)
84
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
85
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
86
+ decoder_in = self.decoder_input_proj(decoder_in)
87
+ decoder_output, _ = self.decoder_lstm(decoder_in)
88
+ output = self.fc_out(self.dropout(decoder_output))
 
 
 
89
  return output
90
+
91
  # ----------------------------------------------------------------------------
92
 
93
  # ============================================================================
 
207
  with torch.no_grad(): # Tắt gradient calculation
208
  print("Encoding image...")
209
  # Use model.cnn (the encoder attribute was renamed back from cnn_net to cnn)
210
+
211
+ cnn_features = model.cnn(image_tensor)
212
+ cnn_features = cnn_features.view(cnn_features.size(0), -1)
213
  print(f"CNN features shape: {cnn_features.shape}")
214
 
215
  print("Encoding question...")