File size: 1,059 Bytes
1692330
 
 
 
 
 
6d071e0
 
 
 
 
 
 
 
 
 
 
 
 
 
1692330
6d071e0
 
 
 
1692330
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights


def aoti_load_(
    module: torch.nn.Module,
    repo_id: str,
    filename: str,
    constants_filename: str,
) -> None:
    """Load an AOT compiled model from the Hub and install it as ``module.forward``.

    The trailing underscore follows the PyTorch convention for in-place
    operations: ``module`` is mutated and nothing is returned.

    Args:
        module: The module whose ``forward`` is replaced by the AOT compiled
            version. Only this instance is affected. The compiled callable is
            stored as an instance attribute, which shadows the class-level
            ``forward`` method for this object alone.
        repo_id: HuggingFace repo ID containing the compiled model.
        filename: Filename of the compiled ``.pt2`` file inside ``repo_id``.
        constants_filename: Filename of the saved constants inside ``repo_id``
            (produced from ``compiled.weights.constants_map``).

    Raises:
        Any error raised by ``hf_hub_download`` (e.g. a missing repo or file)
        or by ``torch.load`` propagates unchanged.
    """
    compiled_graph_file = hf_hub_download(repo_id, filename)
    constants_file = hf_hub_download(repo_id, constants_filename)

    # Load onto CPU first so the load does not depend on a GPU being attached.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered constants file cannot execute arbitrary code.
    constants_map = torch.load(constants_file, map_location="cpu", weights_only=True)
    # NOTE(review): to_cuda=True presumably moves the constants to the GPU
    # when one is available — confirm against spaces.zero.torch.aoti.
    zerogpu_weights = ZeroGPUWeights(constants_map, to_cuda=True)
    compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)

    # Plain assignment is equivalent to setattr with a constant name (B010).
    # It still routes through nn.Module.__setattr__.
    module.forward = compiled