dn6 HF Staff committed on
Commit
6d071e0
·
1 Parent(s): 8bb9483
Files changed (3) hide show
  1. __pycache__/aoti.cpython-310.pyc +0 -0
  2. aoti.py +18 -4
  3. app.py +16 -7
__pycache__/aoti.cpython-310.pyc CHANGED
Binary files a/__pycache__/aoti.cpython-310.pyc and b/__pycache__/aoti.cpython-310.pyc differ
 
aoti.py CHANGED
@@ -2,13 +2,27 @@ import torch
2
  from huggingface_hub import hf_hub_download
3
  from spaces.zero.torch.aoti import ZeroGPUCompiledModel
4
  from spaces.zero.torch.aoti import ZeroGPUWeights
5
- from spaces.zero.torch.aoti import drain_module_parameters
6
 
7
 
8
- def aoti_load_(module: torch.nn.Module, repo_id: str, filename: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  compiled_graph_file = hf_hub_download(repo_id, filename)
10
- state_dict = module.state_dict()
11
- zerogpu_weights = ZeroGPUWeights({name: weight for name, weight in state_dict.items()})
 
 
12
  compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)
13
 
14
  setattr(module, "forward", compiled)
 
2
  from huggingface_hub import hf_hub_download
3
  from spaces.zero.torch.aoti import ZeroGPUCompiledModel
4
  from spaces.zero.torch.aoti import ZeroGPUWeights
 
5
 
6
 
7
+ def aoti_load_(
8
+ module: torch.nn.Module,
9
+ repo_id: str,
10
+ filename: str,
11
+ constants_filename: str,
12
+ ):
13
+ """Load an AOT compiled model and replace the module's forward method.
14
+
15
+ Args:
16
+ module: The module to replace forward with AOT compiled version
17
+ repo_id: HuggingFace repo ID containing the compiled model
18
+ filename: Filename of the compiled .pt2 file
19
+ constants_filename: Filename of the saved constants (from compiled.weights.constants_map)
20
+ """
21
  compiled_graph_file = hf_hub_download(repo_id, filename)
22
+ constants_file = hf_hub_download(repo_id, constants_filename)
23
+
24
+ constants_map = torch.load(constants_file, map_location="cpu", weights_only=True)
25
+ zerogpu_weights = ZeroGPUWeights(constants_map, to_cuda=True)
26
  compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)
27
 
28
  setattr(module, "forward", compiled)
app.py CHANGED
@@ -113,14 +113,23 @@ def create_gpu_game_loop(command_queue: Queue):
113
  """
114
  pipe.to("cuda")
115
  pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
116
- pipe.transformer.apply_inference_patches()
117
  #pipe.transformer.quantize("fp8")
118
- #aoti_load_(pipe.transformer, "diffusers-internal-dev/world-engine-aot", "transformer-fp8.pt2")
119
- aoti_load_(pipe.transformer, "diffusers-internal-dev/world-engine-aot", "transformer-inference-patch.pt2")
120
-
121
- pipe.vae.bake_weight_norm()
122
- aoti_load_(pipe.vae.encoder, "diffusers-internal-dev/world-engine-aot", "encoder.pt2")
123
- aoti_load_(pipe.vae.decoder, "diffusers-internal-dev/world-engine-aot", "decoder.pt2")
 
 
 
 
 
 
 
 
 
124
 
125
  n_frames = pipe.transformer.config.n_frames
126
  print(f"Model loaded! (n_frames={n_frames})")
 
113
  """
114
  pipe.to("cuda")
115
  pipe.blocks.sub_blocks['before_denoise'].sub_blocks['setup_kv_cache']._setup_kv_cache(pipe.transformer, pipe.device, torch.bfloat16)
116
+ #pipe.transformer.apply_inference_patches()
117
  #pipe.transformer.quantize("fp8")
118
+ aoti_load_(
119
+ pipe.transformer,
120
+ "diffusers-internal-dev/world-engine-aot",
121
+ "transformer-fp8.pt2",
122
+ "transformer-fp8-constants.pt"
123
+ )
124
+ #aoti_load_(pipe.transformer, "diffusers-internal-dev/world-engine-aot", "transformer-inference-patch.pt2")
125
+ #pipe.vae.bake_weight_norm()
126
+ #aoti_load_(pipe.vae.encoder, "diffusers-internal-dev/world-engine-aot", "encoder.pt2")
127
+ aoti_load_(
128
+ pipe.vae.decoder,
129
+ "diffusers-internal-dev/world-engine-aot",
130
+ "decoder.pt2",
131
+ "decoder-constants.pt"
132
+ )
133
 
134
  n_frames = pipe.transformer.config.n_frames
135
  print(f"Model loaded! (n_frames={n_frames})")