daVinci-MagiHuman / inference /common /cpu_offload_wrapper.py
ethanchern's picture
init
873b6ec
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Tuple
import torch
class CPUOffloadWrapper:
    """Immutable proxy that decides on which device a wrapped model runs.

    The wrapped ``model`` may be a ``torch.nn.Module`` or any object that
    exposes ``to(device)`` and keeps tensors / sub-modules in its ``__dict__``.

    * ``is_cpu_offload=False``: the model is moved to CUDA once, at
      construction, and stays there.
    * ``is_cpu_offload=True``: the model is moved to the CPU at construction.
      If ``is_running_on_gpu`` is also true, every compute call moves it to
      CUDA for the duration of the call and then re-attaches the tensors that
      were captured beforehand.

    Attribute access that the wrapper does not define itself is forwarded to
    the wrapped model; callables obtained that way are wrapped so that calling
    them goes through the same on/offload logic, except for the names listed
    in ``_non_compute_methods``.
    """

    def __init__(self, model: Any, is_cpu_offload: bool = False, is_running_on_gpu: bool = True):
        """Store the configuration and place ``model`` on its initial device.

        Args:
            model: Object to wrap; must provide ``to(device)``.
            is_cpu_offload: Keep the model on the CPU between calls.
            is_running_on_gpu: Together with ``is_cpu_offload``, enables the
                temporary move to CUDA around each compute call.
        """
        # __setattr__ is disabled on this class, so every field has to be
        # written through object.__setattr__.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "is_cpu_offload", is_cpu_offload)
        object.__setattr__(self, "is_running_on_gpu", is_running_on_gpu)
        cpu_device = torch.device("cpu")
        cuda_device = torch.device("cuda")
        object.__setattr__(self, "cpu_device", cpu_device)
        object.__setattr__(self, "cuda_device", cuda_device)

        # Initialize placement location
        if is_cpu_offload:
            self.model.to(cpu_device)
        else:
            self.model.to(cuda_device)

        # Whitelist of non-compute methods: forwarded as-is, never wrapped
        # with the on/offload logic (no device switch).
        object.__setattr__(
            self,
            "_non_compute_methods",
            {
                "to",
                "cpu",
                "cuda",
                "eval",
                "train",
                "state_dict",
                "load_state_dict",
                "parameters",
                "named_parameters",
                "buffers",
                "named_buffers",
                "modules",
                "named_modules",
                "children",
                "named_children",
                "register_forward_hook",
                "register_forward_pre_hook",
                "register_full_backward_hook",
                "zero_grad",
                "share_memory",
                "half",
                "float",
                "bfloat16",
            },
        )

    @staticmethod
    def _first_tensor_device(module: torch.nn.Module):
        """Return the device of the first parameter, else first buffer, else ``None``."""
        # next(..., None) instead of bare next(): modules without parameters
        # (activations, buffer-only modules) must not raise StopIteration.
        first_param = next(module.parameters(), None)
        if first_param is not None:
            return first_param.device
        first_buffer = next(module.buffers(), None)
        if first_buffer is not None:
            return first_buffer.device
        return None

    # Get current primary device (for external reads)
    @property
    def device(self) -> torch.device:
        """Device the wrapped model currently lives on.

        For a ``torch.nn.Module`` this is the device of its first parameter
        (or first buffer when it has no parameters). For any other object the
        instance ``__dict__`` is scanned in order and the first tensor, or the
        first sub-module that owns a tensor, decides. When nothing is found
        ``cuda_device`` is returned.
        """
        model = self.model
        if isinstance(model, torch.nn.Module):
            found = self._first_tensor_device(model)
            return found if found is not None else self.cuda_device

        for value in model.__dict__.values():
            if isinstance(value, torch.Tensor):
                return value.device
            if isinstance(value, torch.nn.Module):
                found = self._first_tensor_device(value)
                if found is not None:
                    return found
        return self.cuda_device

    def _backup_cpu_state(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]:
        """Capture references to the model's current tensors.

        No data is copied: the returned dicts hold the existing ``.data``
        tensors (and, for non-module models, the tensor attributes themselves)
        so they can be re-attached by ``_restore_cpu_state``.

        Returns:
            ``(params, buffers, others)`` where ``params`` / ``buffers`` map
            qualified names to tensors and ``others`` maps attribute names of
            a non-module model to its direct tensor attributes.
        """
        module_param_backup: Dict[str, torch.Tensor] = {}
        module_buffer_backup: Dict[str, torch.Tensor] = {}
        other_backup: Dict[str, Any] = {}

        def save_module_state(mod: torch.nn.Module, prefix: str) -> None:
            for name, param in mod.named_parameters():
                if param is not None:
                    module_param_backup[prefix + name] = param.data
            for name, buffer in mod.named_buffers():
                if buffer is not None:
                    module_buffer_backup[prefix + name] = buffer.data

        if isinstance(self.model, torch.nn.Module):
            save_module_state(self.model, "")
        else:
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    # Prefix with the attribute name so keys from different
                    # sub-modules cannot collide.
                    save_module_state(attr_val, name + ".")
                elif isinstance(attr_val, torch.Tensor):
                    other_backup[name] = attr_val
        return module_param_backup, module_buffer_backup, other_backup

    def _restore_cpu_state(self, backups: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]):
        """Re-attach the tensors captured by ``_backup_cpu_state``.

        Parameters and buffers get their ``.data`` pointed back at the saved
        tensors; for non-module models the saved tensor attributes are set
        again on the model. Names without a saved entry are left untouched.
        """
        module_param_backup, module_buffer_backup, other_backup = backups

        def restore_module_state(mod: torch.nn.Module, prefix: str) -> None:
            for name, param in mod.named_parameters():
                full_key = prefix + name
                if full_key in module_param_backup:
                    param.data = module_param_backup[full_key]
            for name, buffer in mod.named_buffers():
                full_key = prefix + name
                if full_key in module_buffer_backup:
                    buffer.data = module_buffer_backup[full_key]

        if isinstance(self.model, torch.nn.Module):
            restore_module_state(self.model, "")
        else:
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    restore_module_state(attr_val, name + ".")
            for name, val in other_backup.items():
                setattr(self.model, name, val)

    # Unified on/offload executor
    def _run_with_optional_offload(self, func: Callable[..., Any], *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with the model on the right device.

        With offloading active (``is_cpu_offload and is_running_on_gpu``) the
        model is moved to CUDA for the call and its previous tensors are
        re-attached afterwards, whether the call (or the move itself) succeeds
        or raises. Otherwise tensor arguments that live on a different device
        than the model are moved to the model's device before the call.

        Returns:
            Whatever ``func`` returns.
        """
        if self.is_cpu_offload and self.is_running_on_gpu:
            backups = self._backup_cpu_state()
            # NOTE(review): arguments are passed through unchanged in this
            # branch; presumably callers already supply CUDA tensors - confirm.
            try:
                # The move is inside the try block so that a failed or partial
                # transfer (e.g. out of memory) is still rolled back.
                self.model.to(self.cuda_device)
                return func(*args, **kwargs)
            finally:
                try:
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                finally:
                    # Must run even if synchronize() itself raises.
                    self._restore_cpu_state(backups)

        # Make sure model and args are on the same device. The device is
        # resolved once, and only if there is at least one tensor argument.
        has_tensor = any(isinstance(value, torch.Tensor) for value in (*args, *kwargs.values()))
        if has_tensor:
            target = self.device
            args = [
                arg.to(target) if isinstance(arg, torch.Tensor) and arg.device != target else arg
                for arg in args
            ]
            kwargs = {
                key: value.to(target) if isinstance(value, torch.Tensor) and value.device != target else value
                for key, value in kwargs.items()
            }
        return func(*args, **kwargs)

    # Direct call (equivalent to forward)
    def __call__(self, *args, **kwargs):
        """Invoke the wrapped model through the on/offload executor."""
        return self._run_with_optional_offload(self.model.__call__, *args, **kwargs)

    # Explicit forward; some code calls model.forward(...)
    def forward(self, *args, **kwargs):
        """Invoke ``model.forward`` through the on/offload executor."""
        return self._run_with_optional_offload(self.model.forward, *args, **kwargs)

    # Key: pass through all attrs/methods. Callables are wrapped with the
    # on/offload executor; whitelisted non-compute methods are returned as-is.
    def __getattr__(self, name: str):
        """Forward lookups the wrapper cannot satisfy to the wrapped model.

        Raises:
            AttributeError: If the wrapper is not initialised yet, or the
                wrapped model has no attribute ``name``.
        """
        # __getattr__ only runs after normal lookup failed. Reading self.model
        # here would re-enter this method forever when "model" itself is
        # missing (instance created without __init__, e.g. by copy/pickle),
        # so the wrapper's own fields are fetched without the fallback.
        try:
            model = object.__getattribute__(self, "model")
            passthrough = object.__getattribute__(self, "_non_compute_methods")
        except AttributeError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") from None

        # Fetch attribute from the wrapped model first
        attr = getattr(model, name)

        # Wrap methods (except in whitelist)
        if callable(attr) and name not in passthrough:

            def _wrapped(*args, **kwargs):
                return self._run_with_optional_offload(attr, *args, **kwargs)

            return _wrapped
        return attr

    def __dir__(self):
        """List the wrapper's own names merged with those of the wrapped model."""
        return sorted(set(super().__dir__()) | set(dir(self.model)))

    def __setattr__(self, name: str, value: Any):
        """Reject every attribute assignment; the wrapper is read-only."""
        raise AttributeError("CPUOffloadWrapper is immutable")

    def __repr__(self) -> str:
        return (
            f"CPUOffloadWrapper(is_cpu_offload={self.is_cpu_offload}, "
            f"is_running_on_gpu={self.is_running_on_gpu}, model={repr(self.model)})"
        )