multimodalart HF Staff commited on
Commit
07e76e1
·
verified ·
1 Parent(s): d792322

Create aoti.py

Browse files
Files changed (1) hide show
  1. aoti.py +15 -0
aoti.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from huggingface_hub import hf_hub_download
3
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
4
+ from spaces.zero.torch.aoti import ZeroGPUWeights
5
+ from spaces.zero.torch.aoti import drain_module_parameters
6
+
7
+
8
+ def aoti_load_(module: torch.nn.Module, repo_id: str, filename: str):
9
+ compiled_graph_file = hf_hub_download(repo_id, filename)
10
+ state_dict = module.state_dict()
11
+ zerogpu_weights = ZeroGPUWeights({name: weight for name, weight in state_dict.items()})
12
+ compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)
13
+
14
+ setattr(module, "forward", compiled)
15
+ drain_module_parameters(module)