Spaces:
Sleeping
Sleeping
| ###### At the best-val-loss checkpoint: print and return all generated captions, and save attention heatmaps ##### | |
| import os | |
| import torch | |
| from src.utils.checkpoint_manager import load_checkpoint | |
def _tokens_to_sentence(tokens, w2i, i2w, sp=None, special_ids=frozenset()):
    """Convert one generated token-id sequence into a caption string.

    Args:
        tokens: Sequence of token ids for a single caption. A tensor is
            converted to a plain list so ``.index`` and dict lookups work.
        w2i: Token-string -> id mapping; must contain ``"<eos>"``.
        i2w: Id -> token-string mapping (used only when ``sp`` is None).
        sp: A loaded SentencePiece processor, or None for word-level decoding.
        special_ids: Ids dropped before SentencePiece decoding.

    Returns:
        The decoded caption as a single string.
    """
    if torch.is_tensor(tokens):
        tokens = tokens.tolist()

    # Keep only what precedes the first <eos>.
    eos_id = w2i["<eos>"]
    if eos_id in tokens:
        tokens = tokens[:tokens.index(eos_id)]

    if sp is not None:
        # Remove special tokens before handing the ids to SentencePiece.
        pieces = [token for token in tokens if token not in special_ids]
        return sp.decode(pieces)

    # NOTE(review): the word-level path does not strip <sos>/<pad>, unlike the
    # subword path; kept as in the original — confirm whether that is intended.
    return ' '.join(i2w[i] for i in tokens)


def make_show_all_caption(
    loader,
    encoder,
    decoder,
    optimizer,
    w2i,
    i2w,
    best_path,
    dec_atten_dir,
    enc_dec_atten_dir,
    save_prefix,
    sample,
    layer,
    device,
    use_subword,
    sp_model_path,
    use_beam_search,
    beam_size
):
    """Generate captions for every item in ``loader`` and visualise samples.

    Calls ``load_checkpoint`` on ``best_path`` (presumably restoring the best
    weights into ``encoder``/``decoder``/``optimizer`` — verify against
    ``checkpoint_manager``), generates a caption for every image yielded by
    ``loader``, then for each index in ``sample`` saves the decoder
    self-attention and cross-attention plots and prints the generated caption
    next to its references.

    Args:
        loader: Iterable yielding ``(images, _, batch_references, file_name)``.
        encoder: Callable invoked as ``encoder(images, return_features=True)``.
        decoder: Object providing ``generate``, ``generate_beam``,
            ``show_dec_atten`` and ``show_cross_atten``.
        optimizer: Forwarded to ``load_checkpoint``.
        w2i: Token-string -> id mapping with ``"<sos>"``, ``"<eos>"`` and
            (for subword decoding) ``"<pad>"``.
        i2w: Id -> token-string mapping used for word-level decoding.
        best_path: Checkpoint path handed to ``load_checkpoint``.
        dec_atten_dir: Output directory for decoder self-attention plots.
        enc_dec_atten_dir: Output directory for cross-attention plots.
        save_prefix: Prefix for the saved plot file names.
        sample: Iterable of dataset indices to visualise and print.
        layer: Layer selector forwarded to the ``show_*`` methods.
        device: Device the images and start tokens are placed on.
        use_subword: If true, decode ids with SentencePiece.
        sp_model_path: SentencePiece model file (read only if ``use_subword``).
        use_beam_search: If true use ``generate_beam``, else ``generate``.
        beam_size: Beam width passed to ``generate_beam``.

    Returns:
        ``(all_generated_sentence, all_references)``: the decoded caption of
        every item, and the per-item reference tuples.
    """
    _, best_val_loss = load_checkpoint(
        best_path,
        encoder,
        decoder,
        optimizer,
        device
    )

    all_references = []
    all_generated_token = []
    all_dec_atten = []
    all_enc_dec_atten = []
    all_images = []
    all_file_name = []

    # Inference only: disabling autograd avoids keeping graph state alive for
    # every attention map accumulated below.
    # NOTE(review): this does not switch the models to eval mode; it is assumed
    # the caller or load_checkpoint has done so — confirm.
    with torch.no_grad():
        for images, _, batch_references, file_name in loader:
            images = images.to(device)
            features = encoder(images, return_features=True)  # B, 49, 512 (per original note)
            start_tokens = torch.full(
                (features.size(0),), w2i["<sos>"], device=device
            )  # B,

            if use_beam_search:
                generated_token, dec_atten, enc_dec_atten = decoder.generate_beam(
                    features,
                    start_tokens,
                    w2i["<eos>"],
                    beam_size
                )
            else:
                generated_token, dec_atten, enc_dec_atten = decoder.generate(
                    features,
                    start_tokens,
                    w2i["<eos>"],
                )

            # Shape notes below are carried over from the original comments.
            all_dec_atten.extend(dec_atten)          # all_B, layers, nhead, seq_len, seq_len
            all_enc_dec_atten.extend(enc_dec_atten)  # all_B, layers, nhead, seq_len, 49
            all_images.extend(images.cpu())
            # batch_references is transposed so each item gets its own tuple
            # of references.
            all_references.extend(list(zip(*batch_references)))
            all_generated_token.extend(generated_token)  # all_B, seq_len-1
            all_file_name.extend(file_name)

    # ==================================
    # SentencePiece tokenizer
    # ==================================
    # Load the model once; previously it was re-read from disk per sentence.
    sp = None
    special_ids = frozenset()
    if use_subword:
        import sentencepiece as spm  # lazy: only needed for subword decoding

        sp = spm.SentencePieceProcessor()
        sp.load(sp_model_path)
        special_ids = frozenset((w2i["<pad>"], w2i["<sos>"], w2i["<eos>"]))

    all_generated_sentence = [
        _tokens_to_sentence(sentence_token, w2i, i2w, sp, special_ids)
        for sentence_token in all_generated_token
    ]  # all_B, 1 (sentence)

    for i in sample:
        dec_atten_name = os.path.join(
            dec_atten_dir, f"{save_prefix}_dec_atten_{all_file_name[i]}"
        )
        cross_atten_name = os.path.join(
            enc_dec_atten_dir, f"{save_prefix}_cross_atten_{all_file_name[i]}"
        )
        # NOTE(review): axis labels come from whitespace-splitting the decoded
        # sentence; with subword decoding the word count may differ from the
        # number of generated tokens — confirm the show_* methods handle that.
        caption_words = all_generated_sentence[i].split()
        decoder.show_dec_atten(all_dec_atten[i], caption_words, layer, dec_atten_name)
        decoder.show_cross_atten(
            all_enc_dec_atten[i], caption_words, layer, all_images[i], cross_atten_name
        )

        print("-" * 60)
        print(f' {all_file_name[i]}: {all_generated_sentence[i]}')
        print("-" * 60)
        for inx, reference in enumerate(all_references[i], start=1):
            print(f'Reference {inx}: {reference}')
        print("=" * 60)

    print(f'Best Val Loss: {best_val_loss}')
    return all_generated_sentence, all_references