# pretrained / app.py
# Tin113's picture
# Update app.py
# f097a6b verified
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import warnings
# Silence UserWarning messages (noise from the libraries at startup/inference).
warnings.simplefilter("ignore", category=UserWarning)
# -----------------------
# Attention Module
# -----------------------
class Attention_PT(nn.Module):
    """Additive attention that pools LSTM states, scored against projected CNN features.

    Every step of ``lstm_features`` gets the score
    ``attn(tanh(cnn_proj(cnn_features) + lstm_proj(lstm_features)))``; scores are
    softmax-normalised over dim 1 and used to take a weighted sum of the raw
    ``lstm_features`` along that dimension.

    Args:
        cnn_dim: Feature size of ``cnn_features``.
        lstm_dim: Feature size of ``lstm_features``.
        attention_dim: Size of the shared projection space.
    """

    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super().__init__()
        # These attribute names are the state_dict keys of saved checkpoints - keep them.
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        """Return the attention-weighted sum of ``lstm_features`` over dim 1.

        Assumes ``cnn_features`` is (B, 1, cnn_dim) and ``lstm_features`` is
        (B, T, lstm_dim), as passed by ``PretrainedVQAModel.predict``; the two
        projections are combined by broadcasting. Under that assumption the result
        is (B, lstm_dim).
        """
        energy = torch.tanh(self.lstm_proj(lstm_features) + self.cnn_proj(cnn_features))
        weights = self.attn(energy).softmax(dim=1)
        return torch.sum(lstm_features * weights, dim=1)
# -----------------------
# Pre-trained VQA Model
# -----------------------
class PretrainedVQAModel(nn.Module):
    """Visual question answering model that generates a free-text answer.

    A ResNet18 backbone encodes the image, an LSTM encodes the question, and
    ``Attention_PT`` pools the question states against the image features. The pooled
    vector plus the last question state form a context that conditions an LSTM
    decoder, which ``predict`` runs greedily.

    Args:
        vocab_size: Size of the vocabulary shared by question and answer tokens.
        embedding_dim: Token embedding size.
        lstm_units: Hidden size of the question-encoder and answer-decoder LSTMs.
        attention_dim: Projection size used inside the attention module.
        max_seq_len: Maximum number of answer tokens ``predict`` generates.
        pretrained_backbone: If True (default, the previous behaviour), initialise
            ResNet18 with torchvision's ``ResNet18_Weights.DEFAULT`` weights (which
            torchvision downloads on first use). Pass False for a randomly initialised
            backbone, e.g. when a full state dict is loaded immediately afterwards.
    """

    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, attention_dim=256,
                 max_seq_len=30, pretrained_backbone=True):
        super(PretrainedVQAModel, self).__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # NOTE: submodule attribute names below are the checkpoint's state_dict keys.
        # Image encoder: ResNet18 with its last child (the classifier) removed.
        weights = models.ResNet18_Weights.DEFAULT if pretrained_backbone else None
        resnet = models.resnet18(weights=weights)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])
        self.cnn_output_dim = 512
        # Token embedding shared by the question encoder and the answer decoder.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Question encoder.
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        # Attention of image features over the question states.
        self.attention = Attention_PT(self.cnn_output_dim, lstm_units, attention_dim)
        # Decoder: input is [token embedding, attended question, last question state].
        self.decoder_input_proj = nn.Linear(embedding_dim + 2 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _index_to_word(idx_to_word, idx):
        """Map a token index to its word, trying the ``str`` key first, then the ``int`` key."""
        key = str(idx)
        if key in idx_to_word:
            return idx_to_word[key]
        return idx_to_word[idx]

    def predict(self, image, question, word_to_idx, idx_to_word, device='cpu'):
        """Greedily decode an answer for a single image/question pair.

        Args:
            image: Preprocessed image tensor, (C, H, W) or (1, C, H, W).
            question: Raw question text; lower-cased and split on whitespace.
                Out-of-vocabulary words are mapped to the ``'<PAD>'`` index.
            word_to_idx: Token -> index mapping containing ``'<PAD>'``,
                ``'<START>'`` and ``'<END>'``.
            idx_to_word: Index -> token mapping keyed by ``str(index)``
                (``int`` keys are accepted as a fallback).
            device: Device the input tensors are moved to.

        Returns:
            Generated tokens joined by single spaces, stopping at ``'<END>'`` or after
            ``max_seq_len`` tokens.

        Raises:
            ValueError: If ``question`` contains no words.
        """
        self.eval()
        tokens = question.lower().split()
        if not tokens:
            # A zero-length sequence cannot be fed through the question LSTM.
            raise ValueError("Question must contain at least one word.")
        pad_idx = word_to_idx['<PAD>']
        question_seq = [word_to_idx.get(word, pad_idx) for word in tokens]
        end_idx = word_to_idx['<END>']

        with torch.no_grad():
            if image.dim() == 3:
                image = image.unsqueeze(0)
            image = image.to(device)
            question_ids = torch.tensor([question_seq], dtype=torch.long, device=device)

            # Encode image and question.
            cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
            q_output, _ = self.question_lstm(self.embedding(question_ids))
            q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
            q_last = q_output[:, -1, :]
            # One context vector per decoding step, shaped (batch, 1, 2 * lstm_units).
            context = torch.cat([q_attended, q_last], dim=-1).unsqueeze(1)

            token = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long, device=device)
            state = None  # (h, c) of the decoder LSTM; None means zero-initialised.
            answer_words = []
            for _ in range(self.max_seq_len):
                decoder_in = torch.cat([self.embedding(token), context], dim=-1)
                decoder_in = self.decoder_input_proj(decoder_in)
                # Carry the LSTM state forward so every step is conditioned on all
                # previously generated tokens, not just the most recent one.
                decoder_output, state = self.decoder_lstm(decoder_in, state)
                next_word_idx = self.fc_out(decoder_output[:, -1, :]).argmax(dim=-1).item()
                if next_word_idx == end_idx:
                    break
                answer_words.append(self._index_to_word(idx_to_word, next_word_idx))
                token = torch.tensor([[next_word_idx]], dtype=torch.long, device=device)
        return ' '.join(answer_words)
# -----------------------
# Model Loader
# -----------------------
def load_model():
    """Load the vocabularies and trained VQA weights from the working directory.

    Reads ``word_to_idx.pth``, ``idx_to_word.pth`` and ``vqa_pretrain_model.pth``
    (relative paths) with ``torch.load`` mapped to the CPU.

    Returns:
        ``(model, word_to_idx, idx_to_word)`` with the model on the CPU in eval mode.

    Raises:
        RuntimeError: If any step fails (missing file, unreadable file, state dict
            that does not match the architecture, ...). The original exception is
            attached as ``__cause__``.
    """
    device = 'cpu'
    try:
        # NOTE(review): torch.load unpickles these files - only load trusted artifacts.
        word_to_idx = torch.load("word_to_idx.pth", map_location=device)
        idx_to_word = torch.load("idx_to_word.pth", map_location=device)
        model = PretrainedVQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load("vqa_pretrain_model.pth", map_location=device))
        model.to(device)
        model.eval()
        return model, word_to_idx, idx_to_word
    except Exception as e:
        # Chain explicitly so the underlying error stays visible in tracebacks.
        raise RuntimeError(f"Model loading failed: {str(e)}") from e
# -----------------------
# Gradio Interface
# -----------------------
def create_app():
    """Build the Gradio interface for the VQA model.

    Returns:
        A ``gr.Interface`` taking an uploaded image and a question and showing the
        generated answer. If anything during setup raises (e.g. ``load_model``
        fails), a minimal interface that displays the failure message is returned
        instead.
    """
    try:
        model, word_to_idx, idx_to_word = load_model()

        # Built once and reused for every request: resize to 224x224, convert to a
        # tensor and normalise with the usual ImageNet mean/std values.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        def preprocess_image(image):
            """Turn a PIL image into a normalised (1, 3, 224, 224) batch tensor."""
            # Uploads may be RGBA, palette or grayscale; Normalize needs exactly 3 channels.
            return transform(image.convert("RGB")).unsqueeze(0)

        def predict(image, question):
            """Gradio callback: return the model's answer or an ``Error: ...`` message."""
            if image is None:
                return "Error: Please upload an image."
            if not question or not question.strip():
                return "Error: Please enter a question."
            try:
                image_tensor = preprocess_image(image)
                return model.predict(image_tensor, question, word_to_idx, idx_to_word)
            except Exception as e:
                return f"Error: {str(e)}"

        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Your Question", placeholder="What is in this image?")
            ],
            outputs=gr.Textbox(label="Generated Answer"),
            title="Visual Question Answering",
            description="Upload an image and ask questions about its content",
            allow_flagging="never"
        )
    except Exception as e:
        # Capture the text now: Python unbinds ``e`` when this except clause ends, so a
        # lambda referencing ``e`` directly would raise NameError when Gradio calls it.
        message = f"Initialization failed: {str(e)}"
        return gr.Interface(
            lambda: message,
            inputs=None,
            outputs="text",
            title="Error"
        )
# -----------------------
# Main Execution
# -----------------------
if __name__ == "__main__":
    # Serve on every network interface (0.0.0.0), port 7860.
    app = create_app()
    app.launch(server_port=7860, server_name="0.0.0.0")