DawnC commited on
Commit
7b2098a
·
verified ·
1 Parent(s): b3626fc

Delete aoti.py

Browse files
Files changed (1) hide show
  1. aoti.py +0 -66
aoti.py DELETED
@@ -1,66 +0,0 @@
1
- """
2
- AOTI (Ahead-of-Time Compilation) Utilities for ZeroGPU
3
- Loads pre-compiled model blocks from HuggingFace Hub
4
-
5
- Source: Based on zerogpu-aoti implementation
6
- https://huggingface.co/spaces/zerogpu-aoti/wan2-2-fp8da-aoti-faster
7
- Credits: ZeroGPU AoTI Team (Charles Bensimon, Sayak Paul, Linoy Tsaban, Apolinário Passos)
8
- """
9
-
10
- from typing import cast
11
-
12
- import torch
13
-
14
- from huggingface_hub import hf_hub_download
15
-
16
- from spaces.zero.torch.aoti import ZeroGPUCompiledModel
17
- from spaces.zero.torch.aoti import ZeroGPUWeights
18
-
19
- from torch._functorch._aot_autograd.subclass_parametrization import unwrap_tensor_subclass_parameters
20
-
21
-
22
- def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
23
- """
24
- Creates a shallow copy of a PyTorch module while preserving its structure.
25
-
26
- Args:
27
- module: PyTorch module to clone
28
-
29
- Returns:
30
- Cloned module with copied parameters and buffers
31
- """
32
- clone = object.__new__(module.__class__)
33
- clone.__dict__ = module.__dict__.copy()
34
- clone._parameters = module._parameters.copy()
35
- clone._buffers = module._buffers.copy()
36
- clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
37
- return clone
38
-
39
-
40
- def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
41
- """
42
- Loads pre-compiled AOTI blocks from Hugging Face Hub.
43
-
44
- Args:
45
- module: The model containing repeated blocks to compile
46
- repo_id: Hugging Face repository ID (e.g., 'zerogpu-aoti/Wan2')
47
- variant: Optional variant suffix for subfolder selection (e.g., 'fp8da')
48
-
49
- Example:
50
- >>> aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
51
- """
52
- repeated_blocks = cast(list[str], module._repeated_blocks)
53
-
54
- aoti_files = {name: hf_hub_download(
55
- repo_id=repo_id,
56
- filename='package.pt2',
57
- subfolder=name if variant is None else f'{name}.{variant}',
58
- ) for name in repeated_blocks}
59
-
60
- for block_name, aoti_file in aoti_files.items():
61
- for block in module.modules():
62
- if block.__class__.__name__ == block_name:
63
- block_ = _shallow_clone_module(block)
64
- unwrap_tensor_subclass_parameters(block_)
65
- weights = ZeroGPUWeights(block_.state_dict())
66
- block.forward = ZeroGPUCompiledModel(aoti_file, weights)