| | """Console logger utilities. |
| | |
| | Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py |
| | Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging |
| | """ |
| |
|
| | import logging |
| | import math |
| |
|
| | import fsspec |
| | import lightning |
| | import torch |
| | from timm.scheduler import CosineLRScheduler |
| |
|
| |
|
def fsspec_exists(filename):
    """Return whether `filename` exists on the filesystem fsspec resolves.

    The filesystem object is looked up from the path/URL itself via
    `fsspec.core.url_to_fs`, and the original string is then passed to its
    `exists` method.
    """
    filesystem = fsspec.core.url_to_fs(filename)[0]
    return filesystem.exists(filename)
| |
|
| |
|
def fsspec_listdir(dirname):
    """Return the listing of `dirname` as produced by the fsspec filesystem.

    The filesystem is resolved from `dirname` with `fsspec.core.url_to_fs`;
    the result is whatever that filesystem's `ls` returns for the path.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    return filesystem.ls(dirname)
| |
|
| |
|
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create `dirname` through the fsspec filesystem that owns it.

    Args:
        dirname: Path or URL of the directory to create.
        exist_ok: Forwarded unchanged to the filesystem's `makedirs`.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    filesystem.makedirs(dirname, exist_ok=exist_ok)
| |
|
| |
|
def print_nans(tensor, name):
    """Print `name` followed by `tensor` if any element of it is NaN.

    Nothing is printed, and nothing is returned, when the tensor is NaN-free.
    """
    contains_nan = torch.isnan(tensor).any()
    if not contains_nan:
        return
    print(name, tensor)
| |
|
| |
|
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """timm ``CosineLRScheduler`` that can be stepped without an argument.

    The wrapper keeps its own position counter (``_last_epoch``), so
    ``scheduler.step()`` may be called with no epoch; passing an explicit
    value sets the counter instead, which is what allows resuming.
    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying scheduler and apply position 0 right away."""
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the schedule by one position, or jump to `epoch`.

        Args:
            epoch: Position to jump to. When None, the internal counter is
                incremented by one.
        """
        self._last_epoch = (
            self._last_epoch + 1 if epoch is None else epoch)
        # The parent exposes two entry points; which one applies depends on
        # whether the schedule is expressed in epochs (`t_in_epochs`) or in
        # optimizer updates. The same counter feeds either of them.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)
| |
|
| |
|
class LoggingContext:
    """Temporarily change a logger's level and/or attach a handler.

    Args:
        logger: The `logging.Logger` to modify for the duration of the block.
        level: Level to set on entry; None leaves the level untouched.
        handler: Handler to add on entry and remove on exit (optional).
        close: Whether the handler is closed after being removed.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger, self.level = logger, level
        self.handler, self.close = handler, close

    def __enter__(self):
        """Apply the requested level and handler; returns None."""
        level_requested = self.level is not None
        if level_requested:
            # Remember the current level so __exit__ can put it back.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        """Undo what __enter__ did; exceptions are not suppressed."""
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if not self.handler:
            return
        self.logger.removeHandler(self.handler)
        if self.close:
            self.handler.close()
| |
|
| |
|
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Return a named logger whose log methods are wrapped for multi-GPU use.

    Args:
        name: Name handed to `logging.getLogger`.
        level: Level set on the returned logger.

    Returns:
        The logger, with each of its level methods replaced by a version
        wrapped in Lightning's `rank_zero_only`.
    """
    log = logging.getLogger(name)
    log.setLevel(level)

    # Wrap every level method with the rank-zero decorator. Presumably this
    # keeps messages from being repeated once per process in multi-GPU runs;
    # see Lightning's `rank_zero_only` for the exact semantics.
    rank_zero_only = lightning.pytorch.utilities.rank_zero_only
    method_names = ('debug', 'info', 'warning', 'error',
                    'exception', 'fatal', 'critical')
    for method_name in method_names:
        wrapped = rank_zero_only(getattr(log, method_name))
        setattr(log, method_name, wrapped)

    return log
| |
|
| |
|
class Sampler:
    """Base class for straight-through samplers over noisy logits.

    Subclasses provide the three hooks below. The base versions are
    placeholders: `_sampling_noise` and `_hard_sample` return None and
    `_soft_sample` returns 0.
    """

    def __init__(self, shape):
        # Shape used by subclasses when drawing noise.
        self.shape = shape

    def _sampling_noise(self):
        """Hook: return the noise added to the logits (placeholder)."""
        return None

    def _hard_sample(self, logits):
        """Hook: return the discrete sample for `logits` (placeholder)."""
        return None

    def _soft_sample(self, logits):
        """Hook: return the differentiable relaxation (placeholder)."""
        return 0

    def sample(self, logits):
        """Perturb `logits` with noise and return a straight-through sample.

        The noise is cut to the first `logits.shape[0]` rows and cast to the
        logits' dtype and device before being added. The returned tensor has
        the value of the hard sample, while gradients flow only through the
        soft sample because the difference between the two is detached.
        """
        noise = self._sampling_noise()
        rows = logits.shape[0]
        noise = noise[:rows, :].to(dtype=logits.dtype, device=logits.device)
        perturbed = logits + noise
        hard = self._hard_sample(perturbed)
        soft = self._soft_sample(perturbed)
        correction = (hard - soft).detach()
        return soft + correction
| |
|
| |
|
class TopKSampler(Sampler):
    """Straight-through top-k sampler with summed, rescaled Gamma noise.

    Args:
        k: Number of entries per row kept by the hard sample.
        shape: Shape of the noise tensor returned by `_sampling_noise`.
        gamma_tau: Multiplier applied to the final noise.
        num_betas: Number of Gamma draws summed per noise element. It was
            previously fixed at 10 and duplicated as a literal in the log
            shift; both now come from this single value.
    """

    def __init__(self, k, shape, gamma_tau=1.0, num_betas=10):
        super().__init__(shape)
        self.k = k
        self.gamma_tau = gamma_tau
        self.num_betas = num_betas
        # One Gamma(1/k, 1) variable per (beta index, noise element).
        self.sampler = torch.distributions.gamma.Gamma(
            1 / k * torch.ones(self.num_betas, *self.shape), 1.0)

    def _sampling_noise(self):
        """Draw noise of shape `self.shape`.

        Each of the `num_betas` Gamma draws is divided by `k / i`
        (i = 1..num_betas); the results are summed over the beta axis,
        shifted by `log(num_betas)` and scaled by `gamma_tau / k`.
        """
        noise = self.sampler.sample()
        beta = self.k / torch.arange(
            1, self.num_betas + 1, 1, dtype=torch.float32)
        # Give beta one singleton axis per noise axis after the first so it
        # broadcasts for any `shape`, not only two-dimensional ones.
        beta = beta.reshape(-1, *([1] * (noise.ndim - 1)))
        s = torch.sum(noise / beta, dim=0)
        # The shift must track the number of summed terms.
        s = s - math.log(self.num_betas)
        return self.gamma_tau * (s / self.k)

    def _hard_sample(self, logits):
        """Return a 0/1 mask of entries >= the k-th largest value per row.

        Entries tied with the threshold are all kept, so a row can contain
        more than `k` ones.
        """
        assert logits.ndim == 2
        thresholds, _ = torch.sort(logits, dim=-1)
        # Ascending sort: position -k holds the k-th largest value.
        thresholds = thresholds[:, -self.k][:, None]
        return (logits >= thresholds).type(logits.dtype)

    def _soft_sample(self, logits):
        """Return row-centred logits divided by their norm along the row.

        NOTE(review): a constant row has zero norm and yields NaN here.
        """
        soft_top_k = logits - torch.mean(logits, dim=-1, keepdim=True)
        return soft_top_k / torch.norm(soft_top_k, dim=-1, keepdim=True)
| |
|
| |
|
class DeterministicTopK(TopKSampler):
    """Noise-free top-k: straight-through top-k applied to the raw input.

    Args:
        k: Number of entries per row kept by the hard sample.
    """

    def __init__(self, k):
        # No noise is ever drawn, so a minimal placeholder shape is used.
        super().__init__(k, shape=(1, 1))

    def _sampling_noise(self):
        """Return 0: this sampler adds no noise."""
        return 0

    def discretize(self, x):
        """Return the straight-through top-k of `x` without any noise.

        The value equals the hard top-k mask; gradients flow through the
        soft sample only, because the difference is detached.
        """
        hard_sample = self._hard_sample(x)
        soft_sample = self._soft_sample(x)
        return soft_sample + (hard_sample - soft_sample).detach()

    # Original (misspelled) public name, kept so existing callers still work.
    discreize = discretize

    def sample(self, logits):
        """Noise-free counterpart of the inherited `sample`.

        The inherited implementation slices the value returned by
        `_sampling_noise`, which is the plain int 0 here and cannot be
        subscripted, so it raised TypeError. With no noise to add, sampling
        is the same operation as `discretize`.
        """
        return self.discretize(logits)
| |
|
class GumbelSampler(Sampler):
    """Straight-through Gumbel-softmax sampler over the last axis.

    Args:
        shape: Shape of the Gumbel noise drawn for each sample.
        temperature: Divisor applied to the logits before the softmax of
            the soft sample.
    """

    def __init__(self, shape, temperature=1.0):
        super().__init__(shape)
        self.temperature = temperature

    def _sampling_noise(self):
        """Draw Gumbel noise as -log(-log(U)) with U uniform on [0, 1).

        The 1e-10 offsets keep both logarithms away from log(0).
        """
        return - (1e-10 - (
            torch.rand(*self.shape) + 1e-10).log()).log()

    def _hard_sample(self, logits):
        """Return the one-hot encoding of the row-wise argmax of `logits`.

        The previous code indexed the 2-D input (and the 1-D argmax) with
        three subscripts, which raised IndexError on every input accepted by
        the assert below. The one-hot is now built for the 2-D case that the
        assert requires.
        """
        assert logits.ndim == 2
        indices = torch.argmax(logits, dim=-1, keepdim=True)
        # Start from real zeros rather than `logits * 0`, which would turn
        # inf/NaN logits into NaN entries in the one-hot.
        return torch.zeros_like(logits).scatter(-1, indices, 1.0)

    def _soft_sample(self, logits):
        """Return softmax(logits / temperature) along the last axis."""
        return torch.nn.functional.softmax(
            logits / self.temperature, dim=-1)
| |
|
| |
|
class BinarySampler(GumbelSampler):
    """Straight-through Gumbel sampler for independent binary variables."""

    def sample(self, probs):
        """Sample 0/1 values from `probs` with a straight-through gradient.

        Two independent Gumbel draws (for the positive and the negative
        outcome) are combined into d = exp(neg_noise - pos_noise). The soft
        sample is the two-way Gumbel-softmax probability of the positive
        outcome, p / (p + (1 - p) * d). The hard sample is 1 exactly when
        that probability exceeds 1/2, i.e. when p > (1 - p) * d.

        The previous hard test, p * (1 + d) > 1, is the same decision with
        the two noise draws swapped, so the returned value could contradict
        the relaxation that provides its gradient. The marginal
        P(hard = 1) = p is the same for both forms.

        Args:
            probs: Probabilities of the positive outcome. The noise is drawn
                with shape `self.shape` and is not sliced here, so it has to
                broadcast against `probs`.

        Returns:
            Tensor with the value of the hard sample and the gradient of the
            soft sample.
        """
        pos_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        neg_noise = self._sampling_noise().to(
            dtype=probs.dtype, device=probs.device)
        del_noise_exp = (neg_noise - pos_noise).exp()
        neg_term = (1 - probs) * del_noise_exp
        hard_sample = (probs > neg_term).to(probs.dtype)
        soft_sample = probs / (probs + neg_term)
        return soft_sample + (hard_sample - soft_sample).detach()
| |
|
| |
|
class GaussianSampler:
    """Reparameterised Gaussian sampler over a packed 2-D input."""

    def __init__(self):
        self.softplus = torch.nn.Softplus()

    def sample(self, x):
        """Draw mean + std * eps from a 2-D tensor split along its columns.

        The first `x.shape[-1] // 2` columns are used as the mean. The
        remaining columns are passed through softplus and then a square
        root to give the scale that multiplies standard-normal noise.
        """
        assert x.ndim == 2
        half = x.shape[-1] // 2
        mean, raw_scale = x[:, :half], x[:, half:]
        std = torch.sqrt(self.softplus(raw_scale))
        eps = torch.randn_like(mean)
        return mean + std * eps