| from abc import ABC, abstractmethod |
| from collections import defaultdict |
| from collections.abc import Callable, Iterable |
|
|
| import torch |
|
|
# Type of the "source getter": a zero-argument callable returning an iterable
# of (name, tensor) pairs.
_SourceGetter = Callable[[], Iterable[tuple[str, torch.Tensor]]]
|
|
|
|
class TensorBackuper(ABC):
    """Interface for managing tagged backups of a set of named tensors.

    The tensors are obtained on demand by calling ``source_getter``, which
    returns an iterable of ``(name, tensor)`` pairs. Use :meth:`create` to
    obtain a concrete implementation.
    """

    @staticmethod
    def create(source_getter, single_tag):
        """Instantiate the implementation selected by ``single_tag``.

        Args:
            source_getter: Callable handed to the chosen implementation.
            single_tag: ``None`` selects ``_TensorBackuperNormal``; any other
                value selects ``_TensorBackuperNoop`` bound to that tag.

        Returns:
            The newly constructed backuper.
        """
        if single_tag is not None:
            return _TensorBackuperNoop(source_getter=source_getter, single_tag=single_tag)
        return _TensorBackuperNormal(source_getter=source_getter)

    def __init__(self, source_getter: _SourceGetter):
        """Remember the callable that provides the source tensors."""
        self._source_getter = source_getter

    @property
    @abstractmethod
    def backup_tags(self):
        """Tags known to this backuper (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def get(self, tag: str):
        """Return the tensors associated with ``tag`` (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def backup(self, tag: str):
        """Record the current source tensors under ``tag`` (implemented by subclasses)."""
        raise NotImplementedError

    def copy(self, *, src_tag: str, dst_tag: str):
        """Copy the backup stored under ``src_tag`` into ``dst_tag``.

        Not marked abstract, so a subclass that does not override it can
        still be instantiated; calling it there raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def restore(self, tag: str):
        """Bring the source tensors back to the state recorded under ``tag``."""
        raise NotImplementedError
|
|
|
|
class _TensorBackuperNormal(TensorBackuper):
    """Backuper that keeps, per tag, a pinned CPU copy of every source tensor."""

    def __init__(self, source_getter):
        super().__init__(source_getter=source_getter)
        # tag -> {tensor name -> CPU copy}. This is a defaultdict, so merely
        # looking up an unseen tag creates an empty entry for it.
        self._backups: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)

    @property
    def backup_tags(self):
        """List of tags that currently have an entry in the backup store."""
        return list(self._backups.keys())

    def get(self, tag: str):
        """Return the ``{name: tensor}`` mapping stored under ``tag``."""
        return self._backups[tag]

    @torch.no_grad()
    def backup(self, tag: str) -> None:
        """Copy every source tensor into the CPU buffers stored under ``tag``.

        Buffers are allocated (pinned, on CPU) the first time a name is seen
        for this tag and reused afterwards.
        """
        slot = self._backups[tag]
        for name, tensor in self._source_getter():
            host_copy = slot.get(name)
            if host_copy is None:
                host_copy = torch.empty_like(tensor, device=torch.device("cpu"), pin_memory=True)
                slot[name] = host_copy
            host_copy.copy_(tensor.detach(), non_blocking=True)
        # The copies above were issued with non_blocking=True; wait for them.
        torch.cuda.synchronize()

    @torch.no_grad()
    def copy(self, *, src_tag: str, dst_tag: str):
        """Overwrite the tensors stored under ``dst_tag`` with those of ``src_tag``.

        Only names already present under ``dst_tag`` are copied; each must
        also exist under ``src_tag``.
        """
        for name, dst_tensor in self._backups[dst_tag].items():
            # The source mapping is looked up lazily, per name, so nothing is
            # created for ``src_tag`` when the destination is empty.
            dst_tensor.copy_(self._backups[src_tag][name])

    @torch.no_grad()
    def restore(self, tag: str) -> None:
        """Copy the tensors stored under ``tag`` back into the source tensors."""
        saved = self._backups[tag]
        for name, tensor in self._source_getter():
            assert name in saved
            tensor.copy_(saved[name], non_blocking=True)
        # The copies above were issued with non_blocking=True; wait for them.
        torch.cuda.synchronize()
|
|
|
|
class _TensorBackuperNoop(TensorBackuper):
    """Backuper for one fixed tag that stores no tensor copies.

    ``backup`` only records a checksum per source tensor; ``get`` and
    ``restore`` assert that the live source tensors still match it.
    """

    def __init__(self, source_getter, single_tag):
        super().__init__(source_getter=source_getter)
        self._single_tag = single_tag
        # Checksums recorded by the most recent backup(); None until then.
        self._backup_hash_dict = None

    @property
    def backup_tags(self):
        """One-element list holding the single supported tag."""
        return [self._single_tag]

    def get(self, tag: str):
        """Return the detached live source tensors as a ``{name: tensor}`` dict.

        ``tag`` is not inspected here. Asserts that the tensors' checksums
        equal the ones recorded by ``backup``.
        """
        live = {name: tensor.detach() for name, tensor in dict(self._source_getter()).items()}
        assert _compute_hash_dict(live) == self._backup_hash_dict
        return live

    def backup(self, tag: str) -> None:
        """Record checksums of the current source tensors; ``tag`` must be the single tag."""
        assert tag == self._single_tag
        current_tensors = dict(self._source_getter())
        self._backup_hash_dict = _compute_hash_dict(current_tensors)
        torch.cuda.synchronize()

    def restore(self, tag: str) -> None:
        """Assert that the source tensors still match the recorded checksums.

        Nothing is copied; ``tag`` must be the single tag.
        """
        assert tag == self._single_tag
        assert _compute_hash_dict(dict(self._source_getter())) == self._backup_hash_dict
        torch.cuda.synchronize()
|
|
|
|
def _compute_hash_dict(tensors: dict[str, torch.Tensor]):
    """Return a ``{name: checksum}`` dict with one checksum per input tensor."""
    hashes = {}
    for name, tensor in tensors.items():
        hashes[name] = _compute_hash_tensor(tensor)
    return hashes
|
|
|
|
| def _compute_hash_tensor(x: torch.Tensor): |
| |
| x = x.contiguous() |
| x = x.view(-1) |
| x = x.view(torch.uint32) |
| x = x.sum() |
| return x.item() |
|
|