Spaces:
Sleeping
Sleeping
Commit
·
58b663c
1
Parent(s):
e7324f8
Fix generate caption function
Browse files- source/predict_sample.py +35 -24
source/predict_sample.py
CHANGED
|
@@ -14,7 +14,10 @@ def generate_caption(image: torch.Tensor,
|
|
| 14 |
image_decoder: Decoder,
|
| 15 |
vocab: Vocab,
|
| 16 |
device: torch.device) -> list[str]:
|
| 17 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
Returns:
|
| 20 |
list[str]: caption for given image
|
|
@@ -25,49 +28,57 @@ def generate_caption(image: torch.Tensor,
|
|
| 25 |
image = image.unsqueeze(0)
|
| 26 |
# image: (1, 3, 224, 224)
|
| 27 |
|
| 28 |
-
features = image_encoder.forward(image)
|
| 29 |
-
# features: (1, IMAGE_EMB_DIM)
|
| 30 |
-
features = features.to(device)
|
| 31 |
-
features = features.unsqueeze(0)
|
| 32 |
-
# features: (1, 1, IMAGE_EMB_DIM)
|
| 33 |
-
|
| 34 |
hidden = image_decoder.hidden_state_0
|
| 35 |
cell = image_decoder.cell_state_0
|
| 36 |
# hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)
|
| 37 |
|
| 38 |
sentence = []
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
| 42 |
|
| 43 |
MAX_LENGTH = 20
|
| 44 |
|
| 45 |
for i in range(MAX_LENGTH):
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# lstm_input : (1, 1, WORD_EMB_DIM)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
#
|
| 57 |
|
| 58 |
-
|
| 59 |
-
#
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
# stop if we predict '<eos>'
|
| 65 |
if next_word_pred == vocab.index2word[vocab.EOS]:
|
| 66 |
break
|
| 67 |
|
| 68 |
-
sentence.append(next_word_pred)
|
| 69 |
-
previous_word = next_word_pred
|
| 70 |
-
|
| 71 |
return sentence
|
| 72 |
|
| 73 |
|
|
|
|
| 14 |
image_decoder: Decoder,
|
| 15 |
vocab: Vocab,
|
| 16 |
device: torch.device) -> list[str]:
|
| 17 |
+
"""
|
| 18 |
+
Generate caption of a single image of size (3, 224, 224).
|
| 19 |
+
Caption generation starts with <sos>, and each newly predicted word ID
|
| 20 |
+
is appended for the next LSTM input until the sentence reaches MAX_LENGTH or <eos>.
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
list[str]: caption for given image
|
|
|
|
| 28 |
image = image.unsqueeze(0)
|
| 29 |
# image: (1, 3, 224, 224)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
hidden = image_decoder.hidden_state_0
|
| 32 |
cell = image_decoder.cell_state_0
|
| 33 |
# hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)
|
| 34 |
|
| 35 |
sentence = []
|
| 36 |
|
| 37 |
+
# initialize LSTM input to SOS token = 1
|
| 38 |
+
input_words = [vocab.SOS]
|
| 39 |
|
| 40 |
MAX_LENGTH = 20
|
| 41 |
|
| 42 |
for i in range(MAX_LENGTH):
|
| 43 |
|
| 44 |
+
features = image_encoder.forward(image)
|
| 45 |
+
# features: (1, IMAGE_EMB_DIM)
|
| 46 |
+
features = features.to(device)
|
| 47 |
+
features = features.unsqueeze(0)
|
| 48 |
+
# features: (1, 1, IMAGE_EMB_DIM)
|
| 49 |
+
|
| 50 |
+
input_words_tensor = torch.tensor([input_words])
|
| 51 |
+
# input_words_tensor : (B=1, SEQ_LENGTH)
|
| 52 |
+
input_words_tensor = input_words_tensor.to(device)
|
| 53 |
+
|
| 54 |
+
lstm_input = emb_layer.forward(input_words_tensor)
|
| 55 |
+
# lstm_input : (B=1, SEQ_LENGTH, WORD_EMB_DIM)
|
| 56 |
+
|
| 57 |
+
lstm_input = lstm_input.permute(1, 0, 2)
|
| 58 |
+
# lstm_input : (SEQ_LENGTH, B=1, WORD_EMB_DIM)
|
| 59 |
+
SEQ_LENGTH = lstm_input.shape[0]
|
| 60 |
|
| 61 |
+
features = features.repeat(SEQ_LENGTH, 1, 1)
|
| 62 |
+
# features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)
|
|
|
|
| 63 |
|
| 64 |
+
next_id_pred, (hidden, cell) = image_decoder.forward(lstm_input, features, hidden, cell)
|
| 65 |
+
# next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)
|
| 66 |
|
| 67 |
+
next_id_pred = next_id_pred[-1, 0, :]
|
| 68 |
+
# next_id_pred : (VOCAB_SIZE)
|
| 69 |
+
next_id_pred = torch.argmax(next_id_pred)
|
| 70 |
|
| 71 |
+
# append it to input_words which will be again as input for LSTM
|
| 72 |
+
input_words.append(next_id_pred.item())
|
| 73 |
+
|
| 74 |
+
# id --> word
|
| 75 |
+
next_word_pred = vocab.index_to_word(int(next_id_pred.item()))
|
| 76 |
+
sentence.append(next_word_pred)
|
| 77 |
|
| 78 |
# stop if we predict '<eos>'
|
| 79 |
if next_word_pred == vocab.index2word[vocab.EOS]:
|
| 80 |
break
|
| 81 |
|
|
|
|
|
|
|
|
|
|
| 82 |
return sentence
|
| 83 |
|
| 84 |
|