# Hugging Face Space status: Runtime error
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights


def aoti_load_(
    module: torch.nn.Module,
    repo_id: str,
    filename: str,
    constants_filename: str,
    revision: Optional[str] = None,
) -> None:
    """Load an AOT-compiled model from the Hub and install it as ``module.forward``.

    Following the trailing-underscore convention, ``module`` is modified in
    place and nothing is returned.

    Args:
        module: The module whose ``forward`` is replaced with the AOT-compiled
            version.
        repo_id: Hugging Face Hub repo ID containing the compiled model.
        filename: Filename of the compiled ``.pt2`` archive inside ``repo_id``.
        constants_filename: Filename of the saved constants (from
            ``compiled.weights.constants_map``) inside ``repo_id``.
        revision: Optional Git revision (branch, tag, or commit hash) that
            both files are downloaded from. ``None`` uses the repo's default
            branch, same as omitting it.
    """
    # Fetch both artifacts from the same revision so the compiled graph and
    # the constants it expects always come from a matching snapshot.
    compiled_graph_file = hf_hub_download(repo_id, filename, revision=revision)
    constants_file = hf_hub_download(repo_id, constants_filename, revision=revision)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so the downloaded constants file cannot run arbitrary code on load.
    # Tensors are materialized on CPU here; the to_cuda flag below asks
    # ZeroGPUWeights to place them on the GPU.
    constants_map = torch.load(constants_file, map_location="cpu", weights_only=True)
    zerogpu_weights = ZeroGPUWeights(constants_map, to_cuda=True)

    compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)
    # Install the compiled callable as this instance's forward. This plain
    # assignment is equivalent to setattr with a constant attribute name.
    module.forward = compiled