root
commited on
Commit
·
036458a
1
Parent(s):
33ec12f
Fix LFS and upload model
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- __init__.py +18 -0
- assets/{banner.png → input_0_0.png} +2 -2
- assets/{logo.png → input_1_0.png} +2 -2
- assets/{robot.png → input_1_1.png} +2 -2
- assets/{framework.png → input_2_0.png} +2 -2
- assets/{gsb.png → input_2_1.png} +2 -2
- assets/input_2_2.png +3 -0
- assets/pg_imgs/image1.png +0 -3
- assets/pg_imgs/image2.png +0 -3
- assets/pg_imgs/image3.png +0 -3
- assets/pg_imgs/image4.png +0 -3
- assets/pg_imgs/image5.png +0 -3
- assets/pg_imgs/image6.png +0 -3
- assets/pg_imgs/image7.png +0 -3
- assets/pg_imgs/image8.png +0 -3
- assets/ssae_side_by_side_comparison.png +0 -3
- assets/ssae_side_by_side_heatmap.png +0 -3
- assets/user.png +0 -3
- autoencoder_kl_3d.py +1081 -0
- cache_utils.py +226 -0
- config.json +283 -0
- configuration_hunyuan_image_3.py +310 -0
- generation_config.json +21 -0
- hunyuan_image_3_pipeline.py +913 -0
- image_processor.py +465 -0
- model-0001-of-0032.safetensors +3 -0
- model-0002-of-0032.safetensors +3 -0
- model-0003-of-0032.safetensors +3 -0
- model-0004-of-0032.safetensors +3 -0
- model-0005-of-0032.safetensors +3 -0
- model-0006-of-0032.safetensors +3 -0
- model-0007-of-0032.safetensors +3 -0
- model-0008-of-0032.safetensors +3 -0
- model-0009-of-0032.safetensors +3 -0
- model-0010-of-0032.safetensors +3 -0
- model-0011-of-0032.safetensors +3 -0
- model-0012-of-0032.safetensors +3 -0
- model-0013-of-0032.safetensors +3 -0
- model-0014-of-0032.safetensors +3 -0
- model-0015-of-0032.safetensors +3 -0
- model-0016-of-0032.safetensors +3 -0
- model-0017-of-0032.safetensors +3 -0
- model-0018-of-0032.safetensors +3 -0
- model-0019-of-0032.safetensors +3 -0
- model-0020-of-0032.safetensors +3 -0
- model-0021-of-0032.safetensors +3 -0
- model-0022-of-0032.safetensors +3 -0
- model-0023-of-0032.safetensors +3 -0
- model-0024-of-0032.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
assets/banner_all.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
*.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 37 |
assets/banner_all.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
*.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
assets/**/*.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING
|
| 2 |
+
|
| 3 |
+
from utils import _LazyModule
|
| 4 |
+
from utils.import_utils import define_import_structure
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from .configuration_hunyuan_image_3 import *
|
| 9 |
+
from .modeling_hunyuan_image_3 import *
|
| 10 |
+
from .autoencoder_kl_3d import *
|
| 11 |
+
from .image_processor import *
|
| 12 |
+
from .siglip2 import *
|
| 13 |
+
from .tokenization_hunyuan_image_3 import *
|
| 14 |
+
else:
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
_file = globals()["__file__"]
|
| 18 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
assets/{banner.png → input_0_0.png}
RENAMED
|
File without changes
|
assets/{logo.png → input_1_0.png}
RENAMED
|
File without changes
|
assets/{robot.png → input_1_1.png}
RENAMED
|
File without changes
|
assets/{framework.png → input_2_0.png}
RENAMED
|
File without changes
|
assets/{gsb.png → input_2_1.png}
RENAMED
|
File without changes
|
assets/input_2_2.png
ADDED
|
Git LFS Details
|
assets/pg_imgs/image1.png
DELETED
Git LFS Details
|
assets/pg_imgs/image2.png
DELETED
Git LFS Details
|
assets/pg_imgs/image3.png
DELETED
Git LFS Details
|
assets/pg_imgs/image4.png
DELETED
Git LFS Details
|
assets/pg_imgs/image5.png
DELETED
Git LFS Details
|
assets/pg_imgs/image6.png
DELETED
Git LFS Details
|
assets/pg_imgs/image7.png
DELETED
Git LFS Details
|
assets/pg_imgs/image8.png
DELETED
Git LFS Details
|
assets/ssae_side_by_side_comparison.png
DELETED
Git LFS Details
|
assets/ssae_side_by_side_heatmap.png
DELETED
Git LFS Details
|
assets/user.png
DELETED
Git LFS Details
|
autoencoder_kl_3d.py
ADDED
|
@@ -0,0 +1,1081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reference code
|
| 3 |
+
[FLUX] https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
|
| 4 |
+
[DCAE] https://github.com/mit-han-lab/efficientvit/blob/master/efficientvit/models/efficientvit/dc_ae.py
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Tuple, Optional
|
| 9 |
+
import math
|
| 10 |
+
import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
import torch.multiprocessing as mp
|
| 18 |
+
|
| 19 |
+
from safetensors import safe_open
|
| 20 |
+
import os
|
| 21 |
+
from collections import OrderedDict
|
| 22 |
+
from collections.abc import Iterable
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
from diffusers.utils import BaseOutput
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DiagonalGaussianDistribution(object):
|
| 32 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
| 33 |
+
if parameters.ndim == 3:
|
| 34 |
+
dim = 2 # (B, L, C)
|
| 35 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
| 36 |
+
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
|
| 37 |
+
else:
|
| 38 |
+
raise NotImplementedError
|
| 39 |
+
self.parameters = parameters
|
| 40 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
| 41 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 42 |
+
self.deterministic = deterministic
|
| 43 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 44 |
+
self.var = torch.exp(self.logvar)
|
| 45 |
+
if self.deterministic:
|
| 46 |
+
self.var = self.std = torch.zeros_like(
|
| 47 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
| 51 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
| 52 |
+
sample = randn_tensor(
|
| 53 |
+
self.mean.shape,
|
| 54 |
+
generator=generator,
|
| 55 |
+
device=self.parameters.device,
|
| 56 |
+
dtype=self.parameters.dtype,
|
| 57 |
+
)
|
| 58 |
+
x = self.mean + self.std * sample
|
| 59 |
+
return x
|
| 60 |
+
|
| 61 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
| 62 |
+
if self.deterministic:
|
| 63 |
+
return torch.Tensor([0.0])
|
| 64 |
+
else:
|
| 65 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
| 66 |
+
if other is None:
|
| 67 |
+
return 0.5 * torch.sum(
|
| 68 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
| 69 |
+
dim=reduce_dim,
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
return 0.5 * torch.sum(
|
| 73 |
+
torch.pow(self.mean - other.mean, 2) / other.var +
|
| 74 |
+
self.var / other.var -
|
| 75 |
+
1.0 -
|
| 76 |
+
self.logvar +
|
| 77 |
+
other.logvar,
|
| 78 |
+
dim=reduce_dim,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
| 82 |
+
if self.deterministic:
|
| 83 |
+
return torch.Tensor([0.0])
|
| 84 |
+
logtwopi = np.log(2.0 * np.pi)
|
| 85 |
+
return 0.5 * torch.sum(
|
| 86 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
| 87 |
+
dim=dims,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def mode(self) -> torch.Tensor:
|
| 91 |
+
return self.mean
|
| 92 |
+
|
| 93 |
+
@dataclass
|
| 94 |
+
class DecoderOutput(BaseOutput):
|
| 95 |
+
sample: torch.FloatTensor
|
| 96 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
| 97 |
+
|
| 98 |
+
def swish(x: Tensor) -> Tensor:
|
| 99 |
+
return x * torch.sigmoid(x)
|
| 100 |
+
|
| 101 |
+
def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
|
| 102 |
+
def create_custom_forward(module):
|
| 103 |
+
def custom_forward(*inputs):
|
| 104 |
+
return module(*inputs)
|
| 105 |
+
return custom_forward
|
| 106 |
+
|
| 107 |
+
if use_checkpointing:
|
| 108 |
+
return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
|
| 109 |
+
else:
|
| 110 |
+
return module(*inputs)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Conv3d(nn.Conv3d):
|
| 114 |
+
"""Perform Conv3d on patches with numerical differences from nn.Conv3d within 1e-5. Only symmetric padding is supported."""
|
| 115 |
+
|
| 116 |
+
def forward(self, input):
|
| 117 |
+
B, C, T, H, W = input.shape
|
| 118 |
+
memory_count = (C * T * H * W) * 2 / 1024**3
|
| 119 |
+
if memory_count > 2:
|
| 120 |
+
n_split = math.ceil(memory_count / 2)
|
| 121 |
+
assert n_split >= 2
|
| 122 |
+
chunks = torch.chunk(input, chunks=n_split, dim=-3)
|
| 123 |
+
padded_chunks = []
|
| 124 |
+
for i in range(len(chunks)):
|
| 125 |
+
if self.padding[0] > 0:
|
| 126 |
+
padded_chunk = F.pad(
|
| 127 |
+
chunks[i],
|
| 128 |
+
(0, 0, 0, 0, self.padding[0], self.padding[0]),
|
| 129 |
+
mode="constant" if self.padding_mode == "zeros" else self.padding_mode,
|
| 130 |
+
value=0,
|
| 131 |
+
)
|
| 132 |
+
if i > 0:
|
| 133 |
+
padded_chunk[:, :, :self.padding[0]] = chunks[i - 1][:, :, -self.padding[0]:]
|
| 134 |
+
if i < len(chunks) - 1:
|
| 135 |
+
padded_chunk[:, :, -self.padding[0]:] = chunks[i + 1][:, :, :self.padding[0]]
|
| 136 |
+
else:
|
| 137 |
+
padded_chunk = chunks[i]
|
| 138 |
+
padded_chunks.append(padded_chunk)
|
| 139 |
+
padding_bak = self.padding
|
| 140 |
+
self.padding = (0, self.padding[1], self.padding[2])
|
| 141 |
+
outputs = []
|
| 142 |
+
for i in range(len(padded_chunks)):
|
| 143 |
+
outputs.append(super().forward(padded_chunks[i]))
|
| 144 |
+
self.padding = padding_bak
|
| 145 |
+
return torch.cat(outputs, dim=-3)
|
| 146 |
+
else:
|
| 147 |
+
return super().forward(input)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class AttnBlock(nn.Module):
|
| 151 |
+
def __init__(self, in_channels: int):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.in_channels = in_channels
|
| 154 |
+
|
| 155 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 156 |
+
|
| 157 |
+
self.q = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 158 |
+
self.k = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 159 |
+
self.v = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 160 |
+
self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1)
|
| 161 |
+
|
| 162 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 163 |
+
h_ = self.norm(h_)
|
| 164 |
+
q = self.q(h_)
|
| 165 |
+
k = self.k(h_)
|
| 166 |
+
v = self.v(h_)
|
| 167 |
+
|
| 168 |
+
b, c, f, h, w = q.shape
|
| 169 |
+
q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous()
|
| 170 |
+
k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous()
|
| 171 |
+
v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous()
|
| 172 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 173 |
+
|
| 174 |
+
return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b)
|
| 175 |
+
|
| 176 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 177 |
+
return x + self.proj_out(self.attention(x))
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class ResnetBlock(nn.Module):
|
| 181 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 182 |
+
super().__init__()
|
| 183 |
+
self.in_channels = in_channels
|
| 184 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 185 |
+
self.out_channels = out_channels
|
| 186 |
+
|
| 187 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 188 |
+
self.conv1 = Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 189 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
| 190 |
+
self.conv2 = Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 191 |
+
if self.in_channels != self.out_channels:
|
| 192 |
+
self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 193 |
+
|
| 194 |
+
def forward(self, x):
|
| 195 |
+
h = x
|
| 196 |
+
h = self.norm1(h)
|
| 197 |
+
h = swish(h)
|
| 198 |
+
h = self.conv1(h)
|
| 199 |
+
|
| 200 |
+
h = self.norm2(h)
|
| 201 |
+
h = swish(h)
|
| 202 |
+
h = self.conv2(h)
|
| 203 |
+
|
| 204 |
+
if self.in_channels != self.out_channels:
|
| 205 |
+
x = self.nin_shortcut(x)
|
| 206 |
+
return x + h
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class Downsample(nn.Module):
|
| 210 |
+
def __init__(self, in_channels: int, add_temporal_downsample: bool = True):
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.add_temporal_downsample = add_temporal_downsample
|
| 213 |
+
stride = (2, 2, 2) if add_temporal_downsample else (1, 2, 2) # THW
|
| 214 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 215 |
+
self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0)
|
| 216 |
+
|
| 217 |
+
def forward(self, x: Tensor):
|
| 218 |
+
spatial_pad = (0, 1, 0, 1, 0, 0) # WHT
|
| 219 |
+
x = nn.functional.pad(x, spatial_pad, mode="constant", value=0)
|
| 220 |
+
|
| 221 |
+
temporal_pad = (0, 0, 0, 0, 0, 1) if self.add_temporal_downsample else (0, 0, 0, 0, 1, 1)
|
| 222 |
+
x = nn.functional.pad(x, temporal_pad, mode="replicate")
|
| 223 |
+
|
| 224 |
+
x = self.conv(x)
|
| 225 |
+
return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class DownsampleDCAE(nn.Module):
|
| 229 |
+
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
| 230 |
+
super().__init__()
|
| 231 |
+
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
| 232 |
+
assert out_channels % factor == 0
|
| 233 |
+
self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
| 234 |
+
|
| 235 |
+
self.add_temporal_downsample = add_temporal_downsample
|
| 236 |
+
self.group_size = factor * in_channels // out_channels
|
| 237 |
+
|
| 238 |
+
def forward(self, x: Tensor):
|
| 239 |
+
r1 = 2 if self.add_temporal_downsample else 1
|
| 240 |
+
h = self.conv(x)
|
| 241 |
+
h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
| 242 |
+
shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
|
| 243 |
+
|
| 244 |
+
B, C, T, H, W = shortcut.shape
|
| 245 |
+
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
| 246 |
+
return h + shortcut
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class Upsample(nn.Module):
|
| 250 |
+
def __init__(self, in_channels: int, add_temporal_upsample: bool = True):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.add_temporal_upsample = add_temporal_upsample
|
| 253 |
+
self.scale_factor = (2, 2, 2) if add_temporal_upsample else (1, 2, 2) # THW
|
| 254 |
+
self.conv = Conv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 255 |
+
|
| 256 |
+
def forward(self, x: Tensor):
|
| 257 |
+
x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
|
| 258 |
+
x = self.conv(x)
|
| 259 |
+
return x
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class UpsampleDCAE(nn.Module):
|
| 263 |
+
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
|
| 264 |
+
super().__init__()
|
| 265 |
+
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
|
| 266 |
+
self.conv = Conv3d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
| 267 |
+
|
| 268 |
+
self.add_temporal_upsample = add_temporal_upsample
|
| 269 |
+
self.repeats = factor * out_channels // in_channels
|
| 270 |
+
|
| 271 |
+
def forward(self, x: Tensor):
|
| 272 |
+
r1 = 2 if self.add_temporal_upsample else 1
|
| 273 |
+
h = self.conv(x)
|
| 274 |
+
h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
|
| 275 |
+
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
| 276 |
+
shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
|
| 277 |
+
return h + shortcut
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class Encoder(nn.Module):
|
| 281 |
+
def __init__(
|
| 282 |
+
self,
|
| 283 |
+
in_channels: int,
|
| 284 |
+
z_channels: int,
|
| 285 |
+
block_out_channels: Tuple[int, ...],
|
| 286 |
+
num_res_blocks: int,
|
| 287 |
+
ffactor_spatial: int,
|
| 288 |
+
ffactor_temporal: int,
|
| 289 |
+
downsample_match_channel: bool = True,
|
| 290 |
+
):
|
| 291 |
+
super().__init__()
|
| 292 |
+
assert block_out_channels[-1] % (2 * z_channels) == 0
|
| 293 |
+
|
| 294 |
+
self.z_channels = z_channels
|
| 295 |
+
self.block_out_channels = block_out_channels
|
| 296 |
+
self.num_res_blocks = num_res_blocks
|
| 297 |
+
|
| 298 |
+
# downsampling
|
| 299 |
+
self.conv_in = Conv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
| 300 |
+
|
| 301 |
+
self.down = nn.ModuleList()
|
| 302 |
+
block_in = block_out_channels[0]
|
| 303 |
+
for i_level, ch in enumerate(block_out_channels):
|
| 304 |
+
block = nn.ModuleList()
|
| 305 |
+
block_out = ch
|
| 306 |
+
for _ in range(self.num_res_blocks):
|
| 307 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 308 |
+
block_in = block_out
|
| 309 |
+
down = nn.Module()
|
| 310 |
+
down.block = block
|
| 311 |
+
|
| 312 |
+
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
|
| 313 |
+
add_temporal_downsample = add_spatial_downsample and bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal))
|
| 314 |
+
if add_spatial_downsample or add_temporal_downsample:
|
| 315 |
+
assert i_level < len(block_out_channels) - 1
|
| 316 |
+
block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
|
| 317 |
+
down.downsample = DownsampleDCAE(block_in, block_out, add_temporal_downsample)
|
| 318 |
+
block_in = block_out
|
| 319 |
+
self.down.append(down)
|
| 320 |
+
|
| 321 |
+
# middle
|
| 322 |
+
self.mid = nn.Module()
|
| 323 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 324 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 325 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 326 |
+
|
| 327 |
+
# end
|
| 328 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 329 |
+
self.conv_out = Conv3d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 330 |
+
|
| 331 |
+
self.gradient_checkpointing = False
|
| 332 |
+
|
| 333 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 334 |
+
with torch.no_grad():
|
| 335 |
+
use_checkpointing = bool(self.training and self.gradient_checkpointing)
|
| 336 |
+
|
| 337 |
+
# downsampling
|
| 338 |
+
h = self.conv_in(x)
|
| 339 |
+
for i_level in range(len(self.block_out_channels)):
|
| 340 |
+
for i_block in range(self.num_res_blocks):
|
| 341 |
+
h = forward_with_checkpointing(self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
|
| 342 |
+
if hasattr(self.down[i_level], "downsample"):
|
| 343 |
+
h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
|
| 344 |
+
|
| 345 |
+
# middle
|
| 346 |
+
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
|
| 347 |
+
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
|
| 348 |
+
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
|
| 349 |
+
|
| 350 |
+
# end
|
| 351 |
+
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
|
| 352 |
+
shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2)
|
| 353 |
+
h = self.norm_out(h)
|
| 354 |
+
h = swish(h)
|
| 355 |
+
h = self.conv_out(h)
|
| 356 |
+
h += shortcut
|
| 357 |
+
return h
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class Decoder(nn.Module):
|
| 361 |
+
def __init__(
|
| 362 |
+
self,
|
| 363 |
+
z_channels: int,
|
| 364 |
+
out_channels: int,
|
| 365 |
+
block_out_channels: Tuple[int, ...],
|
| 366 |
+
num_res_blocks: int,
|
| 367 |
+
ffactor_spatial: int,
|
| 368 |
+
ffactor_temporal: int,
|
| 369 |
+
upsample_match_channel: bool = True,
|
| 370 |
+
):
|
| 371 |
+
super().__init__()
|
| 372 |
+
assert block_out_channels[0] % z_channels == 0
|
| 373 |
+
|
| 374 |
+
self.z_channels = z_channels
|
| 375 |
+
self.block_out_channels = block_out_channels
|
| 376 |
+
self.num_res_blocks = num_res_blocks
|
| 377 |
+
|
| 378 |
+
# z to block_in
|
| 379 |
+
block_in = block_out_channels[0]
|
| 380 |
+
self.conv_in = Conv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 381 |
+
|
| 382 |
+
# middle
|
| 383 |
+
self.mid = nn.Module()
|
| 384 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 385 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 386 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 387 |
+
|
| 388 |
+
# upsampling
|
| 389 |
+
self.up = nn.ModuleList()
|
| 390 |
+
for i_level, ch in enumerate(block_out_channels):
|
| 391 |
+
block = nn.ModuleList()
|
| 392 |
+
block_out = ch
|
| 393 |
+
for _ in range(self.num_res_blocks + 1):
|
| 394 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 395 |
+
block_in = block_out
|
| 396 |
+
up = nn.Module()
|
| 397 |
+
up.block = block
|
| 398 |
+
|
| 399 |
+
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
|
| 400 |
+
add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal))
|
| 401 |
+
if add_spatial_upsample or add_temporal_upsample:
|
| 402 |
+
assert i_level < len(block_out_channels) - 1
|
| 403 |
+
block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
|
| 404 |
+
up.upsample = UpsampleDCAE(block_in, block_out, add_temporal_upsample)
|
| 405 |
+
block_in = block_out
|
| 406 |
+
self.up.append(up)
|
| 407 |
+
|
| 408 |
+
# end
|
| 409 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 410 |
+
self.conv_out = Conv3d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
| 411 |
+
|
| 412 |
+
self.gradient_checkpointing = False
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 416 |
+
with torch.no_grad():
|
| 417 |
+
use_checkpointing = bool(self.training and self.gradient_checkpointing)
|
| 418 |
+
# z to block_in
|
| 419 |
+
repeats = self.block_out_channels[0] // (self.z_channels)
|
| 420 |
+
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
| 421 |
+
# middle
|
| 422 |
+
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
|
| 423 |
+
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
|
| 424 |
+
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
|
| 425 |
+
# upsampling
|
| 426 |
+
for i_level in range(len(self.block_out_channels)):
|
| 427 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 428 |
+
h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
|
| 429 |
+
if hasattr(self.up[i_level], "upsample"):
|
| 430 |
+
h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
|
| 431 |
+
# end
|
| 432 |
+
h = self.norm_out(h)
|
| 433 |
+
h = swish(h)
|
| 434 |
+
h = self.conv_out(h)
|
| 435 |
+
return h
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
|
| 439 |
+
_supports_gradient_checkpointing = True
|
| 440 |
+
|
| 441 |
+
@register_to_config
|
| 442 |
+
def __init__(
|
| 443 |
+
self,
|
| 444 |
+
in_channels: int,
|
| 445 |
+
out_channels: int,
|
| 446 |
+
latent_channels: int,
|
| 447 |
+
block_out_channels: Tuple[int, ...],
|
| 448 |
+
layers_per_block: int,
|
| 449 |
+
ffactor_spatial: int,
|
| 450 |
+
ffactor_temporal: int,
|
| 451 |
+
sample_size: int,
|
| 452 |
+
sample_tsize: int,
|
| 453 |
+
scaling_factor: float = None,
|
| 454 |
+
shift_factor: Optional[float] = None,
|
| 455 |
+
downsample_match_channel: bool = True,
|
| 456 |
+
upsample_match_channel: bool = True,
|
| 457 |
+
only_encoder: bool = False,
|
| 458 |
+
only_decoder: bool = False,
|
| 459 |
+
):
|
| 460 |
+
super().__init__()
|
| 461 |
+
self.ffactor_spatial = ffactor_spatial
|
| 462 |
+
self.ffactor_temporal = ffactor_temporal
|
| 463 |
+
self.scaling_factor = scaling_factor
|
| 464 |
+
self.shift_factor = shift_factor
|
| 465 |
+
|
| 466 |
+
if not only_decoder:
|
| 467 |
+
self.encoder = Encoder(
|
| 468 |
+
in_channels=in_channels,
|
| 469 |
+
z_channels=latent_channels,
|
| 470 |
+
block_out_channels=block_out_channels,
|
| 471 |
+
num_res_blocks=layers_per_block,
|
| 472 |
+
ffactor_spatial=ffactor_spatial,
|
| 473 |
+
ffactor_temporal=ffactor_temporal,
|
| 474 |
+
downsample_match_channel=downsample_match_channel,
|
| 475 |
+
)
|
| 476 |
+
if not only_encoder:
|
| 477 |
+
self.decoder = Decoder(
|
| 478 |
+
z_channels=latent_channels,
|
| 479 |
+
out_channels=out_channels,
|
| 480 |
+
block_out_channels=list(reversed(block_out_channels)),
|
| 481 |
+
num_res_blocks=layers_per_block,
|
| 482 |
+
ffactor_spatial=ffactor_spatial,
|
| 483 |
+
ffactor_temporal=ffactor_temporal,
|
| 484 |
+
upsample_match_channel=upsample_match_channel,
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
self.use_slicing = False
|
| 488 |
+
self.slicing_bsz = 1
|
| 489 |
+
self.use_spatial_tiling = False
|
| 490 |
+
self.use_temporal_tiling = False
|
| 491 |
+
self.use_tiling_during_training = False
|
| 492 |
+
|
| 493 |
+
# only relevant if vae tiling is enabled
|
| 494 |
+
self.tile_sample_min_size = sample_size
|
| 495 |
+
self.tile_latent_min_size = sample_size // ffactor_spatial
|
| 496 |
+
self.tile_sample_min_tsize = sample_tsize
|
| 497 |
+
self.tile_latent_min_tsize = sample_tsize // ffactor_temporal
|
| 498 |
+
self.tile_overlap_factor = 0.125
|
| 499 |
+
|
| 500 |
+
self.use_compile = False
|
| 501 |
+
|
| 502 |
+
self.empty_cache = torch.empty(0, device="cuda")
|
| 503 |
+
|
| 504 |
+
def _set_gradient_checkpointing(self, module, value=False):
    """Enable/disable gradient checkpointing on a child Encoder or Decoder.

    Called by the framework for each submodule; only Encoder/Decoder
    instances carry the ``gradient_checkpointing`` flag.
    """
    if isinstance(module, (Encoder, Decoder)):
        module.gradient_checkpointing = value
|
| 507 |
+
|
| 508 |
+
def enable_tiling_during_training(self, use_tiling: bool = True):
    """Switch tiled processing during training on (default) or off."""
    self.use_tiling_during_training = use_tiling

def disable_tiling_during_training(self):
    """Convenience wrapper: turn training-time tiling off."""
    self.enable_tiling_during_training(False)

def enable_temporal_tiling(self, use_tiling: bool = True):
    """Switch temporal tiling on (default) or off."""
    self.use_temporal_tiling = use_tiling

def disable_temporal_tiling(self):
    """Convenience wrapper: turn temporal tiling off."""
    self.enable_temporal_tiling(False)

def enable_spatial_tiling(self, use_tiling: bool = True):
    """Switch spatial tiling on (default) or off."""
    self.use_spatial_tiling = use_tiling

def disable_spatial_tiling(self):
    """Convenience wrapper: turn spatial tiling off."""
    self.enable_spatial_tiling(False)

def enable_tiling(self, use_tiling: bool = True):
    """Alias for spatial tiling control (temporal tiling is left untouched)."""
    self.enable_spatial_tiling(use_tiling)

def disable_tiling(self):
    """Alias for disabling spatial tiling."""
    self.disable_spatial_tiling()

def enable_slicing(self):
    """Enable batch slicing for encode/decode."""
    self.use_slicing = True

def disable_slicing(self):
    """Disable batch slicing."""
    self.use_slicing = False
|
| 537 |
+
|
| 538 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
    """Cross-fade ``b``'s left edge with ``a``'s right edge along the width axis.

    Mutates ``b`` in place over the overlap region and returns it.
    """
    overlap = min(a.shape[-1], b.shape[-1], blend_extent)
    for col in range(overlap):
        ratio = col / overlap
        b[:, :, :, :, col] = a[:, :, :, :, col - overlap] * (1 - ratio) + b[:, :, :, :, col] * ratio
    return b
|
| 543 |
+
|
| 544 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
    """Cross-fade ``b``'s top edge with ``a``'s bottom edge along the height axis.

    Mutates ``b`` in place over the overlap region and returns it.
    """
    overlap = min(a.shape[-2], b.shape[-2], blend_extent)
    for row in range(overlap):
        ratio = row / overlap
        b[:, :, :, row, :] = a[:, :, :, row - overlap, :] * (1 - ratio) + b[:, :, :, row, :] * ratio
    return b
|
| 549 |
+
|
| 550 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
    """Cross-fade ``b``'s leading frames with ``a``'s trailing frames (time axis).

    Mutates ``b`` in place over the overlap region and returns it.
    """
    overlap = min(a.shape[-3], b.shape[-3], blend_extent)
    for frame in range(overlap):
        ratio = frame / overlap
        b[:, :, frame, :, :] = a[:, :, frame - overlap, :, :] * (1 - ratio) + b[:, :, frame, :, :] * ratio
    return b
|
| 555 |
+
|
| 556 |
+
def spatial_tiled_encode(self, x: torch.Tensor):
    """Encode a large input tile-by-tile over H/W and blend tile borders.

    Tiles overlap by ``tile_overlap_factor`` of the tile size; after encoding,
    neighbouring tiles are cross-faded (blend_v/blend_h) and cropped so the
    concatenated latent is seamless.

    Args:
        x: input of shape (B, C, T, H, W).

    Returns:
        Latent moments tensor assembled from all tiles.
    """
    B, C, T, H, W = x.shape
    overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))  # 256 * (1 - 0.25) = 192
    blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)  # 8 * 0.25 = 2
    row_limit = self.tile_latent_min_size - blend_extent  # 8 - 2 = 6

    # Pass 1: encode every (possibly smaller, at the borders) tile.
    rows = []
    for i in range(0, H, overlap_size):
        row = []
        for j in range(0, W, overlap_size):
            tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
            tile = self.encoder(tile)
            row.append(tile)
        rows.append(row)
    # Pass 2: blend each tile with its top and left neighbours, then crop the
    # overlap away before concatenating.
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=-1))
    moments = torch.cat(result_rows, dim=-2)
    return moments
|
| 582 |
+
|
| 583 |
+
def temporal_tiled_encode(self, x: torch.Tensor):
    """Encode a long clip chunk-by-chunk along time and blend chunk borders.

    Each temporal chunk may additionally be spatially tiled when it exceeds
    the spatial tile size.

    Args:
        x: input of shape (B, C, T, H, W).

    Returns:
        Latent moments tensor assembled from all temporal chunks.
    """
    B, C, T, H, W = x.shape
    overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))  # 64 * (1 - 0.25) = 48
    blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)  # 8 * 0.25 = 2
    t_limit = self.tile_latent_min_tsize - blend_extent  # 8 - 2 = 6

    # Pass 1: encode each temporal chunk (spatially tiled if still too large).
    row = []
    for i in range(0, T, overlap_size):
        tile = x[:, :, i: i + self.tile_sample_min_tsize, :, :]
        if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
            tile = self.spatial_tiled_encode(tile)
        else:
            tile = self.encoder(tile)
        row.append(tile)
    # Pass 2: cross-fade with the previous chunk over the temporal overlap,
    # then crop the overlap away before concatenating.
    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            tile = self.blend_t(row[i - 1], tile, blend_extent)
        result_row.append(tile[:, :, :t_limit, :, :])
    moments = torch.cat(result_row, dim=-3)
    return moments
|
| 604 |
+
|
| 605 |
+
def spatial_tiled_decode(self, z: torch.Tensor):
    """Decode latents tile-by-tile over H/W, blending tile borders.

    When a torch.distributed process group with world_size > 1 is
    initialized, tiles are decoded round-robin across ranks and gathered;
    the assembled result is only valid on rank 0 (other ranks return an
    empty placeholder tensor). Otherwise tiles are decoded serially.

    Args:
        z: latent tensor of shape (B, C, T, H, W).

    Returns:
        Decoded sample tensor (rank 0 / single device), or an empty tensor
        on non-zero ranks.
    """
    B, C, T, H, W = z.shape
    overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))  # 24 * (1 - 0.125) = 21
    blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)  # 384 * 0.125 = 48
    row_limit = self.tile_sample_min_size - blend_extent  # 384 - 48 = 336

    # Distributed / multi-GPU path: no padding on the input -> each rank pads
    # its decoded output on the right/bottom -> GPU all_gather -> rank 0
    # reassembles, blends, and crops.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Count the tile grid.
        num_rows = math.ceil(H / overlap_size)
        num_cols = math.ceil(W / overlap_size)
        total_tiles = num_rows * num_cols
        tiles_per_rank = math.ceil(total_tiles / world_size)

        print(f"==={torch.distributed.get_rank()}, {total_tiles=}, {tiles_per_rank=}, {world_size=}")

        # This rank's tile indices (round-robin): rank, rank + world_size, ...
        my_linear_indices = list(range(rank, total_tiles, world_size))
        if my_linear_indices == []:
            # Idle ranks still decode tile 0 so the all_gather collective
            # sees the same number of participants everywhere.
            my_linear_indices = [0]
        print(f"==={torch.distributed.get_rank()}, {my_linear_indices=}")
        decoded_tiles = []  # decoded tiles
        decoded_metas = []  # (ri, rj, pad_w, pad_h)
        H_out_std = self.tile_sample_min_size
        W_out_std = self.tile_sample_min_size
        for lin_idx in my_linear_indices:
            ri = lin_idx // num_cols
            rj = lin_idx % num_cols
            i = ri * overlap_size
            j = rj * overlap_size
            tile = z[
                :,
                :,
                :,
                i : i + self.tile_latent_min_size,
                j : j + self.tile_latent_min_size,
            ]
            dec = self.decoder(tile)
            # Pad boundary tiles' outputs (right/bottom) up to the standard size
            # so all gathered tensors share one shape.
            pad_h = max(0, H_out_std - dec.shape[-2])
            pad_w = max(0, W_out_std - dec.shape[-1])
            if pad_h > 0 or pad_w > 0:
                dec = F.pad(dec, (0, pad_w, 0, pad_h, 0, 0), "constant", 0)
            decoded_tiles.append(dec)
            decoded_metas.append(torch.tensor([ri, rj, pad_w, pad_h], device=z.device, dtype=torch.int64))

        # Ranks may own different tile counts; pad the lists to equal length
        # with placeholder tiles marked (ri, rj) == (-1, -1).
        T_out = decoded_tiles[0].shape[2] if len(decoded_tiles) > 0 else (T-1)*self.ffactor_temporal+1
        while len(decoded_tiles) < tiles_per_rank:
            decoded_tiles.append(torch.zeros([1, 3, T_out, self.tile_sample_min_size, self.tile_sample_min_size], device=z.device, dtype=dec.dtype))
            decoded_metas.append(torch.tensor([-1, -1, self.tile_sample_min_size, self.tile_sample_min_size], device=z.device, dtype=torch.int64))

        # Gather every rank's tiles and metadata on the GPU.
        decoded_tiles = torch.stack(decoded_tiles, dim=0)
        decoded_metas = torch.stack(decoded_metas, dim=0)

        tiles_gather_list = [torch.empty_like(decoded_tiles) for _ in range(world_size)]
        metas_gather_list = [torch.empty_like(decoded_metas) for _ in range(world_size)]

        dist.all_gather(tiles_gather_list, decoded_tiles)
        dist.all_gather(metas_gather_list, decoded_metas)

        if rank != 0:
            # Non-zero ranks return an empty placeholder; the result is only
            # valid on rank 0.
            return torch.empty(0, device=z.device)

        # Rank 0: rebuild the tile grid from (ri, rj) metadata; skip the
        # placeholder entries where (ri, rj) == (-1, -1).
        rows = [[None for _ in range(num_cols)] for _ in range(num_rows)]
        for r in range(world_size):
            gathered_tiles_r = tiles_gather_list[r]  # [tiles_per_rank, B, C, T, H, W]
            gathered_metas_r = metas_gather_list[r]  # [tiles_per_rank, 4], entries: (ri, rj, pad_w, pad_h)
            for k in range(gathered_tiles_r.shape[0]):
                ri = int(gathered_metas_r[k][0])
                rj = int(gathered_metas_r[k][1])
                if ri < 0 or rj < 0:
                    continue
                if ri < num_rows and rj < num_cols:
                    # Strip the right/bottom padding added before gathering.
                    pad_w = int(gathered_metas_r[k][2])
                    pad_h = int(gathered_metas_r[k][3])
                    h_end = None if pad_h == 0 else -pad_h
                    w_end = None if pad_w == 0 else -pad_w
                    rows[ri][rj] = gathered_tiles_r[k][:, :, :, :h_end, :w_end]

        # Blend each tile with its top/left neighbours and crop the overlap.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if tile is None:
                    continue
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=-2)
        return dec

    # Single device: original serial logic.
    rows = []
    for i in range(0, H, overlap_size):
        row = []
        for j in range(0, W, overlap_size):
            tile = z[
                :,
                :,
                :,
                i : i + self.tile_latent_min_size,
                j : j + self.tile_latent_min_size,
            ]
            decoded = self.decoder(tile)
            row.append(decoded)
        rows.append(row)

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=-1))
    dec = torch.cat(result_rows, dim=-2)
    return dec
|
| 736 |
+
|
| 737 |
+
def temporal_tiled_decode(self, z: torch.Tensor):
    """Decode latents chunk-by-chunk along time, blending chunk borders.

    Each temporal chunk falls through to spatial tiling when it exceeds the
    spatial latent tile size.

    Args:
        z: latent tensor of shape (B, C, T, H, W).

    Returns:
        Decoded sample tensor assembled from all temporal chunks.
    """
    B, C, T, H, W = z.shape
    overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))  # 8 * (1 - 0.25) = 6
    blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)  # 64 * 0.25 = 16
    t_limit = self.tile_sample_min_tsize - blend_extent  # 64 - 16 = 48
    assert 0 < overlap_size < self.tile_latent_min_tsize

    # Pass 1: decode each temporal chunk.
    row = []
    for i in range(0, T, overlap_size):
        tile = z[:, :, i: i + self.tile_latent_min_tsize, :, :]
        if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
            decoded = self.spatial_tiled_decode(tile)
        else:
            decoded = self.decoder(tile)
        row.append(decoded)

    # Pass 2: cross-fade with the previous chunk over the temporal overlap,
    # then crop the overlap away before concatenating.
    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            tile = self.blend_t(row[i - 1], tile, blend_extent)
        result_row.append(tile[:, :, :t_limit, :, :])
    dec = torch.cat(result_row, dim=-3)
    return dec
|
| 760 |
+
|
| 761 |
+
def encode(self, x: Tensor, return_dict: bool = True):
    """Encode pixel input into latent distribution parameters.

    Args:
        x: input of shape (B, C, H, W) or (B, C, T, H, W); a 4D input gets
            a singleton time axis inserted, and a single frame is replicated
            to fill one temporal window of ``ffactor_temporal`` frames.
        return_dict: return an ``AutoencoderKLOutput`` instead of a tuple.

    Returns:
        ``AutoencoderKLOutput`` wrapping a ``DiagonalGaussianDistribution``
        built from the encoder's moments, or a 1-tuple when ``return_dict``
        is False.
    """

    def _encode(x):
        # Prefer temporal tiling, then spatial tiling, when the input exceeds
        # the configured tile sizes; otherwise run the encoder directly.
        if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
            return self.temporal_tiled_encode(x)
        if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.spatial_tiled_encode(x)

        if self.use_compile:
            # Wrap the encoder call so torch.compile traces it on demand.
            @torch.compile
            def encoder(x):
                return self.encoder(x)
            return encoder(x)
        return self.encoder(x)

    if len(x.shape) != 5:  # (B, C, T, H, W)
        x = x[:, :, None]
    assert len(x.shape) == 5  # (B, C, T, H, W)
    if x.shape[2] == 1:
        # Single frame: replicate along time to fill one temporal window.
        x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
    else:
        # NOTE(review): this rejects T == ffactor_temporal while requiring T
        # to be a multiple of it — confirm T must be a strict multiple > 1.
        assert x.shape[2] != self.ffactor_temporal and x.shape[2] % self.ffactor_temporal == 0

    if self.use_slicing and x.shape[0] > 1:
        if self.slicing_bsz == 1:
            encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
        else:
            # Split the batch into chunks of slicing_bsz (last one may be smaller).
            sections = [self.slicing_bsz] * (x.shape[0] // self.slicing_bsz)
            if x.shape[0] % self.slicing_bsz != 0:
                sections.append(x.shape[0] % self.slicing_bsz)
            encoded_slices = [_encode(x_slice) for x_slice in x.split(sections)]
        h = torch.cat(encoded_slices)
    else:
        h = _encode(x)
    posterior = DiagonalGaussianDistribution(h)

    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)
|
| 801 |
+
|
| 802 |
+
def decode(self, z: Tensor, return_dict: bool = True, generator=None):
    """Decode latents back to pixel space.

    Args:
        z: latent tensor of shape (B, C, T, H, W).
        return_dict: return a ``DecoderOutput`` instead of a tuple.
        generator: unused; kept for interface compatibility.

    Returns:
        ``DecoderOutput`` (or 1-tuple) on rank 0 / single device; on other
        ranks of an initialized process group, the cached empty placeholder.
    """

    def _decode(z):
        # Prefer temporal tiling, then spatial tiling, when latents exceed
        # the configured tile sizes; otherwise run the decoder directly.
        if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
            return self.temporal_tiled_decode(z)
        if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.spatial_tiled_decode(z)
        return self.decoder(z)

    if self.use_slicing and z.shape[0] > 1:
        decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
        decoded = torch.cat(decoded_slices)
    else:
        decoded = _decode(z)
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() != 0:
            # Distributed decode only materializes the result on rank 0;
            # other ranks return the preallocated empty placeholder.
            return self.empty_cache

    if z.shape[-3] == 1:
        # Single-latent-frame input: keep only the last reconstructed frame.
        decoded = decoded[:, :, -1:]
    if not return_dict:
        return (decoded,)

    return DecoderOutput(sample=decoded)
|
| 826 |
+
|
| 827 |
+
def decode_dist(self, z: Tensor, return_dict: bool = True, generator=None):
    """Decode on GPU with spatial tiling temporarily forced on.

    ``return_dict`` and ``generator`` are accepted for interface
    compatibility but forwarded behavior follows ``self.decode``'s default.
    """
    latents = z.cuda()
    self.use_spatial_tiling = True
    result = self.decode(latents)
    self.use_spatial_tiling = False
    return result
|
| 833 |
+
|
| 834 |
+
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_posterior: bool = True,
    return_dict: bool = True
):
    """Full autoencoder pass: encode, (optionally) sample, decode.

    Args:
        sample: input tensor forwarded to ``self.encode``.
        sample_posterior: draw from the posterior instead of taking its mode.
        return_posterior: accepted for interface compatibility; the posterior
            is always returned alongside the reconstruction.
        return_dict: wrap the result in ``DecoderOutput`` instead of a tuple.
    """
    posterior = self.encode(sample).latent_dist
    if sample_posterior:
        z = posterior.sample()
    else:
        z = posterior.mode()
    dec = self.decode(z).sample
    if return_dict:
        return DecoderOutput(sample=dec, posterior=posterior)
    return (dec, posterior)
|
| 845 |
+
|
| 846 |
+
def random_reset_tiling(self, x: torch.Tensor):
    """Randomly reconfigure spatial/temporal tiling for this input.

    Single-frame inputs disable tiling entirely. Otherwise each axis
    independently picks "no tiling" or a small multiple of the minimum
    valid tile size, and the tile bookkeeping attributes are updated.
    """
    if x.shape[-3] == 1:
        self.disable_spatial_tiling()
        self.disable_temporal_tiling()
        return

    # Tiling puts many constraints on input shape and sample size; arbitrary
    # combinations are likely invalid, so fixed multiples of the minimum
    # compatible size are used here.
    base_spatial = int(1 / self.tile_overlap_factor) * self.ffactor_spatial
    base_temporal = int(1 / self.tile_overlap_factor) * self.ffactor_temporal

    spatial_pick = random.choice([None, 1 * base_spatial, 2 * base_spatial, 3 * base_spatial])
    if spatial_pick is None:
        self.disable_spatial_tiling()
    else:
        self.tile_sample_min_size = spatial_pick
        self.tile_latent_min_size = spatial_pick // self.ffactor_spatial
        self.enable_spatial_tiling()

    temporal_pick = random.choice([None, 1 * base_temporal, 2 * base_temporal, 3 * base_temporal])
    if temporal_pick is None:
        self.disable_temporal_tiling()
    else:
        self.tile_sample_min_tsize = temporal_pick
        self.tile_latent_min_tsize = temporal_pick // self.ffactor_temporal
        self.enable_temporal_tiling()
|
| 870 |
+
|
| 871 |
+
def load_sharded_safetensors(model_dir):
    """Manually load sharded safetensors files and merge them.

    Args:
        model_dir: directory containing the shard files.

    Returns:
        A single dict mapping every tensor name to its CPU tensor.
    """
    # Collect all shard files.
    shard_files = []
    for file in os.listdir(model_dir):
        if file.endswith(".safetensors"):
            shard_files.append(file)

    # Sort by shard index (expects names like "model-0001-of-0032.safetensors").
    shard_files.sort(key=lambda x: int(x.split("-")[1]))

    print(f"找到 {len(shard_files)} 个分片文件")

    # Merge all weights across shards.
    merged_state_dict = dict()

    for shard_file in shard_files:
        shard_path = os.path.join(model_dir, shard_file)
        print(f"加载分片: {shard_file}")

        # Load the current shard with safetensors onto CPU.
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                merged_state_dict[key] = tensor

    print(f"合并完成,总键数量: {len(merged_state_dict)}")
    return merged_state_dict
|
| 907 |
+
|
| 908 |
+
def load_weights(model, weights: dict[str, torch.Tensor]) -> set[str]:
    """Copy matching tensors from ``weights`` into ``model`` in place.

    Keys are matched after stripping any ``"vae."`` substring; keys with no
    counterpart in the model's state dict are skipped silently.

    Args:
        model: target module whose parameters/buffers receive the tensors.
        weights: mapping from (possibly ``"vae."``-prefixed) names to tensors.
            (The annotation was corrected to a mapping: ``.items()`` is used.)

    Returns:
        The set of state-dict keys that were actually updated.

    Raises:
        ValueError: on a shape mismatch or a non-tensor value.
    """
    # Materialize the state dict once; calling model.state_dict() inside the
    # loop would rebuild a fresh OrderedDict for every key.
    state_dict = model.state_dict()

    def _copy_into(name: str, weight) -> None:
        """Validate one entry and copy it into the model tensor in place."""
        if name not in state_dict:
            raise ValueError(f"Unexpected weight {name}")

        model_tensor = state_dict[name]
        if model_tensor.shape != weight.shape:
            raise ValueError(
                f"Shape mismatch for weight {name}: "
                f"model tensor shape {model_tensor.shape} vs. "
                f"loaded tensor shape {weight.shape}"
            )
        if isinstance(weight, torch.Tensor):
            model_tensor.data.copy_(weight.data)
        else:
            raise ValueError(
                f"Unsupported tensor type in load_weights "
                f"for {name}: {type(weight)}"
            )

    loaded_params = set()
    for name, load_tensor in weights.items():
        # Checkpoint keys may carry a "vae." prefix; this removes the
        # substring anywhere in the name (matches the original behavior).
        name = name.replace('vae.', '')
        if name in state_dict:
            _copy_into(name, load_tensor)
            loaded_params.add(name)

    return loaded_params
|
| 941 |
+
|
| 942 |
+
def _worker(path, config,
            rank=None, world_size=None, port=None, req_queue=None, rsp_queue=None):
    """
    Each rank's worker process:
    - idle: block on req_queue.get() (CPU blocking, no GPU work)
    - on request: all ranks run the distributed decode together
    - only rank 0 puts the result on rsp_queue

    Args:
        path: directory of sharded safetensors weights.
        config: config dict for AutoencoderKLConv3D.from_config.
        rank / world_size / port: process-group bootstrap parameters.
        req_queue / rsp_queue: multiprocessing queues for requests/responses.
    """
    # _tame_cpu_threads_and_comm()
    # Basic distributed environment variables for NCCL rendezvous.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)

    # Device binding must happen before any CUDA operation.
    visible = torch.cuda.device_count()
    assert visible >= world_size, f"可见卡数 {visible} < world_size {world_size}"
    local_rank = int(os.environ["LOCAL_RANK"])

    print(f"[worker {rank}] bind to cuda:{local_rank} (visible={visible})", flush=True)
    if not torch.distributed.is_initialized():
        dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    #from .. import load_vae

    #vae = load_vae(vae_type, vae_precision, device, logger, args, weights_only, only_encoder, only_decoder, sample_size, skip_create_dist=True)
    #vae = vae.cuda()
    # Build the VAE and load the sharded weights on every rank.
    vae = AutoencoderKLConv3D.from_config(config)
    merged_state_dict = load_sharded_safetensors(path)
    loaded_params = load_weights(vae, merged_state_dict)
    vae = vae.cuda()
    vae.eval()  # disable Dropout / BatchNorm training behavior
    for param in vae.parameters():
        param.requires_grad = False  # inference only

    # Serve decode requests until the stop sentinel arrives.
    while True:
        req = req_queue.get()  # blocking
        if req == "__STOP__":
            break
        out = vae.decode_dist(req, return_dict=False)
        if rank == 0:
            rsp_queue.put(out)
|
| 986 |
+
|
| 987 |
+
#try:
|
| 988 |
+
# while True:
|
| 989 |
+
# # blocking on CPU queue
|
| 990 |
+
# req = req_queue.get() # blocking
|
| 991 |
+
# if req == "__STOP__":
|
| 992 |
+
# break
|
| 993 |
+
# out = vae.decode_dist(req, return_dict=False)
|
| 994 |
+
# if rank == 0:
|
| 995 |
+
# rsp_queue.put(out)
|
| 996 |
+
#finally:
|
| 997 |
+
# # destroy process group before exit
|
| 998 |
+
# try:
|
| 999 |
+
# dist.destroy_process_group()
|
| 1000 |
+
# except Exception:
|
| 1001 |
+
# pass
|
| 1002 |
+
|
| 1003 |
+
#def _find_free_port():
|
| 1004 |
+
# import socket
|
| 1005 |
+
# with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 1006 |
+
# s.bind(("127.0.0.1", 0))
|
| 1007 |
+
# return s.getsockname()[1]
|
| 1008 |
+
|
| 1009 |
+
# 避免端口冲突的常见做法
|
| 1010 |
+
def _find_free_port(start_port=8100, max_attempts=900):
|
| 1011 |
+
import socket
|
| 1012 |
+
"""获取一个可用的端口"""
|
| 1013 |
+
for port in range(start_port, start_port + max_attempts):
|
| 1014 |
+
try:
|
| 1015 |
+
with socket.socket() as s:
|
| 1016 |
+
s.bind(('localhost', port))
|
| 1017 |
+
return s.getsockname()[1] # 返回实际绑定的端口
|
| 1018 |
+
except OSError:
|
| 1019 |
+
continue
|
| 1020 |
+
raise RuntimeError("找不到可用端口")
|
| 1021 |
+
|
| 1022 |
+
class AutoencoderKLConv3D_Dist(AutoencoderKLConv3D):
    """AutoencoderKLConv3D variant that decodes through a pool of GPU workers.

    ``create_dist`` spawns one worker process per rank; ``decode`` broadcasts
    the request to every worker's queue and returns rank 0's result.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_channels: int,
        block_out_channels: Tuple[int, ...],
        layers_per_block: int,
        ffactor_spatial: int,
        ffactor_temporal: int,
        sample_size: int,
        sample_tsize: int,
        scaling_factor: float = None,
        shift_factor: Optional[float] = None,
        downsample_match_channel: bool = True,
        upsample_match_channel: bool = True,
        only_encoder: bool = False,
        only_decoder: bool = False,
    ):
        # Same construction as the base class; the distributed machinery is
        # only set up later via create_dist().
        super().__init__(in_channels, out_channels, latent_channels, block_out_channels, layers_per_block, ffactor_spatial, ffactor_temporal, sample_size, sample_tsize, scaling_factor, shift_factor, downsample_match_channel, upsample_match_channel, only_encoder, only_decoder)

    def create_dist(self, path, config,
                    ):
        """Spawn the worker processes and their request/response queues.

        Args:
            path: directory of sharded safetensors weights.
            config: config forwarded to each worker's model construction.
        """
        self.world_size = 8  # NOTE(review): hard-coded 8 workers — confirm this matches the deployment GPU count
        self.port = _find_free_port()
        ctx = mp.get_context("spawn")
        # One request queue per rank (pure CPU), plus one shared response queue.
        self.req_queues = [ctx.Queue() for _ in range(self.world_size)]
        self.rsp_queue = ctx.Queue()

        self.procs = []
        for rank in range(self.world_size):
            p = ctx.Process(
                target=_worker,
                args=(
                    path, config,
                    rank, self.world_size, self.port,
                    self.req_queues[rank], self.rsp_queue,
                ),
                daemon=True,
            )
            p.start()
            self.procs.append(p)

    def decode(self, z: Tensor, return_dict: bool = True, generator=None):
        """
        Synchronous inference: put the same request on all ranks' queues and
        return rank 0's result.
        """
        # Fail fast if any worker died.
        for p in self.procs:
            if not p.is_alive():
                raise RuntimeError("One of the processes is not alive")

        # Broadcast the request to every rank's queue.
        for q in self.req_queues:
            q.put(z)

        # Block until rank 0 produces the result.
        return self.rsp_queue.get(timeout=None)
|
cache_utils.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
def cache_init(cache_interval, max_order, num_steps=None,
               enable_first_enhance=False, first_enhance_steps=3,
               enable_tailing_enhance=False, tailing_enhance_steps=1,
               low_freqs_order=0, high_freqs_order=2):
    """Build the bookkeeping dict used by the Taylor feature cache.

    Args:
        cache_interval: how many steps between full recomputations.
        max_order: highest Taylor derivative order kept.
        num_steps: total number of denoising steps (optional).
        enable_first_enhance / first_enhance_steps: fully compute the first
            few steps to enhance contour information.
        enable_tailing_enhance / tailing_enhance_steps: fully compute the
            final step(s) to enhance details.
        low_freqs_order / high_freqs_order: Taylor orders applied to the
            low- and high-frequency components respectively.

    Returns:
        A dict with counters, configuration, and force-control flags
        (the force-control entries exist for training-aware caching and
        are unused here).
    """
    return {
        'counter': 0,
        'current_step': 0,
        'cache_interval': cache_interval,
        'max_order': max_order,
        'num_steps': num_steps,
        # first enhance: fully compute the first steps (contours)
        'enable_first_enhance': enable_first_enhance,
        'first_enhance_steps': first_enhance_steps,
        # tailing enhance: fully compute the last steps (details)
        'enable_tailing_enhance': enable_tailing_enhance,
        'tailing_enhance_steps': tailing_enhance_steps,
        # frequency-split Taylor orders
        'low_freqs_order': low_freqs_order,
        'high_freqs_order': high_freqs_order,
        # training-aware cache switches; unused in this codebase
        'enable_force_control': False,
        'force_compute': False,
    }
|
| 35 |
+
|
| 36 |
+
class TaylorCacheContainer(nn.Module):
    """Finite-difference Taylor cache for a single feature tensor.

    Buffers ``derivative_i`` hold the committed derivative estimates up to
    ``max_order``; ``temp_derivative_i`` stage the next update until
    :meth:`move_temp_to_derivative` commits them. Cached features can then
    be extrapolated with :meth:`taylor_formula`.
    """

    def __init__(self, max_order):
        super().__init__()
        self.max_order = max_order
        # Register each order's buffer individually (non-persistent, starts empty).
        for order in range(max_order + 1):
            self.register_buffer(f"derivative_{order}", None, persistent=False)
            self.register_buffer(f"temp_derivative_{order}", None, persistent=False)

    def get_derivative(self, order):
        """Committed derivative of the given order (may be None)."""
        return getattr(self, f"derivative_{order}")

    def set_derivative(self, order, tensor):
        """Overwrite the committed derivative of the given order."""
        setattr(self, f"derivative_{order}", tensor)

    def set_temp_derivative(self, order, tensor):
        """Stage a derivative of the given order."""
        setattr(self, f"temp_derivative_{order}", tensor)

    def get_temp_derivative(self, order):
        """Staged derivative of the given order (may be None)."""
        return getattr(self, f"temp_derivative_{order}")

    def clear_temp_derivative(self):
        """Drop every staged derivative."""
        for order in range(self.max_order + 1):
            setattr(self, f"temp_derivative_{order}", None)

    def move_temp_to_derivative(self):
        """Commit staged derivatives, stopping at the first missing order."""
        for order in range(self.max_order + 1):
            staged = self.get_temp_derivative(order)
            if staged is None:
                break
            setattr(self, f"derivative_{order}", staged)
        self.clear_temp_derivative()

    def get_all_derivatives(self):
        """All committed derivatives, including None placeholders."""
        return [self.get_derivative(order) for order in range(self.max_order + 1)]

    def get_all_filled_derivatives(self):
        """Committed derivatives that are not None."""
        return [d for d in self.get_all_derivatives() if d is not None]

    def taylor_formula(self, distance):
        """Evaluate the Taylor expansion at the given step distance."""
        total = 0
        for order in range(len(self.get_all_filled_derivatives())):
            total += (1 / math.factorial(order)) * self.get_derivative(order) * (distance ** order)
        return total

    def derivatives_computation(self, x, distance):
        """Update the derivative estimates from a fresh full computation.

        Args:
            x: tensor, the new zeroth-order value.
            distance: int, steps since the last full computation.
        """
        self.set_temp_derivative(0, x)
        for order in range(self.max_order):
            previous = self.get_derivative(order)
            if previous is None:
                break
            delta = (self.get_temp_derivative(order) - previous) / distance
            self.set_temp_derivative(order + 1, delta)
        self.move_temp_to_derivative()

    def clear_derivatives(self):
        """Reset all committed and staged derivatives."""
        for order in range(self.max_order + 1):
            setattr(self, f"derivative_{order}", None)
            setattr(self, f"temp_derivative_{order}", None)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@torch.compile
|
| 101 |
+
def decomposition_FFT(x: torch.Tensor, cutoff_ratio: float = 0.1) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 102 |
+
"""
|
| 103 |
+
Fast Fourier Transform frequency domain decomposition
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
x: Input tensor [B, H*W, D]
|
| 107 |
+
cutoff_ratio: Cutoff frequency ratio (0~0.5)
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Tuple of (low_freq, high_freq) tensors with same dtype as input
|
| 111 |
+
"""
|
| 112 |
+
orig_dtype = x.dtype
|
| 113 |
+
device = x.device
|
| 114 |
+
|
| 115 |
+
x_fp32 = x.to(torch.float32) # Convert to fp32 for FFT compatibility
|
| 116 |
+
|
| 117 |
+
B, HW, D = x_fp32.shape
|
| 118 |
+
freq = torch.fft.fft(x_fp32, dim=1) # FFT on spatial dimension
|
| 119 |
+
|
| 120 |
+
freqs = torch.fft.fftfreq(HW, d=1.0, device=device)
|
| 121 |
+
cutoff = cutoff_ratio * freqs.abs().max()
|
| 122 |
+
|
| 123 |
+
# Create frequency masks
|
| 124 |
+
low_mask = freqs.abs() <= cutoff
|
| 125 |
+
high_mask = ~low_mask
|
| 126 |
+
|
| 127 |
+
low_mask = low_mask[None, :, None] # Broadcast to (B, HW, D)
|
| 128 |
+
high_mask = high_mask[None, :, None]
|
| 129 |
+
|
| 130 |
+
low_freq_complex = freq * low_mask
|
| 131 |
+
high_freq_complex = freq * high_mask
|
| 132 |
+
|
| 133 |
+
# IFFT and take real part
|
| 134 |
+
low_fp32 = torch.fft.ifft(low_freq_complex, dim=1).real
|
| 135 |
+
high_fp32 = torch.fft.ifft(high_freq_complex, dim=1).real
|
| 136 |
+
|
| 137 |
+
low = low_fp32.to(device=device, dtype=orig_dtype)
|
| 138 |
+
high = high_fp32.to(device=device, dtype=orig_dtype)
|
| 139 |
+
|
| 140 |
+
return low, high
|
| 141 |
+
|
| 142 |
+
@torch.compile
|
| 143 |
+
def reconstruction(low_freq: torch.Tensor, high_freq: torch.Tensor) -> torch.Tensor:
|
| 144 |
+
return low_freq + high_freq
|
| 145 |
+
|
| 146 |
+
class CacheWithFreqsContainer(nn.Module):
    """Finite-difference derivative cache split into low/high frequency bands.

    Stores derivatives of order 0..max_order per band ("low_freqs" /
    "high_freqs") as non-persistent buffers named
    ``derivative_{order}_{band}`` and ``temp_derivative_{order}_{band}``.
    New values are first staged in the ``temp_`` slots and then committed via
    :meth:`move_temp_to_derivative`.
    """

    def __init__(self, max_order):
        super().__init__()
        # Highest derivative order this container can hold.
        self.max_order = max_order
        # Register the buffers one by one (non-persistent: cache state is
        # transient and must not be written into checkpoints).
        for i in range(max_order + 1):
            self.register_buffer(f"derivative_{i}_low_freqs", None, persistent=False)
            self.register_buffer(f"derivative_{i}_high_freqs", None, persistent=False)
            self.register_buffer(f"temp_derivative_{i}_low_freqs", None, persistent=False)
            self.register_buffer(f"temp_derivative_{i}_high_freqs", None, persistent=False)

    def get_derivative(self, order, freqs):
        # `freqs` is the band name: "low_freqs" or "high_freqs".
        return getattr(self, f"derivative_{order}_{freqs}")

    def set_derivative(self, order, freqs, tensor):
        # Overwrite the committed derivative of the given order/band.
        setattr(self, f"derivative_{order}_{freqs}", tensor)

    def set_temp_derivative(self, order, freqs, tensor):
        # Stage a derivative; it becomes visible after move_temp_to_derivative().
        setattr(self, f"temp_derivative_{order}_{freqs}", tensor)

    def get_temp_derivative(self, order, freqs):
        return getattr(self, f"temp_derivative_{order}_{freqs}")

    def move_temp_to_derivative(self):
        """Commit staged temp derivatives into the main slots, then clear temps."""
        for i in range(self.max_order + 1):
            if self.get_temp_derivative(i, "low_freqs") is not None:
                setattr(self, f"derivative_{i}_low_freqs", self.get_temp_derivative(i, "low_freqs"))
            # NOTE(review): the else/break below binds only to the high_freqs
            # check, so a filled low_freqs temp with an empty high_freqs temp
            # at the same order stops the loop — confirm this asymmetry is
            # intended (cf. the single-band container's if/else/break).
            if self.get_temp_derivative(i, "high_freqs") is not None:
                setattr(self, f"derivative_{i}_high_freqs", self.get_temp_derivative(i, "high_freqs"))
            else:
                break
        self.clear_temp_derivative()

    def get_all_filled_derivatives(self, freqs):
        # Returns committed derivatives of the given band, lowest order first.
        return [self.get_derivative(i, freqs) for i in range(self.max_order + 1) if self.get_derivative(i, freqs) is not None]

    def taylor_formula(self, distance):
        """Extrapolate both bands by a truncated Taylor series at `distance`,
        then recombine them into a single prediction."""
        low_freqs_output = 0
        high_freqs_output = 0
        for i in range(len(self.get_all_filled_derivatives("low_freqs"))):
            low_freqs_output += (1 / math.factorial(i)) * self.get_derivative(i, "low_freqs") * (distance ** i)
        for i in range(len(self.get_all_filled_derivatives("high_freqs"))):
            high_freqs_output += (1 / math.factorial(i)) * self.get_derivative(i, "high_freqs") * (distance ** i)
        return reconstruction(low_freqs_output, high_freqs_output)

    def hermite_formula(self, distance):
        # NOTE(review): currently byte-identical to taylor_formula — presumably
        # a placeholder for a true Hermite interpolation; confirm.
        low_freqs_output = 0
        high_freqs_output = 0
        for i in range(len(self.get_all_filled_derivatives("low_freqs"))):
            low_freqs_output += (1 / math.factorial(i)) * self.get_derivative(i, "low_freqs") * (distance ** i)
        for i in range(len(self.get_all_filled_derivatives("high_freqs"))):
            high_freqs_output += (1 / math.factorial(i)) * self.get_derivative(i, "high_freqs") * (distance ** i)
        return reconstruction(low_freqs_output, high_freqs_output)

    def derivatives_computation(self, x, distance, low_freqs_order, high_freqs_order):
        '''
        Stage new finite-difference derivatives from a fresh sample and commit them.

        x: tensor, the new x_0
        distance: int, the distance between the current step and the last full computation step
        low_freqs_order / high_freqs_order: highest derivative order to update per band.
        '''
        # Order-0 "derivative" is the decomposed sample itself.
        x_low, x_high = decomposition_FFT(x, cutoff_ratio=0.1)
        self.set_temp_derivative(0, "low_freqs", x_low)
        self.set_temp_derivative(0, "high_freqs", x_high)
        # Each higher order is the forward difference of the order below,
        # computable only once the previous call has filled that order.
        for i in range(low_freqs_order):
            if self.get_derivative(i, "low_freqs") is not None:
                self.set_temp_derivative(i+1, "low_freqs", (self.get_temp_derivative(i, "low_freqs") - self.get_derivative(i, "low_freqs")) / distance)
        for i in range(high_freqs_order):
            if self.get_derivative(i, "high_freqs") is not None:
                self.set_temp_derivative(i+1, "high_freqs", (self.get_temp_derivative(i, "high_freqs") - self.get_derivative(i, "high_freqs")) / distance)
        self.move_temp_to_derivative()

    def clear_temp_derivative(self):
        # Drop only the staged values; committed derivatives stay intact.
        for i in range(self.max_order + 1):
            setattr(self, f"temp_derivative_{i}_low_freqs", None)
            setattr(self, f"temp_derivative_{i}_high_freqs", None)

    def clear_derivatives(self):
        # Full reset: both committed and staged slots for both bands.
        for i in range(self.max_order + 1):
            setattr(self, f"derivative_{i}_low_freqs", None)
            setattr(self, f"derivative_{i}_high_freqs", None)
            setattr(self, f"temp_derivative_{i}_low_freqs", None)
            setattr(self, f"temp_derivative_{i}_high_freqs", None)
|
config.json
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_classification_head": false,
|
| 3 |
+
"anyres_pooling_size": 2,
|
| 4 |
+
"anyres_vit_max_image_size": null,
|
| 5 |
+
"anyres_vit_two_views": false,
|
| 6 |
+
"architectures": [
|
| 7 |
+
"HunyuanImage3ForCausalMM"
|
| 8 |
+
],
|
| 9 |
+
"auto_map": {
|
| 10 |
+
"AutoConfig": "configuration_hunyuan_image_3.HunyuanImage3Config",
|
| 11 |
+
"AutoModel": "modeling_hunyuan_image_3.HunyuanImage3Model",
|
| 12 |
+
"AutoModelForCausalLM": "modeling_hunyuan_image_3.HunyuanImage3ForCausalMM"
|
| 13 |
+
},
|
| 14 |
+
"attention_bias": false,
|
| 15 |
+
"attention_dropout": 0.0,
|
| 16 |
+
"attention_head_dim": 128,
|
| 17 |
+
"bos_token_id": 127958,
|
| 18 |
+
"cla_share_factor": 2,
|
| 19 |
+
"class_num": 0,
|
| 20 |
+
"dense_list": [
|
| 21 |
+
4096,
|
| 22 |
+
0
|
| 23 |
+
],
|
| 24 |
+
"eod_token_id": 3,
|
| 25 |
+
"eos_token_id": 127957,
|
| 26 |
+
"group_limited_greedy": false,
|
| 27 |
+
"hidden_act": "silu",
|
| 28 |
+
"hidden_size": 4096,
|
| 29 |
+
"im_end_id": 128001,
|
| 30 |
+
"im_newline_id": 11,
|
| 31 |
+
"im_start_id": 128000,
|
| 32 |
+
"image_token_id": 128006,
|
| 33 |
+
"initializer_range": 0.02,
|
| 34 |
+
"intermediate_size": 3072,
|
| 35 |
+
"kv_lora_rank": null,
|
| 36 |
+
"mask_init_id": 12,
|
| 37 |
+
"max_position_embeddings": 22800,
|
| 38 |
+
"mlp_bias": false,
|
| 39 |
+
"model_type": "hunyuan_image_3_moe",
|
| 40 |
+
"moe_drop_tokens": false,
|
| 41 |
+
"moe_intermediate_size": [
|
| 42 |
+
3072,
|
| 43 |
+
3072,
|
| 44 |
+
3072,
|
| 45 |
+
3072,
|
| 46 |
+
3072,
|
| 47 |
+
3072,
|
| 48 |
+
3072,
|
| 49 |
+
3072,
|
| 50 |
+
3072,
|
| 51 |
+
3072,
|
| 52 |
+
3072,
|
| 53 |
+
3072,
|
| 54 |
+
3072,
|
| 55 |
+
3072,
|
| 56 |
+
3072,
|
| 57 |
+
3072,
|
| 58 |
+
3072,
|
| 59 |
+
3072,
|
| 60 |
+
3072,
|
| 61 |
+
3072,
|
| 62 |
+
3072,
|
| 63 |
+
3072,
|
| 64 |
+
3072,
|
| 65 |
+
3072,
|
| 66 |
+
3072,
|
| 67 |
+
3072,
|
| 68 |
+
3072,
|
| 69 |
+
3072,
|
| 70 |
+
3072,
|
| 71 |
+
3072,
|
| 72 |
+
3072,
|
| 73 |
+
3072
|
| 74 |
+
],
|
| 75 |
+
"moe_layer_num_skipped": 0,
|
| 76 |
+
"moe_random_routing_dropped_token": false,
|
| 77 |
+
"moe_topk": [
|
| 78 |
+
8,
|
| 79 |
+
8,
|
| 80 |
+
8,
|
| 81 |
+
8,
|
| 82 |
+
8,
|
| 83 |
+
8,
|
| 84 |
+
8,
|
| 85 |
+
8,
|
| 86 |
+
8,
|
| 87 |
+
8,
|
| 88 |
+
8,
|
| 89 |
+
8,
|
| 90 |
+
8,
|
| 91 |
+
8,
|
| 92 |
+
8,
|
| 93 |
+
8,
|
| 94 |
+
8,
|
| 95 |
+
8,
|
| 96 |
+
8,
|
| 97 |
+
8,
|
| 98 |
+
8,
|
| 99 |
+
8,
|
| 100 |
+
8,
|
| 101 |
+
8,
|
| 102 |
+
8,
|
| 103 |
+
8,
|
| 104 |
+
8,
|
| 105 |
+
8,
|
| 106 |
+
8,
|
| 107 |
+
8,
|
| 108 |
+
8,
|
| 109 |
+
8
|
| 110 |
+
],
|
| 111 |
+
"n_group": false,
|
| 112 |
+
"norm_topk_prob": true,
|
| 113 |
+
"norm_type": "rms",
|
| 114 |
+
"num_attention_heads": 32,
|
| 115 |
+
"num_experts": 64,
|
| 116 |
+
"num_hidden_layers": 32,
|
| 117 |
+
"num_key_value_heads": 8,
|
| 118 |
+
"num_media_embeds": 257,
|
| 119 |
+
"num_shared_expert": [
|
| 120 |
+
1,
|
| 121 |
+
1,
|
| 122 |
+
1,
|
| 123 |
+
1,
|
| 124 |
+
1,
|
| 125 |
+
1,
|
| 126 |
+
1,
|
| 127 |
+
1,
|
| 128 |
+
1,
|
| 129 |
+
1,
|
| 130 |
+
1,
|
| 131 |
+
1,
|
| 132 |
+
1,
|
| 133 |
+
1,
|
| 134 |
+
1,
|
| 135 |
+
1,
|
| 136 |
+
1,
|
| 137 |
+
1,
|
| 138 |
+
1,
|
| 139 |
+
1,
|
| 140 |
+
1,
|
| 141 |
+
1,
|
| 142 |
+
1,
|
| 143 |
+
1,
|
| 144 |
+
1,
|
| 145 |
+
1,
|
| 146 |
+
1,
|
| 147 |
+
1,
|
| 148 |
+
1,
|
| 149 |
+
1,
|
| 150 |
+
1,
|
| 151 |
+
1
|
| 152 |
+
],
|
| 153 |
+
"pad_id": 128009,
|
| 154 |
+
"pad_token_id": 128009,
|
| 155 |
+
"pool_type": "last",
|
| 156 |
+
"position_embedding_xdrope": false,
|
| 157 |
+
"pretraining_tp": 1,
|
| 158 |
+
"q_lora_rank": null,
|
| 159 |
+
"qk_nope_head_dim": null,
|
| 160 |
+
"qk_rope_head_dim": null,
|
| 161 |
+
"rms_norm_eps": 1e-05,
|
| 162 |
+
"rope_scaling": {
|
| 163 |
+
"alpha": 1.0,
|
| 164 |
+
"beta_fast": 32,
|
| 165 |
+
"beta_slow": 1,
|
| 166 |
+
"factor": 1.0,
|
| 167 |
+
"mscale": 1.0,
|
| 168 |
+
"mscale_all_dim": 1.0,
|
| 169 |
+
"type": "custom"
|
| 170 |
+
},
|
| 171 |
+
"rope_theta": 10000.0,
|
| 172 |
+
"routed_scaling_factor": false,
|
| 173 |
+
"skip_cls_token": false,
|
| 174 |
+
"text_end_id": 7,
|
| 175 |
+
"text_start_id": 6,
|
| 176 |
+
"tie_word_embeddings": false,
|
| 177 |
+
"topk_group": false,
|
| 178 |
+
"torch_dtype": "bfloat16",
|
| 179 |
+
"transformers_version": "4.50.0",
|
| 180 |
+
"use_cache": true,
|
| 181 |
+
"use_cla": false,
|
| 182 |
+
"use_mixed_mlp_moe": true,
|
| 183 |
+
"use_mla": false,
|
| 184 |
+
"use_qk_norm": true,
|
| 185 |
+
"use_rotary_pos_emb": true,
|
| 186 |
+
"v_head_dim": null,
|
| 187 |
+
"video_end_id": 10,
|
| 188 |
+
"video_start_id": 9,
|
| 189 |
+
"vit_add_patchemb_bias": false,
|
| 190 |
+
"vit_input_resolution": 224,
|
| 191 |
+
"vit_mapping_type": "resampler",
|
| 192 |
+
"vit_norm_type": "fused",
|
| 193 |
+
"vit_patch": 1,
|
| 194 |
+
"vit_path": null,
|
| 195 |
+
"vit_remove_prenorm": false,
|
| 196 |
+
"vit_token": 64,
|
| 197 |
+
"vit_type": "siglip2-so400m-patch16-naflex",
|
| 198 |
+
"vit_used_rms_norm": false,
|
| 199 |
+
"vocab_size": 133120,
|
| 200 |
+
"xdrope_section": null,
|
| 201 |
+
"head_dim": 128,
|
| 202 |
+
"rope_type": "2d",
|
| 203 |
+
"vae_downsample_factor": [
|
| 204 |
+
16,
|
| 205 |
+
16
|
| 206 |
+
],
|
| 207 |
+
"vit_downsample_factor": [
|
| 208 |
+
16,
|
| 209 |
+
16
|
| 210 |
+
],
|
| 211 |
+
"cond_token_attn_type": "joint_full",
|
| 212 |
+
"cond_image_type": "vae_vit",
|
| 213 |
+
"vae_type": "hunyuan-image-vae-v1",
|
| 214 |
+
"vae_dtype": "float32",
|
| 215 |
+
"vae_autocast_dtype": "float16",
|
| 216 |
+
"vae": {
|
| 217 |
+
"_class_name": "AutoencoderKLConv3D",
|
| 218 |
+
"block_out_channels": [
|
| 219 |
+
128,
|
| 220 |
+
256,
|
| 221 |
+
512,
|
| 222 |
+
1024,
|
| 223 |
+
1024
|
| 224 |
+
],
|
| 225 |
+
"in_channels": 3,
|
| 226 |
+
"out_channels": 3,
|
| 227 |
+
"latent_channels": 32,
|
| 228 |
+
"layers_per_block": 2,
|
| 229 |
+
"ffactor_spatial": 16,
|
| 230 |
+
"ffactor_temporal": 4,
|
| 231 |
+
"sample_size": 384,
|
| 232 |
+
"sample_tsize": 96,
|
| 233 |
+
"downsample_match_channel": true,
|
| 234 |
+
"upsample_match_channel": true,
|
| 235 |
+
"scaling_factor": 0.562679178327931
|
| 236 |
+
},
|
| 237 |
+
"vit": {
|
| 238 |
+
"_attn_implementation": "sdpa",
|
| 239 |
+
"attention_dropout": 0.0,
|
| 240 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 241 |
+
"hidden_size": 1152,
|
| 242 |
+
"intermediate_size": 4304,
|
| 243 |
+
"layer_norm_eps": 1e-06,
|
| 244 |
+
"num_attention_heads": 16,
|
| 245 |
+
"num_channels": 3,
|
| 246 |
+
"num_hidden_layers": 27,
|
| 247 |
+
"num_patches": 256,
|
| 248 |
+
"patch_size": 16,
|
| 249 |
+
"torch_dtype": "float32",
|
| 250 |
+
"output_attentions": false,
|
| 251 |
+
"output_hidden_states": false,
|
| 252 |
+
"use_return_dict": true
|
| 253 |
+
},
|
| 254 |
+
"vit_processor": {
|
| 255 |
+
"do_convert_rgb": null,
|
| 256 |
+
"do_normalize": true,
|
| 257 |
+
"do_rescale": true,
|
| 258 |
+
"do_resize": true,
|
| 259 |
+
"image_mean": [
|
| 260 |
+
0.5,
|
| 261 |
+
0.5,
|
| 262 |
+
0.5
|
| 263 |
+
],
|
| 264 |
+
"image_processor_type": "Siglip2ImageProcessorFast",
|
| 265 |
+
"image_std": [
|
| 266 |
+
0.5,
|
| 267 |
+
0.5,
|
| 268 |
+
0.5
|
| 269 |
+
],
|
| 270 |
+
"max_num_patches": 1024,
|
| 271 |
+
"patch_size": 16,
|
| 272 |
+
"processor_class": "Siglip2Processor",
|
| 273 |
+
"resample": 2,
|
| 274 |
+
"rescale_factor": 0.00392156862745098
|
| 275 |
+
},
|
| 276 |
+
"vit_aligner": {
|
| 277 |
+
"projector_type": "mlp_gelu",
|
| 278 |
+
"input_dim": 1152,
|
| 279 |
+
"n_embed": 4096,
|
| 280 |
+
"depth": 2,
|
| 281 |
+
"torch_dtype": "float32"
|
| 282 |
+
}
|
| 283 |
+
}
|
configuration_hunyuan_image_3.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
|
| 14 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 15 |
+
from transformers.utils import logging
|
| 16 |
+
from typing import List, Union, Optional
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class HunyuanImage3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HunyuanImage3Model`]. It is used to instantiate
    an Hunyuan model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Hunyuan-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Hunyuan Image 3 model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`HunyuanImage3Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations or shared MLP representations.
        moe_intermediate_size (`int` or `List`, *optional*, defaults to 11008):
            Dimension of the MLP representations in MoE. Use a list if you want a different size per layer.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether query and key in attention use norm
        use_cla (`bool`, *optional*, defaults to `False`):
            Whether to use CLA in attention
        cla_share_factor (`int`, *optional*, defaults to 1):
            The share factor of CLA
        num_experts (`int` or `List`, *optional*, defaults to 1):
            The number of experts for moe. If it is a list, it will be used as the number of experts for each layer.
        num_shared_expert (`int` or `List`, *optional*, defaults to 1):
            The number of shared experts for moe. If it is a list, it will be used as the number of shared experts
            for each layer.
        moe_topk (`int` or `List`, *optional*, defaults to 1):
            The topk value for moe. If it is a list, it will be used as the topk value for each layer.
        capacity_factor (Not used) (`float` or `List`, *optional*, defaults to 1.0):
            The capacity factor for moe. If it is a list, it will be used as the capacity factor for each layer.
        moe_layer_num_skipped (`int`, *optional*, defaults to 0):
            First moe_layer_num_skipped layers do not use MoE.
    """

    # NOTE(review): this repo's config.json declares model_type
    # "hunyuan_image_3_moe"; confirm the "Hunyuan" default here is intended.
    model_type = "Hunyuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int = 290943,
        hidden_size: int = 4096,
        intermediate_size: int = 11008,
        moe_intermediate_size: Union[int, List] = None,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: Optional[int] = None,
        attention_head_dim: Optional[int] = None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        eod_token_id=3,
        im_start_id=4,
        im_end_id=5,
        text_start_id=6,
        text_end_id=7,
        image_token_id=8,
        video_start_id=9,
        video_end_id=10,
        im_newline_id=11,
        mask_init_id=12,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        mlp_bias=False,
        attention_dropout=0.0,
        use_qk_norm=False,
        use_rotary_pos_emb=True,
        use_cla=False,
        cla_share_factor=1,
        norm_type="hf_rms",
        num_experts: Union[int, List] = 1,
        use_mixed_mlp_moe=False,
        num_shared_expert: Union[int, List] = 1,
        moe_topk: Union[int, List] = 1,
        capacity_factor: float = 1.0,  # annotation fixed: default is a float
        moe_drop_tokens=False,
        moe_random_routing_dropped_token=False,
        use_mla=False,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        moe_layer_num_skipped=0,
        norm_topk_prob=True,
        routed_scaling_factor=1.0,
        group_limited_greedy=False,
        n_group=None,
        topk_group=None,
        add_classification_head=False,
        class_num=0,
        pool_type="last",
        pad_id=-1,
        # Added for HunyuanImage-3 (multimodal generation) on top of the base LLM config
        moe_impl="eager",
        vae_downsample_factor=(16, 16),  # (h, w)
        img_proj_type="unet",
        patch_size=1,
        patch_embed_hidden_dim=1024,
        image_base_size=1024,
        rope_type="2d",
        cond_token_attn_type="full",
        cond_image_type="vae_vit",
        vae_type=None,
        vae_dtype="float32",
        vae_autocast_dtype="float16",
        vae=None,
        vit_type=None,
        vit=None,
        vit_processor=None,
        vit_aligner=None,
        cfg_distilled=False,
        use_meanflow=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # MoE routing configuration
        self.moe_impl = moe_impl
        self.num_experts = num_experts
        self.use_mixed_mlp_moe = use_mixed_mlp_moe
        self.num_shared_expert = num_shared_expert
        self.moe_topk = moe_topk
        self.capacity_factor = capacity_factor
        self.moe_drop_tokens = moe_drop_tokens
        self.moe_random_routing_dropped_token = moe_random_routing_dropped_token

        # Per-head dimension defaults to an even split of the hidden size.
        if attention_head_dim is not None:
            self.attention_head_dim = attention_head_dim
        else:
            self.attention_head_dim = self.hidden_size // num_attention_heads

        # for backward compatibility: absent KV-head count means MHA
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias
        self.attention_dropout = attention_dropout
        self.use_qk_norm = use_qk_norm
        self.use_rotary_pos_emb = use_rotary_pos_emb
        self.use_cla = use_cla
        self.cla_share_factor = cla_share_factor
        self.norm_type = norm_type
        # MLA args
        self.use_mla = use_mla
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim

        # DeepSeek related args
        self.moe_layer_num_skipped = moe_layer_num_skipped
        self.norm_topk_prob = norm_topk_prob
        self.routed_scaling_factor = routed_scaling_factor
        self.group_limited_greedy = group_limited_greedy
        self.n_group = n_group
        self.topk_group = topk_group
        self.add_classification_head = add_classification_head
        self.class_num = class_num
        self.pool_type = pool_type
        self.pad_id = pad_id

        # NOTE(review): class_num defaults to 0 (not None), so this branch
        # always runs; presumably it was meant to be guarded by
        # add_classification_head — confirm before relying on dense_list.
        if self.class_num is not None:
            self.dense_list = [self.hidden_size, self.class_num]

        # Conditioning image configs
        self.cond_token_attn_type = cond_token_attn_type
        self.cond_image_type = cond_image_type

        # ViT args (sub-configs are stored as-is; parsing happens elsewhere)
        self.vit_type = vit_type
        self.vit = vit
        self.vit_processor = vit_processor
        self.vit_aligner = vit_aligner

        # Image Gen args
        self.vae_type = vae_type
        self.vae_dtype = vae_dtype
        self.vae_autocast_dtype = vae_autocast_dtype
        self.vae = vae
        self.vae_downsample_factor = vae_downsample_factor
        self.img_proj_type = img_proj_type
        self.patch_size = patch_size
        self.patch_embed_hidden_dim = patch_embed_hidden_dim
        self.image_base_size = image_base_size
        self.rope_type = rope_type

        # token id
        self.eod_token_id = eod_token_id
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.text_start_id = text_start_id
        self.text_end_id = text_end_id
        self.image_token_id = image_token_id
        self.video_start_id = video_start_id
        self.video_end_id = video_end_id
        self.im_newline_id = im_newline_id
        self.mask_init_id = mask_init_id

        # flag of cfg distilled model
        self.cfg_distilled = cfg_distilled
        # flag of meanflow distilled model
        self.use_meanflow = use_meanflow
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
__all__ = ["HunyuanImage3Config"]
|
generation_config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"disable_compile": true,
|
| 3 |
+
"eos_token_id": [
|
| 4 |
+
127957
|
| 5 |
+
],
|
| 6 |
+
"pad_token_id": 128009,
|
| 7 |
+
"do_sample": true,
|
| 8 |
+
"top_k": 1024,
|
| 9 |
+
"top_p": 0.95,
|
| 10 |
+
"temperature": 0.6,
|
| 11 |
+
"max_length": 22800,
|
| 12 |
+
"sequence_template": "instruct",
|
| 13 |
+
"diff_infer_steps": 50,
|
| 14 |
+
"diff_guidance_scale": 2.5,
|
| 15 |
+
"flow_shift": 3.0,
|
| 16 |
+
"use_system_prompt": "en_unified",
|
| 17 |
+
"drop_think": false,
|
| 18 |
+
"bot_task": "think_recaption",
|
| 19 |
+
"max_new_tokens": 2048,
|
| 20 |
+
"transformers_version": "4.50.0"
|
| 21 |
+
}
|
hunyuan_image_3_pipeline.py
ADDED
|
@@ -0,0 +1,913 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
#
|
| 14 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 15 |
+
#
|
| 16 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 17 |
+
# you may not use this file except in compliance with the License.
|
| 18 |
+
# You may obtain a copy of the License at
|
| 19 |
+
#
|
| 20 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 21 |
+
#
|
| 22 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 23 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 24 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 25 |
+
# See the License for the specific language governing permissions and
|
| 26 |
+
# limitations under the License.
|
| 27 |
+
# ==============================================================================================
|
| 28 |
+
|
| 29 |
+
import inspect
|
| 30 |
+
import math
|
| 31 |
+
from dataclasses import dataclass
|
| 32 |
+
from typing import Any, Callable, Dict, List
|
| 33 |
+
from typing import Optional, Tuple, Union
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
import torch
|
| 37 |
+
from PIL import Image
|
| 38 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 39 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 40 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 41 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 42 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 43 |
+
from diffusers.utils import BaseOutput, logging
|
| 44 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 45 |
+
from .cache_utils import cache_init
|
| 46 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` and return its timestep schedule.

    Delegates to `scheduler.set_timesteps`, supporting either a custom `timesteps`
    list, a custom `sigmas` list, or a plain step count. Extra `kwargs` are
    forwarded to `set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; mutually exclusive with `timesteps`/`sigmas`.
        device (`str` or `torch.device`, *optional*):
            Device to move the timesteps to. If `None`, they are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; requires scheduler support.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; requires scheduler support.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of
        inference steps.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the scheduler
            does not accept the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(param_name: str) -> bool:
        # Whether the scheduler's `set_timesteps` signature takes `param_name`.
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    if sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    # Default path: let the scheduler derive the schedule from the step count.
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescale the guided noise prediction to fix over-exposure from classifier-free
    guidance, per Section 3.4 of [Common Diffusion Noise Schedules and Sample
    Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The CFG-combined noise prediction.
        noise_pred_text (`torch.Tensor`):
            The text-conditional noise prediction.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Interpolation weight between the rescaled and the unmodified prediction.

    Returns:
        `torch.Tensor`: the (partially) rescaled noise prediction.
    """
    # Per-sample std over all non-batch dimensions.
    reduce_dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=reduce_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Match the guided prediction's std to the text prediction's std (fixes overexposure),
    # then blend back toward the original to avoid "plain looking" images.
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@dataclass
class HunyuanImage3Text2ImagePipelineOutput(BaseOutput):
    """
    Output class for the HunyuanImage-3 text-to-image pipeline.

    Args:
        samples (`List[Any]` or `np.ndarray`):
            The generated samples, either as a list of per-sample objects or a
            batched NumPy array.
    """

    samples: Union[List[Any], np.ndarray]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # Accessible as `.prev_sample` or via diffusers' `BaseOutput` tuple/dict access.
    prev_sample: torch.FloatTensor
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Flow-matching discrete scheduler with Euler and higher-order (Heun/midpoint/Kutta) solvers.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the SD3-style timestep schedule (applied when != 1 and flux shift is off).
        reverse (`bool`, defaults to `True`):
            Whether to reverse the timestep schedule.
        solver (`str`, defaults to `"euler"`):
            ODE solver used by `step`: one of `"euler"`, `"heun-2"`, `"midpoint-2"`, `"kutta-4"`.
        use_flux_shift (`bool`, defaults to `False`):
            If `True`, apply a Flux-style, token-count-dependent time shift instead of the SD3 shift.
        flux_base_shift (`float`, defaults to 0.5):
            Flux shift interpolation endpoint at the reference minimum token count.
        flux_max_shift (`float`, defaults to 1.15):
            Flux shift interpolation endpoint at the reference maximum token count.
        n_tokens (`int`, *optional*):
            Number of tokens in the input sequence; recorded in the config.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        reverse: bool = True,
        solver: str = "euler",
        use_flux_shift: bool = False,
        flux_base_shift: float = 0.5,
        flux_max_shift: float = 1.15,
        n_tokens: Optional[int] = None,
    ):
        # Continuous noise levels from 1 down to 0, with num_train_timesteps + 1 boundaries.
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the value fed to model
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
        # Same schedule including the terminal boundary (used by `get_timestep_r`).
        self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

        self.supported_solver = [
            "euler",
            "heun-2", "midpoint-2",
            "kutta-4",
        ]
        if solver not in self.supported_solver:
            raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")

        # empty dt and derivative (for heun)
        # These hold intermediate sub-step state for the multi-stage solvers.
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        # Map a continuous sigma in [0, 1] to the model's timestep scale.
        return sigma * self.config.num_train_timesteps

    @property
    def state_in_first_order(self):
        # No stored first derivative yet: the next `step` call is the first sub-step.
        return self.derivative_1 is None

    @property
    def state_in_second_order(self):
        return self.derivative_2 is None

    @property
    def state_in_third_order(self):
        return self.derivative_3 is None

    def get_timestep_r(self, timestep: Union[float, torch.FloatTensor]):
        # Timestep at the *end* of the current interval; uses the full schedule,
        # which includes the terminal boundary, so `step_index + 1` is always valid.
        if self.step_index is None:
            self._init_step_index(timestep)
        return self.timesteps_full[self.step_index + 1]

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
                      n_tokens: int = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Required when `use_flux_shift` is enabled.
        """
        self.num_inference_steps = num_inference_steps

        sigmas = torch.linspace(1, 0, num_inference_steps + 1)

        # Apply timestep shift
        if self.config.use_flux_shift:
            assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift"
            mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens)
            sigmas = self.flux_time_shift(mu, 1.0, sigmas)
        elif self.config.shift != 1.:
            sigmas = self.sd3_time_shift(sigmas)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        self.sigmas = sigmas
        self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
        self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)

        # empty dt and derivative (for kutta)
        self.derivative_1 = None
        self.derivative_2 = None
        self.derivative_3 = None
        self.dt = None

        # Reset step index
        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        # Locate `timestep` within the schedule and return its index.
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Lazily initialize `_step_index` from the timestep, or from `begin_index`
        # when the pipeline has set one explicitly.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        # Flow matching needs no input scaling; kept for scheduler API compatibility.
        return sample

    @staticmethod
    def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
        # Linear interpolation of the shift parameter over token count (Flux-style).
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

    @staticmethod
    def flux_time_shift(mu: float, sigma: float, t: torch.Tensor):
        # Flux-style time warp; note t must be in (0, 1] — t == 0 yields 1/t == inf.
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def sd3_time_shift(self, t: torch.Tensor):
        # SD3-style time warp controlled by `config.shift`.
        return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        pred_uncond: torch.FloatTensor = None,
        generator: Optional[torch.Generator] = None,
        n_tokens: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            pred_uncond (`torch.FloatTensor`, *optional*):
                Unconditional prediction; upcast alongside the other inputs but otherwise unused here.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)
        model_output = model_output.to(torch.float32)
        pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None

        # dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # Multi-stage solvers consume several `step` calls per schedule interval;
        # `last_inner_step` marks when the interval is complete.
        last_inner_step = True
        if self.config.solver == "euler":
            derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
        elif self.config.solver in ["heun-2", "midpoint-2"]:
            derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
        elif self.config.solver == "kutta-4":
            derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
        else:
            raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        # prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        if last_inner_step:
            self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def first_order_method(self, model_output, sigma, sigma_next, sample):
        # Plain Euler step: one model evaluation per interval.
        derivative = model_output
        dt = sigma_next - sigma
        return derivative, dt, sample, True

    def second_order_method(self, model_output, sigma, sigma_next, sample):
        # Two-stage solvers (Heun / midpoint): first call stores state, second combines.
        if self.state_in_first_order:
            # store for 2nd order step
            self.derivative_1 = model_output
            self.dt = sigma_next - sigma
            self.sample = sample

            derivative = model_output
            if self.config.solver == 'heun-2':
                dt = self.dt
            elif self.config.solver == 'midpoint-2':
                dt = self.dt / 2
            else:
                raise NotImplementedError(f"Solver {self.config.solver} not supported.")
            last_inner_step = False

        else:
            if self.config.solver == 'heun-2':
                # Heun: average of the derivatives at both interval endpoints.
                derivative = 0.5 * (self.derivative_1 + model_output)
            elif self.config.solver == 'midpoint-2':
                # Midpoint: use the derivative evaluated at the interval midpoint.
                derivative = model_output
            else:
                raise NotImplementedError(f"Solver {self.config.solver} not supported.")

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample
            last_inner_step = True

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.derivative_1 = None
            self.dt = None
            self.sample = None

        return derivative, dt, sample, last_inner_step

    def fourth_order_method(self, model_output, sigma, sigma_next, sample):
        # Classic RK4: four model evaluations per interval, combined 1/6-1/3-1/3-1/6.
        if self.state_in_first_order:
            self.derivative_1 = model_output
            self.dt = sigma_next - sigma
            self.sample = sample
            derivative = model_output
            dt = self.dt / 2
            last_inner_step = False

        elif self.state_in_second_order:
            self.derivative_2 = model_output
            derivative = model_output
            dt = self.dt / 2
            last_inner_step = False

        elif self.state_in_third_order:
            self.derivative_3 = model_output
            derivative = model_output
            dt = self.dt
            last_inner_step = False

        else:
            derivative = (1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 +
                          1/6 * model_output)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample
            last_inner_step = True

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.derivative_1 = None
            self.derivative_2 = None
            self.derivative_3 = None
            self.dt = None
            self.sample = None

        return derivative, dt, sample, last_inner_step

    def __len__(self):
        return self.config.num_train_timesteps
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class ClassifierFreeGuidance:
    """
    Classifier-free guidance (CFG) combiner.

    Produces `base + guidance_scale * (pred_cond - pred_uncond)`, where `base`
    is `pred_cond` when `use_original_formulation` is True and `pred_uncond`
    otherwise (the common "scaled" CFG formulation).

    Args:
        use_original_formulation (`bool`, defaults to `False`):
            Whether to anchor the update at `pred_cond` (True) or `pred_uncond` (False).
        start (`float`, defaults to 0.0):
            Intended guidance start fraction. Stored for callers to consult;
            this class itself does not gate on it.
        stop (`float`, defaults to 1.0):
            Intended guidance stop fraction. Stored for callers to consult;
            this class itself does not gate on it.
    """

    def __init__(
        self,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__()
        self.use_original_formulation = use_original_formulation
        # Fix: `start`/`stop` were previously accepted but silently discarded,
        # so callers configuring a guidance window had no way to read it back.
        self.start = start
        self.stop = stop

    def __call__(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: Optional[torch.Tensor],
        guidance_scale: float,
        step: int,
    ) -> torch.Tensor:
        """
        Apply classifier-free guidance for one denoising step.

        Args:
            pred_cond: Conditional model prediction.
            pred_uncond: Unconditional prediction, or `None` when guidance is
                unavailable (e.g. CFG-distilled models).
            guidance_scale: Guidance strength.
            step: Current step index (unused; kept for interface compatibility).

        Returns:
            The guided prediction tensor.
        """
        # Fix: the annotation allows `pred_uncond=None`, but the subtraction below
        # would raise a TypeError. Return the conditional prediction unchanged,
        # which is the standard "guidance disabled" behavior.
        if pred_uncond is None:
            return pred_cond

        shift = pred_cond - pred_uncond
        pred = pred_cond if self.use_original_formulation else pred_uncond
        pred = pred + guidance_scale * shift

        return pred
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
class HunyuanImage3Text2ImagePipeline(DiffusionPipeline):
|
| 533 |
+
r"""
|
| 534 |
+
Pipeline for condition-to-sample generation using Stable Diffusion.
|
| 535 |
+
|
| 536 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 537 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 538 |
+
|
| 539 |
+
Args:
|
| 540 |
+
model ([`ModelMixin`]):
|
| 541 |
+
A model to denoise the diffused latents.
|
| 542 |
+
scheduler ([`SchedulerMixin`]):
|
| 543 |
+
A scheduler to be used in combination with `diffusion_model` to denoise the diffused latents. Can be one of
|
| 544 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 545 |
+
"""
|
| 546 |
+
|
| 547 |
+
model_cpu_offload_seq = ""
|
| 548 |
+
_optional_components = []
|
| 549 |
+
_exclude_from_cpu_offload = []
|
| 550 |
+
_callback_tensor_inputs = ["latents"]
|
| 551 |
+
|
| 552 |
+
def __init__(
    self,
    model,
    scheduler: SchedulerMixin,
    vae,
    progress_bar_config: Dict[str, Any] = None,
):
    """
    Initialize the pipeline and register its components.

    Args:
        model: Denoising model; must expose `config.vae_downsample_factor`.
        scheduler (`SchedulerMixin`): Scheduler used to denoise the latents.
        vae: Autoencoder used to decode the latents.
        progress_bar_config (`Dict[str, Any]`, *optional*):
            Extra keyword arguments merged into the pipeline's progress-bar settings.
    """
    super().__init__()

    # ==========================================================================================
    # Merge caller-provided progress-bar options without clobbering any config a
    # subclass may already have set.
    if progress_bar_config is None:
        progress_bar_config = {}
    if not hasattr(self, '_progress_bar_config'):
        self._progress_bar_config = {}
    self._progress_bar_config.update(progress_bar_config)
    # ==========================================================================================

    self.register_modules(
        model=model,
        scheduler=scheduler,
        vae=vae,
    )

    # should be a tuple or a list corresponding to the size of latents (batch_size, channel, *size)
    # if None, will be treated as a tuple of 1
    self.latent_scale_factor = self.model.config.vae_downsample_factor
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.latent_scale_factor)

    # Must start with APG_mode_
    # NOTE(review): instantiated with defaults, so any `start`/`stop` guidance
    # window must be configured elsewhere — confirm against callers.
    self.cfg_operator = ClassifierFreeGuidance()
|
| 582 |
+
|
| 583 |
+
@staticmethod
|
| 584 |
+
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
| 585 |
+
"""
|
| 586 |
+
Denormalize an image array to [0,1].
|
| 587 |
+
"""
|
| 588 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
| 589 |
+
|
| 590 |
+
@staticmethod
|
| 591 |
+
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
|
| 592 |
+
"""
|
| 593 |
+
Convert a PyTorch tensor to a NumPy image.
|
| 594 |
+
"""
|
| 595 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 596 |
+
return images
|
| 597 |
+
|
| 598 |
+
@staticmethod
|
| 599 |
+
def numpy_to_pil(images: np.ndarray):
|
| 600 |
+
"""
|
| 601 |
+
Convert a numpy image or a batch of images to a PIL image.
|
| 602 |
+
"""
|
| 603 |
+
if images.ndim == 3:
|
| 604 |
+
images = images[None, ...]
|
| 605 |
+
images = (images * 255).round().astype("uint8")
|
| 606 |
+
if images.shape[-1] == 1:
|
| 607 |
+
# special case for grayscale (single channel) images
|
| 608 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
| 609 |
+
else:
|
| 610 |
+
pil_images = [Image.fromarray(image) for image in images]
|
| 611 |
+
|
| 612 |
+
return pil_images
|
| 613 |
+
|
| 614 |
+
def prepare_extra_func_kwargs(self, func, kwargs):
|
| 615 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 616 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 617 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 618 |
+
# and should be between [0, 1]
|
| 619 |
+
extra_kwargs = {}
|
| 620 |
+
|
| 621 |
+
for k, v in kwargs.items():
|
| 622 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
| 623 |
+
if accepts:
|
| 624 |
+
extra_kwargs[k] = v
|
| 625 |
+
return extra_kwargs
|
| 626 |
+
|
| 627 |
+
def prepare_latents(self, batch_size, latent_channel, image_size, dtype, device, generator, latents=None):
|
| 628 |
+
if self.latent_scale_factor is None:
|
| 629 |
+
latent_scale_factor = (1,) * len(image_size)
|
| 630 |
+
elif isinstance(self.latent_scale_factor, int):
|
| 631 |
+
latent_scale_factor = (self.latent_scale_factor,) * len(image_size)
|
| 632 |
+
elif isinstance(self.latent_scale_factor, tuple) or isinstance(self.latent_scale_factor, list):
|
| 633 |
+
assert len(self.latent_scale_factor) == len(image_size), \
|
| 634 |
+
"len(latent_scale_factor) shoudl be the same as len(image_size)"
|
| 635 |
+
latent_scale_factor = self.latent_scale_factor
|
| 636 |
+
else:
|
| 637 |
+
raise ValueError(
|
| 638 |
+
f"latent_scale_factor should be either None, int, tuple of int, or list of int, "
|
| 639 |
+
f"but got {self.latent_scale_factor}"
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
latents_shape = (
|
| 643 |
+
batch_size,
|
| 644 |
+
latent_channel,
|
| 645 |
+
*[int(s) // f for s, f in zip(image_size, latent_scale_factor)],
|
| 646 |
+
)
|
| 647 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 648 |
+
raise ValueError(
|
| 649 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 650 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
if latents is None:
|
| 654 |
+
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
| 655 |
+
else:
|
| 656 |
+
latents = latents.to(device)
|
| 657 |
+
|
| 658 |
+
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
|
| 659 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 660 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 661 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 662 |
+
|
| 663 |
+
return latents
|
| 664 |
+
|
| 665 |
+
    @property
    def guidance_scale(self):
        """CFG weight recorded by the last `__call__` invocation."""
        return self._guidance_scale
|
| 668 |
+
|
| 669 |
+
    @property
    def guidance_rescale(self):
        """Guidance rescale factor recorded by the last `__call__` invocation."""
        return self._guidance_rescale
|
| 672 |
+
|
| 673 |
+
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        """Whether CFG is active: true only when the guidance scale exceeds 1."""
        return self._guidance_scale > 1.0
|
| 679 |
+
|
| 680 |
+
    @property
    def num_timesteps(self):
        """Number of scheduler timesteps used by the last `__call__` invocation."""
        return self._num_timesteps
|
| 683 |
+
|
| 684 |
+
    def set_scheduler(self, new_scheduler):
        """Swap the pipeline's scheduler by re-registering the `scheduler` module."""
        self.register_modules(scheduler=new_scheduler)
|
| 686 |
+
|
| 687 |
+
    @torch.no_grad()
    def __call__(
        self,
        batch_size: int,
        image_size: List[int],
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 7.5,
        meanflow: bool = False,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_rescale: float = 0.0,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],  # NOTE(review): mutable default; never mutated here, only rebound, so harmless in practice
        model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`):
                Number of samples to generate in one pass.
            image_size (`Tuple[int]` or `List[int]`):
                The size (height, width) of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
                argument in their `set_timesteps` method. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate samples closely linked to the
                `condition` at the expense of lower sample quality. Guidance is enabled when `guidance_scale > 1`.
            meanflow (`bool`, *optional*, defaults to `False`):
                When set, an auxiliary timestep `r` is fetched from the scheduler each step and forwarded to the
                model as `timesteps_r`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents; if not provided, latents are sampled with `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~DiffusionPipelineOutput`] instead of a plain tuple.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf); fixes overexposure with zero terminal SNR.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                Called at the end of each denoising step as `callback_on_step_end(self, step, timestep,
                callback_kwargs)`; `callback_kwargs` contains the tensors named by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                Tensor names passed to `callback_on_step_end`; restricted to `._callback_tensor_inputs`.
            model_kwargs (`Dict[str, Any]`):
                Extra model inputs. Must contain `input_ids`; it is popped and fed to the backbone.

        Returns:
            [`~DiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, a pipeline output is returned, otherwise a tuple whose first element
                is the list of generated samples.
        """

        # NOTE(review): popped but never used below — presumably kept for backward compat; confirm.
        callback_steps = kwargs.pop("callback_steps", None)
        pbar_steps = kwargs.pop("pbar_steps", None)

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale

        # With a CFG-distilled model, guidance is baked in: no cond/uncond batch doubling.
        if not kwargs.get('cfg_distilled', False):
            cfg_factor = 1 + self.do_classifier_free_guidance
        else:
            cfg_factor = 1
        # Define call parameters
        device = self._execution_device

        # Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas,
        )

        # Prepare latent variables
        latents = self.prepare_latents(
            batch_size=batch_size,
            latent_channel=self.model.config.vae["latent_channels"],
            image_size=image_size,
            dtype=torch.bfloat16,
            device=device,
            generator=generator,
            latents=latents,
        )

        # Prepare extra step kwargs (only passes `generator` if the scheduler's step accepts it).
        _scheduler_step_extra_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step, {"generator": generator}
        )

        # Prepare model kwargs
        input_ids = model_kwargs.pop("input_ids")
        attention_mask = self.model._prepare_attention_mask_for_generation(  # noqa
            input_ids, self.model.generation_config, model_kwargs=model_kwargs,
        )
        model_kwargs["attention_mask"] = attention_mask.to(latents.device)

        # Sampling loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # Taylor cache: optional feature caching across steps to skip redundant model work.
        cache_dic = None
        if self.model.use_taylor_cache:
            cache_dic = cache_init(
                cache_interval=self.model.taylor_cache_interval,
                max_order=self.model.taylor_cache_order,
                num_steps=len(timesteps),
                enable_first_enhance=self.model.taylor_cache_enable_first_enhance,
                first_enhance_steps=self.model.taylor_cache_first_enhance_steps,
                enable_tailing_enhance=self.model.taylor_cache_enable_tailing_enhance,
                tailing_enhance_steps=self.model.taylor_cache_tailing_enhance_steps,
                low_freqs_order=self.model.taylor_cache_low_freqs_order,
                high_freqs_order=self.model.taylor_cache_high_freqs_order)
            print(f"***use_taylor_cache: {self.model.use_taylor_cache}, cache_dic: {cache_dic}")

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * cfg_factor)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Mean-flow models take an extra timestep `r`; otherwise pass None.
                if meanflow:
                    r = self.scheduler.get_timestep_r(t)
                    r_expand = r.repeat(latent_model_input.shape[0])
                else:
                    r_expand = None
                model_kwargs["timesteps_r"] = r_expand

                t_expand = t.repeat(latent_model_input.shape[0])

                if self.model.use_taylor_cache:
                    cache_dic['current_step'] = i
                    model_kwargs['cache_dic'] = cache_dic
                # Distilled CFG: guidance strength is fed to the model as a scalar input.
                if kwargs.get('cfg_distilled', False):
                    model_kwargs["guidance"] = torch.tensor(
                        [1000.0*self._guidance_scale], device=self.device, dtype=torch.bfloat16
                    )
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids,
                    images=latent_model_input,
                    timesteps=t_expand,
                    **model_kwargs,
                )
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    model_output = self.model(**model_inputs, first_step=(i == 0))
                    pred = model_output["diffusion_prediction"]
                # Scheduler math runs in fp32 regardless of model precision.
                pred = pred.to(dtype=torch.float32)
                # perform guidance
                if self.do_classifier_free_guidance:
                    if not kwargs.get('cfg_distilled', False):
                        pred_cond, pred_uncond = pred.chunk(2)
                        pred = self.cfg_operator(pred_cond, pred_uncond, self.guidance_scale, step=i)

                # NOTE(review): `pred_cond` is only bound on the non-distilled CFG path above;
                # with cfg_distilled + guidance_rescale > 0 this would raise NameError — confirm intent.
                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    pred = rescale_noise_cfg(pred, pred_cond, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(pred, t, latents, **_scheduler_step_extra_kwargs, return_dict=False)[0]

                if i != len(timesteps) - 1:
                    model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa
                        model_output,
                        model_kwargs,
                    )
                    # After the first step the prompt is cached in model_kwargs; drop input_ids.
                    input_ids = None
                    # if input_ids.shape[1] != model_kwargs["position_ids"].shape[1]:
                    #     input_ids = torch.gather(input_ids, 1, index=model_kwargs["position_ids"])

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # Undo the VAE's latent normalization before decoding.
        if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor:
            latents = latents / self.vae.config.scaling_factor
        if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
            latents = latents + self.vae.config.shift_factor

        # Temporal VAEs expect a [B, C, T, H, W] input; add a singleton time axis.
        if hasattr(self.vae, "ffactor_temporal"):
            latents = latents.unsqueeze(2)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            image = self.vae.decode(latents, return_dict=False, generator=generator)[0]

        # b c t h w
        if hasattr(self.vae, "ffactor_temporal"):
            assert image.shape[2] == 1, "image should have shape [B, C, T, H, W] and T should be 1"
            image = image.squeeze(2)

        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image,)

        return HunyuanImage3Text2ImagePipelineOutput(samples=image)
|
image_processor.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
|
| 2 |
+
# you may not use this file except in compliance with the License.
|
| 3 |
+
# You may obtain a copy of the License at
|
| 4 |
+
#
|
| 5 |
+
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
|
| 6 |
+
#
|
| 7 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 10 |
+
# See the License for the specific language governing permissions and
|
| 11 |
+
# limitations under the License.
|
| 12 |
+
# ==============================================================================
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass, field, asdict
|
| 15 |
+
from typing import Tuple, Optional, Callable, Union, Any
|
| 16 |
+
import random
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from torchvision import transforms
|
| 22 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 23 |
+
from transformers.image_utils import load_image
|
| 24 |
+
from transformers.models.siglip2.image_processing_siglip2_fast import Siglip2ImageProcessorFast
|
| 25 |
+
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
|
| 26 |
+
|
| 27 |
+
from .tokenization_hunyuan_image_3 import ImageInfo, ImageTensor, CondImage, Resolution, ResolutionGroup
|
| 28 |
+
|
| 29 |
+
# A conditioning image may be supplied as an in-memory PIL image or a path/URL string.
InputImage = Union[Image.Image, str]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SliceVocabLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that performs vocab slicing, i.e. restricting probabilities within some range. This
    processor is often used in multimodal discrete LLMs to ensure that we only sample within one modality.

    Args:
        vocab_start (`int`, *optional*): start of slice; `None` means from index 0.
        vocab_end (`int`, *optional*): end of slice; `None` means to the end of the vocab.
            When both start and end are `None`, this processor does nothing.
        other_slices (`list`, via kwargs): extra `(start, end)` index pairs whose score
            columns are concatenated after the main slice.
    """

    def __init__(self, vocab_start: Optional[int] = None, vocab_end: Optional[int] = None, **kwargs):
        if vocab_start is not None and vocab_end is not None:
            assert vocab_start < vocab_end, f"Ensure vocab_start {vocab_start} < vocab_end {vocab_end}"
        self.vocab_start = vocab_start
        self.vocab_end = vocab_end
        self.other_slices = kwargs.get("other_slices", [])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Keep the main slice, then append the columns of any extra slices.
        scores_processed = scores[:, self.vocab_start: self.vocab_end]
        for other_slice in self.other_slices:
            scores_processed = torch.cat([scores_processed, scores[:, other_slice[0]: other_slice[1]]], dim=-1)
        return scores_processed

    def __repr__(self):
        # Fix: the original hard-coded the wrong name "SliceVocabLogitsWarper".
        return (f"{type(self).__name__}(vocab_start={self.vocab_start}, "
                f"vocab_end={self.vocab_end}, other_slices={self.other_slices})")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def resize_and_crop(image: Image.Image, target_size: Tuple[int, int], resample=Image.Resampling.LANCZOS, crop_type='center', crop_coords=None) -> Image.Image:
    """Resize `image` to exactly `target_size` (width, height).

    `crop_type='resize'` stretches without preserving aspect ratio; every other
    mode resizes so one side matches the target while preserving aspect ratio,
    then crops the overflow ('center', 'random', or 'fixed' via `crop_coords`
    given as (left, top)).

    Raises:
        ValueError: for an unknown `crop_type`.
    """
    tw, th = target_size
    w, h = image.size

    # Height/width ratios of target and source; decides which side to match.
    tr = th / tw
    r = h / w

    if crop_type == "resize":
        # Direct stretch; the final crop below is a no-op over the full frame.
        # NOTE(review): the shared resize at the bottom re-runs on this branch
        # (same size, so pixel-identical) — looks redundant; confirm intent.
        resize_width = tw
        resize_height = th
        crop_top = 0
        crop_left = 0
        image = image.resize((resize_width, resize_height), resample=resample)
    else:
        # maintain the aspect ratio
        if r < tr:
            # Source is wider than the target: match height, crop width.
            resize_height = th
            resize_width = int(round(th / h * w))
        else:
            resize_width = tw
            resize_height = int(round(tw / w * h))

        if crop_type == 'center':
            crop_top = int(round((resize_height - th) / 2.0))
            crop_left = int(round((resize_width - tw) / 2.0))
        elif crop_type == 'random':
            crop_top = random.randint(0, resize_height - th)
            crop_left = random.randint(0, resize_width - tw)
        elif crop_type == 'fixed':
            assert crop_coords is not None, 'crop_coords should be provided when crop_type is fixed.'
            crop_left, crop_top = crop_coords
        else:
            raise ValueError(f'crop_type must be center, random or fixed, but got {crop_type}')

    image = image.resize((resize_width, resize_height), resample=resample)
    image = image.crop((crop_left, crop_top, crop_left + tw, crop_top + th))

    return image
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass
class ResolutionGroupConfig:
    """Construction parameters forwarded to `ResolutionGroup` (see `to_dict`)."""
    # Base (square) image size the resolution buckets are derived from.
    base_size: Optional[int] = None
    # Step between candidate resolutions; None lets ResolutionGroup pick its default.
    step: Optional[int] = None
    # Resolutions are rounded to multiples of this value.
    align: int = 16

    def to_dict(self):
        """Return the config as a plain dict, suitable for `ResolutionGroup(**...)`."""
        return asdict(self)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@dataclass
class VAEInfo:
    """Geometry of the VAE image encoder: how many pixels map to one latent token."""
    encoder_type: str
    # Spatial downsampling of the VAE itself; -1 means "must be provided".
    down_h_factor: int = -1
    down_w_factor: int = -1
    # Extra patchification applied on top of the VAE downsampling.
    patch_size: int = 1
    # Effective pixels-per-token factors; derived in __post_init__.
    h_factor: int = -1
    w_factor: int = -1
    image_type: Optional[str] = None

    def __post_init__(self):
        # Total factor = VAE downsampling x patch size.
        self.h_factor = self.down_h_factor * self.patch_size
        self.w_factor = self.down_w_factor * self.patch_size
        if self.image_type is None:
            self.image_type = "vae"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@dataclass
class ViTInfo:
    """Geometry and processor of the ViT conditioning encoder."""
    encoder_type: str
    # Pixels per ViT token along each axis; -1 means "must be provided".
    h_factor: int = -1
    w_factor: int = -1
    max_token_length: int = 0  # pad to max_token_length
    # Image preprocessor that turns PIL images into encoder inputs.
    processor: Callable = field(default_factory=BaseImageProcessor)
    image_type: Optional[str] = None

    def __post_init__(self):
        # Default image_type to the encoder family, e.g. "siglip2-..." -> "siglip2".
        if self.image_type is None:
            self.image_type = self.encoder_type.split("-")[0]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class HunyuanImage3ImageProcessor(object):
|
| 143 |
+
    def __init__(self, config):
        """Build resolution buckets, VAE/ViT geometry info, and the ViT preprocessor from `config`.

        Raises:
            ValueError: if `config.vit_type` is not a supported SigLIP2 variant.
        """
        self.config = config

        self.reso_group_config = ResolutionGroupConfig(base_size=config.image_base_size)
        # Standard landscape/portrait resolutions are added on top of the generated buckets.
        self.vae_reso_group = ResolutionGroup(
            **self.reso_group_config.to_dict(),
            extra_resolutions=[
                Resolution("1024x768"),
                Resolution("1280x720"),
                Resolution("768x1024"),
                Resolution("720x1280"),
            ]
        )
        self.img_ratio_slice_logits_processor = None
        self.pil_image_to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # transform to [-1, 1]
        ])
        # NOTE(review): vae_downsample_factor[0] is used for BOTH h and w; if the config
        # stores (h, w) factors, the w factor should presumably be index [1] — confirm.
        self.vae_info = VAEInfo(
            encoder_type=config.vae_type,
            down_h_factor=config.vae_downsample_factor[0], down_w_factor=config.vae_downsample_factor[0],
            patch_size=config.patch_size,
        )

        if config.vit_type == "siglip2-so400m-patch16-naflex":
            self.vit_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor)
        else:
            raise ValueError(f"Unsupported vit_type: {config.vit_type}")
        self.vit_info = ViTInfo(
            encoder_type=config.vit_type,
            h_factor=self.vit_processor.patch_size,
            w_factor=self.vit_processor.patch_size,
            max_token_length=self.vit_processor.max_num_patches,
            processor=self.vit_processor,
        )
        self.cond_token_attn_type = config.cond_token_attn_type
        self.cond_image_type = config.cond_image_type
|
| 180 |
+
|
| 181 |
+
    def build_gen_image_info(self, image_size, add_guidance_token=False, add_timestep_r_token=False) -> ImageInfo:
        """Parse a requested output size and return the `ImageInfo` of the image to generate.

        Args:
            image_size: One of 'HxW', 'W:H', an '<img_ratio_i>' token (index into the
                resolution group), or a (height, width) pair of ints.
            add_guidance_token: Forwarded to `ImageInfo`.
            add_timestep_r_token: Forwarded to `ImageInfo`.

        Raises:
            ValueError: for unparseable string formats or unsupported types.
        """
        # parse image size (HxW, H:W, or <img_ratio_i>)
        if isinstance(image_size, str):
            if image_size.startswith("<img_ratio_"):
                ratio_index = int(image_size.split("_")[-1].rstrip(">"))
                reso = self.vae_reso_group[ratio_index]
                image_size = reso.height, reso.width
            elif 'x' in image_size:
                image_size = [int(s) for s in image_size.split('x')]
            elif ':' in image_size:
                image_size = [int(s) for s in image_size.split(':')]
                assert len(image_size) == 2, f"`image_size` should be in the format of 'W:H', got {image_size}."
                # Note that ratio is width:height
                image_size = [image_size[1], image_size[0]]
            else:
                raise ValueError(
                    f"`image_size` should be in the format of 'HxW', 'W:H' or <img_ratio_i>, got {image_size}.")
            assert len(image_size) == 2, f"`image_size` should be in the format of 'HxW', got {image_size}."
        elif isinstance(image_size, (list, tuple)):
            assert len(image_size) == 2 and all(isinstance(s, int) for s in image_size), \
                f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', got {image_size}."
        else:
            raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'WxH', "
                             f"got {image_size}.")
        # Snap the requested size to the nearest supported resolution bucket (args are (w, h)).
        image_width, image_height = self.vae_reso_group.get_target_size(image_size[1], image_size[0])
        token_height = image_height // self.vae_info.h_factor
        token_width = image_width // self.vae_info.w_factor
        base_size, ratio_idx = self.vae_reso_group.get_base_size_and_ratio_index(image_size[1], image_size[0])
        image_info = ImageInfo(
            image_type="gen_image", image_width=image_width, image_height=image_height,
            token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx,
            add_guidance_token=add_guidance_token, add_timestep_r_token=add_timestep_r_token,
        )
        return image_info
|
| 215 |
+
|
| 216 |
+
    def as_image_tensor(self, image, image_type, **kwargs) -> ImageTensor:
        """Wrap a (PIL or tensor) image as an `ImageTensor` annotated with `ImageInfo`.

        Args:
            image: A PIL image (converted to a [-1, 1] tensor) or an already-built tensor.
            image_type: One of "vae", "siglip2", or "anyres"; selects how token geometry
                and section type are derived.
            **kwargs: Must contain `origin_size` (w, h); "siglip2" additionally requires
                `spatial_shapes` and `pixel_attention_mask`; "anyres" requires
                `resized_image_width`/`resized_image_height`.

        Raises:
            ValueError: for an unknown `image_type`.
        """
        if isinstance(image, Image.Image):
            tensor = self.pil_image_to_tensor(image)
        else:
            tensor = image

        origin_size = kwargs["origin_size"]
        ori_image_width = origin_size[0]
        ori_image_height = origin_size[1]

        if image_type == "vae":
            assert tensor.ndim == 3 or tensor.ndim == 4
            h, w = tensor.shape[-2], tensor.shape[-1]
            # VAE tokens require pixel dims divisible by the effective downsampling factors.
            assert (h % self.vae_info.h_factor == 0 and w % self.vae_info.w_factor == 0), \
                (f"Image size should be divisible by ({self.vae_info.h_factor}, {self.vae_info.w_factor}), "
                 f"but got ({h} x {w}).")
            tk_height = h // self.vae_info.h_factor
            tk_width = w // self.vae_info.w_factor
            base_size, ratio_idx = self.vae_reso_group.get_base_size_and_ratio_index(w, h)
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=w, image_height=h, token_width=tk_width, token_height=tk_height,
                base_size=base_size, ratio_index=ratio_idx,
                ori_image_width=ori_image_width,
                ori_image_height=ori_image_height,
            )
            tensor.section_type = "cond_vae_image"
        elif image_type == "siglip2":
            # NaFlex-style ViT input: geometry comes from the processor's outputs.
            spatial_shapes = kwargs["spatial_shapes"]  # 2 (h, w)
            pixel_attention_mask = kwargs["pixel_attention_mask"]  # seq_len
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=spatial_shapes[1].item() * self.vit_info.w_factor,
                image_height=spatial_shapes[0].item() * self.vit_info.h_factor,
                token_width=spatial_shapes[1].item(),
                token_height=spatial_shapes[0].item(),
                image_token_length=self.vit_info.max_token_length,
                ori_image_width=ori_image_width,
                ori_image_height=ori_image_height,
            )
            tensor.section_type = "cond_vit_image"
            tensor.vision_encoder_kwargs = {
                "spatial_shapes": spatial_shapes,
                "pixel_attention_mask": pixel_attention_mask,
            }
        elif image_type == "anyres":
            token_width = kwargs["resized_image_width"] // self.vit_info.w_factor
            token_height = kwargs["resized_image_height"] // self.vit_info.h_factor
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=kwargs["resized_image_width"],
                image_height=kwargs["resized_image_height"],
                token_width=token_width,
                token_height=token_height,
                # +1 per row and +2 overall — presumably row separators plus begin/end
                # tokens; confirm against the tokenizer.
                image_token_length=token_height * (token_width + 1) + 2,
            )
            tensor.section_type = "cond_vit_image"
        else:
            raise ValueError(f"Unknown image type: {image_type}")
        return tensor
|
| 276 |
+
|
| 277 |
+
def vae_process_image(self, image, target_size, random_crop: bool | str = False) -> ImageTensor:
|
| 278 |
+
origin_size = image.size
|
| 279 |
+
crop_type = random_crop if isinstance(random_crop, str) else ("random" if random_crop else "center")
|
| 280 |
+
resized_image = resize_and_crop(image, target_size, crop_type=crop_type)
|
| 281 |
+
return self.as_image_tensor(resized_image, image_type=self.vae_info.image_type, origin_size=origin_size)
|
| 282 |
+
|
| 283 |
+
def vit_process_image(self, image) -> ImageTensor:
|
| 284 |
+
origin_size = image.size
|
| 285 |
+
inputs = self.vit_info.processor(image)
|
| 286 |
+
image = inputs["pixel_values"].squeeze(0) # (seq_len, dim)
|
| 287 |
+
|
| 288 |
+
remain_keys = set(inputs.keys()) - {"pixel_values"}
|
| 289 |
+
remain_kwargs = {}
|
| 290 |
+
for key in remain_keys:
|
| 291 |
+
if isinstance(inputs[key], torch.Tensor):
|
| 292 |
+
remain_kwargs[key] = inputs[key].squeeze(0)
|
| 293 |
+
else:
|
| 294 |
+
remain_kwargs[key] = inputs[key]
|
| 295 |
+
|
| 296 |
+
return self.as_image_tensor(image, image_type=self.vit_info.image_type, origin_size=origin_size, **remain_kwargs)
|
| 297 |
+
|
| 298 |
+
def get_image_with_size(
|
| 299 |
+
self,
|
| 300 |
+
src: InputImage,
|
| 301 |
+
random_crop: bool | str = False,
|
| 302 |
+
return_type: str = "vae",
|
| 303 |
+
) -> tuple[ImageTensor | CondImage, bool]:
|
| 304 |
+
""" For various image generation tasks, dynamic image sizes """
|
| 305 |
+
image = load_image(src)
|
| 306 |
+
image_flag = "normal"
|
| 307 |
+
img_success = image_flag != "gray"
|
| 308 |
+
origin_size = image.size # (w_ori, h_ori)
|
| 309 |
+
|
| 310 |
+
if "vae" in return_type:
|
| 311 |
+
target_size = self.vae_reso_group.get_target_size(*origin_size)
|
| 312 |
+
vae_image_tensor = self.vae_process_image(image, target_size, random_crop=random_crop)
|
| 313 |
+
else:
|
| 314 |
+
vae_image_tensor = None
|
| 315 |
+
|
| 316 |
+
if "vit" in return_type:
|
| 317 |
+
vit_image_tensor = self.vit_process_image(image)
|
| 318 |
+
else:
|
| 319 |
+
vit_image_tensor = None
|
| 320 |
+
|
| 321 |
+
if return_type == "vae":
|
| 322 |
+
image_tensor = vae_image_tensor
|
| 323 |
+
elif return_type == "vit":
|
| 324 |
+
image_tensor = vit_image_tensor
|
| 325 |
+
elif return_type == "vae_vit":
|
| 326 |
+
image_tensor = CondImage(image_type=return_type, vae_image=vae_image_tensor, vit_image=vit_image_tensor)
|
| 327 |
+
else:
|
| 328 |
+
raise ValueError(f"Unknown return_type: {return_type}")
|
| 329 |
+
|
| 330 |
+
return image_tensor, img_success
|
| 331 |
+
|
| 332 |
+
def build_cond_images(
|
| 333 |
+
self,
|
| 334 |
+
image_list: Optional[list[InputImage]] = None,
|
| 335 |
+
message_list: Optional[list[dict[str, Any]]] = None,
|
| 336 |
+
infer_align_image_size: bool = False,
|
| 337 |
+
) -> Optional[list[CondImage]]:
|
| 338 |
+
if image_list is not None and message_list is not None:
|
| 339 |
+
raise ValueError("`image_list` and `message_list` cannot be provided at the same time.")
|
| 340 |
+
if message_list is not None:
|
| 341 |
+
image_list = []
|
| 342 |
+
for message in message_list:
|
| 343 |
+
visuals = [
|
| 344 |
+
content
|
| 345 |
+
for content in message["content"]
|
| 346 |
+
if isinstance(content, dict) and content["type"] in ["image"]
|
| 347 |
+
]
|
| 348 |
+
image_list.extend([
|
| 349 |
+
vision_info[key]
|
| 350 |
+
for vision_info in visuals
|
| 351 |
+
for key in ["image", "url", "path", "base64"]
|
| 352 |
+
if key in vision_info and vision_info["type"] == "image"
|
| 353 |
+
])
|
| 354 |
+
|
| 355 |
+
if infer_align_image_size:
|
| 356 |
+
random_crop = "resize"
|
| 357 |
+
else:
|
| 358 |
+
random_crop = "center"
|
| 359 |
+
|
| 360 |
+
return [
|
| 361 |
+
self.get_image_with_size(src, return_type=self.cond_image_type, random_crop=random_crop)[0]
|
| 362 |
+
for src in image_list
|
| 363 |
+
]
|
| 364 |
+
|
| 365 |
+
def prepare_full_attn_slices(self, output, batch_idx=None, with_gen=True):
|
| 366 |
+
""" Determine full attention image slices according to strategies. """
|
| 367 |
+
if self.cond_image_type == "vae":
|
| 368 |
+
cond_choices = dict(
|
| 369 |
+
causal=[],
|
| 370 |
+
full=output.vae_image_slices[batch_idx] if batch_idx is not None else output.vae_image_slices
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
elif self.cond_image_type == "vit":
|
| 374 |
+
cond_choices = dict(
|
| 375 |
+
causal=[],
|
| 376 |
+
full=output.vit_image_slices[batch_idx] if batch_idx is not None else output.vit_image_slices
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
elif self.cond_image_type == "vae_vit":
|
| 380 |
+
cond_choices = {
|
| 381 |
+
"causal": [],
|
| 382 |
+
"full": (
|
| 383 |
+
output.vae_image_slices[batch_idx] + output.vit_image_slices[batch_idx]
|
| 384 |
+
if batch_idx is not None
|
| 385 |
+
else output.vae_image_slices + output.vit_image_slices
|
| 386 |
+
),
|
| 387 |
+
"joint_full": (
|
| 388 |
+
output.joint_image_slices[batch_idx]
|
| 389 |
+
if batch_idx is not None
|
| 390 |
+
else output.joint_image_slices
|
| 391 |
+
),
|
| 392 |
+
"full_causal": (
|
| 393 |
+
output.vae_image_slices[batch_idx]
|
| 394 |
+
if batch_idx is not None
|
| 395 |
+
else output.vae_image_slices
|
| 396 |
+
),
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Unknown cond_image_type: {self.cond_image_type}")
|
| 401 |
+
slices = cond_choices[self.cond_token_attn_type]
|
| 402 |
+
|
| 403 |
+
if with_gen:
|
| 404 |
+
gen_image_slices = (
|
| 405 |
+
output.gen_image_slices[batch_idx]
|
| 406 |
+
if batch_idx is not None
|
| 407 |
+
else output.gen_image_slices
|
| 408 |
+
)
|
| 409 |
+
slices = slices + gen_image_slices
|
| 410 |
+
return slices
|
| 411 |
+
|
| 412 |
+
def build_img_ratio_slice_logits_processor(self, tokenizer):
|
| 413 |
+
if self.img_ratio_slice_logits_processor is None:
|
| 414 |
+
self.img_ratio_slice_logits_processor = LogitsProcessorList()
|
| 415 |
+
self.img_ratio_slice_logits_processor.append(
|
| 416 |
+
SliceVocabLogitsProcessor(
|
| 417 |
+
vocab_start=tokenizer.start_ratio_token_id,
|
| 418 |
+
vocab_end=tokenizer.end_ratio_token_id + 1,
|
| 419 |
+
other_slices=getattr(tokenizer, "ratio_token_other_slices", []),
|
| 420 |
+
)
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
def postprocess_outputs(self, outputs: list[Image.Image], batch_cond_images, infer_align_image_size: bool = False):
|
| 424 |
+
if infer_align_image_size:
|
| 425 |
+
target_area = self.vae_reso_group.base_size ** 2
|
| 426 |
+
|
| 427 |
+
for batch_index, (output_image, cond_images) in enumerate(zip(outputs, batch_cond_images)):
|
| 428 |
+
output_image_ratio_index = self.vae_reso_group.get_base_size_and_ratio_index(width=output_image.width, height=output_image.height)[1]
|
| 429 |
+
cond_images_ratio_index_list = []
|
| 430 |
+
cond_images_ori_width_list = []
|
| 431 |
+
cond_images_ori_height_list = []
|
| 432 |
+
for cond_image in cond_images:
|
| 433 |
+
if isinstance(cond_image, ImageTensor):
|
| 434 |
+
cond_images_ratio_index_list.append(cond_image.i.ratio_index)
|
| 435 |
+
cond_images_ori_width_list.append(cond_image.i.ori_image_width)
|
| 436 |
+
cond_images_ori_height_list.append(cond_image.i.ori_image_height)
|
| 437 |
+
else: # CondImage
|
| 438 |
+
cond_images_ratio_index_list.append(cond_image.vae_image.i.ratio_index)
|
| 439 |
+
cond_images_ori_width_list.append(cond_image.vae_image.i.ori_image_width)
|
| 440 |
+
cond_images_ori_height_list.append(cond_image.vae_image.i.ori_image_height)
|
| 441 |
+
|
| 442 |
+
if len(cond_images) == 0:
|
| 443 |
+
continue
|
| 444 |
+
elif len(cond_images) == 1:
|
| 445 |
+
if output_image_ratio_index == cond_images_ratio_index_list[0]:
|
| 446 |
+
if abs(cond_images_ori_height_list[0] / cond_images_ori_width_list[0] - self.vae_reso_group[output_image_ratio_index].ratio) >= 0.01:
|
| 447 |
+
scale = math.sqrt(target_area / (cond_images_ori_width_list[0] * cond_images_ori_height_list[0]))
|
| 448 |
+
new_w = round(cond_images_ori_width_list[0] * scale)
|
| 449 |
+
new_h = round(cond_images_ori_height_list[0] * scale)
|
| 450 |
+
outputs[batch_index] = output_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
|
| 451 |
+
else:
|
| 452 |
+
for cond_image_ratio_index, cond_image_ori_width, cond_image_ori_height in zip(cond_images_ratio_index_list, cond_images_ori_width_list, cond_images_ori_height_list):
|
| 453 |
+
if output_image_ratio_index == cond_image_ratio_index:
|
| 454 |
+
if abs(cond_image_ori_height / cond_image_ori_width - self.vae_reso_group[output_image_ratio_index].ratio) >= 0.01:
|
| 455 |
+
scale = math.sqrt(target_area / (cond_image_ori_width * cond_image_ori_height))
|
| 456 |
+
new_w = round(cond_image_ori_width * scale)
|
| 457 |
+
new_h = round(cond_image_ori_height * scale)
|
| 458 |
+
outputs[batch_index] = output_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
|
| 459 |
+
break
|
| 460 |
+
|
| 461 |
+
return outputs
|
| 462 |
+
|
| 463 |
+
__all__ = [
|
| 464 |
+
"HunyuanImage3ImageProcessor"
|
| 465 |
+
]
|
model-0001-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2db6ab327b5a5a9ff2be48bc41fae98d7de01b0a29f1a5ecc88b079637bce016
|
| 3 |
+
size 5363066616
|
model-0002-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:09538e24c7437751d2384dde73cf4e913dce1e67bfdf87b0b1933963dc117a41
|
| 3 |
+
size 5318937248
|
model-0003-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c268190bd3c0d57b05cd5e859d5dcce1b30df2ede2486396179b28b2517cf820
|
| 3 |
+
size 5344627472
|
model-0004-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:315220e3fcc1a02673670e63c1eb8d2a73970e17e5b787156902fd7f7258220d
|
| 3 |
+
size 5327343192
|
model-0005-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7ad5d55a79c80186537367d8bfcee7de722f8b34c820391b656f14a5fed1b085
|
| 3 |
+
size 5344103080
|
model-0006-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:834abe96ce34acbbb77a72990058c6e324207db8dcd878a01752792dd1fb38b4
|
| 3 |
+
size 5318937248
|
model-0007-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3825bbee6d58000f357ea24b312040eb16a466d14bed5b44b89eeda07344a4fa
|
| 3 |
+
size 5344103088
|
model-0008-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec4108cb77f70a545335ba2108d7fb4bcb6f6831dd3080e5a6b330433e7de69f
|
| 3 |
+
size 5318937256
|
model-0009-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aee7924c4d27ba7fe722b2581d2a87775715085ad38016042b6c1b8c998afb7b
|
| 3 |
+
size 5344103088
|
model-0010-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:559b75c2da0baf8c0389a889990925d140dcee1f8a57ecd0b3cc6db4bb7013be
|
| 3 |
+
size 5318937304
|
model-0011-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:09882d830666ae891b41023d4e7af0cc7083d731d198295411b2ebfbabedbc7c
|
| 3 |
+
size 5344103232
|
model-0012-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c304453394b2b30dc281b9f4e155b6709b682a8602994093240afefc49548bd7
|
| 3 |
+
size 5318937400
|
model-0013-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48d01ca1697b035d9dd0f4482ac17d740876ec82bdc511931f128978bdbdd5c3
|
| 3 |
+
size 5344103232
|
model-0014-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c472e079b15d85e0cd81ee57f34d7815834b0b43b5a8fbccc561285e6cafacb1
|
| 3 |
+
size 5318937400
|
model-0015-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:886837fb5ba9b75b6bf754c98563af1955ad9b27a8130e5480dd608638f03576
|
| 3 |
+
size 5344103232
|
model-0016-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa722844f3e88ff87973a385cc46dd4e5be2a820b1dd9be6957c152b62dd25b2
|
| 3 |
+
size 5318937400
|
model-0017-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a237fe2fad03479bd29f635e8ef7e4a1c9e354e29c2295ecea10fd43602f3e05
|
| 3 |
+
size 5344103224
|
model-0018-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:889f2bb13328f2085593dfe06328ce4be24d2379e294ef37869827a9f59b59a6
|
| 3 |
+
size 5327859080
|
model-0019-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb4dd0053a798204b952a28b103d8abecd35a40ad1702ae6a39f71aedbeef627
|
| 3 |
+
size 5344111888
|
model-0020-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7baf9d180d2e5ef99e35864440190d700f6378c33a8dbc1adc629b2a2ec263ca
|
| 3 |
+
size 5318937392
|
model-0021-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c419d5445acff5eb09a436e0a6236f7b833a4aac6b3f2316c044abe770ed58e
|
| 3 |
+
size 5344103232
|
model-0022-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20d6d65cf9208fb649daafb0e1511b526c0b2eeac0463ead6fa151e6bd8e3207
|
| 3 |
+
size 5318937400
|
model-0023-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c5b70acb118d07ab02683c3467e7636e8197cb3b8969220eb447c7dd97470bd
|
| 3 |
+
size 5344103232
|
model-0024-of-0032.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65d5830e1162a4bed3f4e8af24d8fcf22fbf9ceb8d0260e8a51ca14cf748d64f
|
| 3 |
+
size 5318937400
|