Spaces:
Paused
Paused
File size: 1,059 Bytes
1692330 6d071e0 1692330 6d071e0 1692330 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
def aoti_load_(
    module: torch.nn.Module,
    repo_id: str,
    filename: str,
    constants_filename: str,
):
    """Replace ``module.forward`` with an AOT-compiled graph fetched from the Hub.

    The compiled ``.pt2`` archive and its companion constants file are both
    downloaded from the Hugging Face Hub. The constants are wrapped in
    ``ZeroGPUWeights`` (with ``to_cuda=True``), and the resulting
    ``ZeroGPUCompiledModel`` is assigned to ``module.forward``. The module is
    modified in place and nothing is returned (hence the trailing underscore).

    Args:
        module: Module whose ``forward`` attribute is swapped for the compiled
            version.
        repo_id: Hugging Face repo ID that hosts the compiled artifacts.
        filename: Name of the compiled ``.pt2`` file inside the repo.
        constants_filename: Name of the saved constants file (produced from
            ``compiled.weights.constants_map``).
    """
    # Fetch both artifacts up front; hf_hub_download returns local file paths.
    graph_path = hf_hub_download(repo_id, filename)
    weights_path = hf_hub_download(repo_id, constants_filename)

    # Deserialize on CPU. weights_only=True restricts unpickling to tensors and
    # plain containers rather than arbitrary Python objects.
    constants = torch.load(weights_path, map_location="cpu", weights_only=True)

    # NOTE(review): to_cuda=True presumably has ZeroGPU place the constants on
    # the GPU when one is attached — confirm against spaces.zero.torch.aoti.
    module.forward = ZeroGPUCompiledModel(
        graph_path,
        ZeroGPUWeights(constants, to_cuda=True),
    )
|