import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import warnings

# Suppress unnecessary warnings
warnings.filterwarnings("ignore", category=UserWarning)
# -----------------------
# Attention Module
# -----------------------
class Attention_PT(nn.Module):
    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super(Attention_PT, self).__init__()
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        cnn_proj = self.cnn_proj(cnn_features)
        lstm_proj = self.lstm_proj(lstm_features)
        combined = torch.tanh(cnn_proj + lstm_proj)
        attn_weights = F.softmax(self.attn(combined), dim=1)
        attended_features = (attn_weights * lstm_features).sum(dim=1)
        return attended_features
# -----------------------
# Pre-trained VQA Model
# -----------------------
class PretrainedVQAModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, attention_dim=256, max_seq_len=30):
        super(PretrainedVQAModel, self).__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Pre-trained CNN Encoder (ResNet18)
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])
        self.cnn_output_dim = 512

        # Text Embedding
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # LSTM Encoder
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)

        # Attention
        self.attention = Attention_PT(self.cnn_output_dim, lstm_units, attention_dim)

        # Decoder
        self.decoder_input_proj = nn.Linear(embedding_dim + 2 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)
    def predict(self, image, question, word_to_idx, idx_to_word, device='cpu'):
        self.eval()
        with torch.no_grad():
            if image.dim() == 3:
                image = image.unsqueeze(0)
            image = image.to(device)

            # Process question
            question_seq = [word_to_idx.get(word, word_to_idx['<PAD>'])
                            for word in question.lower().split()]
            question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)

            # Forward pass
            cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
            q_embed = self.embedding(question)
            q_output, _ = self.question_lstm(q_embed)
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]
            context = torch.cat([q_attended, q_last], dim=-1)

            # Generate answer (greedy decoding)
            answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
            answer_words = []
            for _ in range(self.max_seq_len):
                answer_embed = self.embedding(answer_input)
                context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)
                decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)
                decoder_in = self.decoder_input_proj(decoder_in)
                decoder_output, _ = self.decoder_lstm(decoder_in)
                output = self.fc_out(decoder_output[:, -1, :])
                next_word_idx = output.argmax(dim=-1).item()
                if next_word_idx == word_to_idx['<END>']:
                    break
                answer_words.append(idx_to_word[str(next_word_idx)])
                # Append the predicted token so the decoder sees the full generated prefix
                next_token = torch.tensor([[next_word_idx]], dtype=torch.long).to(device)
                answer_input = torch.cat([answer_input, next_token], dim=1)
            return ' '.join(answer_words)
# -----------------------
# Model Loader
# -----------------------
def load_model():
    device = 'cpu'
    try:
        word_to_idx = torch.load("word_to_idx.pth", map_location=device)
        idx_to_word = torch.load("idx_to_word.pth", map_location=device)
        model = PretrainedVQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load("vqa_pretrain_model.pth", map_location=device))
        model.to(device)
        model.eval()
        return model, word_to_idx, idx_to_word
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {str(e)}")
# -----------------------
# Gradio Interface
# -----------------------
def create_app():
    try:
        model, word_to_idx, idx_to_word = load_model()

        def preprocess_image(image):
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            return transform(image).unsqueeze(0)

        def predict(image, question):
            try:
                image_tensor = preprocess_image(image)
                answer = model.predict(image_tensor, question, word_to_idx, idx_to_word)
                return answer
            except Exception as e:
                return f"Error: {str(e)}"

        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Your Question", placeholder="What is in this image?")
            ],
            outputs=gr.Textbox(label="Generated Answer"),
            title="Visual Question Answering",
            description="Upload an image and ask questions about its content",
            allow_flagging="never"
        )
    except Exception as e:
        # Capture the message now: the exception variable is cleared once this block exits
        error_message = f"Initialization failed: {str(e)}"
        return gr.Interface(
            fn=lambda: error_message,
            inputs=None,
            outputs="text",
            title="Error"
        )
# -----------------------
# Main Execution
# -----------------------
if __name__ == "__main__":
    app = create_app()
    app.launch(server_name="0.0.0.0", server_port=7860)
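
# -----------------------
# Example: calling the model without the Gradio UI
# -----------------------
# A minimal sketch, assuming the same checkpoint/vocab files that load_model()
# expects and a local test image "example.jpg" (hypothetical filename):
#
#   model, word_to_idx, idx_to_word = load_model()
#   transform = transforms.Compose([
#       transforms.Resize((224, 224)),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                            std=[0.229, 0.224, 0.225]),
#   ])
#   image = transform(Image.open("example.jpg").convert("RGB"))
#   print(model.predict(image, "what is in this image", word_to_idx, idx_to_word))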