Spaces:
Sleeping
Sleeping
| # import os | |
| # import torch | |
| # from dataclasses import dataclass, is_dataclass, replace, field | |
| # from typing import TypeVar, Generic, Optional, Any, Callable, Iterable | |
| # from pathlib import Path | |
| # import pickle | |
| # import tempfile | |
| # import shutil | |
| # import atexit | |
| # import signal | |
| # from typing import TypeVar, Generic | |
| # | |
| # T = TypeVar('T') | |
| # | |
| # | |
| # @dataclass | |
| # class DetachingCPUList(list[T]): | |
| # cache_dir: Optional[Path] = field(default=None) | |
| # save_serializer: Optional[Callable[[Any, Path], None]] = field(default=None) | |
| # load_serializer: Optional[Callable[[Path], Any]] = field(default=None) | |
| # detach_func: Optional[Callable[[Any], Any]] = field(default=None) | |
| # remove_files_on_delete: bool = field(default=True) | |
| # verbose: bool = field(default=True) | |
| # | |
| # _cache: dict = field(init=False, repr=False, default_factory=dict) | |
| # _fd_map: dict = field(init=False, repr=False, default_factory=dict) | |
| # _no_cache_set: set = field(init=False, repr=False, default_factory=set) # Track paths marked as no_cache | |
| # _cache_dir: Path = field(init=False, repr=False) | |
| # _tmp_dir_created: bool = field(init=False, repr=False, default=False) | |
| # | |
| # def __post_init__(self): | |
| # | |
| # # Set default serializers if not provided | |
| # if self.save_serializer is None: | |
| # def _save_pickle(obj, path: Path): | |
| # path.parent.mkdir(parents=True, exist_ok=True) | |
| # with open(path, "wb") as f: | |
| # pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) | |
| # self.save_serializer = _save_pickle | |
| # | |
| # # Set default deserializers if not provided | |
| # if self.load_serializer is None: | |
| # def _load_pickle(source): | |
| # if isinstance(source, (str, Path)): | |
| # with open(str(source), "rb") as f: | |
| # return pickle.load(f) | |
| # else: | |
| # return pickle.load(source) | |
| # self.load_serializer = _load_pickle | |
| # | |
| # # Set default detach_func if not provided | |
| # if self.detach_func is None: | |
| # self.detach_func = self._detach_recursive | |
| # | |
| # if self.cache_dir is None: | |
| # tmp = tempfile.mkdtemp(prefix="detaching_cpu_list_") | |
| # self._cache_dir = Path(tmp) | |
| # self._tmp_dir_created = True | |
| # else: | |
| # self._cache_dir = Path(self.cache_dir) | |
| # self._cache_dir.mkdir(parents=True, exist_ok=True) | |
| # | |
| # atexit.register(self._cleanup) | |
| # for sig in (signal.SIGINT, signal.SIGTERM): | |
| # try: | |
| # old = signal.getsignal(sig) | |
| # def _handler(signum, frame, _old=old): | |
| # self._cleanup() | |
| # if callable(_old) and _old not in (signal.SIG_DFL, signal.SIG_IGN): | |
| # _old(signum, frame) | |
| # signal.signal(sig, _handler) | |
| # except Exception: | |
| # pass | |
| # | |
| # # -------------------------- | |
| # # Core cleanup | |
| # # -------------------------- | |
| # def _cleanup(self): | |
| # # Close all file descriptors | |
| # for fd in list(self._fd_map.values()): | |
| # try: | |
| # os.close(fd) | |
| # except Exception: | |
| # pass | |
| # self._fd_map.clear() | |
| # # Remove on-disk files (non-guaranteed mode) | |
| # if self.remove_files_on_delete and self._tmp_dir_created and self._cache_dir.exists(): | |
| # shutil.rmtree(self._cache_dir, ignore_errors=True) | |
| # | |
| # # -------------------------- | |
| # # Save / load helpers | |
| # # -------------------------- | |
| # def _save_unlinked(self, item: Any): | |
| # """Save item to disk with guaranteed deletion (fd-based approach).""" | |
| # tmp = tempfile.NamedTemporaryFile(delete=False, dir=str(self._cache_dir)) | |
| # temp_path = Path(tmp.name) | |
| # tmp.close() | |
| # | |
| # assert self.save_serializer is not None, "save_serializer must be defined" | |
| # self.save_serializer(item, temp_path) | |
| # fd = os.open(str(temp_path), os.O_RDONLY) | |
| # os.unlink(str(temp_path)) # unlink immediately - kernel guarantees cleanup on fd close | |
| # pseudo = Path(f"/proc/self/fd/{fd}") | |
| # token = pseudo.as_posix() | |
| # self._fd_map[token] = fd | |
| # if self.verbose: | |
| # print(f"Saved item to fd {fd} with path {pseudo}") | |
| # return pseudo | |
| # | |
| # def _load_from_fd(self, token: str): | |
| # """Load from file descriptor path.""" | |
| # fd = self._fd_map.get(token) | |
| # if fd is None: | |
| # raise RuntimeError(f"FD {token} not available.") | |
| # dupfd = os.dup(fd) | |
| # assert self.load_serializer is not None, "load_serializer must be defined" | |
| # with os.fdopen(dupfd, "rb") as f: | |
| # f.seek(0) | |
| # obj = self.load_serializer(f) | |
| # if self.verbose: | |
| # print(f"Loaded item from fd {fd} (token {token})") | |
| # return obj | |
| # | |
| # def _load_from_disk(self, path: Path): | |
| # """Load from regular file path.""" | |
| # assert self.load_serializer is not None, "load_serializer must be defined" | |
| # if self.verbose: | |
| # print(f"Loading from disk: {path}") | |
| # return self.load_serializer(path) | |
| # | |
| # # -------------------------- | |
| # # Public interface | |
| # # -------------------------- | |
| # def append(self, item, detach_and_cpu: bool = False, save_to_disk: bool = False, no_cache: bool = False): | |
| # """ | |
| # Append an item to the list. | |
| # | |
| # Args: | |
| # item: The item to append | |
| # detach_and_cpu: If True, apply detach_func to move tensors to CPU | |
| # save_to_disk: If True, save to disk using fd-based guaranteed deletion | |
| # no_cache: If True, never cache this item in memory when accessed (always reload from disk) | |
| # """ | |
| # # Validate save_to_disk requires detach capability | |
| # if save_to_disk and not detach_and_cpu: | |
| # raise ValueError("Cannot save to disk without detach_and_cpu=True") | |
| # | |
| # if not save_to_disk and no_cache: | |
| # print("Warning: no_cache=True has no effect when save_to_disk=False") | |
| # | |
| # if detach_and_cpu and self.detach_func: | |
| # item = self.detach_func(item) | |
| # | |
| # if save_to_disk: | |
| # # Always use fd-based guaranteed deletion | |
| # p = self._save_unlinked(item) | |
| # super().append(p) | |
| # | |
| # # Mark as no_cache if requested | |
| # if no_cache: | |
| # self._no_cache_set.add(p.as_posix()) | |
| # else: | |
| # super().append(item) | |
| # | |
| # def insert(self, index: int, item, detach_and_cpu: bool = False, save_to_disk: bool = False, no_cache: bool = False): | |
| # """ | |
| # Insert an item at a specific index in the list. | |
| # | |
| # Args: | |
| # index: The index to insert the item at | |
| # item: The item to insert | |
| # detach_and_cpu: If True, apply detach_func to move tensors to CPU | |
| # save_to_disk: If True, save to disk using fd-based guaranteed deletion | |
| # no_cache: If True, never cache this item in memory when accessed (always reload from disk) | |
| # """ | |
| # # Validate save_to_disk requires detach capability | |
| # if save_to_disk and not detach_and_cpu: | |
| # raise ValueError("Cannot save to disk without detach_and_cpu=True") | |
| # | |
| # if detach_and_cpu and self.detach_func: | |
| # item = self.detach_func(item) | |
| # | |
| # if save_to_disk: | |
| # # Always use fd-based guaranteed deletion | |
| # p = self._save_unlinked(item) | |
| # super().insert(index, p) | |
| # | |
| # # Mark as no_cache if requested | |
| # if no_cache: | |
| # self._no_cache_set.add(p.as_posix()) | |
| # else: | |
| # super().insert(index, item) | |
| # | |
| # def extend(self, items: Iterable[Any], **kwargs): | |
| # for it in items: | |
| # self.append(it, **kwargs) | |
| # | |
| # def __getitem__(self, index): | |
| # """ | |
| # Return the item at `index`. | |
| # | |
| # - If the underlying stored value is a Path, load from disk/fd. | |
| # - If marked as no_cache, always reload from disk (never cache). | |
| # - Otherwise, cache after first load and return cached object on subsequent accesses. | |
| # - If not a Path, return the in-memory value directly. | |
| # """ | |
| # raw = super().__getitem__(index) | |
| # | |
| # # If it's not a Path, it's an in-memory object: return directly | |
| # if not isinstance(raw, Path): | |
| # return raw | |
| # | |
| # # it's a Path -> use its string as cache key (works for both normal paths and /proc/self/fd/<fd>) | |
| # key = raw.as_posix() | |
| # | |
| # # If in cache, return cached object | |
| # if key in self._cache: | |
| # assert key not in self._no_cache_set, "Inconsistent state: item both cached and marked no_cache" | |
| # return self._cache[key] | |
| # | |
| # # Always reload from disk/fd, never cache | |
| # if str(raw).startswith("/proc/self/fd/"): | |
| # obj = self._load_from_fd(key) | |
| # else: | |
| # raise NotImplementedError("Loading from disk is not implemented") | |
| # # return self._load_from_disk(raw) | |
| # | |
| # if key in self._no_cache_set: | |
| # pass # never cache | |
| # else: | |
| # # Cache it permanently (unless marked as no_cache) | |
| # self._cache[key] = obj | |
| # | |
| # return obj | |
| # | |
| # def pop(self, index: int = -1): | |
| # raw = super().pop(index) | |
| # # If it was a Path, return the loaded object (and keep the fd open / mapping intact) | |
| # if isinstance(raw, Path): | |
| # key = raw.as_posix() | |
| # # return cached value if present; else load now and cache it | |
| # if key in self._cache: | |
| # return self._cache[key] | |
| # if str(raw).startswith("/proc/self/fd/"): | |
| # obj = self._load_from_fd(key) | |
| # else: | |
| # obj = self._load_from_disk(raw) | |
| # self._cache[key] = obj | |
| # return obj | |
| # return raw | |
| # | |
| # def clear(self): | |
| # # do not close fds; keep them open until process exit as requested | |
| # # keep cache intact if you want (or clear it if you prefer) | |
| # # here we remove list entries but keep any cached objects and open fds | |
| # super().clear() | |
| # | |
| # def __iter__(self): | |
| # """Iterate over items, loading from disk as needed.""" | |
| # for i in range(len(self)): | |
| # yield self[i] | |
| # | |
| # def __del__(self): | |
| # self._cleanup() | |
| # | |
| # def _detach_recursive(self, obj): | |
| # if isinstance(obj, torch.Tensor): | |
| # return obj.detach().cpu() | |
| # elif isinstance(obj, dict): | |
| # return {k: self._detach_recursive(v) for k, v in obj.items()} | |
| # elif isinstance(obj, (list, tuple)): | |
| # t = type(obj) | |
| # return t(self._detach_recursive(x) for x in obj) | |
| # elif is_dataclass(obj): | |
| # # Replace fields recursively (returns a new instance) | |
| # return replace(obj, **{ | |
| # field.name: self._detach_recursive(getattr(obj, field.name)) | |
| # for field in obj.__dataclass_fields__.values() | |
| # }) | |
| # else: | |
| # return obj | |
| from dataclasses import is_dataclass, replace, dataclass | |
| import torch | |
class DetachingCPUList(list):
    """A ``list`` that can detach tensors and move them to CPU on insertion.

    ``append``, ``extend`` and ``insert`` behave like their ``list``
    counterparts, except that passing ``detach_and_cpu=True`` first runs the
    incoming item(s) through :meth:`_detach_recursive`.

    The ``save_to_disk`` and ``no_cache`` parameters are accepted for
    interface compatibility only; they are currently ignored.
    """

    # TODO Naama: Add back disk saving
    def append(self, item, detach_and_cpu=False, save_to_disk=False, no_cache=False):
        """Append ``item`` to the list.

        Args:
            item: The object to store.
            detach_and_cpu: If True, store the result of
                ``_detach_recursive(item)`` instead of ``item`` itself.
            save_to_disk: Accepted but currently ignored.
            no_cache: Accepted but currently ignored.
        """
        if detach_and_cpu:
            item = self._detach_recursive(item)
        super().append(item)

    def extend(self, iterable, detach_and_cpu=False, save_to_disk=False, no_cache=False):
        """Append every element of ``iterable`` to the list.

        Args:
            iterable: The objects to store.
            detach_and_cpu: If True, each element is passed through
                ``_detach_recursive`` before being stored.
            save_to_disk: Accepted but currently ignored (kept so the
                signature matches ``append``/``insert``).
            no_cache: Accepted but currently ignored (kept so the signature
                matches ``append``/``insert``).
        """
        if detach_and_cpu:
            # Lazy generator: elements are detached one by one as list.extend
            # consumes them.
            iterable = (self._detach_recursive(x) for x in iterable)
        super().extend(iterable)

    def insert(self, index, item, detach_and_cpu=False, save_to_disk=False, no_cache=False):
        """Insert ``item`` before position ``index``.

        Args:
            index: Position passed through to ``list.insert``.
            item: The object to store.
            detach_and_cpu: If True, store the result of
                ``_detach_recursive(item)`` instead of ``item`` itself.
            save_to_disk: Accepted but currently ignored.
            no_cache: Accepted but currently ignored.
        """
        if detach_and_cpu:
            item = self._detach_recursive(item)
        super().insert(index, item)

    def _detach_recursive(self, obj):
        """Return a copy of ``obj`` with every tensor detached and on CPU.

        Handles tensors, dicts, lists, tuples (including namedtuples) and
        dataclass instances, recursing into their contents. Anything else is
        returned unchanged.

        Args:
            obj: Arbitrary (possibly nested) object.

        Returns:
            For a tensor, ``obj.detach().cpu()``. For a dict, a new plain
            ``dict`` with converted values. For a list/tuple, a new container
            of the same type with converted elements. For a dataclass
            instance, a new instance created with ``dataclasses.replace``.
            Otherwise ``obj`` itself.
        """
        # Local import: the file-level import line does not bring in `fields`.
        from dataclasses import fields as dataclass_fields

        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu()
        if isinstance(obj, dict):
            return {k: self._detach_recursive(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            container_type = type(obj)
            converted = (self._detach_recursive(x) for x in obj)
            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # namedtuple constructors take one positional argument per
                # field, not a single iterable.
                return container_type(*converted)
            return container_type(converted)
        # is_dataclass() is also True for dataclass *classes*; only instances
        # can be rebuilt with replace().
        if is_dataclass(obj) and not isinstance(obj, type):
            # fields() skips ClassVar/InitVar pseudo-fields, and replace()
            # raises ValueError for init=False fields, so only init fields
            # are passed. init=False fields are re-derived by the dataclass
            # itself (default / __post_init__) on the new instance.
            return replace(obj, **{
                f.name: self._detach_recursive(getattr(obj, f.name))
                for f in dataclass_fields(obj)
                if f.init
            })
        return obj