Spaces:
Sleeping
Sleeping
File size: 3,878 Bytes
# Print and return every generated caption at the best-val-loss checkpoint, and save attention heatmaps.
import os
import torch
from src.utils.checkpoint_manager import load_checkpoint
def make_show_all_caption(
    loader,
    encoder,
    decoder,
    optimizer,
    w2i,
    i2w,
    best_path,
    dec_atten_dir,
    enc_dec_atten_dir,
    save_prefix,
    sample,
    layer,
    device,
    use_subword,
    sp_model_path,
    use_beam_search,
    beam_size
):
    """Generate captions for every item in ``loader`` from the best checkpoint.

    The checkpoint at ``best_path`` is loaded through ``load_checkpoint``,
    captions are generated for the whole loader, token ids are decoded to
    sentences, and for each index in ``sample`` the attention maps are drawn
    via the decoder's ``show_*`` methods and the caption plus its references
    are printed.

    Args:
        loader: Iterable yielding ``(images, _, batch_references, file_name)``
            tuples; the second element is ignored.
        encoder: Callable invoked as ``encoder(images, return_features=True)``.
        decoder: Object exposing ``generate``, ``generate_beam``,
            ``show_dec_atten`` and ``show_cross_atten``.
        optimizer: Forwarded unchanged to ``load_checkpoint``.
        w2i: Mapping from token string to id. Must contain ``"<sos>"`` and
            ``"<eos>"``, and additionally ``"<pad>"`` when ``use_subword``
            is true.
        i2w: Lookup from token id to word; only used when ``use_subword``
            is false.
        best_path: Checkpoint location passed to ``load_checkpoint``.
        dec_atten_dir: Directory prefix for decoder self-attention outputs.
        enc_dec_atten_dir: Directory prefix for cross-attention outputs.
        save_prefix: Prefix placed in front of every output name.
        sample: Iterable of integer indices into the accumulated results
            whose attention maps are rendered and whose captions are printed.
        layer: Forwarded to the decoder's ``show_*`` methods.
        device: Device the images and start tokens are placed on.
        use_subword: If true, decode ids with a SentencePiece model;
            otherwise look every id up in ``i2w`` and join with spaces.
        sp_model_path: SentencePiece model file; read only when
            ``use_subword`` is true.
        use_beam_search: Selects ``decoder.generate_beam`` over
            ``decoder.generate``.
        beam_size: Passed to ``decoder.generate_beam``; unused otherwise.

    Returns:
        A tuple ``(all_generated_sentence, all_references)``: one decoded
        sentence per sample, and the per-sample reference tuples in the same
        order.
    """
    # The second value returned by load_checkpoint is reported below as the
    # best validation loss.
    _, best_val_loss = load_checkpoint(
        best_path,
        encoder,
        decoder,
        optimizer,
        device
    )

    sos_id = w2i["<sos>"]
    eos_id = w2i["<eos>"]

    all_references = []
    all_generated_token = []
    all_dec_atten = []
    all_enc_dec_atten = []
    all_images = []
    all_file_name = []

    # Generation needs no gradients; disabling autograd avoids keeping the
    # graph alive for every collected attention tensor.
    # NOTE(review): train/eval mode is not changed here -- confirm the caller
    # has already put encoder/decoder in eval mode if dropout is used.
    with torch.no_grad():
        for images, _, batch_references, file_name in loader:
            images = images.to(device)
            features = encoder(images, return_features=True)  # B, 49, 512 per the original note
            start_tokens = torch.full(
                (features.size(0),), sos_id, device=device
            )  # B,

            if use_beam_search:
                generated_token, dec_atten, enc_dec_atten = decoder.generate_beam(
                    features,
                    start_tokens,
                    eos_id,
                    beam_size
                )
            else:
                generated_token, dec_atten, enc_dec_atten = decoder.generate(
                    features,
                    start_tokens,
                    eos_id,
                )

            # Shapes below are the original author's notes, not verified here.
            all_dec_atten.extend(dec_atten)  # all_B, layers, nhead, seq_len, seq_len
            all_enc_dec_atten.extend(enc_dec_atten)  # all_B, layers, nhead, seq_len, 49
            all_images.extend(images.cpu())
            # Transpose the batch of references so each entry groups the
            # references belonging to one sample.
            all_references.extend(zip(*batch_references))
            all_generated_token.extend(generated_token)  # all_B, seq_len-1
            all_file_name.extend(file_name)

    # ==================================
    # SentencePiece tokenizer
    # ==================================
    # Load the model once instead of once per sentence.
    sp = None
    special_ids = ()
    if use_subword:
        import sentencepiece as spm

        sp = spm.SentencePieceProcessor()
        sp.load(sp_model_path)
        # A tuple keeps the original ``==``-based membership test.
        special_ids = (w2i["<pad>"], sos_id, eos_id)

    all_generated_sentence = []
    for sentence_token in all_generated_token:
        # Keep only the tokens that precede the first <eos>.
        if eos_id in sentence_token:
            sentence_token = sentence_token[:sentence_token.index(eos_id)]

        if use_subword:
            # Strip special tokens before SentencePiece decoding.
            sentence_token = [
                token for token in sentence_token if token not in special_ids
            ]
            sentence = sp.decode(sentence_token)
        else:
            sentence = ' '.join(i2w[i] for i in sentence_token)

        all_generated_sentence.append(sentence)  # one sentence string per sample

    for i in sample:
        dec_atten_name = os.path.join(
            dec_atten_dir, f"{save_prefix}_dec_atten_{all_file_name[i]}"
        )
        cross_atten_name = os.path.join(
            enc_dec_atten_dir, f"{save_prefix}_cross_atten_{all_file_name[i]}"
        )
        # The whitespace-split caption is what the show_* methods receive as
        # the token sequence.
        caption_words = all_generated_sentence[i].split()

        decoder.show_dec_atten(all_dec_atten[i], caption_words, layer, dec_atten_name)
        decoder.show_cross_atten(
            all_enc_dec_atten[i], caption_words, layer, all_images[i], cross_atten_name
        )

        print("-" * 60)
        print(f' {all_file_name[i]}: {all_generated_sentence[i]}')
        print("-" * 60)
        for inx, reference in enumerate(all_references[i], start=1):
            print(f'Reference {inx}: {reference}')
        print("=" * 60)

    print(f'Best Val Loss: {best_val_loss}')
    return all_generated_sentence, all_references
|