Tin113 commited on
Commit
4029376
·
verified ·
1 Parent(s): 2af0460

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -1
app.py CHANGED
@@ -3,6 +3,39 @@ import gradio as gr
3
  from PIL import Image
4
  from torchvision import transforms
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  # Load mô hình từ Hugging Face Model Hub hoặc local
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
@@ -47,7 +80,7 @@ def predict(image, question):
47
  answer = idx_to_word[predicted_idx]
48
  return answer
49
 
50
- # 🎨 Giao diện Gradio
51
  iface = gr.Interface(
52
  fn=predict,
53
  inputs=[gr.Image(type="pil"), gr.Textbox(label="Câu hỏi")],
 
3
  from PIL import Image
4
  from torchvision import transforms
5
 
6
+ import torch.nn as nn
7
+ import torchvision.models as models
8
+
9
+ class VQAModel(nn.Module):
10
+ def __init__(self, vocab_size):
11
+ super(VQAModel, self).__init__()
12
+ # Dùng ResNet làm CNN Encoder
13
+ self.cnn = models.resnet18(pretrained=True)
14
+ self.cnn.fc = nn.Linear(512, 256) # Thay FC layer
15
+
16
+ # Dùng LSTM làm Text Encoder
17
+ self.embedding = nn.Embedding(vocab_size, 256)
18
+ self.lstm = nn.LSTM(256, 256, batch_first=True)
19
+
20
+ # Fully Connected Layer để dự đoán câu trả lời
21
+ self.fc = nn.Linear(256, vocab_size)
22
+
23
+ def forward(self, image, question):
24
+ # Encode ảnh
25
+ img_features = self.cnn(image)
26
+
27
+ # Encode câu hỏi
28
+ q_embed = self.embedding(question)
29
+ _, (q_features, _) = self.lstm(q_embed)
30
+
31
+ # Kết hợp đặc trưng ảnh và câu hỏi
32
+ combined = img_features + q_features.squeeze(0)
33
+ output = self.fc(combined)
34
+
35
+ return output
36
+
37
+
38
+
39
  # Load mô hình từ Hugging Face Model Hub hoặc local
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
 
80
  answer = idx_to_word[predicted_idx]
81
  return answer
82
 
83
+ # Giao diện Gradio
84
  iface = gr.Interface(
85
  fn=predict,
86
  inputs=[gr.Image(type="pil"), gr.Textbox(label="Câu hỏi")],