# waypoint-1-small / aoti.py
# Last commit: 6d071e0 ("update") by dn6 (HF Staff)
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
def aoti_load_(
    module: torch.nn.Module,
    repo_id: str,
    filename: str,
    constants_filename: str,
):
    """Swap ``module.forward`` for an ahead-of-time compiled graph from the Hub.

    Two files are fetched from ``repo_id`` with ``hf_hub_download``: the
    compiled ``.pt2`` package and a separately saved constants file. The
    constants are deserialized onto the CPU and handed to ``ZeroGPUWeights``
    with ``to_cuda=True``. The resulting ``ZeroGPUCompiledModel`` is then
    installed as the module's ``forward`` attribute.

    Args:
        module: Module whose ``forward`` attribute is overwritten in place.
        repo_id: Hugging Face Hub repository holding both files.
        filename: Name of the compiled ``.pt2`` file inside ``repo_id``.
        constants_filename: Name of the saved constants file (produced from
            ``compiled.weights.constants_map``).
    """
    pt2_path = hf_hub_download(repo_id, filename)
    weights_path = hf_hub_download(repo_id, constants_filename)

    # weights_only=True limits unpickling to tensors and plain containers,
    # so a tampered constants file cannot execute arbitrary code on load.
    cpu_constants = torch.load(weights_path, map_location="cpu", weights_only=True)
    gpu_weights = ZeroGPUWeights(cpu_constants, to_cuda=True)

    # Plain attribute assignment routes through nn.Module.__setattr__, exactly
    # like setattr(module, "forward", ...) would.
    module.forward = ZeroGPUCompiledModel(pt2_path, gpu_weights)