File size: 745 Bytes
9f5a022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

# Shared TensorBoard writer used by log_data() and log_img().
# Stays None until init_logger() assigns a SummaryWriter to it.
writer = None
def log_data(data, i):
    """Record every entry of ``data`` as a scalar on the module-level writer.

    Args:
        data: Mapping of tag -> value; each pair is forwarded to
            ``writer.add_scalar``.
        i: Step index passed as the third argument of ``add_scalar``.

    Raises:
        RuntimeError: If the module-level ``writer`` is still ``None``,
            i.e. ``init_logger()`` has not been called yet.
    """
    # Fail with an actionable message instead of an AttributeError on None.
    if writer is None:
        raise RuntimeError(
            "logger is not initialised; call init_logger() before log_data()"
        )

    for tag, value in data.items():
        writer.add_scalar(tag, value, i)

def log_img(img, name, step=None):
    """Record an image on the module-level writer under the tag ``name``.

    Args:
        img: Image data forwarded unchanged to ``writer.add_image``.
        name: Tag the image is stored under.
        step: Optional step index forwarded as ``global_step``. The default
            ``None`` matches the previous behaviour of not passing a step.

    Raises:
        RuntimeError: If the module-level ``writer`` is still ``None``,
            i.e. ``init_logger()`` has not been called yet.
    """
    # Fail with an actionable message instead of an AttributeError on None.
    if writer is None:
        raise RuntimeError(
            "logger is not initialised; call init_logger() before log_img()"
        )

    writer.add_image(name, img, global_step=step)


def save_grid_with_label(img_grid, label, out_file):
    """Render an image tensor with a title above it and save it to a file.

    Args:
        img_grid: Tensor whose axes are reordered with ``permute(1, 2, 0)``
            before display -- presumably a channel-first (C, H, W) grid;
            TODO confirm against callers.
        label: Text used as the figure title.
        out_file: Destination handed to ``plt.savefig``.
    """
    # detach()/cpu() make the conversion work for tensors that track
    # gradients or live on an accelerator; .numpy() rejects both otherwise.
    image = img_grid.detach().cpu().permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(image)
        ax.set_title(label, fontsize=20)
        ax.axis('off')

        # Leave headroom at the top so the title is not clipped.
        plt.subplots_adjust(top=0.85)

        plt.savefig(out_file, bbox_inches='tight', pad_inches=0.1)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
        # Kept from the original: this also closes every other open figure.
        plt.close("all")




def init_logger(dir="runs"):
    """Create the shared module-level ``SummaryWriter`` on first use.

    Args:
        dir: Argument passed to ``SummaryWriter`` when the writer is created.
            Ignored when a truthy writer is already set.
    """
    global writer

    # Already initialised: keep the existing writer untouched.
    if writer:
        return

    writer = SummaryWriter(dir)