# github-actions[bot]
# Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
# e0552b0
import mlflow
from typing import Protocol, Callable, Union
from contextlib import contextmanager
import torch
from torch import Tensor
def to_float(x):
    """Convert a scalar-like value to a built-in ``float``.

    Tensors are detached and unwrapped with ``.item()`` first; every other
    value is handed directly to ``float()``.
    """
    scalar = x.detach().item() if torch.is_tensor(x) else x
    return float(scalar)
# Value kinds accepted by the logging API: a plain float or a torch Tensor.
type AddType = float | Tensor
class LogCall(Protocol):
    """Structural type of the objects produced by attribute access on a logger.

    Such an object can be called to record a value, used as a context
    manager, and extended with further attribute access.
    """

    def __call__(
        self,
        value: Union[AddType, Callable[[], AddType]],
    ) -> None:
        """Record ``value``, given directly or as a zero-argument callable."""

    def __enter__(self) -> "LogCall":
        """Enter the ``with`` block opened on this object."""

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the ``with`` block opened on this object."""

    def __getattr__(self, name: str) -> "LogCall":
        """Return another ``LogCall`` for the attribute ``name``."""
class Debug:
    """Attribute-driven metric recorder that averages values and logs them to MLflow.

    Metric keys are built by chaining attributes; ``with`` blocks push a name
    onto a prefix that is prepended to every key recorded inside them::

        debug = Debug("experiment", "run")
        debug.loss(0.3)                        # eager    -> "loss"
        debug.pre.grad(lambda: g.norm())       # deferred -> "grad"
        with debug.encoder:
            debug.attn.entropy(e)              # -> "encoder.attn.entropy"
        debug.commit(step)

    Eager values are converted with ``to_float`` immediately. Deferred values
    (anything recorded through ``pre``, inside a ``pre`` scope, or any
    callable) are stored untouched and resolved by ``process``. On commit,
    each key is averaged and dots in the key are replaced by ``/``.

    Names that already exist on the class or its proxies (``add``, ``commit``,
    ``logs``, ``name``, ``parent``, ...) resolve normally and therefore cannot
    be used as metric path segments.
    """

    # Pending values per dotted metric key. Created per instance in
    # ``__init__`` so that two ``Debug`` objects never share one buffer.
    logs: "dict[str, list[AddType | Callable[[], AddType]]]"
    # Dotted prefix of the currently active ``with`` scopes ("" at top level).
    _current_prefix: str = ""
    # True while inside a scope whose values must be stored unconverted.
    _is_pre: bool = False
    # ``commit`` sends metrics only on positive multiples of this interval.
    _log_every: int = 10

    def __init__(
        self,
        experiment: str,
        run: str,
        tracking_uri: str = "http://127.0.0.1:5000",
        log_every: int = 10,
    ):
        """Point MLflow at ``tracking_uri`` and start a run named ``run``.

        Args:
            experiment: Experiment name passed to ``mlflow.set_experiment``.
            run: Run name passed to ``mlflow.start_run``.
            tracking_uri: Address passed to ``mlflow.set_tracking_uri``.
            log_every: ``commit`` logs only when ``step`` is a positive
                multiple of this value.

        Raises:
            ValueError: If ``log_every`` is smaller than 1.
        """
        if log_every < 1:
            raise ValueError(f"log_every must be >= 1, got {log_every}")
        self.logs = {}
        self._log_every = log_every
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment)
        # Close any run that is still active before opening a new one.
        if mlflow.active_run() is not None:
            mlflow.end_run()
        mlflow.start_run(run_name=run)

    @staticmethod
    def _join(*parts: str) -> str:
        """Join the non-empty ``parts`` with dots."""
        return ".".join(part for part in parts if part)

    @staticmethod
    def _check_name(name: str) -> None:
        """Raise ``AttributeError`` for dunder names.

        Protocol probes such as ``__setstate__`` or ``__deepcopy__`` must fail
        like on any ordinary object instead of yielding a metric proxy.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

    @contextmanager
    def _scope(self, name: str, is_pre: bool = False):
        """Temporarily extend the key prefix with ``name`` and set the pre flag.

        The previous prefix and flag are restored on exit, even if the body
        raises. An empty ``name`` leaves the prefix unchanged.
        """
        old_prefix = self._current_prefix
        old_is_pre = self._is_pre
        self._current_prefix = self._join(old_prefix, name)
        self._is_pre = is_pre
        try:
            yield self
        finally:
            self._current_prefix = old_prefix
            self._is_pre = old_is_pre

    def __getattr__(self, name: str) -> "LogCall":
        """Return a proxy for the metric path starting at ``name``.

        ``pre`` yields a proxy whose values are always stored deferred; any
        other non-dunder name yields a regular logger proxy.
        """
        self._check_name(name)
        if name == "pre":
            return self._PreProxy(self)
        return self._DynamicLogger(self, name)

    class _PreProxy:
        """Proxy behind ``debug.pre``: records values without converting them."""

        def __init__(self, outer, path=""):
            self.outer = outer
            self.path = path
            self._cm = None

        def __getattr__(self, name):
            """Return a proxy for ``path`` extended by ``name``."""
            Debug._check_name(name)
            return Debug._PreProxy(self.outer, Debug._join(self.path, name))

        def __call__(self, value: "AddType | Callable[[], AddType]"):
            """Store ``value`` unconverted under the scoped key for this path."""
            # Honour the active ``with`` prefix, exactly like _DynamicLogger.
            full_key = Debug._join(self.outer._current_prefix, self.path)
            self.outer.preAdd(full_key, value)

        def __enter__(self):
            """Open a deferred scope named after this path; returns the owner."""
            self._cm = self.outer._scope(self.path, is_pre=True)
            return self._cm.__enter__()

        def __exit__(self, exc_type, exc_val, exc_tb):
            """Close the scope opened by ``__enter__``, if any."""
            cm, self._cm = self._cm, None
            if cm is not None:
                return cm.__exit__(exc_type, exc_val, exc_tb)
            return None

    class _DynamicLogger:
        """Proxy for a regular metric path such as ``debug.encoder.loss``."""

        def __init__(self, parent, name):
            self.parent = parent
            self.name = name
            self._cm = None

        def __getattr__(self, name: str) -> "Debug._DynamicLogger":
            """Return a proxy for this path extended by ``name``."""
            Debug._check_name(name)
            return Debug._DynamicLogger(self.parent, f"{self.name}.{name}")

        def __call__(self, value: "AddType | Callable[[], AddType]"):
            """Record ``value`` under the scoped key for this path.

            Callables, and every value recorded inside a ``pre`` scope, are
            stored deferred; other values are converted to float immediately.
            """
            owner = self.parent
            full_key = Debug._join(owner._current_prefix, self.name)
            if owner._is_pre or callable(value):
                owner.preAdd(full_key, value)
            else:
                owner.add(full_key, value)

        def __enter__(self):
            """Open a scope named after this path, inheriting the pre flag."""
            self._cm = self.parent._scope(self.name, is_pre=self.parent._is_pre)
            return self._cm.__enter__()

        def __exit__(self, exc_type, exc_val, exc_tb):
            """Close the scope opened by ``__enter__``, if any."""
            cm, self._cm = self._cm, None
            if cm is not None:
                return cm.__exit__(exc_type, exc_val, exc_tb)
            return None

    def add(
        self,
        key: str,
        value: "AddType",
    ):
        """Convert ``value`` with ``to_float`` and append it to ``key``."""
        self.logs.setdefault(key, []).append(to_float(value))

    def preAdd(
        self,
        key: str,
        fn: "AddType | Callable[[], AddType]",
    ):
        """Append ``fn`` to ``key`` unconverted; ``process`` resolves it later."""
        self.logs.setdefault(key, []).append(fn)

    def process(self):
        """Resolve every pending entry to a float.

        Callables are invoked first; every result goes through ``to_float``.
        """
        for key, items in self.logs.items():
            # Rebinding an existing key does not resize the dict, so this is
            # safe while iterating over it.
            self.logs[key] = [
                to_float(item() if callable(item) else item) for item in items
            ]

    def commit(self, step: int):
        """Resolve pending values, log per-key means when due, then clear.

        Metrics are sent to MLflow only when ``step`` is a positive multiple
        of the logging interval. Keys with no values are skipped.
        """
        self.process()
        if step > 0 and step % self._log_every == 0:
            metrics: dict[str, float] = {
                key.replace(".", "/"): sum(values) / len(values)
                for key, values in self.logs.items()
                if values
            }
            mlflow.log_metrics(metrics, step)
        # NOTE(review): the buffer is cleared on every commit, so a logged
        # step only averages values recorded since the previous commit.
        # Confirm that averaging over the whole interval was not intended.
        self.logs = {}

    def log_params(self, params: dict):
        """Forward ``params`` to ``mlflow.log_params``."""
        mlflow.log_params(params)

    def end(self):
        """End the active MLflow run."""
        mlflow.end_run()