"""Visual Question Answering demo.

ResNet18 image encoder + LSTM question encoder with additive attention,
feeding an LSTM decoder that generates a free-form answer word by word.
Served as a Gradio web app.
"""

import warnings

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, models

# Silence noisy but harmless warnings (e.g. torchvision deprecations).
warnings.filterwarnings("ignore", category=UserWarning)

# Special vocabulary tokens.
# NOTE(review): the original code used word_to_idx[''] for all three roles
# (unknown, start-of-answer, end-of-answer), which is almost certainly an
# artifact of angle brackets being stripped from '<unk>'/'<start>'/'<end>'.
# Verify these keys against the saved word_to_idx.pth vocabulary.
UNK_TOKEN = '<unk>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'


# -----------------------
# Attention Module
# -----------------------
class Attention_PT(nn.Module):
    """Additive (Bahdanau-style) attention over question timesteps,
    conditioned on the CNN image features."""

    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super(Attention_PT, self).__init__()
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        """Pool ``lstm_features`` over time with image-conditioned weights.

        Args:
            cnn_features: image features, broadcastable against the time axis
                (callers pass shape (batch, 1, cnn_dim)).
            lstm_features: (batch, seq_len, lstm_dim) question encodings.

        Returns:
            (batch, lstm_dim) attended question representation.
        """
        cnn_proj = self.cnn_proj(cnn_features)
        lstm_proj = self.lstm_proj(lstm_features)
        combined = torch.tanh(cnn_proj + lstm_proj)
        # Softmax over the time dimension so weights sum to 1 per example.
        attn_weights = F.softmax(self.attn(combined), dim=1)
        attended_features = (attn_weights * lstm_features).sum(dim=1)
        return attended_features


# -----------------------
# Pre-trained VQA Model
# -----------------------
class PretrainedVQAModel(nn.Module):
    """Encoder-decoder VQA model: frozen-architecture ResNet18 image encoder,
    LSTM question encoder, attention fusion, and an LSTM answer decoder."""

    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256,
                 attention_dim=256, max_seq_len=30):
        super(PretrainedVQAModel, self).__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Pre-trained CNN encoder (ResNet18 minus its classification head);
        # the remaining global-pool output is a 512-dim feature vector.
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])
        self.cnn_output_dim = 512

        # Shared text embedding for both question and answer tokens.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Question encoder.
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)

        # Image-conditioned attention over question timesteps.
        self.attention = Attention_PT(self.cnn_output_dim, lstm_units, attention_dim)

        # Decoder: each step consumes [token embedding | attended ctx | last hidden].
        self.decoder_input_proj = nn.Linear(embedding_dim + 2 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def predict(self, image, question, word_to_idx, idx_to_word, device='cpu'):
        """Greedy-decode an answer for one (image, question) pair.

        Args:
            image: (C, H, W) or (1, C, H, W) preprocessed image tensor.
            question: question string; lower-cased and whitespace-tokenized.
            word_to_idx: dict mapping token -> int id (must contain the
                special tokens UNK_TOKEN/START_TOKEN/END_TOKEN).
            idx_to_word: dict mapping *string* id -> token (string keys as
                produced by the saved vocabulary file).
            device: torch device string to run inference on.

        Returns:
            The generated answer as a space-joined string (may be empty).
        """
        self.eval()
        with torch.no_grad():
            if image.dim() == 3:
                image = image.unsqueeze(0)
            image = image.to(device)

            # Tokenize the question; unknown words map to the UNK id.
            unk_idx = word_to_idx[UNK_TOKEN]
            question_seq = [word_to_idx.get(word, unk_idx)
                            for word in question.lower().split()]
            question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)

            # Encode image and question, then fuse via attention.
            cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
            q_embed = self.embedding(question)
            q_output, _ = self.question_lstm(q_embed)
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]
            context = torch.cat([q_attended, q_last], dim=-1)

            # Greedy decoding: start from START_TOKEN, stop at END_TOKEN or
            # after max_seq_len steps.
            end_idx = word_to_idx[END_TOKEN]
            answer_input = torch.tensor([[word_to_idx[START_TOKEN]]],
                                        dtype=torch.long).to(device)
            answer_words = []
            for _ in range(self.max_seq_len):
                answer_embed = self.embedding(answer_input)
                context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
                decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
                decoder_in = self.decoder_input_proj(decoder_in)
                decoder_output, _ = self.decoder_lstm(decoder_in)
                output = self.fc_out(decoder_output[:, -1, :])
                next_word_idx = output.argmax(dim=-1).item()
                if next_word_idx == end_idx:
                    break
                # Vocabulary file keys ids as strings; fall back to UNK text
                # rather than raising on an out-of-vocabulary id.
                answer_words.append(idx_to_word.get(str(next_word_idx), UNK_TOKEN))
                answer_input = torch.tensor([[next_word_idx]],
                                            dtype=torch.long).to(device)
            return ' '.join(answer_words)


# -----------------------
# Model Loader
# -----------------------
def load_model():
    """Load vocabulary files and model weights from the working directory.

    Returns:
        (model, word_to_idx, idx_to_word) with the model on CPU in eval mode.

    Raises:
        RuntimeError: if any artifact is missing or fails to load; the
            original exception is preserved as the cause.
    """
    device = 'cpu'
    try:
        word_to_idx = torch.load("word_to_idx.pth", map_location=device)
        idx_to_word = torch.load("idx_to_word.pth", map_location=device)

        model = PretrainedVQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load("vqa_pretrain_model.pth", map_location=device))
        model.to(device)
        model.eval()
        return model, word_to_idx, idx_to_word
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Model loading failed: {str(e)}") from e


# -----------------------
# Gradio Interface
# -----------------------
def create_app():
    """Build the Gradio interface.

    Returns the VQA demo interface, or a stub error interface if model
    loading fails at startup.
    """
    try:
        model, word_to_idx, idx_to_word = load_model()

        def preprocess_image(image):
            """PIL image -> (1, 3, 224, 224) ImageNet-normalized tensor."""
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            return transform(image).unsqueeze(0)

        def predict(image, question):
            try:
                image_tensor = preprocess_image(image)
                answer = model.predict(image_tensor, question, word_to_idx, idx_to_word)
                return answer
            except Exception as e:
                # Surface the error in the UI rather than crashing the app.
                return f"Error: {str(e)}"

        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Your Question", placeholder="What is in this image?"),
            ],
            outputs=gr.Textbox(label="Generated Answer"),
            title="Visual Question Answering",
            description="Upload an image and ask questions about its content",
            allow_flagging="never",
        )
    except Exception as e:
        # BUG FIX: Python unbinds `e` when the except block exits, so a
        # lambda that reads `e` at call time raised NameError. Capture the
        # message in a local first.
        error_message = f"Initialization failed: {str(e)}"
        return gr.Interface(
            fn=lambda: error_message,
            inputs=None,
            outputs="text",
            title="Error",
        )


# -----------------------
# Main Execution
# -----------------------
if __name__ == "__main__":
    app = create_app()
    app.launch(server_name="0.0.0.0", server_port=7860)