import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
# -----------------------
# Attention Module
# -----------------------
class Attention(nn.Module):
    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super(Attention, self).__init__()
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        # cnn_features: (batch, 1, cnn_dim)
        # lstm_features: (batch, seq_len, lstm_dim)
        cnn_proj = self.cnn_proj(cnn_features)                         # (batch, 1, attention_dim)
        lstm_proj = self.lstm_proj(lstm_features)                      # (batch, seq_len, attention_dim)
        combined = torch.tanh(cnn_proj + lstm_proj)                    # (batch, seq_len, attention_dim)
        attn_weights = F.softmax(self.attn(combined), dim=1)           # (batch, seq_len, 1)
        attended_features = (attn_weights * lstm_features).sum(dim=1)  # (batch, lstm_dim)
        return attended_features
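
# Illustrative shape check for Attention: a hedged sketch, not called by the app.
# The helper name and dummy dims are assumptions; the dims mirror the defaults
# used by VQAModel below.
def _attention_shape_check():
    attn = Attention(cnn_dim=512, lstm_dim=256, attention_dim=256)
    cnn_feats = torch.randn(2, 1, 512)    # (batch, 1, cnn_dim)
    lstm_feats = torch.randn(2, 10, 256)  # (batch, seq_len, lstm_dim)
    out = attn(cnn_feats, lstm_feats)
    assert out.shape == (2, 256)          # (batch, lstm_dim)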
# -----------------------
# VQA Model
# -----------------------
class VQAModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, cnn_output_dim=512, attention_dim=256, max_seq_len=30):
        super(VQAModel, self).__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # CNN encoder: extracts image features
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, cnn_output_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        # Text embedding
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # LSTM encoders for the caption and the question
        self.caption_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        # Attention, shared by the caption and question channels
        self.attention = Attention(cnn_output_dim, lstm_units, attention_dim)
        # Decoder: trained with teacher forcing.
        # Context vector: concatenation of the caption attention, the question
        # attention, and the final question state.
        # Context size = lstm_units + lstm_units + lstm_units = 3 * lstm_units
        # (e.g. 768 if lstm_units=256). Combined with the answer embedding
        # (embedding_dim) => decoder input = embedding_dim + 3 * lstm_units.
        self.decoder_input_proj = nn.Linear(embedding_dim + 3 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, image, caption, question, answer_input):
        # --- CNN encoder ---
        cnn_features = self.cnn(image)                               # (batch, cnn_output_dim, 1, 1)
        cnn_features = cnn_features.view(cnn_features.size(0), -1)   # (batch, cnn_output_dim)
        # --- Text encoders ---
        cap_embed = self.embedding(caption)                          # (batch, cap_seq_len, embedding_dim)
        cap_output, _ = self.caption_lstm(cap_embed)                 # (batch, cap_seq_len, lstm_units)
        q_embed = self.embedding(question)                           # (batch, q_seq_len, embedding_dim)
        q_output, _ = self.question_lstm(q_embed)                    # (batch, q_seq_len, lstm_units)
        # --- Attention ---
        cap_attended = self.attention(cnn_features.unsqueeze(1), cap_output)  # (batch, lstm_units)
        q_attended = self.attention(cnn_features.unsqueeze(1), q_output)      # (batch, lstm_units)
        q_last = q_output[:, -1, :]                                  # (batch, lstm_units)
        # Context vector: (batch, 3*lstm_units)
        context = torch.cat([cap_attended, q_attended, q_last], dim=-1)
        # --- Decoder with teacher forcing ---
        # answer_input: (batch, ans_seq_len)
        answer_embed = self.embedding(answer_input)                  # (batch, ans_seq_len, embedding_dim)
        context_repeated = context.unsqueeze(1).repeat(1, answer_input.size(1), 1)  # (batch, ans_seq_len, 3*lstm_units)
        decoder_in = torch.cat([answer_embed, context_repeated], dim=-1)  # (batch, ans_seq_len, embedding_dim + 3*lstm_units)
        decoder_in = self.decoder_input_proj(decoder_in)             # (batch, ans_seq_len, lstm_units)
        decoder_output, _ = self.decoder_lstm(decoder_in)            # (batch, ans_seq_len, lstm_units)
        output = self.fc_out(self.dropout(decoder_output))           # (batch, ans_seq_len, vocab_size)
        return output
    def predict(self, image, question, word_to_idx, idx_to_word, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.eval()
        self.to(device)
        # Add a batch dimension if the image does not have one
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(device)
        # Out-of-vocabulary words fall back to the <PAD> index
        question_seq = [word_to_idx.get(word, word_to_idx['<PAD>']) for word in question.lower().split()]
        question = torch.tensor(question_seq, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            # Encode the image and the question
            cnn_features = self.cnn(image)
            cnn_features = cnn_features.view(cnn_features.size(0), -1)
            q_embed = self.embedding(question)
            q_output, _ = self.question_lstm(q_embed)
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]
            # At prediction time there is no caption, so build a simplified context
            # vector from the question alone (the question attention fills the caption slot)
            context = torch.cat([q_attended, q_attended, q_last], dim=-1)  # (1, 3*lstm_units)
            # Initialize the answer with the <START> token
            answer_input = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long).to(device)
            answer_words = []
            hidden = None
            # Greedy decoding: feed one token per step and carry the LSTM state,
            # so each token is processed exactly once (as in training)
            for _ in range(self.max_seq_len):
                answer_embed = self.embedding(answer_input)           # (1, 1, embedding_dim)
                decoder_in = torch.cat([answer_embed, context.unsqueeze(1)], dim=-1)
                decoder_in = self.decoder_input_proj(decoder_in)
                decoder_output, hidden = self.decoder_lstm(decoder_in, hidden)
                output = self.fc_out(decoder_output[:, -1, :])
                next_word_idx = output.argmax(dim=-1).item()
                if next_word_idx == word_to_idx['<END>']:
                    break
                answer_words.append(idx_to_word[next_word_idx])
                answer_input = torch.tensor([[next_word_idx]], dtype=torch.long).to(device)
        return ' '.join(answer_words)
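
# Illustrative teacher-forcing contract of VQAModel.forward: a hedged sketch,
# not called by the app. The helper name, vocab size, and sequence lengths are
# made-up values for the shape check only.
def _vqa_forward_shape_check():
    model = VQAModel(vocab_size=1000)
    image = torch.randn(2, 3, 224, 224)
    caption = torch.randint(0, 1000, (2, 12))
    question = torch.randint(0, 1000, (2, 8))
    answer_input = torch.randint(0, 1000, (2, 5))  # answer tokens, shifted right
    logits = model(image, caption, question, answer_input)
    assert logits.shape == (2, 5, 1000)            # (batch, ans_seq_len, vocab_size)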
def load_model(model_path, word_to_idx_path, idx_to_word_path, device='cpu'):
    try:
        # Load the vocab dictionaries from .pth files
        word_to_idx = torch.load(word_to_idx_path, map_location=device)
        idx_to_word = torch.load(idx_to_word_path, map_location=device)
        # Initialize the model
        model = VQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        return model, word_to_idx, idx_to_word
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
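
# Expected vocab file format: an assumption inferred from the loading code above
# (plain dicts serialized with torch.save), e.g.:
#   torch.save({'<PAD>': 0, '<START>': 1, '<END>': 2, ...}, 'word_to_idx.pth')
#   torch.save({0: '<PAD>', 1: '<START>', 2: '<END>', ...}, 'idx_to_word.pth')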
def predict(image, question, model, word_to_idx, idx_to_word, device='cpu'):
    try:
        # Preprocess the image
        image = transform(image).unsqueeze(0).to(device)
        # Generate the answer
        answer = model.predict(image, question, word_to_idx, idx_to_word, device)
        return answer
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error generating answer"
# Image preprocessing transform (ImageNet normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
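
# Quick sanity sketch for the preprocessing pipeline: illustrative only, not
# called by the app; the helper name and dummy image size are assumptions.
def _transform_shape_check():
    dummy = Image.new("RGB", (640, 480))
    tensor = transform(dummy)
    assert tensor.shape == (3, 224, 224)  # (C, H, W) after Resize + ToTensor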
def create_interface():
    device = 'cpu'  # Always use CPU on Spaces
    try:
        model, word_to_idx, idx_to_word = load_model(
            "vqa_model.pth",
            "word_to_idx.pth",
            "idx_to_word.pth",
            device
        )

        def predict(image, question):
            try:
                # Reuse the module-level preprocessing transform
                image = transform(image).unsqueeze(0).to(device)
                answer = model.predict(image, question, word_to_idx, idx_to_word, device)
                return answer
            except Exception as e:
                return f"Error: {str(e)}"

        iface = gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Question")
            ],
            outputs=gr.Textbox(label="Answer"),
            title="Visual Question Answering",
            description="Upload an image of an animal and ask a related question (ENGLISH ONLY)"
        )
        return iface
    except Exception as e:
        print(f"Failed to build interface: {e}")
        return gr.Interface(fn=lambda: "Model failed to load", inputs=None, outputs="text")
if __name__ == "__main__":
iface = create_interface()
iface.launch(
server_name="0.0.0.0",
server_port=7860
)