import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights


def aoti_load_(
    module: torch.nn.Module,
    repo_id: str,
    filename: str,
    constants_filename: str,
):
    """Swap ``module.forward`` for an AOT-compiled model fetched from the Hub.

    Both artifacts are downloaded with ``hf_hub_download`` from the same
    repository. The constants file is deserialized with ``torch.load`` and
    wrapped in ``ZeroGPUWeights``; together with the compiled graph file it
    is used to build a ``ZeroGPUCompiledModel``, which is then assigned to
    the ``forward`` attribute of ``module``.

    ``module`` is modified in place and nothing is returned.

    Args:
        module: Module whose ``forward`` attribute is overwritten with the
            compiled model.
        repo_id: HuggingFace repo ID that hosts both files.
        filename: Name of the compiled ``.pt2`` file inside the repo.
        constants_filename: Name of the saved constants file inside the repo
            (the constants saved from ``compiled.weights.constants_map``).
    """
    # Fetch the compiled graph first, then the constants, from the same repo.
    graph_path = hf_hub_download(repo_id, filename)
    constants_path = hf_hub_download(repo_id, constants_filename)

    # The constants are deserialized onto CPU with weights_only=True;
    # ZeroGPUWeights is constructed with to_cuda=True.
    weights = ZeroGPUWeights(
        torch.load(constants_path, map_location="cpu", weights_only=True),
        to_cuda=True,
    )

    # From here on, accessing module.forward yields the compiled model object.
    module.forward = ZeroGPUCompiledModel(graph_path, weights)