# caption-gen / inference.py
# NOTE(review): the following lines were Hugging Face page chrome captured by
# the scrape — "Sher1988's picture", "Update inference.py", "455dcea verified" —
# kept here as a comment so the file parses as Python.
import torch
import torch.nn.functional as F
# from utils.transforms import transforms
# from utils.vocab import Vocabulary
# from utils.helpers import VOCAB_PATH, ENCODER_PATH, DECODER_PATH
# from encoder import EncoderCNN
# from decoder import DecoderRNN
# import PIL.Image as Image
def sample(features, decoder, vocab, max_len=20):
    """Greedy-decode a caption from encoded image features.

    Args:
        features: encoder output tensor; only ``.device`` is read here —
            its shape/meaning is defined by ``decoder`` (presumably
            (1, feature_dim) — TODO confirm against the encoder).
        decoder: callable ``(features, word_idx, hidden=None) ->
            (outputs, hidden)`` where ``outputs`` has shape
            (1, 1, vocab_size).
        vocab: mapping usable both ways — ``vocab['<SOS>'] -> index`` and
            ``vocab[index] -> word``.
        max_len: maximum number of words to emit before giving up.

    Returns:
        The generated caption as a space-joined string, without the
        <SOS>/<EOS> markers.
    """
    device = features.device
    result_caption = []
    # Seed the decoder with the start-of-sequence token, shape (1, 1).
    word_idx = torch.tensor([vocab['<SOS>']]).unsqueeze(0).to(device)
    # Pure inference: disable gradient bookkeeping (saves memory/time).
    with torch.no_grad():
        outputs, hidden = decoder(features, word_idx)  # outputs: (1, 1, vocab_size)
        for _ in range(max_len):
            predicted = outputs.argmax(2)  # greedy pick, shape (1, 1)
            word = vocab[predicted.item()]  # .item() unwraps the tensor index
            if word == '<EOS>':
                break
            result_caption.append(word)
            # Features are only fed on the first step; afterwards the
            # recurrent hidden state carries the image information.
            outputs, hidden = decoder(None, predicted, hidden)
    return ' '.join(result_caption)
def beam_sample(features, decoder, vocab, beam_size=5, max_len=30):
    """Beam-search decode a caption from encoded image features.

    Each beam is a tuple ``(log_score, token_sequence, hidden_state)``.
    Finished beams (ending in <EOS>) are carried over unchanged so they
    keep competing on score with still-growing beams.

    Args:
        features: encoder output tensor; only ``.device`` is read here,
            the rest is interpreted by ``decoder``.
        decoder: callable ``(features, word_idx, hidden) ->
            (outputs, hidden)`` with ``outputs`` of shape (1, 1, vocab_size).
        vocab: two-way mapping (word -> index and index -> word).
        beam_size: number of hypotheses kept per step; must not exceed
            the vocabulary size (``topk`` would fail otherwise).
        max_len: hard cap on the number of decoding steps.

    Returns:
        The best-scoring caption as a space-joined string, with the
        <SOS>/<EOS> control tokens stripped.
    """
    device = features.device
    start_idx = vocab['<SOS>']
    eos_idx = vocab['<EOS>']
    # (log_score, sequence, hidden_state)
    beams = [(0.0, [start_idx], None)]
    with torch.no_grad():
        for step in range(max_len):
            candidates = []
            for score, seq, hidden in beams:
                # A finished beam is kept as-is; it still competes on score.
                if seq[-1] == eos_idx:
                    candidates.append((score, seq, hidden))
                    continue
                # Predict the next word from this beam's last token.
                curr_word = torch.tensor([seq[-1]]).unsqueeze(0).to(device)
                # Image features are only fed on the very first step.
                feat_input = features if step == 0 else None
                outputs, next_hidden = decoder(feat_input, curr_word, hidden)
                # Log-probs so beam scores add instead of multiply.
                log_probs = F.log_softmax(outputs.squeeze(1), dim=1)
                top_probs, top_idxs = log_probs.topk(beam_size)
                for i in range(beam_size):
                    candidates.append((score + top_probs[0][i].item(),
                                       seq + [top_idxs[0][i].item()],
                                       next_hidden))
            # Keep only the beam_size highest-scoring candidates.
            beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]
            # Early exit once every surviving beam has emitted <EOS>.
            if all(s[-1] == eos_idx for _, s, _ in beams):
                break
    # Best sequence, minus the control tokens.
    best_seq = beams[0][1]
    return ' '.join(vocab[idx] for idx in best_seq
                    if idx not in (start_idx, eos_idx))
def sample_with_temp(features, decoder, vocab, temp=0.8, max_len=30):
    """Stochastically decode a caption using temperature sampling.

    Instead of greedy argmax, each step samples the next word from the
    softmax distribution of the (temperature-scaled) logits, producing
    varied captions across calls.

    Args:
        features: encoder output tensor; only ``.device`` is read here.
        decoder: callable ``(features, word_idx, hidden=None) ->
            (outputs, hidden)`` with ``outputs`` of shape (1, 1, vocab_size).
        vocab: two-way mapping (word -> index and index -> word).
        temp: temperature; < 1 sharpens the distribution (closer to
            greedy), > 1 flattens it (more random). Must be > 0.
        max_len: maximum number of words to emit.

    Returns:
        The sampled caption as a space-joined string, without <SOS>/<EOS>.
    """
    device = features.device
    result_caption = []
    word_idx = torch.tensor([vocab['<SOS>']]).unsqueeze(0).to(device)
    # Pure inference: no gradient bookkeeping needed.
    with torch.no_grad():
        outputs, hidden = decoder(features, word_idx)  # outputs: (1, 1, vocab_size)
        for _ in range(max_len):
            # Apply temperature to the logits before normalizing.
            logits = outputs.squeeze(1) / temp
            probs = F.softmax(logits, dim=-1)
            # Sample from the distribution instead of taking the argmax.
            predicted = torch.multinomial(probs, 1)
            word = vocab[predicted.item()]
            if word == '<EOS>':
                break
            result_caption.append(word)
            # Features were consumed on the first step; rely on hidden state.
            outputs, hidden = decoder(None, predicted, hidden)
    return ' '.join(result_caption)