import json import torch import torchvision.transforms as T from PIL import Image from model import EncoderDecoder # Load vocab từ vocab.json with open("vocab.json", "r", encoding="utf-8") as f: vocab_json = json.load(f) itos = {int(k): v for k, v in vocab_json["itos"].items()} stoi = {v: k for k, v in itos.items()} class Vocabulary: def __init__(self, itos, stoi): self.itos = itos self.stoi = stoi def __len__(self): return len(self.itos) vocab = Vocabulary(itos, stoi) # ✅ Load checkpoint và các siêu tham số từ file .pth ckpt = torch.load("attention_model_state.pth", map_location="cpu") embed_dim = 300 # giữ nguyên như lúc huấn luyện vocab_size = ckpt["vocab_size"] attention_dim = ckpt["attention_dim"] encoder_dim = ckpt["encoder_dim"] decoder_dim = ckpt["decoder_dim"] # ✅ Khởi tạo mô hình với tham số từ checkpoint embed_tensor = torch.randn(vocab_size, embed_dim) model = EncoderDecoder(embed_tensor, vocab, attention_dim, encoder_dim, decoder_dim) model.load_state_dict(ckpt["state_dict"]) model.eval() # ✅ Tiền xử lý ảnh transform = T.Compose([ T.Resize(226), T.CenterCrop(224), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) # ✅ Hàm dự đoán caption từ ảnh def predict_caption(image: Image.Image): if image.mode != "RGB": image = image.convert("RGB") image_tensor = transform(image) caption = model.predict_caption(image_tensor) return caption