Tin113 commited on
Commit
76b9d33
·
verified ·
1 Parent(s): aef5bf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -26
app.py CHANGED
@@ -9,57 +9,137 @@ import os
9
  import sys
10
 
11
  class Attention(nn.Module):
12
- # ... (Dán code class Attention của bạn vào đây) ...
13
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
14
  super(Attention, self).__init__()
15
- self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
16
- self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
17
  self.attn = nn.Linear(attention_dim, 1)
18
- # Thêm các lớp kích hoạt nếu có trong code gốc của bạn
19
- self.tanh = nn.Tanh()
20
- self.softmax = nn.Softmax(dim=1)
21
 
22
  def forward(self, cnn_features, lstm_features):
23
- cnn_proj = self.cnn_proj(cnn_features)
24
- lstm_proj = self.lstm_proj(lstm_features)
25
- # Đảm bảo broadcasting hoạt động đúng
26
- combined = self.tanh(cnn_proj + lstm_proj) # cnn_proj sẽ được broadcast
27
- attn_logits = self.attn(combined)
28
- attn_weights = self.softmax(attn_logits)
29
- attended_features = (attn_weights * lstm_features).sum(dim=1)
30
  return attended_features
 
31
 
 
 
 
32
  class VQAModel(nn.Module):
33
- # ... (Dán code class VQAModel gốc của bạn vào đây) ...
34
- # Đảm bảo các tham số mặc định khớp với lúc bạn lưu model
35
  def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
36
  super(VQAModel, self).__init__()
37
  self.vocab_size = vocab_size
38
  self.max_seq_len = max_seq_len
39
 
40
- # CNN Encoder (giống hệt lúc train)
41
  self.cnn = nn.Sequential(
42
- nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
43
- nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
44
- nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
45
- nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1), nn.ReLU(),
 
 
 
 
 
 
 
46
  nn.AdaptiveAvgPool2d((1, 1))
47
  )
 
 
48
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
49
- # Có caption_lstm trong định nghĩa model gốc
 
50
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
51
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
 
 
52
  self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
53
- # Kích thước input decoder dựa trên context gốc (có cả caption)
 
 
 
 
54
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
55
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
56
  self.fc_out = nn.Linear(lstm_units, vocab_size)
57
- self.dropout = nn.Dropout(0.5) # Tự động tắt khi model.eval()
58
 
59
- # Hàm forward không thực sự được gọi trong predict_gradio theo cách làm này
60
- # Nhưng nó cần tồn tại để model load đúng cấu trúc
61
  def forward(self, image, caption, question, answer_input):
62
- raise NotImplementedError("Use the specific prediction logic for Gradio.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # ----------------------------------------------------------------------------
64
 
65
  # ============================================================================
 
9
  import sys
10
 
11
  class Attention(nn.Module):
 
12
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
13
  super(Attention, self).__init__()
14
+ self.cnn = nn.Linear(cnn_dim, attention_dim)
15
+ self.lstm = nn.Linear(lstm_dim, attention_dim)
16
  self.attn = nn.Linear(attention_dim, 1)
 
 
 
17
 
18
  def forward(self, cnn_features, lstm_features):
19
+ # cnn_features: (batch, 1, cnn_dim)
20
+ # lstm_features: (batch, seq_len, lstm_dim)
21
+ cnn = self.cnn(cnn_features) # (batch, 1, attention_dim)
22
+ lstm = self.lstm(lstm_features) # (batch, seq_len, attention_dim)
23
+ combined = torch.tanh(cnn + lstm) # (batch, seq_len, attention_dim)
24
+ attn_weights = F.softmax(self.attn(combined), dim=1) # (batch, seq_len, 1)
25
+ attended_features = (attn_weights * lstm_features).sum(dim=1) # (batch, lstm_dim)
26
  return attended_features
27
+
28
 
29
+ # -----------------------
30
+ # VQA Model
31
+ # -----------------------
32
  class VQAModel(nn.Module):
 
 
33
  def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
34
  super(VQAModel, self).__init__()
35
  self.vocab_size = vocab_size
36
  self.max_seq_len = max_seq_len
37
 
38
+ # CNN Encoder: Trích xuất đặc trưng ảnh
39
  self.cnn = nn.Sequential(
40
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
41
+ nn.ReLU(),
42
+ nn.MaxPool2d(2),
43
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
44
+ nn.ReLU(),
45
+ nn.MaxPool2d(2),
46
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
47
+ nn.ReLU(),
48
+ nn.MaxPool2d(2),
49
+ nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
50
+ nn.ReLU(),
51
  nn.AdaptiveAvgPool2d((1, 1))
52
  )
53
+
54
+ # Text Embedding
55
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
56
+
57
+ # LSTM Encoders cho caption và question
58
  self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
59
  self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
60
+
61
+ # Attention cho từng kênh
62
  self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
63
+
64
+ # Decoder: sử dụng teacher forcing
65
+ # Context vector: kết hợp của attention từ caption, attention từ question và trạng thái cuối của question
66
+ # Kích thước context = lstm_units + lstm_units + lstm_units = 3 * lstm_units (ví dụ 768 nếu lstm_units=256)
67
+ # Kết hợp với embedding của câu trả lời (embedding_dim) => đầu vào của decoder = embedding_dim + 3*lstm_units
68
  self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
69
  self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
70
  self.fc_out = nn.Linear(lstm_units, vocab_size)
71
+ self.dropout = nn.Dropout(0.5)
72
 
 
 
73
  def forward(self, image, caption, question, answer_input):
74
+ # --- CNN Encoder ---
75
+ cnn_features = self.cnn(image) # (batch, cnn_output_dim, 1, 1)
76
+ cnn_features = cnn_features.view(cnn_features.size(0), -1) # (batch, cnn_output_dim)
77
+
78
+ # --- Text Encoders ---
79
+ cap_embed = self.embedding(caption) # (batch, cap_seq_len, embedding_dim)
80
+ cap_output, _ = self.caption_lstm(cap_embed) # (batch, cap_seq_len, lstm_units)
81
+
82
+ q_embed = self.embedding(question) # (batch, q_seq_len, embedding_dim)
83
+ q_output, _ = self.question_lstm(q_embed) # (batch, q_seq_len, lstm_units)
84
+
85
+ # --- Attention ---
86
+ cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output) # (batch, lstm_units)
87
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output) # (batch, lstm_units)
88
+
89
+ q_last = q_output[:, -1, :] # (batch, lstm_units)
90
+
91
+ # Context vector: (batch, 3*lstm_units)
92
+ context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
93
+
94
+ # --- Decoder với Teacher Forcing ---
95
+ # answer_input: (batch, ans_seq_len)
96
+ answer_embed = self.embedding(answer_input) # (batch, ans_seq_len, embedding_dim)
97
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1) # (batch, ans_seq_len, 3*lstm_units)
98
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1) # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
99
+ decoder_in = self.decoder_input_proj(decoder_in) # (batch, ans_seq_len, lstm_units)
100
+
101
+ decoder_output, _ = self.decoder_lstm(decoder_in) # (batch, ans_seq_len, lstm_units)
102
+ output = self.fc_out(self.dropout(decoder_output)) # (batch, ans_seq_len, vocab_size)
103
+ return output
104
+
105
+ def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
106
+ self.eval()
107
+ self.to(device)
108
+
109
+ image = image.unsqueeze(0).to(device)
110
+ question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
111
+ question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
112
+
113
+ # Encode image & question
114
+ cnn_features = self.cnn(image)
115
+ cnn_features = cnn_features.view(cnn_features.size(0), -1)
116
+ q_embed = self.embedding(question)
117
+ q_output, _ = self.question_lstm(q_embed)
118
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
119
+ q_last = q_output[:, -1, :]
120
+ # Ở predict, ta tạo context vector từ q_attended lặp lại (chỉ dùng question cho ví dụ)
121
+ context = torch.cat([q_attended, q_attended, q_last], dim=-1) # (1, 3*lstm_units)
122
+
123
+ # Khởi tạo câu trả lời với token <START>
124
+ answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
125
+ answer_words = []
126
+
127
+ hidden = None
128
+ for _ in range(self.max_seq_len):
129
+ answer_embed = self.embedding(answer_input) # (1, seq_len, embedding_dim)
130
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
131
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
132
+ decoder_in = self.decoder_input_proj(decoder_in)
133
+ decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
134
+ output = self.fc_out(decoder_output[:, -1, :])
135
+ next_word_idx = output.argmax(dim=-1).item()
136
+ if next_word_idx == word_to_idx['<END>']:
137
+ break
138
+ answer_words.append(idx_to_word[next_word_idx])
139
+ answer_input = torch.cat([answer_input, torch.tensor([[next_word_idx]], dtype=torch.long).to(device)], dim=1)
140
+
141
+ return ' '.join(answer_words)
142
+
143
  # ----------------------------------------------------------------------------
144
 
145
  # ============================================================================