Spaces:
Runtime error
Runtime error
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| from typing import Any, Callable, Dict, Tuple | |
| import torch | |
class CPUOffloadWrapper:
    """Read-only proxy around a model that can park its tensors on the CPU between calls.

    When ``is_cpu_offload`` and ``is_running_on_gpu`` are both true, every
    wrapped call snapshots the current parameter/buffer tensors, moves the
    model to the CUDA device, runs the call, and then rebinds the snapshot.
    Otherwise the call runs in place and top-level tensor arguments are moved
    to the model's current device.

    Attribute access that the wrapper itself does not satisfy is forwarded to
    the wrapped model; callables obtained this way are routed through the same
    on/offload path unless their name is in ``_non_compute_methods``.

    The wrapper rejects attribute assignment after construction.
    """

    def __init__(self, model: Any, is_cpu_offload: bool = False, is_running_on_gpu: bool = True):
        """Store the configuration and place ``model`` on its initial device.

        Args:
            model: Object to wrap. It must provide ``.to(device)``; it may be a
                ``torch.nn.Module`` or a plain object holding modules/tensors
                as instance attributes.
            is_cpu_offload: Place the model on the CPU now and (together with
                ``is_running_on_gpu``) enable the per-call on/offload path.
            is_running_on_gpu: Second condition for the per-call on/offload path.
        """
        # __setattr__ is overridden to raise, so state is written through
        # object.__setattr__ during construction.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "is_cpu_offload", is_cpu_offload)
        object.__setattr__(self, "is_running_on_gpu", is_running_on_gpu)

        cpu_device = torch.device("cpu")
        cuda_device = torch.device("cuda")
        object.__setattr__(self, "cpu_device", cpu_device)
        object.__setattr__(self, "cuda_device", cuda_device)

        # Initialize placement location
        if is_cpu_offload:
            self.model.to(cpu_device)
        else:
            self.model.to(cuda_device)

        # Whitelist non-compute methods that shouldn't trigger device hops
        # (pass-through only; no device switch)
        object.__setattr__(
            self,
            "_non_compute_methods",
            {
                "to",
                "cpu",
                "cuda",
                "eval",
                "train",
                "state_dict",
                "load_state_dict",
                "parameters",
                "named_parameters",
                "buffers",
                "named_buffers",
                "modules",
                "named_modules",
                "children",
                "named_children",
                "register_forward_hook",
                "register_forward_pre_hook",
                "register_full_backward_hook",
                "zero_grad",
                "share_memory",
                "half",
                "float",
                "bfloat16",
            },
        )

    @staticmethod
    def _module_device(module: torch.nn.Module):
        """Return the device of the first parameter (else first buffer) of ``module``, or None."""
        first = next(module.parameters(), None)
        if first is None:
            first = next(module.buffers(), None)
        return None if first is None else first.device

    @staticmethod
    def _to_device(value: Any, device: torch.device) -> Any:
        """Return ``value`` moved to ``device`` if it is a tensor living elsewhere; otherwise unchanged."""
        if isinstance(value, torch.Tensor) and value.device != device:
            return value.to(device)
        return value

    # Get current primary device (for external reads)
    @property
    def device(self) -> torch.device:
        """Device the wrapped model currently lives on.

        Exposed as a property because the wrapper itself reads it as an
        attribute (``self.device``) when aligning call arguments.

        For a ``torch.nn.Module`` this is the device of its first parameter,
        or of its first buffer when it has no parameters. For any other object
        the instance attributes are scanned in order and the first tensor or
        tensor-holding module decides. ``cuda_device`` is returned when
        nothing is found.
        """
        model = self.model
        if isinstance(model, torch.nn.Module):
            found = self._module_device(model)
            return self.cuda_device if found is None else found

        for value in model.__dict__.values():
            if isinstance(value, torch.Tensor):
                return value.device
            if isinstance(value, torch.nn.Module):
                found = self._module_device(value)
                # A module without any tensors says nothing about placement;
                # keep scanning instead of raising StopIteration.
                if found is not None:
                    return found
        return self.cuda_device

    def _backup_cpu_state(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]:
        """Collect references to the model's current tensors.

        Returns:
            ``(params, buffers, others)`` where ``params`` and ``buffers`` map
            (prefixed) names to the ``.data`` tensors of parameters/buffers,
            and ``others`` maps attribute names to tensors stored directly on
            a non-module model. The tensors are referenced, not copied.
        """
        # Backup module parameters and buffers
        module_param_backup: Dict[str, torch.Tensor] = {}
        module_buffer_backup: Dict[str, torch.Tensor] = {}
        other_backup: Dict[str, Any] = {}

        def save_module_state(mod: torch.nn.Module, prefix: str) -> None:
            for name, param in mod.named_parameters():
                if param is not None:
                    module_param_backup[prefix + name] = param.data
            for name, buffer in mod.named_buffers():
                if buffer is not None:
                    module_buffer_backup[prefix + name] = buffer.data

        if isinstance(self.model, torch.nn.Module):
            save_module_state(self.model, "")
        else:
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    # Prefix with the attribute name so keys from different
                    # sub-modules cannot collide.
                    save_module_state(attr_val, name + ".")
                elif isinstance(attr_val, torch.Tensor):
                    other_backup[name] = attr_val
        return module_param_backup, module_buffer_backup, other_backup

    def _restore_cpu_state(self, backups: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]):
        """Rebind the tensors captured by ``_backup_cpu_state`` onto the model.

        Args:
            backups: The ``(params, buffers, others)`` tuple returned by
                ``_backup_cpu_state``. Names missing from the backup are left
                untouched.
        """
        # Restore module parameters and buffers
        module_param_backup, module_buffer_backup, other_backup = backups

        def restore_module_state(mod: torch.nn.Module, prefix: str) -> None:
            for name, param in mod.named_parameters():
                full_key = prefix + name
                if full_key in module_param_backup:
                    param.data = module_param_backup[full_key]
            for name, buffer in mod.named_buffers():
                full_key = prefix + name
                if full_key in module_buffer_backup:
                    buffer.data = module_buffer_backup[full_key]

        if isinstance(self.model, torch.nn.Module):
            restore_module_state(self.model, "")
            return

        for name, attr_val in self.model.__dict__.items():
            if isinstance(attr_val, torch.nn.Module):
                restore_module_state(attr_val, name + ".")
        for name, val in other_backup.items():
            setattr(self.model, name, val)

    # Unified on/offload executor
    def _run_with_optional_offload(self, func: Callable[..., Any], *args, **kwargs):
        """Run ``func(*args, **kwargs)`` with the configured device handling.

        Offload path (``is_cpu_offload`` and ``is_running_on_gpu``): snapshot
        the current tensors, move the model to ``cuda_device``, call ``func``
        and always rebind the snapshot afterwards. Arguments are passed
        through unchanged on this path.

        Direct path: positional and keyword arguments that are tensors on a
        different device than the model are moved to the model's device
        before calling ``func``. Only top-level arguments are inspected.

        Returns:
            Whatever ``func`` returns.
        """
        if self.is_cpu_offload and self.is_running_on_gpu:
            backups = self._backup_cpu_state()
            try:
                # The move happens inside the try so that a failed or partial
                # transfer still ends with the snapshot being rebound.
                self.model.to(self.cuda_device)
                return func(*args, **kwargs)
            finally:
                try:
                    # NOTE(review): presumably waits for queued GPU work before
                    # the GPU-side tensors are dropped - confirm.
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                finally:
                    # Restore even if synchronize() raised.
                    self._restore_cpu_state(backups)

        # Make sure model and args are on the same device
        target = self.device
        moved_args = [self._to_device(arg, target) for arg in args]
        moved_kwargs = {key: self._to_device(val, target) for key, val in kwargs.items()}
        return func(*moved_args, **moved_kwargs)

    # Direct call (equivalent to forward)
    def __call__(self, *args, **kwargs):
        """Invoke the wrapped model's ``__call__`` through the on/offload executor."""
        return self._run_with_optional_offload(self.model.__call__, *args, **kwargs)

    # Explicit forward; some code calls model.forward(...)
    def forward(self, *args, **kwargs):
        """Invoke the wrapped model's ``forward`` through the on/offload executor."""
        return self._run_with_optional_offload(self.model.forward, *args, **kwargs)

    # Key: passthrough all attrs/methods. For callables, wrap with on/offload;
    # for non-compute methods, pass-through only with no device switch.
    def __getattr__(self, name: str):
        """Forward unknown attribute lookups to the wrapped model.

        Python calls this only after normal lookup on the wrapper has failed.

        Raises:
            AttributeError: If the wrapper has not been initialised yet, or if
                the wrapped model has no attribute ``name``.
        """
        # Read wrapper state without going through __getattr__ again: on an
        # instance created via __new__ (as copy/pickle do) ``self.model`` would
        # otherwise re-enter this method and recurse until RecursionError.
        try:
            model = object.__getattribute__(self, "model")
            passthrough = object.__getattribute__(self, "_non_compute_methods")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

        # Fetch attribute from the wrapped model first
        attr = getattr(model, name)

        # Wrap methods (except in whitelist)
        if callable(attr) and name not in passthrough:

            def _wrapped(*args, **kwargs):
                return self._run_with_optional_offload(attr, *args, **kwargs)

            return _wrapped
        return attr

    def __dir__(self):
        """List the wrapper's own names together with those of the wrapped model."""
        return sorted(set(super().__dir__()) | set(dir(self.model)))

    def __setattr__(self, name: str, value: Any):
        """Reject every attribute assignment.

        Raises:
            AttributeError: Always.
        """
        raise AttributeError("CPUOffloadWrapper is immutable")

    def __repr__(self) -> str:
        """Return a string showing both flags and the wrapped model's repr."""
        return f"CPUOffloadWrapper(is_cpu_offload={self.is_cpu_offload}, is_running_on_gpu={self.is_running_on_gpu}, model={repr(self.model)})"