File size: 3,663 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import time
import subprocess
import torch
import torch.nn as nn

from torch.utils.tensorboard import SummaryWriter


class TensorboardLogger:
    """Thin convenience wrapper around a ``SummaryWriter``.

    Every method forwards to the underlying writer stored in ``self.writer``.
    """

    def __init__(self, log_dir='runs/experiment'):
        """Create the underlying writer, pointed at ``log_dir``."""
        self.writer = SummaryWriter(log_dir=log_dir)

    def log_scalar(self, tag: str, value: float, step: int) -> None:
        """Record a single scalar ``value`` under ``tag`` at ``step``."""
        self.writer.add_scalar(tag, value, step)

    def log_histogram(self, model: nn.Module, step: int) -> None:
        """Record histograms for every named parameter of ``model``.

        Each parameter is written under ``weights/<name>``. Its gradient is
        written under ``gradients/<name>`` only when a gradient is present.
        """
        for param_name, tensor in model.named_parameters():
            self.writer.add_histogram(f'weights/{param_name}', tensor, step)
            grad = tensor.grad
            if grad is None:
                # Nothing to record for this parameter's gradient.
                continue
            self.writer.add_histogram(f'gradients/{param_name}', grad, step)

    def log_model_graph(self, model: nn.Module, input_sample: torch.Tensor) -> None:
        """Record the graph of ``model`` using ``input_sample`` as its input."""
        self.writer.add_graph(model, input_sample)

    def close(self) -> None:
        """Close the underlying writer."""
        self.writer.close()


def launch_tensorboard(log_dir: str = "runs", port: int = 6006, open_browser: bool = True):
    """Start a background TensorBoard process reading from ``log_dir``.

    The log directory is created when missing, then the ``tensorboard``
    command is spawned as a child process with its stdout/stderr discarded.
    On a successful spawn the function sleeps for two seconds before
    returning, giving the server a moment to start.

    Args:
        log_dir: Directory passed to TensorBoard via ``--logdir``.
        port: Port passed to TensorBoard via ``--port``.
        open_browser: When False, ``--host=127.0.0.1`` is appended to the
            command line. NOTE(review): despite its name, this flag does not
            open a browser in either case -- confirm the intended meaning.

    Returns:
        The ``subprocess.Popen`` handle of the spawned process, so the caller
        can terminate it later, or ``None`` when the process could not be
        started (for example, ``tensorboard`` is not on ``PATH``).
    """
    # exist_ok makes a separate existence check unnecessary (and race-free).
    os.makedirs(log_dir, exist_ok=True)
    print(f"[TensorBoard] launching at http://localhost:{port}/")

    tb_command = ["tensorboard", f"--logdir={log_dir}", f"--port={port}"]
    if not open_browser:
        tb_command.append("--host=127.0.0.1")

    try:
        process = subprocess.Popen(
            tb_command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError as err:
        # The dashboard is optional: a missing or non-executable tensorboard
        # binary must not abort the training run that requested it.
        print(f"[TensorBoard] could not start tensorboard: {err}")
        return None

    time.sleep(2)
    return process


class TensorboardLoggerCallback:
    """Training callback that forwards per-epoch state to a TensorboardLogger.

    The callback owns a ``TensorboardLogger`` (``self.tb_logger``) and, when
    requested, starts a TensorBoard process at construction time.
    """

    def __init__(self,
                 log_dir: str = 'logs/test_experiment',
                 log_histograms: bool = False,
                 launch_tb: bool = True,
                 tb_port: int = 6006,
                 open_tb_in_browser: bool = False,
                 ):
        """Create the logger and optionally launch TensorBoard.

        Args:
            log_dir: Directory the logger writes to; also given to TensorBoard.
            log_histograms: Whether ``on_epoch_end`` records weight/gradient
                histograms.
            launch_tb: Whether to call ``launch_tensorboard`` right away.
            tb_port: Port forwarded to ``launch_tensorboard``.
            open_tb_in_browser: Forwarded to ``launch_tensorboard`` as
                ``open_browser``.
        """
        self.tb_logger = TensorboardLogger(log_dir=log_dir)
        self.log_histograms = log_histograms
        # Filled in by on_train_begin; used by on_epoch_end when the epoch
        # state does not carry a model of its own.
        self.model = None

        if launch_tb:
            launch_tensorboard(log_dir=log_dir,
                               port=tb_port,
                               open_browser=open_tb_in_browser)

    def on_train_begin(self, model: nn.Module):
        """Remember ``model`` for later histogram logging."""
        self.model = model
        print("[TensorBoard] Training started, callback init")

    def on_train_end(self):
        """Close the logger once training has finished."""
        self.tb_logger.close()
        print("[TensorBoard] Training ended, logs saved")

    def on_epoch_end(self, state):
        """Log losses, metrics and (optionally) histograms for one epoch.

        Args:
            state: Mapping read through ``.get``. Keys consulted: ``'epoch'``
                (defaults to 0), ``'train_loss'``, ``'val_loss'``,
                ``'train_metrics'``, ``'val_metrics'`` (name -> value
                mappings) and ``'model'``.

        Missing or ``None`` losses and metric mappings are skipped instead of
        being recorded as a made-up ``0.0`` data point. For histograms the
        model in ``state`` takes precedence; otherwise the model received in
        ``on_train_begin`` is used.
        """
        epoch = state.get('epoch', 0)

        # Losses: only log values that were actually reported this epoch.
        for tag, key in (('Loss/train', 'train_loss'), ('Loss/val', 'val_loss')):
            loss = state.get(key)
            if loss is not None:
                self.tb_logger.log_scalar(tag, loss, epoch)

        # Training and validation metrics, one scalar per metric name.
        for split in ('train', 'val'):
            metrics = state.get(f'{split}_metrics')
            if metrics is None:
                continue
            for metric_name, metric_value in metrics.items():
                self.tb_logger.log_scalar(
                    f'Metrics/{split}/{metric_name}', metric_value, epoch)

        # Weights and gradients.
        if self.log_histograms:
            model = state.get('model')
            if model is None:
                model = self.model
            if model is not None:
                self.tb_logger.log_histogram(model, epoch)