| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import functools |
| | import logging |
| | from contextlib import contextmanager |
| | import inspect |
| | import time |
| |
|
# Module-level logger; deserialize_model emits its "dropping parameter"
# warnings through it.
logger: logging.Logger = logging.getLogger(__name__)

# Small stabilizing constant: cal_snr adds it to the noise-energy denominator
# and to the argument of log10 so neither can be exactly zero.
EPS: float = 1e-8
| |
|
| |
|
def capture_init(init):
    """Decorator for ``__init__`` that remembers the constructor arguments.

    The returned wrapper stores ``(args, kwargs)`` — the positional arguments
    as a tuple and the keyword arguments as a dict — on the instance as
    ``self._init_args_kwargs`` and only then calls the original ``__init__``.
    ``serialize_model`` reads this attribute back.
    """
    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Record before delegating, so the attribute already exists while the
        # real constructor runs.
        captured = (args, kwargs)
        self._init_args_kwargs = captured
        init(self, *args, **kwargs)

    return wrapper
| |
|
| |
|
def deserialize_model(package, strict=False):
    """Rebuild a model from a package dict (the format ``serialize_model`` emits).

    Args:
        package (dict): mapping with keys ``'class'`` (the class to
            instantiate), ``'args'`` / ``'kwargs'`` (constructor arguments)
            and ``'state'`` (passed to ``load_state_dict``).
        strict (bool): if True, call the constructor with the stored
            arguments unchanged. If False, keyword arguments the constructor's
            signature does not accept are dropped with a warning, so a
            package whose saved kwargs no longer match the class can still
            be loaded.

    Returns:
        The newly constructed model with ``package['state']`` loaded.

    The input ``package`` is never modified.
    """
    klass = package['class']
    kwargs = package['kwargs']
    if not strict:
        params = inspect.signature(klass).parameters
        # A ``**kwargs`` catch-all accepts any keyword: nothing must be dropped.
        accepts_any = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
        if not accepts_any:
            # Build a filtered copy instead of deleting from the caller's dict.
            filtered = {}
            for key, value in kwargs.items():
                if key in params:
                    filtered[key] = value
                else:
                    logger.warning("Dropping inexistant parameter %s", key)
            kwargs = filtered
    model = klass(*package['args'], **kwargs)
    model.load_state_dict(package['state'])
    return model
| |
|
| |
|
def copy_state(state):
    """Return a new dict holding a CPU copy of every value in ``state``.

    Values are presumably torch tensors (e.g. from ``state_dict()``) — each
    one is passed through ``.cpu()`` and then ``.clone()``, so the result
    does not alias the original storage. Key order is preserved.
    """
    copied = {}
    for name, tensor in state.items():
        # .cpu() returns the same object for CPU tensors; .clone() guarantees
        # a fresh copy either way.
        copied[name] = tensor.cpu().clone()
    return copied
| |
|
| |
|
def serialize_model(model):
    """Pack ``model`` into a plain dict that ``deserialize_model`` can rebuild.

    Expects ``model._init_args_kwargs`` to exist (it is set by the
    ``capture_init`` decorator); raises AttributeError otherwise.

    Returns:
        dict with keys ``"class"`` (the model's class), ``"args"`` and
        ``"kwargs"`` (the captured constructor arguments) and ``"state"``
        (a CPU copy of ``model.state_dict()`` made with ``copy_state``).
    """
    init_args, init_kwargs = model._init_args_kwargs
    return {
        "class": model.__class__,
        "args": init_args,
        "kwargs": init_kwargs,
        "state": copy_state(model.state_dict()),
    }
| |
|
| |
|
@contextmanager
def swap_state(model, state):
    """
    Context manager that temporarily loads ``state`` into ``model``, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state

    A CPU copy of the current state (via ``copy_state``) is taken first and
    loaded back on exit — after the body completes, if the body raises, and
    also if loading ``state`` itself fails.
    """
    old_state = copy_state(model.state_dict())
    try:
        # Load inside the try: torch's load_state_dict may already have copied
        # some tensors before raising (e.g. on missing keys or size mismatch),
        # so the original weights must be restored on that path as well.
        model.load_state_dict(state)
        yield
    finally:
        model.load_state_dict(old_state)
| |
|
| |
|
def pull_metric(history, name):
    """Collect the values stored under ``name`` across a metrics history.

    Args:
        history: iterable of mappings (one per recorded step).
        name: key to extract.

    Returns:
        list of ``metrics[name]`` for each mapping that contains ``name``,
        in iteration order; mappings without the key are skipped.
    """
    return [entry[name] for entry in history if name in entry]
| |
|
| |
|
class LogProgress:
    """
    Sort of like tqdm but using log lines and not as real time.
    Args:
        - logger: logger obtained from `logging.getLogger` (only its
          ``log(level, msg)`` method is used),
        - iterable: iterable object to wrap
        - updates (int): number of lines that will be printed, e.g.
          if `updates=5`, log every 1/5th of the total length.
        - total (int): length of the iterable, in case it does not support
          `len`. Used as given whenever it is not None (including 0).
        - name (str): prefix to use in the log.
        - level: logging level (like `logging.INFO`).

    Each line reads ``"{name} | {index}/{total} | {speed}"``, followed by
    ``" | Key value"`` pairs for the infos passed to the latest `update` call.
    """
    def __init__(self,
                 logger,
                 iterable,
                 updates=5,
                 total=None,
                 name="LogProgress",
                 level=logging.INFO):
        self.iterable = iterable
        # `total or len(...)` would discard an explicit total=0 (and then fail
        # on iterables without __len__); only fall back when total is missing.
        self.total = len(iterable) if total is None else total
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        """Replace the key/value pairs appended to subsequent log lines."""
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        # perf_counter is meant for measuring elapsed durations, unlike
        # time.time(), which jumps when the system clock is adjusted.
        self._begin = time.perf_counter()
        return self

    def __next__(self):
        # _index is the 0-based position of the item being fetched, i.e. the
        # number of items already consumed before it.
        self._index += 1
        try:
            return next(self._iterator)
        finally:
            # Also runs when the iterator is exhausted (StopIteration), so a
            # final "total/total" line can be emitted.
            log_every = max(1, self.total // self.updates)
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        elapsed = time.perf_counter() - self._begin
        # A zero elapsed time (coarse clock, very fast loop) would otherwise
        # raise ZeroDivisionError; report it as infinite speed instead.
        self._speed = (1 + self._index) / elapsed if elapsed > 0 else float("inf")
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
| |
|
| |
|
def colorize(text, color):
    """
    Wrap ``text`` in an ANSI SGR escape so a compatible terminal renders it
    with the attribute ``color`` (e.g. ``"1"`` for bold, as ``bold`` uses).

    The result is ``ESC[<color>m`` + text + ``ESC[0m``; the trailing
    sequence resets all attributes.
    """
    return f"\033[{color}m" + text + "\033[0m"
| |
|
| |
|
def bold(text):
    """
    Return ``text`` wrapped in the ANSI bold escape (SGR code 1) using
    ``colorize``.
    """
    return colorize(text, color="1")
| |
|
| |
|
def cal_snr(lbl, est, eps=None):
    """Signal-to-noise ratio, in dB, of ``est`` measured against ``lbl``.

    Computes ``10 * log10(sum(lbl**2) / (sum((est - lbl)**2) + eps) + eps)``
    with both sums taken over the last dimension.

    Args:
        lbl: reference signal tensor.
        est: estimated signal tensor (combined with ``lbl`` by broadcasting
            subtraction).
        eps (float, optional): stabilizer added to the noise energy and to
            the log10 argument. Defaults to the module constant ``EPS``
            (1e-8). A larger value may be needed for float16 inputs, where
            1e-8 rounds to zero and a perfect estimate would give 0/0.

    Returns:
        Tensor of SNR values in dB, with the last dimension reduced.
    """
    import torch
    if eps is None:
        eps = EPS
    signal_energy = torch.sum(lbl ** 2, dim=-1)
    noise_energy = torch.sum((est - lbl) ** 2, dim=-1)
    return 10.0 * torch.log10(signal_energy / (noise_energy + eps) + eps)
| |
|