image-caption-vi / inference.py
kRnos22's picture
Upload 6 files
3974b9b verified
import json
import torch
import torchvision.transforms as T
from PIL import Image
from model import EncoderDecoder
# Load vocab từ vocab.json
with open("vocab.json", "r", encoding="utf-8") as f:
vocab_json = json.load(f)
itos = {int(k): v for k, v in vocab_json["itos"].items()}
stoi = {v: k for k, v in itos.items()}
class Vocabulary:
def __init__(self, itos, stoi):
self.itos = itos
self.stoi = stoi
def __len__(self):
return len(self.itos)
vocab = Vocabulary(itos, stoi)
# ✅ Load checkpoint và các siêu tham số từ file .pth
ckpt = torch.load("attention_model_state.pth", map_location="cpu")
embed_dim = 300 # giữ nguyên như lúc huấn luyện
vocab_size = ckpt["vocab_size"]
attention_dim = ckpt["attention_dim"]
encoder_dim = ckpt["encoder_dim"]
decoder_dim = ckpt["decoder_dim"]
# ✅ Khởi tạo mô hình với tham số từ checkpoint
embed_tensor = torch.randn(vocab_size, embed_dim)
model = EncoderDecoder(embed_tensor, vocab, attention_dim, encoder_dim, decoder_dim)
model.load_state_dict(ckpt["state_dict"])
model.eval()
# ✅ Tiền xử lý ảnh
transform = T.Compose([
T.Resize(226),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
# ✅ Hàm dự đoán caption từ ảnh
def predict_caption(image: Image.Image):
if image.mode != "RGB":
image = image.convert("RGB")
image_tensor = transform(image)
caption = model.predict_caption(image_tensor)
return caption