Spaces:
Sleeping
Sleeping
| from torchvision import transforms | |
| import torch | |
| import torch.utils.data | |
| from PIL import Image | |
| from source.vocab import Vocab | |
| from source.model import Decoder, Encoder | |
| from source.config import Config | |
def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device,
                     max_length: int = 20) -> list[str]:
    """
    Greedily generate a caption for a single image.

    Decoding starts from the <sos> token. At every step the whole list of
    token IDs produced so far is embedded and passed to the decoder, the
    arg-max of the last time step is taken as the next token ID, and that ID
    is appended to the list. Decoding stops after ``max_length`` steps or as
    soon as the predicted word equals the <eos> word.

    Args:
        image: a single un-batched image tensor; a batch axis is added here.
            The original comments state the expected size is (3, 224, 224).
        image_encoder: encoder whose ``forward`` is applied to the batched image.
        emb_layer: embedding layer mapping token IDs to word embeddings.
        image_decoder: decoder exposing ``hidden_state_0`` / ``cell_state_0``
            and a ``forward(lstm_input, hidden, cell)`` method returning
            ``(predictions, (hidden, cell))``.
        vocab: vocabulary exposing ``SOS``, ``EOS``, ``index2word`` and
            ``index_to_word``.
        device: device on which the tensors are placed.
        max_length: maximum number of decoding steps (default 20, the value
            that used to be hard-coded). Values <= 0 yield an empty caption.

    Returns:
        list[str]: the predicted words, without <sos> and <eos>.
    """
    image = image.to(device).unsqueeze(0)
    # image: (1, 3, 224, 224) for a (3, 224, 224) input

    # The <eos> word does not change during decoding, so look it up once.
    eos_word = vocab.index2word[vocab.EOS]

    sentence: list[str] = []
    # The decoder input starts with the <sos> token only.
    input_words = [vocab.SOS]

    # Inference only: without no_grad() every step would extend an autograd
    # graph (hidden/cell are carried across steps), wasting memory and time.
    with torch.no_grad():
        # The encoder output depends only on the image, so it is computed once
        # instead of once per decoding step.
        features = image_encoder.forward(image)
        features = features.to(device).unsqueeze(0)
        # features: (1, 1, IMAGE_EMB_DIM) according to the original comments
        # NOTE(review): `features` is never handed to the decoder below, so as
        # written the caption does not depend on the encoder output. Confirm
        # against the training code whether the features should be combined
        # with `lstm_input`.

        hidden = image_decoder.hidden_state_0
        cell = image_decoder.cell_state_0
        # hidden, cell: presumably (NUM_LAYER, 1, HIDDEN_DIM) - TODO confirm

        for _ in range(max_length):
            input_words_tensor = torch.tensor([input_words]).to(device)
            # input_words_tensor: (B=1, SEQ_LENGTH)

            lstm_input = emb_layer.forward(input_words_tensor).permute(1, 0, 2)
            # lstm_input: (SEQ_LENGTH, B=1, WORD_EMB_DIM)

            # NOTE(review): the full prefix is re-fed on every step while
            # hidden/cell are also carried over from the previous step. Kept
            # as-is so the output is unchanged; verify this matches training.
            next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, hidden, cell)
            # next_id_pred: (SEQ_LENGTH, 1, VOCAB_SIZE)

            # Greedy choice: most likely token at the last time step.
            next_id = int(torch.argmax(next_id_pred[-1, 0, :]).item())

            # The prediction becomes part of the next decoder input.
            input_words.append(next_id)

            next_word = vocab.index_to_word(next_id)
            if next_word == eos_word:
                break
            sentence.append(next_word)

    return sentence
def main_caption(image):
    """
    Produce a caption string for a raw image array.

    The array is cast to uint8 and interpreted as an RGB image, resized to
    256x256, center-cropped to 224, converted to a tensor and normalized with
    the channel mean/std given below. The vocabulary, embedding layer, encoder
    and decoder are built from ``Config`` and their saved weights are loaded
    on every call; the caption words are then joined with single spaces.

    Args:
        image: array supporting ``.astype('uint8')`` and accepted by
            ``PIL.Image.fromarray`` in 'RGB' mode.

    Returns:
        str: the generated caption, words separated by one space.
    """
    cfg = Config()
    target_device = cfg.DEVICE

    vocab = Vocab()
    vocab.load_vocab(cfg.VOCAB_FILE)

    # Raw array -> PIL image -> normalized tensor.
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = preprocess(pil_image)

    # Build the three modules from the configured dimensions.
    encoder = Encoder(image_emb_dim=cfg.IMAGE_EMB_DIM,
                      device=target_device)
    embedding = torch.nn.Embedding(num_embeddings=cfg.VOCAB_SIZE,
                                   embedding_dim=cfg.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    decoder = Decoder(word_emb_dim=cfg.WORD_EMB_DIM,
                      hidden_dim=cfg.HIDDEN_DIM,
                      num_layers=cfg.NUM_LAYER,
                      vocab_size=cfg.VOCAB_SIZE,
                      device=target_device)

    # Put every module in evaluation mode before use.
    for module in (embedding, encoder, decoder):
        module.eval()

    # Restore the saved weights, mapped onto the configured device.
    weight_files = (
        (embedding, cfg.EMBEDDING_WEIGHT_FILE),
        (encoder, cfg.ENCODER_WEIGHT_FILE),
        (decoder, cfg.DECODER_WEIGHT_FILE),
    )
    for module, weight_file in weight_files:
        module.load_state_dict(torch.load(f=weight_file, map_location=target_device))

    embedding = embedding.to(target_device)
    encoder = encoder.to(target_device)
    decoder = decoder.to(target_device)
    image_tensor = image_tensor.to(target_device)

    words = generate_caption(image_tensor, encoder, embedding, decoder, vocab, device=target_device)
    return ' '.join(words)