import torch
import torch.nn.functional as F

# from utils.transforms import transforms
# from utils.vocab import Vocabulary
# from utils.helpers import VOCAB_PATH, ENCODER_PATH, DECODER_PATH
# from encoder import EncoderCNN
# from decoder import DecoderRNN
# import PIL.Image as Image

# Note: the Vocabulary object is used bidirectionally below --
# vocab['<start>'] maps a word to its index, and vocab[idx] maps an
# index back to its word.


def sample(features, decoder, vocab, max_len=20):
    """Greedy decoding: pick the highest-probability word at each step."""
    device = features.device
    result_caption = []

    word_idx = torch.tensor([vocab['<start>']]).unsqueeze(0).to(device)  # shape (1, 1)
    outputs, hidden = decoder(features, word_idx)  # outputs: (1, 1, vocab_size)

    for _ in range(max_len):
        predicted = outputs.argmax(2)
        word = vocab[predicted.item()]  # .item() extracts the Python int from the tensor
        if word == '<end>':
            break
        result_caption.append(word)
        # Pass features=None and the previous hidden state on later steps
        outputs, hidden = decoder(None, predicted, hidden)

    return ' '.join(result_caption)


def beam_sample(features, decoder, vocab, beam_size=5, max_len=30):
    """Beam search decoding: keep the beam_size highest-scoring partial captions."""
    device = features.device

    # Each beam is a tuple: (log_score, sequence of word indices, hidden_state)
    start_token = torch.tensor([vocab['<start>']]).to(device)
    beams = [(0.0, [start_token.item()], None)]

    for step in range(max_len):
        candidates = []
        for score, seq, hidden in beams:
            # Finished beams are carried over unchanged
            if seq[-1] == vocab['<end>']:
                candidates.append((score, seq, hidden))
                continue

            # Predict the next word from the last word in this beam
            curr_word = torch.tensor([seq[-1]]).unsqueeze(0).to(device)

            # Image features are fed only on the first step
            feat_input = features if step == 0 else None
            outputs, next_hidden = decoder(feat_input, curr_word, hidden)

            # Accumulate log-probabilities so beam scores stay additive
            log_probs = F.log_softmax(outputs.squeeze(1), dim=1)
            top_probs, top_idxs = log_probs.topk(beam_size)
            for i in range(beam_size):
                candidates.append((score + top_probs[0][i].item(),
                                   seq + [top_idxs[0][i].item()],
                                   next_hidden))

        # Sort by score and keep the top beam_size candidates
        beams = sorted(candidates, key=lambda x: x[0], reverse=True)[:beam_size]

        # Stop early once every beam has produced <end>
        if all(s[-1] == vocab['<end>'] for _, s, _ in beams):
            break

    # Return the best sequence, stripping the <start>/<end> tokens
    best_seq = beams[0][1]
    return ' '.join(vocab[idx] for idx in best_seq
                    if idx not in (vocab['<start>'], vocab['<end>']))


def sample_with_temp(features, decoder, vocab, temp=0.8, max_len=30):
    """Temperature sampling: draw each word from the softened softmax distribution."""
    device = features.device
    result_caption = []

    word_idx = torch.tensor([vocab['<start>']]).unsqueeze(0).to(device)
    outputs, hidden = decoder(features, word_idx)  # outputs: (1, 1, vocab_size)

    for _ in range(max_len):
        # Dividing logits by temp < 1 sharpens the distribution; temp > 1 flattens it
        logits = outputs.squeeze(1) / temp
        probs = F.softmax(logits, dim=-1)

        # Sample from the distribution instead of taking the argmax
        predicted = torch.multinomial(probs, 1)

        word = vocab[predicted.item()]
        if word == '<end>':
            break
        result_caption.append(word)
        outputs, hidden = decoder(None, predicted, hidden)

    return ' '.join(result_caption)
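

# ---------------------------------------------------------------------------
# Usage sketch (kept commented out, like the imports at the top). This is an
# illustrative, unverified example: it assumes the commented-out modules above
# (EncoderCNN, DecoderRNN, Vocabulary, transforms, and the *_PATH constants)
# exist with plausible but unconfirmed constructor signatures, and a pickled
# Vocabulary whose __getitem__ maps both word -> index and index -> word, as
# the functions in this module require.
# ---------------------------------------------------------------------------
# import pickle
#
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
# with open(VOCAB_PATH, 'rb') as f:
#     vocab = pickle.load(f)
#
# encoder = EncoderCNN().to(device).eval()                 # assumed signature
# decoder = DecoderRNN(vocab_size=len(vocab)).to(device).eval()  # assumed signature
# encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=device))
# decoder.load_state_dict(torch.load(DECODER_PATH, map_location=device))
#
# image = transforms(Image.open('example.jpg').convert('RGB'))
# image = image.unsqueeze(0).to(device)  # add batch dimension: (1, C, H, W)
#
# with torch.no_grad():
#     features = encoder(image)
#     print(sample(features, decoder, vocab))                       # greedy
#     print(beam_sample(features, decoder, vocab, beam_size=5))     # beam search
#     print(sample_with_temp(features, decoder, vocab, temp=0.8))   # sampling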