import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
# -----------------------
# Attention Module
# -----------------------
class Attention(nn.Module):
    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super(Attention, self).__init__()
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        # cnn_features: (batch, 1, cnn_dim)
        # lstm_features: (batch, seq_len, lstm_dim)
        cnn_proj = self.cnn_proj(cnn_features)                         # (batch, 1, attention_dim)
        lstm_proj = self.lstm_proj(lstm_features)                      # (batch, seq_len, attention_dim)
        combined = torch.tanh(cnn_proj + lstm_proj)                    # broadcast add -> (batch, seq_len, attention_dim)
        attn_weights = F.softmax(self.attn(combined), dim=1)           # (batch, seq_len, 1)
        attended_features = (attn_weights * lstm_features).sum(dim=1)  # (batch, lstm_dim)
        return attended_features
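
# A minimal shape sanity check for the attention block (a sketch, not part of
# the original app; the dimensions below are illustrative):
#
#   attn = Attention(cnn_dim=512, lstm_dim=256, attention_dim=256)
#   img_feat = torch.randn(2, 1, 512)    # (batch, 1, cnn_dim)
#   seq_feat = torch.randn(2, 10, 256)   # (batch, seq_len, lstm_dim)
#   attn(img_feat, seq_feat).shape       # -> torch.Size([2, 256])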
# -----------------------
# VQA Model
# -----------------------
class VQAModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
        super(VQAModel, self).__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # CNN encoder: extracts image features
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Text embedding shared by caption, question, and answer tokens
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM encoders for the caption and the question
        self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)

        # Attention shared by both text channels
        self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)

        # Decoder (trained with teacher forcing).
        # The context vector concatenates caption attention, question attention,
        # and the question's final hidden state: 3 * lstm_units (e.g. 768 for lstm_units=256).
        # Combined with the answer embedding, the decoder input is embedding_dim + 3 * lstm_units.
        self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)
    def forward(self, image, caption, question, answer_input):
        # --- CNN encoder ---
        cnn_features = self.cnn(image)                              # (batch, cnn_output_dim, 1, 1)
        cnn_features = cnn_features.view(cnn_features.size(0), -1)  # (batch, cnn_output_dim)

        # --- Text encoders ---
        cap_embed = self.embedding(caption)           # (batch, cap_seq_len, embedding_dim)
        cap_output, _ = self.caption_lstm(cap_embed)  # (batch, cap_seq_len, lstm_units)
        q_embed = self.embedding(question)            # (batch, q_seq_len, embedding_dim)
        q_output, _ = self.question_lstm(q_embed)     # (batch, q_seq_len, lstm_units)

        # --- Attention ---
        cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output)  # (batch, lstm_units)
        q_attended = self.attention(cnn_features.unsqueeze(1), q_output)      # (batch, lstm_units)
        q_last = q_output[:, -1, :]                                           # (batch, lstm_units)

        # Context vector: (batch, 3 * lstm_units)
        context = torch.cat([cap_attended, q_attended, q_last], dim=-1)

        # --- Decoder with teacher forcing ---
        # answer_input: (batch, ans_seq_len)
        answer_embed = self.embedding(answer_input)                                 # (batch, ans_seq_len, embedding_dim)
        context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)  # (batch, ans_seq_len, 3 * lstm_units)
        decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)            # (batch, ans_seq_len, embedding_dim + 3 * lstm_units)
        decoder_in = self.decoder_input_proj(decoder_in)                            # (batch, ans_seq_len, lstm_units)
        decoder_output, _ = self.decoder_lstm(decoder_in)                           # (batch, ans_seq_len, lstm_units)
        output = self.fc_out(self.dropout(decoder_output))                          # (batch, ans_seq_len, vocab_size)
        return output
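
    # A teacher-forced forward pass, as a sketch (not in the original file;
    # all sizes below are illustrative):
    #
    #   model = VQAModel(vocab_size=1000)
    #   image = torch.randn(2, 3, 224, 224)
    #   caption = torch.randint(0, 1000, (2, 15))
    #   question = torch.randint(0, 1000, (2, 10))
    #   answer_in = torch.randint(0, 1000, (2, 8))
    #   logits = model(image, caption, question, answer_in)  # (2, 8, 1000)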
    @torch.no_grad()
    def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.eval()
        self.to(device)

        # Add a batch dimension if the image does not have one
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(device)

        # Tokenize the question; out-of-vocabulary words fall back to <PAD>
        question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
        question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)

        # Encode the image and the question
        cnn_features = self.cnn(image)
        cnn_features = cnn_features.view(cnn_features.size(0), -1)
        q_embed = self.embedding(question)
        q_output, _ = self.question_lstm(q_embed)
        q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
        q_last = q_output[:, -1, :]

        # At inference time there is no caption, so the question attention is
        # reused in the caption slot to keep the context at 3 * lstm_units.
        context = torch.cat([q_attended, q_attended, q_last], dim=-1)  # (1, 3 * lstm_units)

        # Start the answer with the <START> token and decode greedily.
        # Since the LSTM hidden state is carried across iterations, only the
        # most recent token is fed at each step (re-feeding the whole prefix
        # while also carrying `hidden` would process earlier tokens twice).
        next_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
        answer_words = []
        hidden = None
        for _ in range(self.max_seq_len):
            answer_embed = self.embedding(next_input)                         # (1, 1, embedding_dim)
            decoder_in = torch.cat([answer_embed, context.unsqueeze(1)], dim=-1)
            decoder_in = self.decoder_input_proj(decoder_in)
            decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
            output = self.fc_out(decoder_output[:, -1, :])
            next_word_idx = output.argmax(dim=-1).item()
            if next_word_idx == word_to_idx['<END>']:
                break
            answer_words.append(idx_to_word[next_word_idx])
            next_input = torch.tensor([[next_word_idx]], dtype=torch.long).to(device)
        return ' '.join(answer_words)
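
# Greedy-decoding usage, as a sketch (not in the original file; the tiny
# vocabulary below is hypothetical, but the special-token names match what
# this file expects):
#
#   w2i = {'<PAD>': 0, '<START>': 1, '<END>': 2, 'a': 3, 'cat': 4}
#   i2w = {i: w for w, i in w2i.items()}
#   model = VQAModel(vocab_size=len(w2i))
#   img = torch.randn(3, 224, 224)  # predict() adds the batch dimension
#   model.predict(img, "a cat", w2i, i2w, device='cpu')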
def load_model(model_path, word_to_idx_path, idx_to_word_path, device='cpu'):
    try:
        # Load the vocabulary dictionaries from .pth files
        word_to_idx = torch.load(word_to_idx_path, map_location=device)
        idx_to_word = torch.load(idx_to_word_path, map_location=device)

        # Initialize the model and load the trained weights
        model = VQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        return model, word_to_idx, idx_to_word
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
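
# The vocabulary files are assumed to be plain Python dicts serialized with
# torch.save, e.g. (hypothetical, mirroring how load_model reads them back):
#
#   torch.save(word_to_idx, "word_to_idx.pth")   # {'<PAD>': 0, '<START>': 1, ...}
#   torch.save(idx_to_word, "idx_to_word.pth")   # {0: '<PAD>', 1: '<START>', ...}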
def predict(image, question, model, word_to_idx, idx_to_word, device='cpu'):
    try:
        # Preprocess the image (transform is defined below at module level)
        image = transform(image).unsqueeze(0).to(device)
        # Generate the answer
        answer = model.predict(image, question, word_to_idx, idx_to_word, device)
        return answer
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error generating answer"
# Image preprocessing transform (ImageNet normalization statistics)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
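
# The transform maps a PIL image to a normalized (3, 224, 224) float tensor,
# matching the CNN encoder's expected input. A quick check (sketch, not in
# the original file):
#
#   img = Image.new("RGB", (640, 480))
#   transform(img).shape  # -> torch.Size([3, 224, 224])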
def create_interface():
    device = 'cpu'  # Always use CPU on Spaces
    try:
        model, word_to_idx, idx_to_word = load_model(
            "vqa_model.pth",
            "word_to_idx.pth",
            "idx_to_word.pth",
            device
        )

        def predict_fn(image, question):
            try:
                # Reuse the module-level transform rather than rebuilding it per call
                image_tensor = transform(image).unsqueeze(0).to(device)
                return model.predict(image_tensor, question, word_to_idx, idx_to_word, device)
            except Exception as e:
                return f"Error: {str(e)}"

        iface = gr.Interface(
            fn=predict_fn,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Question")
            ],
            outputs=gr.Textbox(label="Answer"),
            title="Visual Question Answering",
            description="Upload an image of an animal and ask a question about it (ENGLISH ONLY)"
        )
        return iface
    except Exception as e:
        # Fall back to a stub interface that reports the load failure
        msg = f"Model failed to load: {e}"
        return gr.Interface(fn=lambda: msg, inputs=None, outputs="text")
if __name__ == "__main__":
    iface = create_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860
    )