Tin113 commited on
Commit
b25b382
·
verified ·
1 Parent(s): 4029376

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -23
app.py CHANGED
@@ -5,35 +5,119 @@ from torchvision import transforms
5
 
6
  import torch.nn as nn
7
  import torchvision.models as models
8
-
 
 
9
  class VQAModel(nn.Module):
10
- def __init__(self, vocab_size):
11
  super(VQAModel, self).__init__()
12
- # Dùng ResNet làm CNN Encoder
13
- self.cnn = models.resnet18(pretrained=True)
14
- self.cnn.fc = nn.Linear(512, 256) # Thay FC layer
15
-
16
- # Dùng LSTM làm Text Encoder
17
- self.embedding = nn.Embedding(vocab_size, 256)
18
- self.lstm = nn.LSTM(256, 256, batch_first=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Fully Connected Layer để dự đoán câu trả lời
21
- self.fc = nn.Linear(256, vocab_size)
 
22
 
23
- def forward(self, image, question):
24
- # Encode ảnh
25
- img_features = self.cnn(image)
26
 
27
- # Encode câu hỏi
 
 
28
  q_embed = self.embedding(question)
29
- _, (q_features, _) = self.lstm(q_embed)
30
-
31
- # Kết hợp đặc trưng ảnh và câu hỏi
32
- combined = img_features + q_features.squeeze(0)
33
- output = self.fc(combined)
34
-
35
- return output
36
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  # Load mô hình từ Hugging Face Model Hub hoặc local
 
5
 
6
  import torch.nn as nn
7
  import torchvision.models as models
8
+ # -----------------------
9
+ # VQA Model
10
+ # -----------------------
11
  class VQAModel(nn.Module):
12
+ def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
13
  super(VQAModel, self).__init__()
14
+ self.vocab_size = vocab_size
15
+ self.max_seq_len = max_seq_len
16
+
17
+ # CNN Encoder: Trích xuất đặc trưng ảnh
18
+ self.cnn = nn.Sequential(
19
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
20
+ nn.ReLU(),
21
+ nn.MaxPool2d(2),
22
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
23
+ nn.ReLU(),
24
+ nn.MaxPool2d(2),
25
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
26
+ nn.ReLU(),
27
+ nn.MaxPool2d(2),
28
+ nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
29
+ nn.ReLU(),
30
+ nn.AdaptiveAvgPool2d((1, 1))
31
+ )
32
+
33
+ # Text Embedding
34
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
35
+
36
+ # LSTM Encoders cho caption và question
37
+ self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
38
+ self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
39
+
40
+ # Attention cho từng kênh
41
+ self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
42
+
43
+ # Decoder: sử dụng teacher forcing
44
+ # Context vector: kết hợp của attention từ caption, attention từ question và trạng thái cuối của question
45
+ # Kích thước context = lstm_units + lstm_units + lstm_units = 3 * lstm_units (ví dụ 768 nếu lstm_units=256)
46
+ # Kết hợp với embedding của câu trả lời (embedding_dim) => đầu vào của decoder = embedding_dim + 3*lstm_units
47
+ self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
48
+ self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
49
+ self.fc_out = nn.Linear(lstm_units, vocab_size)
50
+ self.dropout = nn.Dropout(0.5)
51
+
52
+ def forward(self, image, caption, question, answer_input):
53
+ # --- CNN Encoder ---
54
+ cnn_features = self.cnn(image) # (batch, cnn_output_dim, 1, 1)
55
+ cnn_features = cnn_features.view(cnn_features.size(0), -1) # (batch, cnn_output_dim)
56
+
57
+ # --- Text Encoders ---
58
+ cap_embed = self.embedding(caption) # (batch, cap_seq_len, embedding_dim)
59
+ cap_output, _ = self.caption_lstm(cap_embed) # (batch, cap_seq_len, lstm_units)
60
+
61
+ q_embed = self.embedding(question) # (batch, q_seq_len, embedding_dim)
62
+ q_output, _ = self.question_lstm(q_embed) # (batch, q_seq_len, lstm_units)
63
+
64
+ # --- Attention ---
65
+ cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output) # (batch, lstm_units)
66
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output) # (batch, lstm_units)
67
+
68
+ q_last = q_output[:, -1, :] # (batch, lstm_units)
69
+
70
+ # Context vector: (batch, 3*lstm_units)
71
+ context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
72
+
73
+ # --- Decoder với Teacher Forcing ---
74
+ # answer_input: (batch, ans_seq_len)
75
+ answer_embed = self.embedding(answer_input) # (batch, ans_seq_len, embedding_dim)
76
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1) # (batch, ans_seq_len, 3*lstm_units)
77
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1) # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
78
+ decoder_in = self.decoder_input_proj(decoder_in) # (batch, ans_seq_len, lstm_units)
79
+
80
+ decoder_output, _ = self.decoder_lstm(decoder_in) # (batch, ans_seq_len, lstm_units)
81
+ output = self.fc_out(self.dropout(decoder_output)) # (batch, ans_seq_len, vocab_size)
82
+ return output
83
 
84
+ def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
85
+ self.eval()
86
+ self.to(device)
87
 
88
+ image = image.unsqueeze(0).to(device)
89
+ question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
90
+ question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
91
 
92
+ # Encode image & question
93
+ cnn_features = self.cnn(image)
94
+ cnn_features = cnn_features.view(cnn_features.size(0), -1)
95
  q_embed = self.embedding(question)
96
+ q_output, _ = self.question_lstm(q_embed)
97
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
98
+ q_last = q_output[:, -1, :]
99
+ # predict, ta tạo context vector từ q_attended lặp lại (chỉ dùng question cho ví dụ)
100
+ context = torch.cat([q_attended, q_attended, q_last], dim=-1) # (1, 3*lstm_units)
101
+
102
+ # Khởi tạo câu trả lời với token <START>
103
+ answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
104
+ answer_words = []
105
+
106
+ hidden = None
107
+ for _ in range(self.max_seq_len):
108
+ answer_embed = self.embedding(answer_input) # (1, seq_len, embedding_dim)
109
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
110
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
111
+ decoder_in = self.decoder_input_proj(decoder_in)
112
+ decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
113
+ output = self.fc_out(decoder_output[:, -1, :])
114
+ next_word_idx = output.argmax(dim=-1).item()
115
+ if next_word_idx == word_to_idx['<END>']:
116
+ break
117
+ answer_words.append(idx_to_word[next_word_idx])
118
+ answer_input = torch.cat([answer_input, torch.tensor([[next_word_idx]], dtype=torch.long).to(device)], dim=1)
119
+
120
+ return ' '.join(answer_words)
121
 
122
 
123
  # Load mô hình từ Hugging Face Model Hub hoặc local