File size: 1,532 Bytes
3974b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import json
import torch
import torchvision.transforms as T
from PIL import Image
from model import EncoderDecoder

# Load vocab từ vocab.json
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab_json = json.load(f)
    itos = {int(k): v for k, v in vocab_json["itos"].items()}
    stoi = {v: k for k, v in itos.items()}

class Vocabulary:
    def __init__(self, itos, stoi):
        self.itos = itos
        self.stoi = stoi
    def __len__(self):
        return len(self.itos)

vocab = Vocabulary(itos, stoi)

# ✅ Load checkpoint và các siêu tham số từ file .pth
ckpt = torch.load("attention_model_state.pth", map_location="cpu")
embed_dim = 300  # giữ nguyên như lúc huấn luyện
vocab_size = ckpt["vocab_size"]
attention_dim = ckpt["attention_dim"]
encoder_dim = ckpt["encoder_dim"]
decoder_dim = ckpt["decoder_dim"]

# ✅ Khởi tạo mô hình với tham số từ checkpoint
embed_tensor = torch.randn(vocab_size, embed_dim)
model = EncoderDecoder(embed_tensor, vocab, attention_dim, encoder_dim, decoder_dim)
model.load_state_dict(ckpt["state_dict"])
model.eval()

# ✅ Tiền xử lý ảnh
transform = T.Compose([
    T.Resize(226),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# ✅ Hàm dự đoán caption từ ảnh
def predict_caption(image: Image.Image):
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_tensor = transform(image)
    caption = model.predict_caption(image_tensor)
    return caption