Spaces:
Sleeping
Sleeping
| import json | |
| import torch | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from model import EncoderDecoder | |
| # Load vocab từ vocab.json | |
| with open("vocab.json", "r", encoding="utf-8") as f: | |
| vocab_json = json.load(f) | |
| itos = {int(k): v for k, v in vocab_json["itos"].items()} | |
| stoi = {v: k for k, v in itos.items()} | |
| class Vocabulary: | |
| def __init__(self, itos, stoi): | |
| self.itos = itos | |
| self.stoi = stoi | |
| def __len__(self): | |
| return len(self.itos) | |
| vocab = Vocabulary(itos, stoi) | |
| # ✅ Load checkpoint và các siêu tham số từ file .pth | |
| ckpt = torch.load("attention_model_state.pth", map_location="cpu") | |
| embed_dim = 300 # giữ nguyên như lúc huấn luyện | |
| vocab_size = ckpt["vocab_size"] | |
| attention_dim = ckpt["attention_dim"] | |
| encoder_dim = ckpt["encoder_dim"] | |
| decoder_dim = ckpt["decoder_dim"] | |
| # ✅ Khởi tạo mô hình với tham số từ checkpoint | |
| embed_tensor = torch.randn(vocab_size, embed_dim) | |
| model = EncoderDecoder(embed_tensor, vocab, attention_dim, encoder_dim, decoder_dim) | |
| model.load_state_dict(ckpt["state_dict"]) | |
| model.eval() | |
| # ✅ Tiền xử lý ảnh | |
| transform = T.Compose([ | |
| T.Resize(226), | |
| T.CenterCrop(224), | |
| T.ToTensor(), | |
| T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) | |
| ]) | |
| # ✅ Hàm dự đoán caption từ ảnh | |
| def predict_caption(image: Image.Image): | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| image_tensor = transform(image) | |
| caption = model.predict_caption(image_tensor) | |
| return caption |