Tin113 commited on
Commit
84165ce
·
verified ·
1 Parent(s): 2fccb7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -125
app.py CHANGED
@@ -2,37 +2,28 @@ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- import json
6
- from torchvision import transforms
7
  from PIL import Image
8
- import numpy as np
9
 
10
- # ============================================================================
11
- # 1. ĐỊNH NGHĨA LẠI CÁC CLASS MODEL (QUAN TRỌNG!)
12
- # (Copy từ code huấn luyện gốc, ĐÃ SỬA Attention theo lỗi trước)
13
- # ============================================================================
14
  # -----------------------
15
  # Attention Module
16
  # -----------------------
17
- class Attention(nn.Module):
18
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
19
- super(Attention, self).__init__()
20
  self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
21
  self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
22
  self.attn = nn.Linear(attention_dim, 1)
23
 
24
  def forward(self, cnn_features, lstm_features):
25
- # cnn_features: (batch, 1, cnn_dim)
26
- # lstm_features: (batch, seq_len, lstm_dim)
27
- cnn_proj = self.cnn_proj(cnn_features) # (batch, 1, attention_dim)
28
- lstm_proj = self.lstm_proj(lstm_features) # (batch, seq_len, attention_dim)
29
- combined = torch.tanh(cnn_proj + lstm_proj) # (batch, seq_len, attention_dim)
30
- attn_weights = F.softmax(self.attn(combined), dim=1) # (batch, seq_len, 1)
31
- attended_features = (attn_weights * lstm_features).sum(dim=1) # (batch, lstm_dim)
32
  return attended_features
33
- # -----------------------
34
- # VQA Model
35
- # -----------------------
36
  # -----------------------
37
  # Pre-trained VQA Model
38
  # -----------------------
@@ -42,11 +33,11 @@ class PretrainedVQAModel(nn.Module):
42
  self.vocab_size = vocab_size
43
  self.max_seq_len = max_seq_len
44
 
45
- # Pre-trained CNN Encoder (ResNet50)
46
  resnet = models.resnet18(pretrained=True)
47
- self.cnn = nn.Sequential(*list(resnet.children())[:-1]) # Remove the final FC layer
48
- self.cnn_output_dim = 512
49
-
50
  # Text Embedding
51
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
52
 
@@ -64,146 +55,145 @@ class PretrainedVQAModel(nn.Module):
64
 
65
  def forward(self, image, question, answer_input):
66
  # CNN Encoder
67
- cnn_features = self.cnn(image) # (batch, cnn_output_dim, 1, 1)
68
- cnn_features = cnn_features.view(cnn_features.size(0), -1) # (batch, cnn_output_dim)
69
 
70
  # Question Encoder
71
- q_embed = self.embedding(question) # (batch, q_seq_len, embedding_dim)
72
- q_output, _ = self.question_lstm(q_embed) # (batch, q_seq_len, lstm_units)
73
 
74
  # Attention
75
- q_attended = self.attention(cnn_features.unsqueeze(1), q_output) # (batch, lstm_units)
76
- q_last = q_output[:, -1, :] # (batch, lstm_units)
77
 
78
  # Context Vector
79
- context = torch.cat([q_attended, q_last], dim=-1) # (batch, 2*lstm_units)
80
 
81
  # Decoder with Teacher Forcing
82
- answer_embed = self.embedding(answer_input) # (batch, ans_seq_len, embedding_dim)
83
- context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1) # (batch, ans_seq_len, 2*lstm_units)
84
- decoder_in = torch.cat([answer_embed, context_repeated], dim=-1) # (batch, ans_seq_len, embedding_dim + 2*lstm_units)
85
- decoder_in = self.decoder_input_proj(decoder_in) # (batch, ans_seq_len, lstm_units)
86
 
87
- decoder_output, _ = self.decoder_lstm(decoder_in) # (batch, ans_seq_len, lstm_units)
88
- output = self.fc_out(self.dropout(decoder_output)) # (batch, ans_seq_len, vocab_size)
89
  return output
90
 
91
- def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
92
  self.eval()
93
- self.to(device)
94
- if image.dim() == 3:
95
- image = image.unsqueeze(0)
96
- image = image.to(device)
97
-
98
- question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
99
- question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
100
-
101
- # Encode image and question
102
- cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
103
- q_embed = self.embedding(question)
104
- q_output, _ = self.question_lstm(q_embed)
105
- q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
106
- q_last = q_output[:, -1, :]
107
- context = torch.cat([q_attended, q_last], dim=-1)
108
-
109
- # Generate answer
110
- answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
111
- answer_words = []
112
- hidden = None
113
- for _ in range(self.max_seq_len):
114
- answer_embed = self.embedding(answer_input)
115
- context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
116
- decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
117
- decoder_in = self.decoder_input_proj(decoder_in)
118
- decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
119
- output = self.fc_out(decoder_output[:, -1, :])
120
- next_word_idx = output.argmax(dim=-1).item()
121
- if next_word_idx == word_to_idx['<END>']:
122
- break
123
- answer_words.append(idx_to_word[next_word_idx])
124
- answer_input = torch.tensor([[next_word_idx]], dtype=torch.long).to(device)
125
- return ' '.join(answer_words)
126
-
 
 
 
 
127
 
128
-
129
- def load_model(model_path, word_to_idx_path, idx_to_word_path, device='cpu'):
 
 
 
130
  try:
131
- # Load từ điển từ file .pth
132
- word_to_idx = torch.load(word_to_idx_path, map_location=device)
133
- idx_to_word = torch.load(idx_to_word_path, map_location=device)
134
 
135
- # Khởi tạo mô hình
136
  model = PretrainedVQAModel(vocab_size=len(word_to_idx))
137
- model.load_state_dict(torch.load(model_path, map_location=device))
138
  model.to(device)
139
  model.eval()
140
-
141
  return model, word_to_idx, idx_to_word
142
  except Exception as e:
143
  print(f"Error loading model: {e}")
144
  raise
145
-
146
- def predict(image, question, model, word_to_idx, idx_to_word, device='cpu'):
147
- try:
148
- # Chuyển đổi ảnh
149
- image = transform(image).unsqueeze(0).to(device)
150
-
151
- # Dự đoán
152
- answer = model.predict(image, question, word_to_idx, idx_to_word, device)
153
- return answer
154
- except Exception as e:
155
- print(f"Prediction error: {e}")
156
- return "Error generating answer"
157
-
158
- # Tạo transform cho ảnh
159
- transform = transforms.Compose([
160
- transforms.Resize((224, 224)),
161
- transforms.ToTensor(),
162
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
163
- ])
164
  def create_interface():
165
- device = 'cpu' # Luôn dùng CPU trên Spaces
166
-
167
  try:
168
- model, word_to_idx, idx_to_word = load_model(
169
- "vqa_model.pth",
170
- "word_to_idx.pth",
171
- "idx_to_word.pth",
172
- device
173
- )
 
 
 
 
 
174
 
175
  def predict(image, question):
176
  try:
177
- transform = transforms.Compose([
178
- transforms.Resize((224, 224)),
179
- transforms.ToTensor(),
180
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
181
- std=[0.229, 0.224, 0.225])
182
- ])
183
- image = transform(image).unsqueeze(0).to(device)
184
- answer = model.predict(image, question, word_to_idx, idx_to_word, device)
185
  return answer
186
  except Exception as e:
187
- return f"Error: {str(e)}"
188
 
189
- iface = gr.Interface(
 
190
  fn=predict,
191
  inputs=[
192
  gr.Image(type="pil", label="Upload Image"),
193
- gr.Textbox(label="Question")
194
  ],
195
- outputs=gr.Textbox(label="Answer"),
196
- title="VQA train từ đầu",
197
- description="Tải ảnh về động vật lên đặt câu hỏi liên quan (CHỈ HỖ TRỢ TIẾNG ANH)"
 
198
  )
199
- return iface
200
  except Exception as e:
201
- return gr.Interface(lambda: "Model failed to load", None, "text")
 
 
 
 
 
202
 
 
 
 
203
  if __name__ == "__main__":
 
204
  iface = create_interface()
205
  iface.launch(
206
  server_name="0.0.0.0",
207
- server_port=7860
208
- )
209
-
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ from torchvision import transforms, models
 
6
  from PIL import Image
7
+ import os
8
 
 
 
 
 
9
  # -----------------------
10
  # Attention Module
11
  # -----------------------
12
+ class Attention_PT(nn.Module):
13
  def __init__(self, cnn_dim, lstm_dim, attention_dim):
14
+ super(Attention_PT, self).__init__()
15
  self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
16
  self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
17
  self.attn = nn.Linear(attention_dim, 1)
18
 
19
  def forward(self, cnn_features, lstm_features):
20
+ cnn_proj = self.cnn_proj(cnn_features)
21
+ lstm_proj = self.lstm_proj(lstm_features)
22
+ combined = torch.tanh(cnn_proj + lstm_proj)
23
+ attn_weights = F.softmax(self.attn(combined), dim=1)
24
+ attended_features = (attn_weights * lstm_features).sum(dim=1)
 
 
25
  return attended_features
26
+
 
 
27
  # -----------------------
28
  # Pre-trained VQA Model
29
  # -----------------------
 
33
  self.vocab_size = vocab_size
34
  self.max_seq_len = max_seq_len
35
 
36
+ # Pre-trained CNN Encoder (ResNet18)
37
  resnet = models.resnet18(pretrained=True)
38
+ self.cnn = nn.Sequential(*list(resnet.children())[:-1]) # Remove final FC layer
39
+ self.cnn_output_dim = 512 # Output dim for ResNet18 features
40
+
41
  # Text Embedding
42
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
43
 
 
55
 
56
  def forward(self, image, question, answer_input):
57
  # CNN Encoder
58
+ cnn_features = self.cnn(image)
59
+ cnn_features = cnn_features.view(cnn_features.size(0), -1)
60
 
61
  # Question Encoder
62
+ q_embed = self.embedding(question)
63
+ q_output, _ = self.question_lstm(q_embed)
64
 
65
  # Attention
66
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
67
+ q_last = q_output[:, -1, :]
68
 
69
  # Context Vector
70
+ context = torch.cat([q_attended, q_last], dim=-1)
71
 
72
  # Decoder with Teacher Forcing
73
+ answer_embed = self.embedding(answer_input)
74
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
75
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
76
+ decoder_in = self.decoder_input_proj(decoder_in)
77
 
78
+ decoder_output, _ = self.decoder_lstm(decoder_in)
79
+ output = self.fc_out(self.dropout(decoder_output))
80
  return output
81
 
82
+ def predict(self, image, question, word_to_idx, idx_to_word, device='cpu'):
83
  self.eval()
84
+ with torch.no_grad():
85
+ if image.dim() == 3:
86
+ image = image.unsqueeze(0)
87
+ image = image.to(device)
88
+
89
+ # Process question
90
+ question_seq = [word_to_idx.get(word, word_to_idx['<PAD>'])
91
+ for word in question.lower().split()]
92
+ question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
93
+
94
+ # Encode image and question
95
+ cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
96
+ q_embed = self.embedding(question)
97
+ q_output, _ = self.question_lstm(q_embed)
98
+ q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
99
+ q_last = q_output[:, -1, :]
100
+ context = torch.cat([q_attended, q_last], dim=-1)
101
+
102
+ # Generate answer
103
+ answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
104
+ answer_words = []
105
+
106
+ for _ in range(self.max_seq_len):
107
+ answer_embed = self.embedding(answer_input)
108
+ context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
109
+ decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
110
+ decoder_in = self.decoder_input_proj(decoder_in)
111
+ decoder_output, _ = self.decoder_lstm(decoder_in)
112
+ output = self.fc_out(decoder_output[:, -1, :])
113
+ next_word_idx = output.argmax(dim=-1).item()
114
+
115
+ if next_word_idx == word_to_idx['<END>']:
116
+ break
117
+
118
+ answer_words.append(idx_to_word[str(next_word_idx)])
119
+ answer_input = torch.tensor([[next_word_idx]], dtype=torch.long).to(device)
120
+
121
+ return ' '.join(answer_words)
122
 
123
+ # -----------------------
124
+ # Load Model Function
125
+ # -----------------------
126
+ def load_model():
127
+ device = 'cpu'
128
  try:
129
+ # Load dictionaries
130
+ word_to_idx = torch.load("word_to_idx.pth", map_location=device)
131
+ idx_to_word = torch.load("idx_to_word.pth", map_location=device)
132
 
133
+ # Initialize model
134
  model = PretrainedVQAModel(vocab_size=len(word_to_idx))
135
+ model.load_state_dict(torch.load("vqa_model.pth", map_location=device))
136
  model.to(device)
137
  model.eval()
 
138
  return model, word_to_idx, idx_to_word
139
  except Exception as e:
140
  print(f"Error loading model: {e}")
141
  raise
142
+
143
+ # -----------------------
144
+ # Gradio Interface
145
+ # -----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def create_interface():
 
 
147
  try:
148
+ model, word_to_idx, idx_to_word = load_model()
149
+
150
+ # Image preprocessing
151
+ def preprocess_image(image):
152
+ transform = transforms.Compose([
153
+ transforms.Resize((224, 224)),
154
+ transforms.ToTensor(),
155
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
156
+ std=[0.229, 0.224, 0.225])
157
+ ])
158
+ return transform(image).unsqueeze(0)
159
 
160
  def predict(image, question):
161
  try:
162
+ image_tensor = preprocess_image(image)
163
+ answer = model.predict(image_tensor, question, word_to_idx, idx_to_word, 'cpu')
 
 
 
 
 
 
164
  return answer
165
  except Exception as e:
166
+ return f"Error generating answer: {str(e)}"
167
 
168
+ # Create interface
169
+ return gr.Interface(
170
  fn=predict,
171
  inputs=[
172
  gr.Image(type="pil", label="Upload Image"),
173
+ gr.Textbox(label="Your Question", placeholder="Ask something about the image...")
174
  ],
175
+ outputs=gr.Textbox(label="Generated Answer"),
176
+ title="Visual Question Answering with ResNet18",
177
+ description="Upload an image and ask natural language questions about its content",
178
+ allow_flagging="never"
179
  )
180
+
181
  except Exception as e:
182
+ return gr.Interface(
183
+ lambda: f"Failed to load model: {str(e)}",
184
+ inputs=None,
185
+ outputs="text",
186
+ title="Error"
187
+ )
188
 
189
+ # -----------------------
190
+ # Main Execution
191
+ # -----------------------
192
  if __name__ == "__main__":
193
+ # Create and launch interface
194
  iface = create_interface()
195
  iface.launch(
196
  server_name="0.0.0.0",
197
+ server_port=7860,
198
+ enable_queue=True
199
+ )