import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from torchvision import transforms
from PIL import Image
import numpy as np

# ============================================================================

# -----------------------
# Attention Module
# -----------------------
class Attention(nn.Module):
    """Additive attention that pools an LSTM sequence, guided by image features.

    Args:
        cnn_dim: Size of the image feature vector.
        lstm_dim: Size of each LSTM output step.
        attention_dim: Size of the shared projection space.
    """

    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super(Attention, self).__init__()
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        """Return the attention-weighted sum of ``lstm_features`` over time.

        Args:
            cnn_features: (batch, 1, cnn_dim) image features; the singleton
                time axis broadcasts against every sequence step.
            lstm_features: (batch, seq_len, lstm_dim) LSTM outputs.

        Returns:
            Tensor of shape (batch, lstm_dim).
        """
        cnn_proj = self.cnn_proj(cnn_features)  # (batch, 1, attention_dim)
        lstm_proj = self.lstm_proj(lstm_features)  # (batch, seq_len, attention_dim)
        combined = torch.tanh(cnn_proj + lstm_proj)  # (batch, seq_len, attention_dim)
        # Softmax over the time axis so each sample's weights sum to 1.
        attn_weights = F.softmax(self.attn(combined), dim=1)  # (batch, seq_len, 1)
        attended_features = (attn_weights * lstm_features).sum(dim=1)  # (batch, lstm_dim)
        return attended_features


# -----------------------
# VQA Model
# -----------------------
class VQAModel(nn.Module):
    """CNN + LSTM encoders with an LSTM decoder that generates answer tokens.

    ``forward`` (training, teacher forcing) conditions on an image, a caption
    and a question; ``predict`` (inference) conditions on an image and a
    question only.
    """

    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256,
                 cnn_output_dim=512, attention_dim=256, max_seq_len=30):
        super(VQAModel, self).__init__()
        self.vocab_size = vocab_size
        # Upper bound on the number of decoding steps in ``predict``.
        self.max_seq_len = max_seq_len

        # CNN encoder: extracts one global feature vector per image.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Token embedding shared by caption, question and answer.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM encoders for the caption and the question.
        self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)

        # A single attention module, shared by the caption and question paths.
        self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)

        # Decoder (trained with teacher forcing).
        # Context = caption attention + question attention + last question
        # state = 3 * lstm_units (e.g. 768 when lstm_units=256). It is
        # concatenated with the answer-token embedding, so the decoder input
        # is embedding_dim + 3 * lstm_units wide before projection.
        self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, image, caption, question, answer_input):
        """Compute per-step vocabulary logits with teacher forcing.

        Args:
            image: (batch, 3, H, W) image tensor.
            caption: (batch, cap_seq_len) caption token ids.
            question: (batch, q_seq_len) question token ids.
            answer_input: (batch, ans_seq_len) decoder input token ids.

        Returns:
            (batch, ans_seq_len, vocab_size) logits.
        """
        # --- CNN encoder ---
        cnn_features = self.cnn(image)  # (batch, cnn_output_dim, 1, 1)
        cnn_features = cnn_features.view(cnn_features.size(0), -1)  # (batch, cnn_output_dim)

        # --- Text encoders ---
        cap_embed = self.embedding(caption)  # (batch, cap_seq_len, embedding_dim)
        cap_output, _ = self.caption_lstm(cap_embed)  # (batch, cap_seq_len, lstm_units)
        q_embed = self.embedding(question)  # (batch, q_seq_len, embedding_dim)
        q_output, _ = self.question_lstm(q_embed)  # (batch, q_seq_len, lstm_units)

        # --- Attention ---
        cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output)  # (batch, lstm_units)
        q_attended = self.attention(cnn_features.unsqueeze(1), q_output)  # (batch, lstm_units)
        q_last = q_output[:, -1, :]  # (batch, lstm_units)

        # Context vector: (batch, 3*lstm_units)
        context = torch.cat([cap_attended, q_attended, q_last], dim=-1)

        # --- Decoder with teacher forcing ---
        answer_embed = self.embedding(answer_input)  # (batch, ans_seq_len, embedding_dim)
        # The same context is appended to every decoder step.
        context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)  # (batch, ans_seq_len, 3*lstm_units)
        decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)  # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
        decoder_in = self.decoder_input_proj(decoder_in)  # (batch, ans_seq_len, lstm_units)
        decoder_output, _ = self.decoder_lstm(decoder_in)  # (batch, ans_seq_len, lstm_units)
        output = self.fc_out(self.dropout(decoder_output))  # (batch, ans_seq_len, vocab_size)
        return output

    @staticmethod
    def _token_index(word_to_idx, token):
        """Return the index of *token*, falling back to the legacy '' key.

        Raises:
            KeyError: If neither *token* nor '' is in the vocabulary.
        """
        if token in word_to_idx:
            return word_to_idx[token]
        return word_to_idx['']

    def predict(self, image, question, word_to_idx, idx_to_word,
                device='cuda' if torch.cuda.is_available() else 'cpu',
                *, unk_token='<unk>', start_token='<start>', end_token='<end>'):
        """Greedily decode an answer to *question* about *image*.

        Args:
            image: Preprocessed image tensor, (3, H, W) or (1, 3, H, W).
            question: Question text; lower-cased and split on whitespace.
            word_to_idx: Mapping from token to vocabulary index.
            idx_to_word: Indexable mapping from vocabulary index to token.
            device: Device to run inference on; the model is moved there.
            unk_token: Name of the unknown-word token.
            start_token: Name of the token that starts decoding.
            end_token: Name of the token that stops decoding.
                For each special token, the '' key is used if the name is
                not in the vocabulary.

        Returns:
            The generated tokens joined with single spaces (may be empty).

        Raises:
            ValueError: If *question* contains no words.
            KeyError: If a special token cannot be resolved.
        """
        self.eval()
        self.to(device)

        # NOTE(review): the '' fallback keeps vocabularies that store every
        # special token under '' working, but then start == end == unk, which
        # looks like '<...>' markers lost as markup. Confirm the real token
        # names against the saved vocabulary.
        unk_idx = self._token_index(word_to_idx, unk_token)
        start_idx = self._token_index(word_to_idx, start_token)
        end_idx = self._token_index(word_to_idx, end_token)

        # Add a batch dimension if the image has none.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(device)

        question_seq = [word_to_idx.get(word, unk_idx) for word in question.lower().split()]
        if not question_seq:
            # A zero-length sequence cannot go through the question LSTM.
            raise ValueError("Question must contain at least one word")
        question_ids = torch.tensor([question_seq], dtype=torch.long, device=device)

        with torch.no_grad():  # inference only: skip autograd bookkeeping
            # Encode image and question.
            cnn_features = self.cnn(image)
            cnn_features = cnn_features.view(cnn_features.size(0), -1)
            q_embed = self.embedding(question_ids)
            q_output, _ = self.question_lstm(q_embed)
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]

            # No caption is available here, so the question attention fills
            # the caption-attention slot of the context built in ``forward``.
            # NOTE(review): this differs from the training-time context.
            context = torch.cat([q_attended, q_attended, q_last], dim=-1).unsqueeze(1)  # (1, 1, 3*lstm_units)

            answer_words = []
            token = torch.tensor([[start_idx]], dtype=torch.long, device=device)
            hidden = None
            for _ in range(self.max_seq_len):
                # Feed only the newest token: ``hidden`` already summarizes
                # every earlier token. Re-feeding the whole prefix together
                # with the carried-over state would process earlier tokens
                # twice and diverge from teacher-forced training.
                decoder_in = torch.cat([self.embedding(token), context], dim=-1)
                decoder_in = self.decoder_input_proj(decoder_in)
                decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
                logits = self.fc_out(decoder_output[:, -1, :])
                next_word_idx = logits.argmax(dim=-1).item()
                if next_word_idx == end_idx:
                    break
                answer_words.append(idx_to_word[next_word_idx])
                token = torch.tensor([[next_word_idx]], dtype=torch.long, device=device)

        return ' '.join(answer_words)


def load_model(model_path, word_to_idx_path, idx_to_word_path, device='cpu'):
    """Load the vocabularies and the trained VQAModel weights.

    Args:
        model_path: Path to the saved model ``state_dict``.
        word_to_idx_path: Path to the saved token -> index mapping.
        idx_to_word_path: Path to the saved index -> token mapping.
        device: Device to map tensors and the model onto.

    Returns:
        Tuple ``(model, word_to_idx, idx_to_word)``; the model is in eval mode.

    Raises:
        Exception: Whatever loading raised; it is printed, then re-raised.
    """
    try:
        # torch.load unpickles its input, so only load trusted files.
        word_to_idx = torch.load(word_to_idx_path, map_location=device)
        idx_to_word = torch.load(idx_to_word_path, map_location=device)

        # The vocabulary size determines the embedding / output layer shapes,
        # so it must match the checkpoint for load_state_dict to succeed.
        model = VQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        return model, word_to_idx, idx_to_word
    except Exception as e:
        print(f"Error loading model: {e}")
        raise


# Preprocessing applied to every input image before it reaches the CNN.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def _preprocess_image(image, device):
    """Convert a PIL image into a normalized (1, 3, 224, 224) tensor on *device*.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("No image provided")
    # Force 3 channels: RGBA or grayscale images would otherwise not match
    # the CNN's 3-channel first convolution.
    return transform(image.convert("RGB")).unsqueeze(0).to(device)


def predict(image, question, model, word_to_idx, idx_to_word, device='cpu'):
    """Answer *question* about a PIL *image* without raising.

    Returns:
        The generated answer, or "Error generating answer" on any failure
        (the error is printed).
    """
    try:
        image = _preprocess_image(image, device)
        answer = model.predict(image, question, word_to_idx, idx_to_word, device)
        return answer
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error generating answer"


def create_interface():
    """Build the Gradio UI, or a stub interface if the model fails to load."""
    device = 'cpu'  # Always use the CPU on Spaces

    try:
        model, word_to_idx, idx_to_word = load_model(
            "vqa_model.pth",
            "word_to_idx.pth",
            "idx_to_word.pth",
            device
        )

        # Keep the name ``predict``: Gradio may derive the API endpoint name
        # from the function. Errors are shown in the answer box, not raised.
        def predict(image, question):
            try:
                image_tensor = _preprocess_image(image, device)
                return model.predict(image_tensor, question, word_to_idx, idx_to_word, device)
            except Exception as e:
                return f"Error: {str(e)}"

        iface = gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Question")
            ],
            outputs=gr.Textbox(label="Answer"),
            title="Visual Question Answering",
            description="Tải ảnh về động vật lên và đặt câu hỏi liên quan (CHỈ HỖ TRỢ TIẾNG ANH)"
        )
        return iface
    except Exception as e:
        # Report the failure instead of discarding it, then serve a stub UI.
        print(f"Failed to create interface: {e}")
        return gr.Interface(lambda: "Model failed to load", None, "text")


if __name__ == "__main__":
    iface = create_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860
    )