File size: 4,331 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import mlflow
from typing import Protocol, Callable, Union
from contextlib import contextmanager
import torch
from torch import Tensor


def to_float(x):
    """Return *x* as a builtin Python ``float``.

    Tensors are detached from the autograd graph and unwrapped with
    ``.item()`` (so they must hold exactly one element). Anything else is
    handed straight to ``float()``.
    """
    scalar = x.detach().item() if torch.is_tensor(x) else x
    return float(scalar)


# Values that can be recorded eagerly: a Python float or a torch Tensor.
# Tensors are expected to hold a single element, because to_float() unwraps
# them with .item().
type AddType = float | Tensor


class LogCall(Protocol):
    """Structural type of the loggers handed out by ``Debug`` attribute access.

    A ``LogCall`` can be:

    * called with a value, or a zero-argument callable producing one, to
      record it;
    * used as a context manager to open a nested key scope;
    * extended by further attribute access to build a dotted key.
    """

    def __getattr__(self, name: str) -> "LogCall": ...

    def __call__(self, value: AddType | Callable[[], AddType]) -> None: ...

    def __enter__(self) -> "LogCall": ...

    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...


class Debug:
    """Hierarchical metric logger backed by an MLflow run.

    Metrics are recorded through dynamic attribute access::

        dbg = Debug("my-experiment", "run-1")
        dbg.loss(loss_tensor)             # eager: converted to float right away
        dbg.grad_norm(lambda: compute())  # callable: evaluated at commit time
        with dbg.encoder:                 # nested scope -> key "encoder.attn"
            dbg.attn(0.3)
        with dbg.pre.decoder:             # every value in this scope is deferred
            dbg.out(t)                    # key "decoder.out"
        dbg.commit(step)

    Dotted keys become ``/``-separated MLflow metric names. Each logged metric
    is the mean of the values recorded for that key since the previous
    ``commit``.
    """

    # key -> recorded values; callables stay unevaluated until process().
    # Created per instance in __init__ (a class-level {} would be shared).
    logs: dict[str, list[AddType | Callable[[], AddType]]]
    # Immutable class-level defaults; _scope() shadows them on the instance.
    _current_prefix = ""
    _is_pre = False

    def __init__(
        self,
        experiment: str,
        run: str,
        tracking_uri: str = "http://127.0.0.1:5000",
    ):
        """Start a new MLflow run named *run* under *experiment*.

        Any MLflow run that is already active is ended first.

        Args:
            experiment: Experiment name passed to ``mlflow.set_experiment``.
            run: Name of the run to start.
            tracking_uri: MLflow tracking URI. The default is the previously
                hard-coded local server.
        """
        # Set first, so the buffer exists even if an MLflow call below raises.
        self.logs = {}
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)
        if mlflow.active_run() is not None:
            mlflow.end_run()
        mlflow.start_run(run_name=run)

    @staticmethod
    def _join_key(prefix: str, name: str) -> str:
        """Join two dotted-key fragments, skipping empty ones (no stray dots)."""
        return ".".join(part for part in (prefix, name) if part)

    @staticmethod
    def _is_dunder(name: str) -> bool:
        """Return True for ``__special__`` names, which are never metric names."""
        return len(name) > 4 and name.startswith("__") and name.endswith("__")

    @contextmanager
    def _scope(self, name: str, is_pre: bool = False):
        """Temporarily extend the key prefix by *name* and set the deferred flag.

        The previous prefix and flag are restored on exit, including when the
        body raises. Yields this ``Debug`` instance.
        """
        saved = (self._current_prefix, self._is_pre)
        self._current_prefix = self._join_key(self._current_prefix, name)
        self._is_pre = is_pre
        try:
            yield self
        finally:
            self._current_prefix, self._is_pre = saved

    def __getattr__(self, name: str) -> LogCall:
        """Return a logger for metric *name*; ``pre`` returns a deferred-scope proxy.

        This is only invoked when normal attribute lookup fails. Dunder names
        raise ``AttributeError``, so protocol probes (for example
        ``copy.deepcopy`` calling ``getattr(obj, "__deepcopy__", None)``) are
        not mistaken for metric loggers.
        """
        if Debug._is_dunder(name):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        if name == "pre":
            return self._PreProxy(self)
        return self._DynamicLogger(self, name)

    class _PreProxy:
        """Builds a dotted path after ``Debug.pre``; everything it records is deferred."""

        def __init__(self, outer, path=""):
            self.outer = outer
            self.path = path
            # Set here so __exit__ never falls through to __getattr__.
            self._cm = None

        def __getattr__(self, name):
            if Debug._is_dunder(name):
                raise AttributeError(name)
            return Debug._PreProxy(self.outer, Debug._join_key(self.path, name))

        def __call__(self, value: AddType | Callable[[], AddType]):
            # Respect the enclosing scope, just like _DynamicLogger.__call__.
            key = Debug._join_key(self.outer._current_prefix, self.path)
            self.outer.preAdd(key, value)

        def __enter__(self):
            self._cm = self.outer._scope(self.path, is_pre=True)
            return self._cm.__enter__()

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self._cm is not None:
                return self._cm.__exit__(exc_type, exc_val, exc_tb)

    class _DynamicLogger:
        """Logger for one dotted metric name, relative to the current scope."""

        def __init__(self, parent, name):
            self.parent = parent
            self.name = name
            self._cm = None

        def __getattr__(self, name: str) -> "Debug._DynamicLogger":
            if Debug._is_dunder(name):
                raise AttributeError(name)
            return Debug._DynamicLogger(self.parent, f"{self.name}.{name}")

        def __call__(self, value: AddType | Callable[[], AddType]):
            full_key = Debug._join_key(self.parent._current_prefix, self.name)
            # Inside a ``pre`` scope, or when given a callable, defer evaluation
            # to commit time; otherwise convert to float immediately.
            if self.parent._is_pre or callable(value):
                self.parent.preAdd(full_key, value)
            else:
                self.parent.add(full_key, value)

        def __enter__(self):
            self._cm = self.parent._scope(self.name, is_pre=self.parent._is_pre)
            return self._cm.__enter__()

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self._cm is not None:
                return self._cm.__exit__(exc_type, exc_val, exc_tb)

    def add(
        self,
        key: str,
        value: AddType,
    ):
        """Record *value* under *key*, converting it to float immediately."""
        self.logs.setdefault(key, []).append(to_float(value))

    def preAdd(
        self,
        key: str,
        fn: Callable[[], AddType],
    ):
        """Record *fn* under *key* without evaluating it.

        ``process()`` later calls it and converts the result to float.
        Non-callable values are also accepted; they are converted at
        ``process()`` time.
        """
        self.logs.setdefault(key, []).append(fn)

    def process(self):
        """Resolve the buffer in place: call deferred callables, then convert every entry to float."""
        for key in self.logs:
            self.logs[key] = [
                to_float(item() if callable(item) else item) for item in self.logs[key]
            ]

    def commit(self, step: int, log_every: int = 10):
        """Finish one step: send per-key means to MLflow when due, then clear the buffer.

        Metrics are sent only when ``step > 0`` and *step* is a multiple of
        *log_every*. Values recorded on any other step are discarded, and their
        deferred callables are not evaluated, which avoids computing results
        that would be thrown away.

        Args:
            step: Current step, also used as the MLflow step.
            log_every: Logging interval in steps (default 10, the previously
                hard-coded value).

        Raises:
            ValueError: If *log_every* is less than 1.
        """
        if log_every < 1:
            raise ValueError(f"log_every must be >= 1, got {log_every}")
        if step > 0 and step % log_every == 0:
            self.process()
            metrics: dict[str, float] = {
                key.replace(".", "/"): sum(values) / len(values)
                for key, values in self.logs.items()
                if values
            }
            if metrics:
                mlflow.log_metrics(metrics, step=step)
        self.logs = {}

    def log_params(self, params: dict):
        """Forward *params* to ``mlflow.log_params`` for the active run."""
        mlflow.log_params(params)

    def end(self):
        """End the active MLflow run."""
        mlflow.end_run()