| import numpy as np | |
| import torch | |
| from torch.utils.tensorboard import SummaryWriter | |
def tensorboard_plot_loss(win_name, loss, writer):
    """Log the most recent loss value as a TensorBoard scalar.

    Args:
        win_name: Suffix for the scalar tag; the full tag is ``Loss/<win_name>``.
        loss: Indexable, sized history of loss values. Only the last entry is
            logged, and the history length is used as the global step. It must
            be non-empty, because ``loss[-1]`` raises on an empty sequence.
        writer: A ``SummaryWriter`` (or any object exposing ``add_scalar`` and
            ``flush``).
    """
    tag = f"Loss/{win_name}"
    latest_value = loss[-1]
    # The step is the number of recorded values, so the first point lands at step 1.
    global_step = len(loss)
    writer.add_scalar(tag, latest_value, global_step)
    # Flush right away so the point shows up in TensorBoard without waiting.
    writer.flush()
def normalize_img(imgs):
    """Return a display-ready copy of a paired image batch, scaled into [0, 1].

    The batch is split in halves. Each image ``i`` in the first half is divided
    by its own maximum value, and the same factor is applied to image
    ``i + b // 2`` in the second half. Both images of a pair therefore share
    one scale. (Presumably the first half holds ground-truth images, per the
    ``gt_batch`` name, and the second half the matching predictions. Confirm
    this with the callers.) The result is then clamped to [0, 1].

    The input tensor is never modified. The work is done on a detached copy,
    so tensors that require grad are accepted and the returned tensor carries
    no autograd history. Integer inputs are promoted to float32 so the
    division is not truncated.

    Args:
        imgs: 4-D tensor of shape ``(b, c, h, w)``.

    Returns:
        A new floating-point tensor of the same shape, with values in [0, 1].
        A per-image maximum of 0 is treated as a factor of 1, which avoids
        0/0 = NaN. With an odd ``b``, the last image is not rescaled; it is
        only clamped.

    Raises:
        ValueError: If ``imgs`` is not 4-D.
    """
    if imgs.dim() != 4:
        raise ValueError(
            f"normalize_img expects a 4-D (b, c, h, w) tensor, got shape {tuple(imgs.shape)}"
        )
    gt_batch = imgs.shape[0] // 2

    # Work on a detached copy. Writing into the caller's tensor would leak the
    # rescaling back to them, and would fail on leaf tensors requiring grad.
    out = imgs.detach().clone()
    if not out.is_floating_point():
        out = out.float()

    if gt_batch > 0:
        # One max per first-half image, shaped to broadcast over (c, h, w).
        factors = out[:gt_batch].flatten(1).max(dim=1).values.reshape(-1, 1, 1, 1)
        # An all-zero image would otherwise produce NaN, which clamp keeps.
        factors = factors.masked_fill(factors == 0, 1.0)
        out[:gt_batch] = out[:gt_batch] / factors
        out[gt_batch:2 * gt_batch] = out[gt_batch:2 * gt_batch] / factors

    return torch.clamp(out, 0.0, 1.0)
def tensorboard_show_batch(imgs, writer, win_name=None, nrow=2, normalize=True, step=0):
    """Log a batch of images to TensorBoard and flush the writer.

    Args:
        imgs: Image batch handed to ``writer.add_images``. When ``normalize``
            is true, it is first passed through ``normalize_img``.
        writer: A ``SummaryWriter`` (or any object exposing ``add_images`` and
            ``flush``).
        win_name: Tag for the image panel. It is formatted with ``str.format``,
            so the default ``None`` produces the tag ``"None"``.
        nrow: Not used by this function; kept so existing callers keep working.
        normalize: Whether to rescale the batch with ``normalize_img`` before
            logging.
        step: Global step recorded with the images.
    """
    batch = normalize_img(imgs) if normalize else imgs
    tag = f"{win_name}"
    writer.add_images(tag, batch, step)
    writer.flush()
def tensorboard_log(log_info, writer, win_name='logger', step=0):
    """Write a text entry to TensorBoard and flush the writer.

    Args:
        log_info: Text passed to ``writer.add_text`` as the entry body.
        writer: A ``SummaryWriter`` (or any object exposing ``add_text`` and
            ``flush``).
        win_name: Tag the text is filed under.
        step: Global step recorded with the entry.
    """
    entry = (win_name, log_info, step)
    writer.add_text(*entry)
    writer.flush()