"""Thin TensorBoard and matplotlib logging helpers.

Call ``init_logger()`` before ``log_data`` / ``log_img``: both use the
module-level ``writer`` that ``init_logger()`` creates.
"""
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

# Shared TensorBoard writer; stays None until init_logger() creates it.
writer: "SummaryWriter | None" = None
def log_data(data, i):
    """Log every scalar in ``data`` to TensorBoard at step ``i``.

    Args:
        data: Mapping of tag -> scalar value; each pair becomes one
            ``writer.add_scalar`` call.
        i: Global step recorded with every scalar.

    Requires ``init_logger()`` to have been called first. Until then the
    module-level ``writer`` is ``None``, so this raises ``AttributeError``
    (unless ``data`` is empty, in which case nothing is logged).
    """
    # Iterate pairs directly instead of keys() + a second lookup per key.
    for key, value in data.items():
        writer.add_scalar(key, value, i)
def log_img(img, name, i=None, dataformats="CHW"):
    """Log a single image to TensorBoard under the tag ``name``.

    Args:
        img: Image tensor/array accepted by ``SummaryWriter.add_image``.
        name: TensorBoard tag for the image.
        i: Optional global step, mirroring the ``i`` argument of ``log_data``.
            ``None`` (the default) reproduces the original step-less call.
        dataformats: Layout of ``img`` (e.g. ``"CHW"``, ``"HWC"``); ``"CHW"``
            is ``add_image``'s own default, so omitting it changes nothing.

    Requires ``init_logger()`` to have been called first. Until then the
    module-level ``writer`` is ``None`` and this raises ``AttributeError``.
    """
    writer.add_image(name, img, global_step=i, dataformats=dataformats)
def save_grid_with_label(img_grid, label, out_file):
    """Render an image grid with a title and save it to ``out_file``.

    Args:
        img_grid: 3-D torch tensor, channels first. It is permuted with
            ``permute(1, 2, 0)`` into channels-last for ``imshow``; presumably
            ``(C, H, W)`` such as a ``make_grid`` output (TODO confirm).
        label: Title drawn above the grid.
        out_file: Destination passed to ``Figure.savefig``.
    """
    # .numpy() raises for CUDA tensors and for tensors that require grad;
    # detach()/cpu() handle both and are no-ops for a plain CPU tensor.
    img_np = img_grid.detach().cpu().permute(1, 2, 0).numpy()
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(img_np)
        ax.set_title(label, fontsize=20)
        ax.axis('off')
        # Leave headroom at the top of the figure for the title.
        fig.subplots_adjust(top=0.85)
        fig.savefig(out_file, bbox_inches='tight', pad_inches=0.1)
    finally:
        # Release the figure even if drawing or saving fails, so a failing
        # call does not leave an open figure behind.
        plt.close(fig)
        # NOTE(review): this also closes every other open pyplot figure,
        # including ones the caller owns; kept for backward compatibility.
        plt.close("all")
def init_logger(dir="runs"):
    """Create the module-level TensorBoard ``SummaryWriter`` if none exists yet.

    Args:
        dir: Log directory handed to ``SummaryWriter``. (The name shadows the
            ``dir`` builtin but is kept because callers may pass it by
            keyword.) It is ignored when a writer already exists: the first
            call wins.

    Returns:
        The active ``SummaryWriter``, whether newly created or pre-existing.
    """
    global writer
    # Explicit None check instead of relying on the writer's truthiness.
    if writer is None:
        writer = SummaryWriter(dir)
    return writer