Spaces:
Configuration error
Configuration error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import numpy as np | |
| import os | |
| import ntpath | |
| import time | |
| import glob | |
| from scipy.misc import imresize | |
| import torchvision.utils as vutils | |
| from operator import itemgetter | |
| from tensorboardX import SummaryWriter | |
class Visualizer():
    """Small convenience wrapper around a tensorboardX ``SummaryWriter``.

    The writer logs into ``checkpoints_dir``. Any ``events*`` file already
    present in that directory is deleted when the object is constructed.
    """

    def __init__(self, checkpoints_dir, name):
        """Create the summary writer after clearing old event files.

        Args:
            checkpoints_dir: directory that receives the event files and the
                exported ``tensorboard_all_scalars.json``.
            name: label stored on the instance as ``self.name``.
        """
        self.win_size = 256
        self.name = name
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4

        # Remove event files left over from a previous run in this directory.
        for filename in glob.glob(os.path.join(self.checkpoints_dir, "events*")):
            os.remove(filename)

        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        """Clear the ``saved`` flag."""
        self.saved = False

    def image_summary(self, mode, epoch, images):
        """Log a batch of images as one grid under ``'<mode>/Image'``.

        Args:
            mode: tag prefix (e.g. a split name).
            epoch: global step passed to the writer.
            images: batch of images handed to ``torchvision.utils.make_grid``
                (original note: "(b, c, 0, 1) array of images").
        """
        # Each image is normalised independently (scale_each=True).
        grid = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), grid, epoch)

    @staticmethod
    def _lookup_words(idx, vocabulary):
        """Return ``[vocabulary[j] for j in idx]`` for any index container.

        ``idx`` may be a list/tuple, an array/tensor exposing ``tolist()``
        (including a 0-d one), or a single scalar. The result is always a
        list, also for zero or one index.
        """
        if hasattr(idx, 'tolist'):
            # Tensors/arrays become plain Python ints (a 0-d one becomes a scalar).
            idx = idx.tolist()
        try:
            indices = list(idx)
        except TypeError:
            # A lone scalar index is not iterable: treat it as one element.
            indices = [idx]
        return [vocabulary[j] for j in indices]

    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        """Log one text entry per element of ``text``.

        Tags have the form ``'<mode>/<type>_<i>_<gt|prediction>'``.

        Args:
            mode: tag prefix.
            epoch: global step passed to the writer.
            type: kind of text being logged (original note:
                "ingredients/recipe"); only used to build the tag.
            text: iterable of index containers (ground truth) or, when
                ``gt`` is false, of vectors whose non-zero positions mark
                the sampled entries.
            vocabulary: indexable mapping from index to word.
            gt: True when ``text`` holds ground-truth indices, False when it
                holds a sample.
            max_length: if more than this many words are looked up, a
                "too big" message is logged instead of the words.
        """
        suffix = 'gt' if gt else 'prediction'
        for i, el in enumerate(text):
            if not gt:
                # We are printing a sample: positions of the non-zero entries.
                # NOTE(review): the +1 presumably skips a reserved index 0 in
                # the vocabulary -- confirm against the caller.
                idx = el.nonzero().squeeze() + 1
            else:
                # We are printing the ground truth: indices are used as given.
                idx = el

            # A list comprehension (not itemgetter) so that zero or one index
            # still yields a sequence of words rather than an error or a
            # bare string.
            words_list = self._lookup_words(idx, vocabulary)

            tag = '{}/{}_{}_{}'.format(mode, type, i, suffix)
            if len(words_list) <= max_length:
                message = ', '.join(w for w in words_list if w != '<pad>')
            else:
                message = 'Number of sampled ingredients is too big: {}'.format(len(words_list))
            self.writer.add_text(tag, message, epoch)

    def scalar_summary(self, mode, epoch, **args):
        """Log every keyword argument as a scalar under ``'<mode>/<key>'``.

        After logging, all scalars known to the writer are exported to
        ``<checkpoints_dir>/tensorboard_all_scalars.json``.
        """
        for k, v in args.items():
            self.writer.add_scalar('{}/{}'.format(mode, k), v, epoch)

        self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))

    def histo_summary(self, model, step):
        """Log a histogram of the tensor of values.

        One histogram is written per entry of ``model.named_parameters()``,
        tagged with the parameter name.
        """
        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        """Close the underlying summary writer."""
        self.writer.close()