Spaces:
Sleeping
Sleeping
File size: 6,383 Bytes
cd94662 84165ce cd94662 f097a6b cd94662 84165ce cd94662 84165ce cd94662 84165ce cd94662 84165ce 624c509 2fccb7e 624c509 cd94662 84165ce f097a6b 84165ce cd94662 f097a6b cd94662 f097a6b 624c509 cd94662 624c509 cd94662 84165ce 624c509 84165ce f097a6b 84165ce cd94662 84165ce f097a6b 84165ce cd94662 84165ce cd94662 2fccb7e b2e80e9 cd94662 f097a6b 84165ce f097a6b cd94662 84165ce cd94662 84165ce f097a6b cd94662 f097a6b cd94662 84165ce cd94662 f097a6b cd94662 84165ce f097a6b 84165ce cd94662 84165ce cd94662 84165ce f097a6b 84165ce cd94662 84165ce cd94662 f097a6b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import warnings
# Suppress every UserWarning process-wide ("turn off unnecessary warnings").
warnings.filterwarnings(action="ignore", category=UserWarning)
# -----------------------
# Attention Module
# -----------------------
class Attention_PT(nn.Module):
    """Additive attention that pools LSTM features, guided by a CNN feature vector.

    Both inputs are projected into a shared ``attention_dim`` space, summed,
    passed through ``tanh`` and scored by a single linear unit. The scores are
    softmax-normalised over dim 1 and used to take a weighted sum of
    ``lstm_features`` over that same dim.

    NOTE(review): assumes cnn_features is (batch, 1, cnn_dim) and lstm_features
    is (batch, seq, lstm_dim), so the softmax runs over the sequence axis.

    The attribute names ``cnn_proj``, ``lstm_proj`` and ``attn`` are
    state_dict keys and must not be renamed.
    """

    def __init__(self, cnn_dim, lstm_dim, attention_dim):
        super().__init__()
        self.cnn_proj = nn.Linear(cnn_dim, attention_dim)
        self.lstm_proj = nn.Linear(lstm_dim, attention_dim)
        self.attn = nn.Linear(attention_dim, 1)

    def forward(self, cnn_features, lstm_features):
        """Return the attention-weighted sum of ``lstm_features`` over dim 1."""
        energy = torch.tanh(
            self.cnn_proj(cnn_features) + self.lstm_proj(lstm_features)
        )
        weights = F.softmax(self.attn(energy), dim=1)
        return torch.sum(weights * lstm_features, dim=1)
# -----------------------
# Pre-trained VQA Model
# -----------------------
class PretrainedVQAModel(nn.Module):
    """VQA model: ResNet18 image encoder, LSTM question encoder, attention, LSTM answer decoder.

    Submodule attribute names (``cnn``, ``embedding``, ``question_lstm``,
    ``attention``, ``decoder_input_proj``, ``decoder_lstm``, ``fc_out``,
    ``dropout``) are the ``state_dict`` key prefixes. Renaming any of them
    would stop existing checkpoints from loading.

    NOTE(review): no ``forward`` is defined here; training presumably lives
    elsewhere — TODO confirm.

    Args:
        vocab_size: Number of tokens; one embedding table serves question and answer.
        embedding_dim: Token embedding size.
        lstm_units: Hidden size of both the question and the decoder LSTMs.
        attention_dim: Hidden size of the attention scorer.
        max_seq_len: Maximum number of answer tokens ``predict`` generates.
    """

    def __init__(self, vocab_size, embedding_dim=256, lstm_units=256, attention_dim=256, max_seq_len=30):
        super(PretrainedVQAModel, self).__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Image encoder: ResNet18 with its last child (the fc head) removed,
        # giving 512-channel pooled features.
        # NOTE(review): DEFAULT weights are fetched here even when a full
        # checkpoint is loaded right afterwards.
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])
        self.cnn_output_dim = 512

        # Shared token embedding for question and answer words.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Question encoder.
        self.question_lstm = nn.LSTM(embedding_dim, lstm_units, batch_first=True)
        # Image-guided attention over the question LSTM outputs.
        self.attention = Attention_PT(self.cnn_output_dim, lstm_units, attention_dim)

        # Decoder input = token embedding + context (attended + last question state).
        self.decoder_input_proj = nn.Linear(embedding_dim + 2 * lstm_units, lstm_units)
        self.decoder_lstm = nn.LSTM(lstm_units, lstm_units, batch_first=True)
        self.fc_out = nn.Linear(lstm_units, vocab_size)
        # Not used by predict(); presumably applied during training — TODO confirm.
        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _lookup_word(idx_to_word, idx):
        """Map a token index to its word; accepts str-keyed, int-keyed or list vocabularies."""
        key = str(idx)
        return idx_to_word[key] if key in idx_to_word else idx_to_word[idx]

    def _encode_context(self, image, question_ids):
        """Return the decoder context: attention-pooled and final question LSTM states, concatenated."""
        cnn_features = self.cnn(image).view(-1, self.cnn_output_dim)
        q_output, _ = self.question_lstm(self.embedding(question_ids))
        q_attended = self.attention(cnn_features.unsqueeze(1), q_output)
        q_last = q_output[:, -1, :]
        return torch.cat([q_attended, q_last], dim=-1)

    def predict(self, image, question, word_to_idx, idx_to_word, device='cpu'):
        """Greedily decode an answer for one image/question pair.

        The decoder LSTM state is carried from step to step, so each
        prediction is conditioned on every previously generated token, not
        just the most recent one.

        Args:
            image: Preprocessed image tensor, ``(C, H, W)`` or ``(1, C, H, W)``.
            question: Raw question text. It is lower-cased and split on
                whitespace. Words missing from ``word_to_idx`` map to ``<PAD>``.
            word_to_idx: Vocabulary. Must contain ``<PAD>``, ``<START>`` and ``<END>``.
            idx_to_word: Reverse vocabulary keyed by ``str(index)``. Int keys or
                a list are also accepted.
            device: Device the inputs are moved to. The model is assumed to
                already be on it.

        Returns:
            At most ``max_seq_len`` generated words joined by single spaces
            (``<END>`` itself is not included).

        Raises:
            ValueError: If ``question`` contains no words.
        """
        self.eval()
        with torch.no_grad():
            if image.dim() == 3:
                image = image.unsqueeze(0)
            image = image.to(device)

            pad_idx = word_to_idx['<PAD>']
            question_ids = [word_to_idx.get(word, pad_idx) for word in question.lower().split()]
            if not question_ids:
                # An empty sequence would crash the LSTM and the [:, -1] indexing.
                raise ValueError("Question must contain at least one word.")
            question_tensor = torch.tensor([question_ids], dtype=torch.long, device=device)

            # The context is constant across decoding steps; shape it once as (batch, 1, features).
            step_context = self._encode_context(image, question_tensor).unsqueeze(1)

            end_idx = word_to_idx['<END>']
            token = torch.tensor([[word_to_idx['<START>']]], dtype=torch.long, device=device)
            state = None  # LSTM (h, c); None means a zero initial state.
            answer_words = []
            for _ in range(self.max_seq_len):
                step_in = torch.cat([self.embedding(token), step_context], dim=-1)
                step_out, state = self.decoder_lstm(self.decoder_input_proj(step_in), state)
                next_idx = self.fc_out(step_out[:, -1, :]).argmax(dim=-1).item()
                if next_idx == end_idx:
                    break
                answer_words.append(self._lookup_word(idx_to_word, next_idx))
                token = torch.tensor([[next_idx]], dtype=torch.long, device=device)
        return ' '.join(answer_words)
# -----------------------
# Model Loader
# -----------------------
def load_model(device='cpu'):
    """Load the vocabularies and trained VQA weights.

    Reads ``word_to_idx.pth``, ``idx_to_word.pth`` and ``vqa_pretrain_model.pth``
    via relative paths, so they are resolved against the current working directory.

    Args:
        device: Device used for ``map_location`` and for placing the model.
            Defaults to ``'cpu'``, the previously hard-coded value.

    Returns:
        ``(model, word_to_idx, idx_to_word)`` with the model in eval mode.

    Raises:
        RuntimeError: If any file cannot be read or the checkpoint does not fit
            the architecture. The underlying error is chained as ``__cause__``.
    """
    try:
        # NOTE(review): torch.load unpickles; only load trusted, local files.
        word_to_idx = torch.load("word_to_idx.pth", map_location=device)
        idx_to_word = torch.load("idx_to_word.pth", map_location=device)
        # Vocabulary size must match the checkpoint's embedding/fc_out shapes.
        model = PretrainedVQAModel(vocab_size=len(word_to_idx))
        model.load_state_dict(torch.load("vqa_pretrain_model.pth", map_location=device))
        model.to(device)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {e}") from e
    return model, word_to_idx, idx_to_word
# -----------------------
# Gradio Interface
# -----------------------
def create_app():
    """Build the Gradio VQA interface.

    The model is loaded once, up front. If loading or interface construction
    fails, a minimal interface that reports the failure is returned instead of
    raising.
    """
    try:
        model, word_to_idx, idx_to_word = load_model()

        # Built once and reused for every request. Normalisation uses the
        # standard ImageNet mean/std.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        def preprocess_image(image):
            """Convert a PIL image into a normalised (1, 3, 224, 224) tensor."""
            # Force 3 channels: RGBA or greyscale input would break the 3-channel Normalize.
            return transform(image.convert("RGB")).unsqueeze(0)

        def predict(image, question):
            """Gradio callback: answer ``question`` about ``image`` or return an error string."""
            if image is None:
                return "Error: please upload an image."
            try:
                image_tensor = preprocess_image(image)
                return model.predict(image_tensor, question, word_to_idx, idx_to_word)
            except Exception as e:
                return f"Error: {str(e)}"

        return gr.Interface(
            fn=predict,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Your Question", placeholder="What is in this image?")
            ],
            outputs=gr.Textbox(label="Generated Answer"),
            title="Visual Question Answering",
            description="Upload an image and ask questions about its content",
            allow_flagging="never"
        )
    except Exception as e:
        # Format the message now. Python unbinds ``e`` when this except clause
        # ends, so a lambda reading ``e`` later would raise NameError when called.
        message = f"Initialization failed: {str(e)}"
        return gr.Interface(
            lambda: message,
            inputs=None,
            outputs="text",
            title="Error"
        )
# -----------------------
# Main Execution
# -----------------------
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    create_app().launch(server_name="0.0.0.0", server_port=7860)