# Hugging Face Spaces page header (scrape residue); the Space's reported status was "Runtime error".
import torch
import torch.nn.functional as F
# from utils.transforms import transforms
# from utils.vocab import Vocabulary
# from utils.helpers import VOCAB_PATH, ENCODER_PATH, DECODER_PATH
# from encoder import EncoderCNN
# from decoder import DecoderRNN
# import PIL.Image as Image
def sample(features, decoder, vocab, max_len=20):
    """Greedy caption decoding: repeatedly take the argmax token until <EOS>.

    features: image feature tensor; consumed by the decoder on the first call only.
    decoder:  callable returning (logits, hidden) for (features, token[, hidden]);
              logits are expected to be shaped (1, 1, vocab_size).
    vocab:    maps '<SOS>'/'<EOS>' to indices and indices back to word strings.
    max_len:  hard cap on the number of generated words.
    Returns the caption as a space-separated string (without <SOS>/<EOS>).
    """
    device = features.device
    caption = []
    # Prime the decoder with the image features and the start token, shape (1, 1).
    token = torch.tensor([[vocab['<SOS>']]], device=device)
    outputs, hidden = decoder(features, token)
    for _ in range(max_len):
        predicted = outputs.argmax(2)
        word = vocab[predicted.item()]  # .item() extracts the scalar index
        if word == '<EOS>':
            break
        caption.append(word)
        # Later steps feed only the previous token plus the hidden state.
        outputs, hidden = decoder(None, predicted, hidden)
    return ' '.join(caption)
def beam_sample(features, decoder, vocab, beam_size=5, max_len=30):
    """Beam-search caption decoding.

    Keeps the `beam_size` highest-scoring partial sequences, where a sequence's
    score is the running sum of per-step log-probabilities (no length
    normalization, matching the original behavior).

    features:  image feature tensor; fed to the decoder on the first step only.
    decoder:   callable returning (logits, hidden) for (features, token, hidden);
               logits are expected to be shaped (1, 1, vocab_size).
    vocab:     maps '<SOS>'/'<EOS>' to indices and indices back to word strings.
    beam_size: beams kept per step (must not exceed the vocabulary size,
               or `topk` will raise).
    max_len:   hard cap on sequence length.
    Returns the best caption as a space-separated string, <SOS>/<EOS> stripped.
    """
    device = features.device
    # Hoist special-token lookups out of the loops (they were re-queried every
    # iteration); also avoids the original's tensor->.item() round-trip for <SOS>.
    sos_idx = vocab['<SOS>']
    eos_idx = vocab['<EOS>']
    # Each beam is (cumulative log-prob, token-index sequence, hidden state).
    beams = [(0.0, [sos_idx], None)]
    # `step` replaces the original `_`, which was confusingly read below
    # despite conventionally meaning "unused".
    for step in range(max_len):
        candidates = []
        for score, seq, hidden in beams:
            # Finished beams pass through unchanged so they can still compete.
            if seq[-1] == eos_idx:
                candidates.append((score, seq, hidden))
                continue
            # Previous token as a (1, 1) tensor for the decoder.
            curr_word = torch.tensor([seq[-1]]).unsqueeze(0).to(device)
            # Image features are consumed only on the very first step.
            feat_input = features if step == 0 else None
            outputs, next_hidden = decoder(feat_input, curr_word, hidden)
            # Log-probabilities over the vocabulary for this step.
            log_probs = F.log_softmax(outputs.squeeze(1), dim=1)
            top_probs, top_idxs = log_probs.topk(beam_size)
            for i in range(beam_size):
                candidates.append((score + top_probs[0][i].item(),
                                   seq + [top_idxs[0][i].item()],
                                   next_hidden))
        # Keep only the k best-scoring candidates.
        beams = sorted(candidates, key=lambda x: x[0], reverse=True)[:beam_size]
        # Stop early once every surviving beam has emitted <EOS>.
        if all(s[-1] == eos_idx for _, s, _ in beams):
            break
    # Best sequence, with the special tokens dropped before joining.
    best_seq = beams[0][1]
    return ' '.join(vocab[idx] for idx in best_seq
                    if idx not in (sos_idx, eos_idx))
def sample_with_temp(features, decoder, vocab, temp=0.8, max_len=30):
    """Stochastic caption decoding: draw each token from the temperature-scaled
    softmax instead of taking the argmax.

    temp < 1 sharpens the distribution (closer to greedy); temp > 1 flattens it.
    features: image feature tensor; consumed by the decoder on the first call only.
    decoder:  callable returning (logits, hidden) for (features, token[, hidden]).
    vocab:    maps '<SOS>'/'<EOS>' to indices and indices back to word strings.
    max_len:  hard cap on the number of generated words.
    Returns the caption as a space-separated string (without <SOS>/<EOS>).
    """
    device = features.device
    caption = []
    # Start-of-sequence token as a (1, 1) tensor.
    start = torch.tensor([[vocab['<SOS>']]], device=device)
    outputs, hidden = decoder(features, start)  # outputs: (1, 1, vocab_size)
    for _ in range(max_len):
        # Temperature-scaled distribution over the vocabulary.
        scaled_logits = outputs.squeeze(1) / temp
        dist = F.softmax(scaled_logits, dim=-1)
        # Sample one token index from the distribution (instead of argmax).
        predicted = torch.multinomial(dist, 1)
        word = vocab[predicted.item()]
        if word == '<EOS>':
            break
        caption.append(word)
        outputs, hidden = decoder(None, predicted, hidden)
    return ' '.join(caption)