Instructions to use Snapmap/diffcheckstuffiused with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Snapmap/diffcheckstuffiused with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Snapmap/diffcheckstuffiused", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +7 -0
- LLM/Florence-2-base/model.safetensors +3 -0
- LLM/Florence-2-base/pytorch_model.bin +3 -0
- SEEDVR2/ema_vae_fp16.safetensors +3 -0
- assets/bpe_simple_vocab_16e6.txt.gz +3 -0
- checkpoints/qwen_image_fp8_hq.safetensors +3 -0
- depthcrafter/stabilityai_stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors +3 -0
- detection/vitpose_h_wholebody_model.onnx +3 -0
- detection/yolov10m.onnx +3 -0
- mediapipe/selfie_multiclass_256x256.tflite +3 -0
- sams/bpe_simple_vocab_16e6.txt.gz +3 -0
- ultralytics/bbox/adetailerFootYolov8x_v20.pt +3 -0
- ultralytics/bbox/face_yolov8m.pt +3 -0
- ultralytics/bbox/face_yolov8m[1].pt +3 -0
- unet/Z-Image/assets/DMDR.webp +3 -0
- unet/Z-Image/assets/Z-Image-Gallery.pdf +3 -0
- unet/Z-Image/assets/architecture.webp +3 -0
- unet/Z-Image/assets/decoupled-dmd.webp +3 -0
- unet/Z-Image/assets/image_arena_all.jpg +3 -0
- unet/Z-Image/assets/reasoning.png +3 -0
- unet/Z-Image/assets/showcase.jpg +3 -0
- unet/Z-Image/src/config/__init__.py +91 -0
- unet/Z-Image/src/config/inference.py +8 -0
- unet/Z-Image/src/config/manifests/z-image-turbo.txt +20 -0
- unet/Z-Image/src/config/model.py +45 -0
- unet/Z-Image/src/tools/__init__.py +9 -0
- unet/Z-Image/src/tools/generate_manifest.py +127 -0
- unet/Z-Image/src/utils/__init__.py +15 -0
- unet/Z-Image/src/utils/attention.py +516 -0
- unet/Z-Image/src/utils/helpers.py +260 -0
- unet/Z-Image/src/utils/import_utils.py +31 -0
- unet/Z-Image/src/utils/loader.py +224 -0
- unet/Z-Image/src/zimage/__init__.py +9 -0
- unet/Z-Image/src/zimage/autoencoder.py +369 -0
- unet/Z-Image/src/zimage/pipeline.py +293 -0
- unet/Z-Image/src/zimage/transformer.py +571 -0
- upscale_models/1x-ITF-SkinDiffDetail-Lite-v1.pth +3 -0
- upscale_models/1x_PureVision.pth +3 -0
- upscale_models/2x_PureVision.pth +3 -0
- upscale_models/4x-ClearRealityV1.pth +3 -0
- upscale_models/4x-UltraSharp.pth +3 -0
- upscale_models/4xFFHQDAT.safetensors +3 -0
- upscale_models/4xNomos8k_atd_jpg.pth +3 -0
- upscale_models/4xNomos8k_span_otf_weak.pth +3 -0
- upscale_models/4x_NMKD-Siax_200k.pth +3 -0
- upscale_models/4x_NMKD-Superscale-SP_178000_G.pth +3 -0
- upscale_models/4x_foolhardy_Remacri.pth +3 -0
- upscale_models/RealESRGAN_x4plus.pth +3 -0
- vae_approx/taew2_1.pth +3 -0
- vitmatte/model.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
unet/Z-Image/assets/Z-Image-Gallery.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
unet/Z-Image/assets/image_arena_all.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
unet/Z-Image/assets/decoupled-dmd.webp filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
unet/Z-Image/assets/architecture.webp filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
unet/Z-Image/assets/reasoning.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
unet/Z-Image/assets/DMDR.webp filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
unet/Z-Image/assets/showcase.jpg filter=lfs diff=lfs merge=lfs -text
|
LLM/Florence-2-base/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03075d2d2d2bbd3e180b9ba0afae4aa8563226e2d32911656966e05b2f2ee060
|
| 3 |
+
size 463221266
|
LLM/Florence-2-base/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b480ac374593b0dcb18ffa63b23213734e04cd43eab0d620d23e39708d4a4a7e
|
| 3 |
+
size 464421827
|
SEEDVR2/ema_vae_fp16.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20678548f420d98d26f11442d3528f8b8c94e57ee046ef93dbb7633da8612ca1
|
| 3 |
+
size 501324814
|
assets/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
checkpoints/qwen_image_fp8_hq.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a61c42a58c181813fd94748df62ca1cdb53d0ec4b32c34af09375e5309126fa
|
| 3 |
+
size 89460
|
depthcrafter/stabilityai_stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af602cd0eb4ad6086ec94fbf1438dfb1be5ec9ac03fd0215640854e90d6463a3
|
| 3 |
+
size 195531910
|
detection/vitpose_h_wholebody_model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f21466cd6c93d0066782ad5923c14a4e6569133def212dc2895c73596c2e553b
|
| 3 |
+
size 420252
|
detection/yolov10m.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89b526498a6d55f869a6ab52e3a2eb20ad45b3711c1f7de3dd9ca0b399dfd6d7
|
| 3 |
+
size 61659339
|
mediapipe/selfie_multiclass_256x256.tflite
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
|
| 3 |
+
size 16371837
|
sams/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|
ultralytics/bbox/adetailerFootYolov8x_v20.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f39f32ab83b43002ca466605603b6c6dcff124ecfb23dab1c74c36ecb95cb4b
|
| 3 |
+
size 136712062
|
ultralytics/bbox/face_yolov8m.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:717923c19b3f4bbf5250b728f1fa6b2cb72a33aed1d236ea9caf0e21ad943e5f
|
| 3 |
+
size 52026019
|
ultralytics/bbox/face_yolov8m[1].pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:717923c19b3f4bbf5250b728f1fa6b2cb72a33aed1d236ea9caf0e21ad943e5f
|
| 3 |
+
size 52026019
|
unet/Z-Image/assets/DMDR.webp
ADDED
|
Git LFS Details
|
unet/Z-Image/assets/Z-Image-Gallery.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f9895b3246d2547bac74bbe0be975da500eaae93f2cad4248ad3281786b1ac6
|
| 3 |
+
size 15767436
|
unet/Z-Image/assets/architecture.webp
ADDED
|
Git LFS Details
|
unet/Z-Image/assets/decoupled-dmd.webp
ADDED
|
Git LFS Details
|
unet/Z-Image/assets/image_arena_all.jpg
ADDED
|
Git LFS Details
|
unet/Z-Image/assets/reasoning.png
ADDED
|
Git LFS Details
|
unet/Z-Image/assets/showcase.jpg
ADDED
|
Git LFS Details
|
unet/Z-Image/src/config/__init__.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Z-Image Configuration."""
|
| 2 |
+
|
| 3 |
+
from .inference import (
|
| 4 |
+
DEFAULT_CFG_TRUNCATION,
|
| 5 |
+
DEFAULT_GUIDANCE_SCALE,
|
| 6 |
+
DEFAULT_HEIGHT,
|
| 7 |
+
DEFAULT_INFERENCE_STEPS,
|
| 8 |
+
DEFAULT_MAX_SEQUENCE_LENGTH,
|
| 9 |
+
DEFAULT_WIDTH,
|
| 10 |
+
)
|
| 11 |
+
from .model import (
|
| 12 |
+
ADALN_EMBED_DIM,
|
| 13 |
+
BASE_IMAGE_SEQ_LEN,
|
| 14 |
+
BASE_SHIFT,
|
| 15 |
+
BYTES_PER_GB,
|
| 16 |
+
DEFAULT_LOAD_DEVICE,
|
| 17 |
+
DEFAULT_LOAD_DTYPE_STR,
|
| 18 |
+
DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
|
| 19 |
+
DEFAULT_SCHEDULER_SHIFT,
|
| 20 |
+
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
|
| 21 |
+
DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
|
| 22 |
+
DEFAULT_TRANSFORMER_DIM,
|
| 23 |
+
DEFAULT_TRANSFORMER_F_PATCH_SIZE,
|
| 24 |
+
DEFAULT_TRANSFORMER_IN_CHANNELS,
|
| 25 |
+
DEFAULT_TRANSFORMER_N_HEADS,
|
| 26 |
+
DEFAULT_TRANSFORMER_N_KV_HEADS,
|
| 27 |
+
DEFAULT_TRANSFORMER_N_LAYERS,
|
| 28 |
+
DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
|
| 29 |
+
DEFAULT_TRANSFORMER_NORM_EPS,
|
| 30 |
+
DEFAULT_TRANSFORMER_PATCH_SIZE,
|
| 31 |
+
DEFAULT_TRANSFORMER_QK_NORM,
|
| 32 |
+
DEFAULT_TRANSFORMER_T_SCALE,
|
| 33 |
+
DEFAULT_VAE_IN_CHANNELS,
|
| 34 |
+
DEFAULT_VAE_LATENT_CHANNELS,
|
| 35 |
+
DEFAULT_VAE_NORM_NUM_GROUPS,
|
| 36 |
+
DEFAULT_VAE_OUT_CHANNELS,
|
| 37 |
+
DEFAULT_VAE_SCALE_FACTOR,
|
| 38 |
+
DEFAULT_VAE_SCALING_FACTOR,
|
| 39 |
+
FREQUENCY_EMBEDDING_SIZE,
|
| 40 |
+
MAX_IMAGE_SEQ_LEN,
|
| 41 |
+
MAX_PERIOD,
|
| 42 |
+
MAX_SHIFT,
|
| 43 |
+
ROPE_AXES_DIMS,
|
| 44 |
+
ROPE_AXES_LENS,
|
| 45 |
+
ROPE_THETA,
|
| 46 |
+
SEQ_MULTI_OF,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
__all__ = [
|
| 50 |
+
"ADALN_EMBED_DIM",
|
| 51 |
+
"SEQ_MULTI_OF",
|
| 52 |
+
"ROPE_THETA",
|
| 53 |
+
"ROPE_AXES_DIMS",
|
| 54 |
+
"ROPE_AXES_LENS",
|
| 55 |
+
"FREQUENCY_EMBEDDING_SIZE",
|
| 56 |
+
"MAX_PERIOD",
|
| 57 |
+
"BASE_IMAGE_SEQ_LEN",
|
| 58 |
+
"MAX_IMAGE_SEQ_LEN",
|
| 59 |
+
"BASE_SHIFT",
|
| 60 |
+
"MAX_SHIFT",
|
| 61 |
+
"DEFAULT_VAE_SCALE_FACTOR",
|
| 62 |
+
"DEFAULT_VAE_IN_CHANNELS",
|
| 63 |
+
"DEFAULT_VAE_OUT_CHANNELS",
|
| 64 |
+
"DEFAULT_VAE_LATENT_CHANNELS",
|
| 65 |
+
"DEFAULT_VAE_NORM_NUM_GROUPS",
|
| 66 |
+
"DEFAULT_VAE_SCALING_FACTOR",
|
| 67 |
+
"DEFAULT_TRANSFORMER_PATCH_SIZE",
|
| 68 |
+
"DEFAULT_TRANSFORMER_F_PATCH_SIZE",
|
| 69 |
+
"DEFAULT_TRANSFORMER_IN_CHANNELS",
|
| 70 |
+
"DEFAULT_TRANSFORMER_DIM",
|
| 71 |
+
"DEFAULT_TRANSFORMER_N_LAYERS",
|
| 72 |
+
"DEFAULT_TRANSFORMER_N_REFINER_LAYERS",
|
| 73 |
+
"DEFAULT_TRANSFORMER_N_HEADS",
|
| 74 |
+
"DEFAULT_TRANSFORMER_N_KV_HEADS",
|
| 75 |
+
"DEFAULT_TRANSFORMER_NORM_EPS",
|
| 76 |
+
"DEFAULT_TRANSFORMER_QK_NORM",
|
| 77 |
+
"DEFAULT_TRANSFORMER_CAP_FEAT_DIM",
|
| 78 |
+
"DEFAULT_TRANSFORMER_T_SCALE",
|
| 79 |
+
"DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS",
|
| 80 |
+
"DEFAULT_SCHEDULER_SHIFT",
|
| 81 |
+
"DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING",
|
| 82 |
+
"DEFAULT_LOAD_DEVICE",
|
| 83 |
+
"DEFAULT_LOAD_DTYPE_STR",
|
| 84 |
+
"BYTES_PER_GB",
|
| 85 |
+
"DEFAULT_HEIGHT",
|
| 86 |
+
"DEFAULT_WIDTH",
|
| 87 |
+
"DEFAULT_INFERENCE_STEPS",
|
| 88 |
+
"DEFAULT_GUIDANCE_SCALE",
|
| 89 |
+
"DEFAULT_CFG_TRUNCATION",
|
| 90 |
+
"DEFAULT_MAX_SEQUENCE_LENGTH",
|
| 91 |
+
]
|
unet/Z-Image/src/config/inference.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference-specific configuration for Z-Image."""
|
| 2 |
+
|
| 3 |
+
# Default inference settings, reconstructed as plain Python from the diff
# rendering (the gutter markers between statements have been removed).

# Output image size.
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024

# Sampling defaults.
DEFAULT_INFERENCE_STEPS = 8
DEFAULT_GUIDANCE_SCALE = 0.0
DEFAULT_CFG_TRUNCATION = 1.0

# NOTE(review): presumably the maximum prompt length in tokens - confirm
# against the pipeline code that consumes it.
DEFAULT_MAX_SEQUENCE_LENGTH = 512
|
unet/Z-Image/src/config/manifests/z-image-turbo.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Z-Image Model Manifest
|
| 2 |
+
# Format: <md5hash> <filepath>
|
| 3 |
+
# Generated automatically - DO NOT edit manually
|
| 4 |
+
|
| 5 |
+
5e3226ed72a9a4a080f2a4ca78b98ddc model_index.json
|
| 6 |
+
ca682fcc6c5a94cf726b7187e64b9411 scheduler/scheduler_config.json
|
| 7 |
+
1e97eb35d9d0b6aa60c58a8df8d7d99a text_encoder/config.json
|
| 8 |
+
30b85686b9a9b002e012494fadc027cb text_encoder/model-00001-of-00003.safetensors
|
| 9 |
+
e6a24ea164404a01ad2800dbae4e1a13 text_encoder/model-00002-of-00003.safetensors
|
| 10 |
+
09e190ed15ff14795b6277e023cfcb2d text_encoder/model-00003-of-00003.safetensors
|
| 11 |
+
589f5395156900f49d617aee8a8d8708 text_encoder/model.safetensors.index.json
|
| 12 |
+
6423133b9cc1a2077b57822c30c211aa tokenizer/tokenizer.json
|
| 13 |
+
b06e103ac555ec4b51266078b518c0f0 tokenizer/tokenizer_config.json
|
| 14 |
+
baed87136fe5f848e24b072f99856cc3 transformer/config.json
|
| 15 |
+
54889d0dd179b4fa2fd7bd0e487d856e transformer/diffusion_pytorch_model-00001-of-00003.safetensors
|
| 16 |
+
fe81e804658d345323512c63224b0604 transformer/diffusion_pytorch_model-00002-of-00003.safetensors
|
| 17 |
+
4e074e09129f98ad840414951f122feb transformer/diffusion_pytorch_model-00003-of-00003.safetensors
|
| 18 |
+
76d788eb0d42c59cc8f8ec007db639aa transformer/diffusion_pytorch_model.safetensors.index.json
|
| 19 |
+
ba9e2980c8630b4abccc643bc9f4a542 vae/config.json
|
| 20 |
+
6f83de55cb720c7fae051b14528577bf vae/diffusion_pytorch_model.safetensors
|
unet/Z-Image/src/config/model.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model configuration constants for Z-Image."""
|
| 2 |
+
|
| 3 |
+
# Model configuration constants, reconstructed as plain Python from the diff
# rendering (the gutter markers between statements have been removed).
# Values are unchanged; the section labels below only restate name prefixes.

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32

# ROPE_* settings (presumably rotary position embeddings - confirm against
# the transformer code). Both lists have three entries.
ROPE_THETA = 256.0
ROPE_AXES_DIMS = [32, 48, 48]
ROPE_AXES_LENS = [1536, 512, 512]

FREQUENCY_EMBEDDING_SIZE = 256
MAX_PERIOD = 10000

# Image sequence length / shift bounds.
BASE_IMAGE_SEQ_LEN = 256
MAX_IMAGE_SEQ_LEN = 4096
BASE_SHIFT = 0.5
MAX_SHIFT = 1.15

# VAE defaults.
DEFAULT_VAE_SCALE_FACTOR = 8
DEFAULT_VAE_IN_CHANNELS = 3
DEFAULT_VAE_OUT_CHANNELS = 3
DEFAULT_VAE_LATENT_CHANNELS = 4
DEFAULT_VAE_NORM_NUM_GROUPS = 32
DEFAULT_VAE_SCALING_FACTOR = 0.18215

# Transformer defaults. The two patch-size values are one-element tuples,
# not bare integers.
DEFAULT_TRANSFORMER_PATCH_SIZE = (2,)
DEFAULT_TRANSFORMER_F_PATCH_SIZE = (1,)
DEFAULT_TRANSFORMER_IN_CHANNELS = 16
DEFAULT_TRANSFORMER_DIM = 3840
DEFAULT_TRANSFORMER_N_LAYERS = 30
DEFAULT_TRANSFORMER_N_REFINER_LAYERS = 2
DEFAULT_TRANSFORMER_N_HEADS = 30
DEFAULT_TRANSFORMER_N_KV_HEADS = 30
DEFAULT_TRANSFORMER_NORM_EPS = 1e-5
DEFAULT_TRANSFORMER_QK_NORM = True
DEFAULT_TRANSFORMER_CAP_FEAT_DIM = 2560
DEFAULT_TRANSFORMER_T_SCALE = 1000.0

# Scheduler defaults.
DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS = 1000
DEFAULT_SCHEDULER_SHIFT = 3.0
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING = False

# Loading defaults.
DEFAULT_LOAD_DEVICE = "cuda"
DEFAULT_LOAD_DTYPE_STR = "bfloat16"

# 2**30 bytes, i.e. a binary gigabyte (GiB).
BYTES_PER_GB = 2**30
|
unet/Z-Image/src/tools/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tools for Z-Image model management."""
|
| 2 |
+
|
| 3 |
+
from .generate_manifest import compute_md5, get_essential_files
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"compute_md5",
|
| 7 |
+
"get_essential_files",
|
| 8 |
+
]
|
| 9 |
+
|
unet/Z-Image/src/tools/generate_manifest.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Generate manifest file with MD5 checksums for model weights.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m tools.generate_manifest ckpts/Z-Image-Turbo
|
| 6 |
+
python -m tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums # Only list files
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import hashlib
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def compute_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Compute the MD5 hash of a file.

    The file is read in binary mode, ``chunk_size`` bytes at a time, so the
    whole file is never held in memory at once.

    Args:
        file_path: File to hash.
        chunk_size: Number of bytes to read per iteration. Must not be 0.

    Returns:
        The hexadecimal MD5 digest of the file contents.

    Raises:
        ValueError: If ``chunk_size`` is 0.
        OSError: If the file cannot be opened or read.
    """
    # read(0) returns b"" immediately, which would end the loop before any
    # data is hashed and silently yield the digest of empty input.
    if chunk_size == 0:
        raise ValueError("chunk_size must not be 0")

    md5_hash = hashlib.md5()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_essential_files(model_dir: Path) -> List[Path]:
    """Get the list of essential model files found under ``model_dir``.

    Each entry of the built-in pattern list is either a literal relative path
    or a glob pattern (recognised by containing ``*``). Only regular files are
    returned: missing paths and directories that happen to match are skipped.

    Args:
        model_dir: Root directory of the model.

    Returns:
        The matching files, de-duplicated and sorted.
    """
    essential_patterns = [
        "model_index.json",
        "transformer/config.json",
        "transformer/*.safetensors*",
        "vae/config.json",
        "vae/*.safetensors",
        "text_encoder/config.json",
        "text_encoder/*.safetensors*",
        "tokenizer/tokenizer.json",
        "tokenizer/tokenizer_config.json",
        "scheduler/scheduler_config.json",
    ]

    found = set()
    for pattern in essential_patterns:
        if "*" in pattern:
            candidates = model_dir.glob(pattern)
        else:
            candidates = (model_dir / pattern,)
        # is_file() is False for both missing paths and directories, so a
        # directory named like a weight file never reaches the hashing step.
        found.update(path for path in candidates if path.is_file())

    return sorted(found)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main(argv=None) -> int:
    """Command-line entry point: write a manifest of essential model files.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``main()`` call.

    Returns:
        0 on success. 1 if the model directory is missing, no essential files
        were found, or at least one file could not be hashed (the manifest is
        still written, without the failed entries).
    """
    parser = argparse.ArgumentParser(description="Generate manifest file for model weights")
    parser.add_argument("model_dir", type=str, help="Path to model directory")
    parser.add_argument(
        "--output", "-o", type=str, default=None,
        help="Output manifest file path (default: auto-detect to config/manifests/)",
    )
    parser.add_argument(
        "--no-checksums", action="store_true",
        help="Only list files without computing checksums",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Print progress",
    )

    args = parser.parse_args(argv)

    model_dir = Path(args.model_dir)
    # is_dir() also rejects a regular file passed by mistake.
    if not model_dir.is_dir():
        print(f"Error: Model directory not found: {model_dir}")
        return 1

    # Determine output path
    if args.output:
        output_file = Path(args.output)
    else:
        # Auto-detect: save to config/manifests/{model-name}.txt
        model_name = model_dir.name.lower()  # e.g., "Z-Image-Turbo" -> "z-image-turbo"
        script_dir = Path(__file__).parent
        config_dir = script_dir.parent / "config" / "manifests"
        output_file = config_dir / f"{model_name}.txt"

    # Get essential files
    files = get_essential_files(model_dir)

    if not files:
        print(f"Warning: No essential files found in {model_dir}")
        return 1

    print(f"Found {len(files)} essential files")

    # Create the destination directory for both the default and --output paths.
    output_file.parent.mkdir(parents=True, exist_ok=True)

    written = 0
    failed = 0

    # Generate manifest
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("# Z-Image Model Manifest\n")
        if args.no_checksums:
            f.write("# Format: <filepath>\n")
        else:
            f.write("# Format: <md5hash> <filepath>\n")
        f.write("# Generated automatically - DO NOT edit manually\n\n")

        for file_path in files:
            # Always use forward slashes so the manifest is identical no
            # matter which OS generated it.
            rel_path = file_path.relative_to(model_dir).as_posix()

            if args.no_checksums:
                f.write(f"{rel_path}\n")
                written += 1
                if args.verbose:
                    print(f" {rel_path}")
                continue

            if args.verbose:
                print(f"Computing MD5 for {rel_path}...", end=" ", flush=True)

            try:
                md5_hash = compute_md5(file_path)
            except OSError as e:
                # Best effort: skip unreadable files, but record the failure
                # so the summary and exit code stay honest.
                failed += 1
                print(f"✗ Error: {e} ({rel_path})")
                continue

            f.write(f"{md5_hash} {rel_path}\n")
            written += 1
            if args.verbose:
                print(f"✓ {md5_hash}")

    print(f"\n✓ Manifest saved to: {output_file}")
    print(f" Total files: {written}")
    if not args.no_checksums:
        print(" With MD5 checksums for integrity verification")

    if failed:
        print(f"Warning: {failed} file(s) could not be hashed and were left out of the manifest")
        return 1

    return 0


if __name__ == "__main__":
    # SystemExit works even when the site-provided exit() helper is absent
    # (for example under ``python -S``).
    raise SystemExit(main())
|
| 127 |
+
|
unet/Z-Image/src/utils/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for Z-Image."""
|
| 2 |
+
|
| 3 |
+
from .attention import AttentionBackend, dispatch_attention, set_attention_backend
|
| 4 |
+
from .helpers import format_bytes, print_memory_stats, ensure_model_weights
|
| 5 |
+
from .loader import load_from_local_dir
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"load_from_local_dir",
|
| 9 |
+
"format_bytes",
|
| 10 |
+
"print_memory_stats",
|
| 11 |
+
"ensure_model_weights",
|
| 12 |
+
"AttentionBackend",
|
| 13 |
+
"set_attention_backend",
|
| 14 |
+
"dispatch_attention",
|
| 15 |
+
]
|
unet/Z-Image/src/utils/attention.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Attention backend utilities for Z-Image."""
|
| 2 |
+
|
| 3 |
+
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py
|
| 4 |
+
from enum import Enum
|
| 5 |
+
import functools
|
| 6 |
+
import inspect
|
| 7 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from .import_utils import is_flash_attn_3_available, is_flash_attn_available, is_torch_version
|
| 13 |
+
|
| 14 |
+
# Optional-backend detection, evaluated once at import time.
_CAN_USE_FLASH_ATTN_2 = is_flash_attn_available()
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()

# MPS Flash Attention (Apple Silicon)
try:
    import mps_flash_attn

    _CAN_USE_MPS_FLASH = mps_flash_attn.is_available()
except ImportError:
    _CAN_USE_MPS_FLASH = False
    mps_flash_attn = None

_TORCH_VERSION_CHECK = is_torch_version(">=", "2.5.0")  # have enable_gqa func call in SDPA

# Importing this module fails outright on an older PyTorch.
if not _TORCH_VERSION_CHECK:
    raise RuntimeError("PyTorch version must be >= 2.5.0 to use this backend.")
else:
    print("PyTorch version is >= 2.5.0, check pass.")

# Bind the flash-attn entry points, or None placeholders when the package is
# unavailable, so the names always exist at module level.
if _CAN_USE_FLASH_ATTN_2:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
    flash_attn_func = None
    flash_attn_varlen_func = None

if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import (
        flash_attn_func as flash_attn_3_func,
        flash_attn_varlen_func as flash_attn_3_varlen_func,
    )

    # Record whether the installed flash_attn_3_func accepts a
    # "return_attn_probs" parameter.
    _flash_attn_3_sig = inspect.signature(flash_attn_3_func)
    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = "return_attn_probs" in _flash_attn_3_sig.parameters
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None
    _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AttentionBackend(str, Enum):
    """Supported attention backends.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    # Flash Attention
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"
    FLASH_3 = "_flash_3"
    FLASH_VARLEN_3 = "_flash_varlen_3"
    # MPS Flash Attention (Apple Silicon)
    MPS_FLASH = "mps_flash"
    # PyTorch Native Backends
    NATIVE = "native"
    NATIVE_FLASH = "_native_flash"
    NATIVE_MATH = "_native_math"

    @classmethod
    def print_available_backends(cls):
        """Print the string value of every member of this enum.

        The list is built from the enum members alone; nothing here checks
        whether the corresponding backend is installed on this machine.
        """
        available_backends = [backend.value for backend in cls.__members__.values()]
        print(f"Available attention backends list: {available_backends}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Registry for attention implementations
# Both dicts are keyed by backend name and filled by @register_backend.
_ATTENTION_BACKENDS: Dict[str, Callable] = {}
_ATTENTION_CONSTRAINTS: Dict[str, List[Callable]] = {}


def register_backend(name: str, constraints: Optional[List[Callable]] = None):
    """Return a decorator that registers a function as attention backend ``name``.

    Args:
        name: Key under which the function is stored. Registering the same
            name again overwrites the earlier entry.
        constraints: Optional list of callables stored alongside the function;
            an empty list is stored when omitted.

    Returns:
        A decorator that records the function in the two registries and
        returns it unchanged.
    """

    def decorator(func):
        _ATTENTION_BACKENDS[name] = func
        _ATTENTION_CONSTRAINTS[name] = constraints or []
        return func

    return decorator
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# --- Checks ---
|
| 87 |
+
def _check_device_cuda(query: torch.Tensor, **kwargs) -> None:
    """Raise ``ValueError`` unless ``query`` lives on a CUDA device.

    Extra keyword arguments are accepted and ignored.
    """
    if query.device.type != "cuda":
        raise ValueError("Query must be on a CUDA device.")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, **kwargs) -> None:
    """Raise ``ValueError`` unless ``query`` is bfloat16 or float16.

    Only ``query`` is inspected; extra keyword arguments are accepted and
    ignored.
    """
    if query.dtype not in (torch.bfloat16, torch.float16):
        raise ValueError("Query must be either bfloat16 or float16.")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _check_device_mps(query: torch.Tensor, **kwargs) -> None:
    """Raise ``ValueError`` unless ``query`` lives on an MPS device.

    Extra keyword arguments are accepted and ignored.
    """
    if query.device.type != "mps":
        raise ValueError("Query must be on MPS device.")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
    """Turn an optional attention mask into the form used for additive masking.

    Args:
        attn_mask: ``None``, a boolean mask, or a non-boolean mask. A 2-D mask
            gets two singleton axes inserted between its two dimensions.
        dtype: dtype of the additive mask created from a boolean mask.

    Returns:
        ``None`` when no mask is given. For a boolean mask, a new tensor of
        ``dtype`` holding 0 where the mask is True and ``-inf`` where it is
        False. Any other mask is returned with its values and dtype untouched
        (``dtype`` is not applied to it).
    """
    if attn_mask is None:
        return None

    if attn_mask.ndim == 2:
        attn_mask = attn_mask[:, None, None, :]

    # Convert bool mask to float additive mask
    if attn_mask.dtype == torch.bool:
        # NOTE: We skip checking for all-True mask (torch.all) to avoid graph breaks in torch.compile
        new_mask = torch.zeros_like(attn_mask, dtype=dtype)
        new_mask.masked_fill_(~attn_mask, float("-inf"))
        return new_mask

    return attn_mask
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
    """Normalize an attention mask to shape [batch_size, seq_len_k] (bool).

    Args:
        attn_mask: Mask with 1 to 4 dimensions. Boolean masks are used as-is.
            Floating-point masks are treated as additive masks and converted
            with ``mask > -1`` (assumes masked positions hold ``-inf`` or
            another value <= -1). Other dtypes are converted with
            ``mask != 0``.
        batch_size: Target size of the first dimension.
        seq_len_k: Target size of the second dimension.

    Returns:
        A boolean tensor of shape ``(batch_size, seq_len_k)``.

    Raises:
        RuntimeError: Raised by torch when the mask cannot be broadcast or
            reshaped to the target shape.
    """
    if attn_mask.dtype != torch.bool:
        # For varlen flash attn, we strictly need a bool mask indicating valid
        # tokens. Convert first and then fall through, so non-bool masks get
        # the same shape normalisation as bool masks.
        if torch.is_floating_point(attn_mask):
            attn_mask = attn_mask > -1  # Assuming -inf is masked
        else:
            attn_mask = attn_mask != 0

    if attn_mask.ndim == 1:
        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 2:
        # Broadcasts a leading dimension of 1 across the batch and is a no-op
        # when the shape already matches; torch raises on any other size.
        attn_mask = attn_mask.expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 3:
        # Reduce over dim 1, then broadcast across the batch if needed.
        attn_mask = attn_mask.any(dim=1).expand(batch_size, seq_len_k)
    elif attn_mask.ndim == 4:
        # Broadcast batch/key dims, then reduce the two middle dims. Two
        # chained reductions are equivalent to any(dim=(1, 2)).
        attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k).any(dim=1).any(dim=1)

    if attn_mask.shape != (batch_size, seq_len_k):
        # Fallback reshape
        return attn_mask.reshape(batch_size, seq_len_k)

    return attn_mask
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_varlen_without_mask(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    device: Optional[torch.device] = None,
):
    """Build sequence-length tensors for a batch in which every sequence is full length.

    Results are memoised per argument tuple, so repeated calls with the same
    arguments return the same tensor objects; callers should not modify them
    in place.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len_q: Length of every query sequence.
        seq_len_kv: Length of every key/value sequence.
        device: Device on which the tensors are created.

    Returns:
        ``((seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k),
        (seq_len_q, seq_len_kv))``. The four tensors are int32; each
        ``cu_seqlens`` tensor has ``batch_size + 1`` entries holding multiples
        of the sequence length, starting at 0.
    """
    # Optimized to avoid Inductor "pointless_cumsum_replacement" crash and remove graph breaks
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)

    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q
    cu_seqlens_k = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_kv

    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (seq_len_q, seq_len_kv)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _prepare_for_flash_attn_varlen_with_mask(
|
| 165 |
+
batch_size: int,
|
| 166 |
+
seq_len_q: int,
|
| 167 |
+
attn_mask: torch.Tensor,
|
| 168 |
+
device: Optional[torch.device] = None,
|
| 169 |
+
):
|
| 170 |
+
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
|
| 171 |
+
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
|
| 172 |
+
# Use arange for Q to avoid Inductor crash
|
| 173 |
+
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q
|
| 174 |
+
|
| 175 |
+
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
|
| 176 |
+
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
|
| 177 |
+
|
| 178 |
+
max_seqlen_q = seq_len_q
|
| 179 |
+
max_seqlen_k = attn_mask.shape[1] # not max().item(), static shape to avoid graph break
|
| 180 |
+
|
| 181 |
+
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _prepare_for_flash_attn_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> tuple:
    """Return the sequence-length bookkeeping needed by the varlen attention kernels.

    Delegates to the unmasked helper when ``attn_mask`` is ``None`` and to the masked
    helper otherwise. ``seq_len_kv`` is only used on the unmasked path.

    Returns:
        ``((seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))``.
    """
    if attn_mask is None:
        return _prepare_for_flash_attn_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@register_backend(AttentionBackend.FLASH, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run attention through ``flash_attn_func``.

    ``attn_mask`` is part of the signature but is not forwarded to the kernel.

    Raises:
        RuntimeError: If ``_CAN_USE_FLASH_ATTN_2`` is false.
    """
    if _CAN_USE_FLASH_ATTN_2:
        return flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
        )

    raise RuntimeError(
        f"Flash Attention backend '{AttentionBackend.FLASH}' is not usable because of missing package."
    )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@register_backend(AttentionBackend.FLASH_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run attention through ``flash_attn_varlen_func`` on packed (unpadded) keys/values.

    ``query``/``key`` are unpacked as 4-D tensors whose first two dims are batch and
    sequence length. When ``attn_mask`` is given it is normalized to a
    ``[batch, seq_len_kv]`` bool mask and only the key/value positions it keeps are
    passed to the kernel; queries are never dropped.

    Raises:
        RuntimeError: If ``_CAN_USE_FLASH_ATTN_2`` is false.
    """
    if not _CAN_USE_FLASH_ATTN_2:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN}' requires flash-attn.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv).to(torch.bool)

    (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)

    if attn_mask is not None:
        # Boolean indexing gathers the kept positions in batch-major order, which is
        # the layout cu_seqlens_k describes. Unlike slicing `[:valid_len]`, this is
        # also correct when the kept tokens are not a contiguous prefix (e.g. left
        # padding), and it needs no per-sample Python loop.
        keep = attn_mask.to(key.device)
        key_packed = key[keep]
        value_packed = value[keep]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
    )
    # Restore the batch dimension that was flattened away for the packed call.
    return out.unflatten(0, (batch_size, -1))
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
@register_backend(AttentionBackend.FLASH_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,  # Unused in simple FA3 func
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run attention through ``flash_attn_3_func``.

    ``attn_mask`` and ``dropout_p`` are accepted but not forwarded to the kernel.
    If the kernel returns a tuple, only its first element is returned.

    Raises:
        RuntimeError: If ``_CAN_USE_FLASH_ATTN_3`` is false.
    """
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_3}' requires Flash Attention 3 beta.")

    call_args = dict(q=query, k=key, v=value, softmax_scale=scale, causal=is_causal)
    # Only pass return_attn_probs when the installed kernel is flagged as accepting it.
    if _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS:
        call_args["return_attn_probs"] = False

    result = flash_attn_3_func(**call_args)
    return result[0] if isinstance(result, tuple) else result
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@register_backend(AttentionBackend.FLASH_VARLEN_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run attention through ``flash_attn_3_varlen_func`` on packed (unpadded) keys/values.

    ``query``/``key`` are unpacked as 4-D tensors whose first two dims are batch and
    sequence length. When ``attn_mask`` is given it is normalized to a
    ``[batch, seq_len_kv]`` bool mask and only the key/value positions it keeps are
    passed to the kernel; queries are never dropped. ``dropout_p`` is accepted but not
    forwarded. If the kernel returns a tuple, only its first element is used.

    Raises:
        RuntimeError: If ``_CAN_USE_FLASH_ATTN_3`` is false.
    """
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN_3}' requires Flash Attention 3 beta.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv).to(torch.bool)

    (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)

    if attn_mask is not None:
        # Boolean indexing gathers the kept positions in batch-major order, which is
        # the layout cu_seqlens_k describes. Unlike slicing `[:valid_len]`, this is
        # also correct when the kept tokens are not a contiguous prefix (e.g. left
        # padding), and it needs no per-sample Python loop.
        keep = attn_mask.to(key.device)
        key_packed = key[keep]
        value_packed = value[keep]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    kwargs = {
        "q": query_packed,
        "k": key_packed,
        "v": value_packed,
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seqlen_q": max_seqlen_q,
        "max_seqlen_k": max_seqlen_k,
        "softmax_scale": scale,
        "causal": is_causal,
    }

    # Only pass return_attn_probs when the installed kernel's signature accepts it.
    if "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters:
        kwargs["return_attn_probs"] = False

    out = flash_attn_3_varlen_func(**kwargs)
    if isinstance(out, tuple):
        out = out[0]

    # Restore the batch dimension that was flattened away for the packed call.
    return out.unflatten(0, (batch_size, -1))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@register_backend(AttentionBackend.MPS_FLASH, constraints=[_check_device_mps, _check_qkv_dtype_bf16_or_fp16])
def _mps_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """MPS Flash Attention for Apple Silicon (M1/M2/M3/M4).

    Inputs have dims 1 and 2 swapped before the kernel call and the output is swapped
    back and made contiguous. ``dropout_p`` is accepted but not forwarded.

    Raises:
        RuntimeError: If ``_CAN_USE_MPS_FLASH`` is false.
    """
    if not _CAN_USE_MPS_FLASH:
        raise RuntimeError(
            f"MPS Flash Attention backend '{AttentionBackend.MPS_FLASH}' requires mps-flash-attn package. "
            "Install with: pip install mps-flash-attn"
        )

    # (B, S, H, D) -> (B, H, S, D), the layout mps-flash-attn expects.
    q_bhsd, k_bhsd, v_bhsd = (t.transpose(1, 2) for t in (query, key, value))

    # Convert the mask to MFA format (bool, True = masked); no mask stays None.
    if attn_mask is None:
        mfa_mask = None
    else:
        mfa_mask = mps_flash_attn.convert_mask(_process_mask(attn_mask, q_bhsd.dtype))

    attended = mps_flash_attn.flash_attention(
        q_bhsd,
        k_bhsd,
        v_bhsd,
        is_causal=is_causal,
        scale=scale,
        attn_mask=mfa_mask,
    )

    # Back to (B, S, H, D).
    return attended.transpose(1, 2).contiguous()
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def _native_attention_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend_kernel=None,
) -> torch.Tensor:
    """Call ``F.scaled_dot_product_attention``, optionally pinned to one SDPA kernel.

    Dims 1 and 2 of the inputs are swapped before the call and swapped back (plus
    ``contiguous()``) on the output. The mask goes through ``_process_mask`` first.

    Args:
        backend_kernel: When not ``None``, the call runs inside
            ``torch.nn.attention.sdpa_kernel(backend_kernel)``.
    """
    q_t, k_t, v_t = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    sdpa_options = {
        "attn_mask": _process_mask(attn_mask, q_t.dtype),
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
    }

    if backend_kernel is None:
        attended = F.scaled_dot_product_attention(q_t, k_t, v_t, **sdpa_options)
    else:
        with torch.nn.attention.sdpa_kernel(backend_kernel):
            attended = F.scaled_dot_product_attention(q_t, k_t, v_t, **sdpa_options)

    return attended.transpose(1, 2).contiguous()
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
@register_backend(AttentionBackend.NATIVE_FLASH)
def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Run SDPA pinned to ``SDPBackend.FLASH_ATTENTION``.

    The incoming ``attn_mask`` is not forwarded; ``None`` is passed to the wrapper
    instead.
    """
    # NOTE(review): presumably the mask is dropped because the SDPA flash kernel does
    # not take an explicit mask -- confirm.
    flash_kernel = torch.nn.attention.SDPBackend.FLASH_ATTENTION
    return _native_attention_wrapper(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        backend_kernel=flash_kernel,
    )
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
@register_backend(AttentionBackend.NATIVE_MATH)
def _math_attention(*args, **kwargs):
    """Forward all arguments to the native SDPA wrapper, pinned to ``SDPBackend.MATH``."""
    math_kernel = torch.nn.attention.SDPBackend.MATH
    return _native_attention_wrapper(*args, **kwargs, backend_kernel=math_kernel)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
@register_backend(AttentionBackend.NATIVE)
def _native_attention(*args, **kwargs):
    """Forward all arguments to the native SDPA wrapper without pinning a kernel."""
    no_pinned_kernel = None
    return _native_attention_wrapper(*args, **kwargs, backend_kernel=no_pinned_kernel)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def dispatch_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend: Union[str, AttentionBackend, None] = None,
) -> torch.Tensor:
    """Route an attention call to the implementation selected by ``backend``.

    Args:
        backend: An ``AttentionBackend`` member (its ``.value`` is used), any other
            object (converted with ``str``), or ``None`` for
            ``AttentionBackend.NATIVE``. A value that matches none of the known
            backends falls through to the native implementation.

    Returns:
        Whatever the selected backend function returns.
    """
    if backend is None:
        backend = AttentionBackend.NATIVE
    elif isinstance(backend, AttentionBackend):
        backend = backend.value
    else:
        backend = str(backend)

    call_args = (query, key, value, attn_mask, dropout_p, is_causal, scale)

    # Explicit dispatch to avoid dynamo guard issues on global dict
    if backend == AttentionBackend.FLASH:
        return _flash_attention(*call_args)
    if backend == AttentionBackend.FLASH_VARLEN:
        return _flash_varlen_attention(*call_args)
    if backend == AttentionBackend.FLASH_3:
        return _flash_attention_3(*call_args)
    if backend == AttentionBackend.FLASH_VARLEN_3:
        return _flash_varlen_attention_3(*call_args)
    if backend == AttentionBackend.MPS_FLASH:
        return _mps_flash_attention(*call_args)
    if backend == AttentionBackend.NATIVE_FLASH:
        return _native_flash_attention(*call_args)
    if backend == AttentionBackend.NATIVE_MATH:
        return _math_attention(*call_args)
    return _native_attention(*call_args)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def set_attention_backend(backend: Union[str, AttentionBackend, None]):
    """Store the attention backend on ``ZImageAttention._attention_backend``.

    ``AttentionBackend`` members are stored as their ``.value``, the same
    normalization ``dispatch_attention`` applies, so the stored string compares equal
    to the backend it names. Other non-``None`` values are stored as ``str(backend)``;
    ``None`` is stored as ``None``.

    If ``zimage.transformer`` cannot be imported the call does nothing.
    """
    try:
        from zimage.transformer import ZImageAttention
    except ImportError:
        # Best effort: without the transformer module there is nothing to configure.
        return

    if isinstance(backend, AttentionBackend):
        # str() on an Enum member may yield "AttentionBackend.X" rather than its value.
        backend = backend.value
    elif backend is not None:
        backend = str(backend)

    ZImageAttention._attention_backend = backend
|
unet/Z-Image/src/utils/helpers.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helper utilities for Z-Image."""
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional, List, Tuple, Dict
|
| 7 |
+
|
| 8 |
+
from loguru import logger
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from config import BYTES_PER_GB
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def format_bytes(size: float) -> str:
    """
    Render a byte count as gigabytes.

    Args:
        size: Size in bytes

    Returns:
        ``size / BYTES_PER_GB`` with two decimals, followed by " GB"
    """
    return f"{size / BYTES_PER_GB:.2f} GB"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def print_memory_stats(stage: str) -> None:
    """
    Log current and peak CUDA memory usage.

    Logs a warning and returns when CUDA is unavailable. Otherwise synchronizes the
    device and logs current/peak allocated and reserved memory at info level.

    Args:
        stage: Label included in the header line
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping memory stats")
        return

    torch.cuda.synchronize()
    now_allocated = torch.cuda.memory_allocated()
    now_reserved = torch.cuda.memory_reserved()
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()

    logger.info(f"[{stage}] Memory Stats:")
    logger.info(f" Current Allocated: {format_bytes(now_allocated)}")
    logger.info(f" Current Reserved: {format_bytes(now_reserved)}")
    logger.info(f" Peak Allocated: {format_bytes(peak_allocated)}")
    logger.info(f" Peak Reserved: {format_bytes(peak_reserved)}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def compute_file_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Return the hex MD5 digest of a file, reading it ``chunk_size`` bytes at a time."""
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. end of file.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_manifest(manifest_file: Path) -> Dict[str, Optional[str]]:
    """Load a manifest file into a mapping of file path -> MD5 hash (or None).

    Blank lines and lines starting with ``#`` are skipped. Each remaining line is
    split on whitespace and may be:

    * ``<path>``: no checksum; the value is ``None``.
    * ``<md5> <path>``: when the first token is 32 hex characters.
    * ``<path> <md5>``: any other two-token line.

    Hashes are stored lower-cased so they compare equal to ``hexdigest()`` output
    regardless of how they were written. Lines with more than two tokens are logged
    as invalid and skipped.

    Returns:
        The parsed mapping; empty when ``manifest_file`` does not exist.
    """
    manifest: Dict[str, Optional[str]] = {}
    if not manifest_file.exists():
        return manifest

    with open(manifest_file, "r", encoding="utf-8") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue

            parts = line.split()

            if len(parts) == 1:
                # Only file path, no checksum
                manifest[parts[0]] = None
            elif len(parts) == 2:
                first = parts[0]
                if len(first) == 32 and all(c in "0123456789abcdef" for c in first.lower()):
                    md5_hash, file_path = parts
                else:
                    file_path, md5_hash = parts
                manifest[file_path] = md5_hash.lower()
            else:
                logger.warning(f"Invalid manifest format at line {line_num}: {line}")

    return manifest
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def verify_file_integrity(
    base_dir: Path,
    manifest: Dict[str, Optional[str]],
    verify_checksums: bool = True
) -> Tuple[bool, List[str], List[str]]:
    """
    Verify file integrity using a manifest.

    Args:
        base_dir: Base directory for relative file paths
        manifest: Dictionary of relative paths to MD5 hashes (None if no hash provided)
        verify_checksums: If True, verify MD5 checksums when available; if False, only check existence

    Returns:
        Tuple of (all_valid: bool, missing_files: List[str], corrupted_files: List[str]).
        A file whose checksum cannot be computed because of an OS error is reported
        as corrupted.
    """
    missing: List[str] = []
    corrupted: List[str] = []

    for rel_path, expected_md5 in manifest.items():
        file_path = base_dir / rel_path

        if not file_path.exists():
            missing.append(rel_path)
            continue

        # Only verify checksum if requested AND hash is available
        if not verify_checksums or expected_md5 is None:
            continue

        try:
            actual_md5 = compute_file_md5(file_path)
        except OSError as e:
            logger.error(f"Failed to compute checksum for {rel_path}: {e}")
            corrupted.append(rel_path)
            continue

        # hexdigest() is lower-case; compare case-insensitively so an upper-case
        # hash in the manifest is not reported as a mismatch.
        if actual_md5 != expected_md5.lower():
            corrupted.append(rel_path)
            logger.debug(f"Checksum mismatch for {rel_path}: expected {expected_md5}, got {actual_md5}")

    all_valid = not missing and not corrupted
    return all_valid, missing, corrupted
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _resolve_manifest_path(target_dir: Path, manifest_name: Optional[str]) -> Path:
    """Pick the manifest: explicit name, then per-model config manifest, then ``manifest.txt`` in the model dir."""
    manifests_dir = Path(__file__).parent.parent / "config" / "manifests"
    if manifest_name:
        # Explicitly specified manifest from config/manifests/
        return manifests_dir / manifest_name

    # Auto-detect, e.g. "Z-Image-Turbo" -> "z-image-turbo.txt"
    config_manifest = manifests_dir / f"{target_dir.name.lower()}.txt"
    if config_manifest.exists():
        return config_manifest

    # Fallback
    return target_dir / "manifest.txt"


def _log_file_list(log, header: str, files: List[str], limit: int = 10) -> None:
    """Log ``header`` followed by at most ``limit`` entries of ``files`` and a count of the rest."""
    log(header)
    for name in files[:limit]:
        log(f" - {name}")
    if len(files) > limit:
        log(f" ... and {len(files) - limit} more")


def ensure_model_weights(
    model_path: str,
    repo_id: str = "Tongyi-MAI/Z-Image-Turbo",
    verify: bool = False,
    manifest_name: Optional[str] = None
) -> Path:
    """
    Ensure model weights exist and optionally verify integrity.

    The model directory is checked against a manifest. When files are missing or
    corrupted (or there is no manifest and the directory does not exist), the
    repository is downloaded into ``model_path`` and the manifest check is repeated.

    Args:
        model_path: Path to model directory
        repo_id: HuggingFace repo ID for download
        verify: If True, verify MD5 checksums; if False, only check existence
        manifest_name: Manifest file name in src/config/manifests/ (auto-detect if None)

    Returns:
        Path to validated model directory

    Raises:
        RuntimeError: If the download fails (the original error is chained).
        FileNotFoundError: If files are still missing or corrupted after the download.
    """
    target_dir = Path(model_path)
    manifest_path = _resolve_manifest_path(target_dir, manifest_name)
    manifest = load_manifest(manifest_path)

    if not manifest:
        logger.warning(f"Manifest file not found: {manifest_path}")
        logger.warning("Skipping file verification (assuming model exists)")
        if target_dir.exists():
            logger.info(f"✓ Model directory exists: {target_dir}")
            return target_dir
        logger.warning(f"Model directory not found: {target_dir}")
        missing_files = ["entire model directory"]
        corrupted_files = []
    else:
        # Count files with checksums
        files_with_checksums = sum(1 for v in manifest.values() if v is not None)

        if verify and files_with_checksums == 0:
            logger.info("Verify requested but no checksums in manifest, only checking existence")
        elif verify:
            logger.info(f"Verifying {files_with_checksums} file(s) with MD5 checksums...")

        all_valid, missing_files, corrupted_files = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )

        if all_valid:
            if verify and files_with_checksums > 0:
                logger.success(f"✓ All files verified with MD5 checksums in {target_dir}")
            else:
                logger.info(f"✓ All {len(manifest)} required files exist in {target_dir}")
            return target_dir

    # Report missing and corrupted files
    if missing_files:
        _log_file_list(logger.warning, f"Missing {len(missing_files)} file(s):", missing_files)
    if corrupted_files:
        _log_file_list(
            logger.error, f"Corrupted {len(corrupted_files)} file(s) (checksum mismatch):", corrupted_files
        )

    # Imported only once a download is actually needed, so an already-valid local
    # model does not require huggingface_hub to be importable.
    from huggingface_hub import snapshot_download

    # Download model weights
    logger.info(f"\nAttempting to download from {repo_id}...")
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        logger.success("✓ Download completed")
    except Exception as e:
        logger.error(f"✗ Download failed: {e}")
        logger.info(
            f"\nIf you are offline, please manually download from:\n"
            f" https://huggingface.co/{repo_id}\n"
            f"and place in: {target_dir.absolute()}"
        )
        raise RuntimeError(f"Failed to download model weights: {e}") from e

    # Verify after download
    if manifest:
        all_valid, missing_after, corrupted_after = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )

        if not all_valid:
            error_msg = []
            if missing_after:
                error_msg.append(f"Still missing {len(missing_after)} file(s)")
            if corrupted_after:
                error_msg.append(f"Still corrupted {len(corrupted_after)} file(s)")

            raise FileNotFoundError(
                f"After download: {', '.join(error_msg)}\n"
                f"Please verify the download or manually place files in:\n"
                f" {target_dir.absolute()}"
            )

    logger.success("✓ All model weights validated successfully")
    return target_dir
|
unet/Z-Image/src/utils/import_utils.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.util
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def is_flash_attn_available() -> bool:
    """Return True when a module spec for ``flash_attn`` can be found (the module is not imported)."""
    spec = importlib.util.find_spec("flash_attn")
    return spec is not None
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def is_flash_attn_3_available() -> bool:
    """Return True when a module spec for ``flash_attn_interface`` can be found (the module is not imported)."""
    spec = importlib.util.find_spec("flash_attn_interface")
    return spec is not None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def is_torch_version(operator: str, version: str):
    """Compare the installed torch version against ``version``.

    Args:
        operator: One of ``">"``, ``">="``, ``"=="``, ``"!="``, ``"<="``, ``"<"``.
        version: Version string to compare against, parsed with ``packaging.version``.

    Returns:
        The result of ``torch.__version__ <operator> version``; ``False`` for an
        unrecognized operator.
    """
    # The `operator` parameter shadows the stdlib module name, hence the alias.
    import operator as _operator

    from packaging import version as pversion

    torch_version = pversion.parse(torch.__version__)
    target_version = pversion.parse(version)

    comparisons = {
        ">": _operator.gt,
        ">=": _operator.ge,
        "==": _operator.eq,
        "!=": _operator.ne,
        "<=": _operator.le,
        "<": _operator.lt,
    }
    compare = comparisons.get(operator)
    if compare is None:
        return False
    return compare(torch_version, target_version)
|
unet/Z-Image/src/utils/loader.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading utilities for Z-Image components."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys
|
| 7 |
+
from typing import Optional, Union
|
| 8 |
+
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from safetensors.torch import load_file
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoModel, AutoTokenizer
|
| 13 |
+
|
| 14 |
+
from config import (
|
| 15 |
+
DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
|
| 16 |
+
DEFAULT_SCHEDULER_SHIFT,
|
| 17 |
+
DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
|
| 18 |
+
DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
|
| 19 |
+
DEFAULT_TRANSFORMER_DIM,
|
| 20 |
+
DEFAULT_TRANSFORMER_F_PATCH_SIZE,
|
| 21 |
+
DEFAULT_TRANSFORMER_IN_CHANNELS,
|
| 22 |
+
DEFAULT_TRANSFORMER_N_HEADS,
|
| 23 |
+
DEFAULT_TRANSFORMER_N_KV_HEADS,
|
| 24 |
+
DEFAULT_TRANSFORMER_N_LAYERS,
|
| 25 |
+
DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
|
| 26 |
+
DEFAULT_TRANSFORMER_NORM_EPS,
|
| 27 |
+
DEFAULT_TRANSFORMER_PATCH_SIZE,
|
| 28 |
+
DEFAULT_TRANSFORMER_QK_NORM,
|
| 29 |
+
DEFAULT_TRANSFORMER_T_SCALE,
|
| 30 |
+
DEFAULT_VAE_IN_CHANNELS,
|
| 31 |
+
DEFAULT_VAE_LATENT_CHANNELS,
|
| 32 |
+
DEFAULT_VAE_NORM_NUM_GROUPS,
|
| 33 |
+
DEFAULT_VAE_OUT_CHANNELS,
|
| 34 |
+
DEFAULT_VAE_SCALING_FACTOR,
|
| 35 |
+
ROPE_AXES_DIMS,
|
| 36 |
+
ROPE_AXES_LENS,
|
| 37 |
+
ROPE_THETA,
|
| 38 |
+
)
|
| 39 |
+
from zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL
|
| 40 |
+
from zimage.scheduler import FlowMatchEulerDiscreteScheduler
|
| 41 |
+
|
| 42 |
+
DIFFUSERS_AVAILABLE = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_config(config_path: str) -> dict:
    """Read a JSON file and return the parsed content.

    The file is opened as UTF-8 explicitly so the result does not depend on the
    platform's default text encoding.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> dict:
    """Load safetensors weights from a directory into one state dict.

    When the directory contains a ``*.safetensors.index.json`` file, every
    shard named in its ``weight_map`` is loaded and merged. Otherwise a single
    ``*.safetensors`` file from the directory is loaded.

    Args:
        weight_dir: Directory holding the weight files.
        device: Device string forwarded to ``load_file``.
        dtype: If given, every tensor not already of this dtype is cast to it.

    Returns:
        Mapping of tensor name to tensor.

    Raises:
        FileNotFoundError: If no safetensors file is present, or if a shard
            listed in the index does not exist in ``weight_dir``.
    """
    weight_dir = Path(weight_dir)
    # glob() yields entries in filesystem order; sort so the same file is
    # selected on every run and every platform.
    index_files = sorted(weight_dir.glob("*.safetensors.index.json"))

    state_dict = {}
    if index_files:
        # Load sharded weights
        index_path = index_files[0]
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        weight_map = index.get("weight_map", {})
        # A shard appears once per tensor it stores; de-duplicate so each
        # file is read only once.
        for shard_file in sorted(set(weight_map.values())):
            shard_path = weight_dir / shard_file
            if not shard_path.is_file():
                raise FileNotFoundError(
                    f"Shard {shard_file} listed in {index_path} was not found in {weight_dir}"
                )
            state_dict.update(load_file(str(shard_path), device=str(device)))
    else:
        # Load single safetensors file
        safetensors_files = sorted(weight_dir.glob("*.safetensors"))
        if not safetensors_files:
            raise FileNotFoundError(f"No safetensors files found in {weight_dir}")
        state_dict = load_file(str(safetensors_files[0]), device=str(device))

    # Cast to target dtype if specified
    if dtype is not None:
        state_dict = {k: v.to(dtype) if v.dtype != dtype else v for k, v in state_dict.items()}

    return state_dict
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_from_local_dir(
    model_dir: Union[str, Path],
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = False,
    compile: bool = False,
) -> dict:
    """
    Load all Z-Image components from local directory.

    Reads the ``transformer``, ``vae``, ``text_encoder`` and ``scheduler``
    sub-directories of ``model_dir``; the tokenizer is read from ``tokenizer``
    when that directory exists and from ``text_encoder`` otherwise.

    Args:
        model_dir: Path to model directory
        device: Device to load models on
        dtype: Data type for the transformer and text encoder weights
            (the VAE is always kept in float32)
        verbose: Whether to display loading logs
        compile: Whether to compile transformer and vae with torch.compile

    Returns:
        Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler
    """
    model_dir = Path(model_dir)

    # Make the bundled sources importable. Skip the insert when the entry is
    # already present so repeated calls do not keep growing sys.path.
    src_dir = str(model_dir.parent.parent / "Z-Image" / "src")
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)
    from zimage.transformer import ZImageTransformer2DModel

    if verbose:
        logger.info(f"Loading Z-Image from: {model_dir}")

    # DiT
    if verbose:
        logger.info("Loading DiT...")
    transformer_dir = model_dir / "transformer"
    config = load_config(str(transformer_dir / "config.json"))

    # Instantiate on the meta device; real storage is attached below by
    # load_state_dict(..., assign=True).
    with torch.device("meta"):
        transformer = ZImageTransformer2DModel(
            all_patch_size=tuple(config.get("all_patch_size", DEFAULT_TRANSFORMER_PATCH_SIZE)),
            all_f_patch_size=tuple(config.get("all_f_patch_size", DEFAULT_TRANSFORMER_F_PATCH_SIZE)),
            in_channels=config.get("in_channels", DEFAULT_TRANSFORMER_IN_CHANNELS),
            dim=config.get("dim", DEFAULT_TRANSFORMER_DIM),
            n_layers=config.get("n_layers", DEFAULT_TRANSFORMER_N_LAYERS),
            n_refiner_layers=config.get("n_refiner_layers", DEFAULT_TRANSFORMER_N_REFINER_LAYERS),
            n_heads=config.get("n_heads", DEFAULT_TRANSFORMER_N_HEADS),
            n_kv_heads=config.get("n_kv_heads", DEFAULT_TRANSFORMER_N_KV_HEADS),
            norm_eps=config.get("norm_eps", DEFAULT_TRANSFORMER_NORM_EPS),
            qk_norm=config.get("qk_norm", DEFAULT_TRANSFORMER_QK_NORM),
            cap_feat_dim=config.get("cap_feat_dim", DEFAULT_TRANSFORMER_CAP_FEAT_DIM),
            rope_theta=config.get("rope_theta", ROPE_THETA),
            t_scale=config.get("t_scale", DEFAULT_TRANSFORMER_T_SCALE),
            axes_dims=config.get("axes_dims", ROPE_AXES_DIMS),
            axes_lens=config.get("axes_lens", ROPE_AXES_LENS),
        ).to(dtype)

    # DiT (weights to CPU then move to GPU to optimize memory)
    state_dict = load_sharded_safetensors(transformer_dir, device="cpu", dtype=dtype)
    dit_load = transformer.load_state_dict(state_dict, strict=False, assign=True)
    del state_dict
    # strict=False hides absent weights; surface them, because any parameter
    # the checkpoint did not provide is still a meta tensor at this point.
    dit_missing = list(getattr(dit_load, "missing_keys", None) or [])
    if dit_missing:
        logger.warning(f"DiT checkpoint is missing {len(dit_missing)} key(s), e.g. {dit_missing[:5]}")

    if verbose:
        logger.info("Moving DiT to GPU...")
    transformer = transformer.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    transformer.eval()

    # VAE
    if verbose:
        logger.info("Loading VAE...")
    vae_dir = model_dir / "vae"
    vae_config = load_config(str(vae_dir / "config.json"))

    vae = LocalAutoencoderKL(
        in_channels=vae_config.get("in_channels", DEFAULT_VAE_IN_CHANNELS),
        out_channels=vae_config.get("out_channels", DEFAULT_VAE_OUT_CHANNELS),
        down_block_types=tuple(vae_config.get("down_block_types", ("DownEncoderBlock2D",))),
        up_block_types=tuple(vae_config.get("up_block_types", ("UpDecoderBlock2D",))),
        block_out_channels=tuple(vae_config.get("block_out_channels", (64,))),
        layers_per_block=vae_config.get("layers_per_block", 1),
        latent_channels=vae_config.get("latent_channels", DEFAULT_VAE_LATENT_CHANNELS),
        norm_num_groups=vae_config.get("norm_num_groups", DEFAULT_VAE_NORM_NUM_GROUPS),
        scaling_factor=vae_config.get("scaling_factor", DEFAULT_VAE_SCALING_FACTOR),
        shift_factor=vae_config.get("shift_factor", None),
        use_quant_conv=vae_config.get("use_quant_conv", True),
        use_post_quant_conv=vae_config.get("use_post_quant_conv", True),
        mid_block_add_attention=vae_config.get("mid_block_add_attention", True),
    )

    # VAE (fp32 for better precision)
    vae_state_dict = load_sharded_safetensors(vae_dir, device="cpu")
    vae_load = vae.load_state_dict(vae_state_dict, strict=False)
    del vae_state_dict
    vae_missing = list(getattr(vae_load, "missing_keys", None) or [])
    if vae_missing:
        logger.warning(f"VAE checkpoint is missing {len(vae_missing)} key(s), e.g. {vae_missing[:5]}")
    vae.to(device=device, dtype=torch.float32)
    vae.eval()
    # Guarded like the DiT step above so both cache flushes behave the same.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Text Encoder
    if verbose:
        logger.info("Loading Text Encoder...")
    text_encoder_dir = model_dir / "text_encoder"
    text_encoder = AutoModel.from_pretrained(
        str(text_encoder_dir),
        # torch_dtype=dtype,  # some version use this
        dtype=dtype,
        trust_remote_code=True,
    )
    text_encoder.to(device)
    text_encoder.eval()

    # Tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if verbose:
        logger.info("Loading Tokenizer...")
    tokenizer_dir = model_dir / "tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(
        str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir),
        trust_remote_code=True,
    )

    # Scheduler
    if verbose:
        logger.info("Loading Scheduler...")
    scheduler_dir = model_dir / "scheduler"
    scheduler_config = load_config(str(scheduler_dir / "scheduler_config.json"))
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=scheduler_config.get("num_train_timesteps", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS),
        shift=scheduler_config.get("shift", DEFAULT_SCHEDULER_SHIFT),
        use_dynamic_shifting=scheduler_config.get("use_dynamic_shifting", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING),
    )

    if compile:
        if verbose:
            logger.info("Compiling DiT and VAE...")
        transformer = torch.compile(transformer)
        vae = torch.compile(vae)

    if verbose:
        logger.success("All components loaded successfully")

    return {
        "transformer": transformer,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
    }
|
unet/Z-Image/src/zimage/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Z-Image PyTorch Native Implementation."""
|
| 2 |
+
|
| 3 |
+
from .pipeline import generate
|
| 4 |
+
from .transformer import ZImageTransformer2DModel
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ZImageTransformer2DModel",
|
| 8 |
+
"generate",
|
| 9 |
+
]
|
unet/Z-Image/src/zimage/autoencoder.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AutoencoderKL implementation compatible with diffusers weights."""
|
| 2 |
+
|
| 3 |
+
# Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class AutoencoderKLOutput:
    """Result wrapper returned by ``AutoencoderKL.decode`` when ``return_dict`` is true."""

    # Tensor produced by the decoder.
    sample: torch.Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AutoencoderConfig:
    """Attribute-style bag of configuration values.

    Every keyword argument becomes an instance attribute. Reading an ordinary
    attribute that was never set yields ``None`` instead of raising.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        """Return the stored value for ``key``, or ``default`` when unset."""
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Dunder names are
        # protocol probes (``__deepcopy__``, ``__setstate__``, ...): answering
        # them with None makes ``hasattr`` report a hook that is then called
        # as ``None(...)``, which breaks copy.deepcopy and unpickling.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.__dict__.get(name)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def swish(x):
    """Return ``x`` multiplied elementwise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ResnetBlock2D(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    When the input and output channel counts differ, the skip path goes
    through a 1x1 convolution; otherwise the input is added unchanged.
    ``temb_channels`` and the ``temb`` argument of ``forward`` are accepted
    but not used by this implementation.
    """

    def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6):
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # First normalisation/convolution stage (may change channel count).
        self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Second stage keeps the channel count.
        self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = swish

        channels_differ = self.in_channels != self.out_channels
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if channels_differ else None
        )

    def forward(self, input_tensor, temb=None):
        """Apply the block to ``input_tensor`` and return the residual sum."""
        branch = self.conv1(self.nonlinearity(self.norm1(input_tensor)))
        branch = self.conv2(self.dropout(self.nonlinearity(self.norm2(branch))))

        skip = input_tensor
        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return (skip + branch) / 1.0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Attention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    The input is group-normalised, flattened to one token per spatial
    position, projected to query/key/value, attended with scaled dot-product
    attention and projected back; the result is added to the input.
    ``heads`` is stored and ``dim_head`` is accepted, but the forward pass
    does not split the channels into heads.
    """

    def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        self.group_norm = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)
        # Kept as a one-element ModuleList; forward only uses element 0.
        self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)])

    def forward(self, hidden_states):
        """Return ``hidden_states`` plus its attended projection (same shape)."""
        batch, channels, height, width = hidden_states.shape
        normed = self.group_norm(hidden_states)
        tokens = normed.view(batch, channels, -1).transpose(1, 2)  # (B, H*W, C)

        attended = nn.functional.scaled_dot_product_attention(
            self.to_q(tokens),
            self.to_k(tokens),
            self.to_v(tokens),
        )

        projected = self.to_out[0](attended)
        # Back to the (B, C, H, W) layout of the input.
        feature_map = projected.transpose(1, 2).view(batch, channels, height, width)
        return hidden_states + feature_map
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Downsample2D(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 conv or 2x2 average pooling.

    With ``with_conv`` false no convolution is created and ``out_channels``
    and ``padding`` have no effect.
    """

    def __init__(self, channels, with_conv=True, out_channels=None, padding=1):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        self.conv = nn.Conv2d(channels, out_channels or channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, hidden_states):
        """Return the downsampled feature map."""
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
        return self.conv(hidden_states)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Upsample2D(nn.Module):
    """Double spatial resolution by nearest-neighbour interpolation.

    When ``with_conv`` is true the interpolated map is passed through a
    3x3 convolution; otherwise ``out_channels`` has no effect.
    """

    def __init__(self, channels, with_conv=True, out_channels=None):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            return
        self.conv = nn.Conv2d(channels, out_channels or channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        """Return the upsampled (and optionally convolved) feature map."""
        upsampled = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class DownEncoderBlock2D(nn.Module):
    """Encoder stage: ``num_layers`` residual blocks, then an optional downsample.

    Only the first residual block changes the channel count. When
    ``add_downsample`` is false, ``downsamplers`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels if layer_idx == 0 else out_channels,
                    out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.downsamplers = (
            nn.ModuleList([Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)])
            if add_downsample
            else None
        )

    def forward(self, hidden_states):
        """Run the residual blocks, then each downsampler on a right/bottom zero-padded map."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.downsamplers is None:
            return hidden_states

        for downsampler in self.downsamplers:
            # The conv itself uses padding=0, so pad one pixel on the right
            # and bottom edges before applying it.
            padded = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
            hidden_states = downsampler(padded)

        return hidden_states
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class UpDecoderBlock2D(nn.Module):
    """Decoder stage: ``num_layers`` residual blocks, then an optional upsample.

    Only the first residual block changes the channel count. When
    ``add_upsample`` is false, ``upsamplers`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels if layer_idx == 0 else out_channels,
                    out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

    def forward(self, hidden_states):
        """Run the residual blocks followed by any upsamplers."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class UNetMidBlock2D(nn.Module):
    """Middle block: residual block, single-head attention, residual block.

    The channel count is unchanged throughout. ``attention_head_dim`` is
    accepted but not used.
    """

    def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None):
        super().__init__()
        # Two identical residual blocks that bracket the attention layer.
        self.resnets = nn.ModuleList(
            [ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups) for _ in range(2)]
        )
        self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)])

    def forward(self, hidden_states):
        """Apply resnet[0], every attention layer, then resnet[1]."""
        first_resnet, second_resnet = self.resnets[0], self.resnets[1]
        hidden_states = first_resnet(hidden_states)
        for attn in self.attentions:
            hidden_states = attn(hidden_states)
        return second_resnet(hidden_states)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class Encoder(nn.Module):
    """Convolutional encoder: stem conv, down stages, mid block, output head.

    One ``DownEncoderBlock2D`` is built per entry of ``block_out_channels``;
    every stage except the last downsamples. With ``double_z`` the output
    convolution produces ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        double_z=True,
    ):
        super().__init__()
        widths = list(block_out_channels)
        stem_width, final_width = widths[0], widths[-1]
        last_stage = len(widths) - 1

        self.conv_in = nn.Conv2d(in_channels, stem_width, kernel_size=3, stride=1, padding=1)

        self.down_blocks = nn.ModuleList([])
        # Stage i takes the width of stage i-1 as input; the first stage
        # takes the stem width.
        stage_inputs = [stem_width] + widths[:-1]
        for stage_idx, (stage_in, stage_out) in enumerate(zip(stage_inputs, widths)):
            self.down_blocks.append(
                DownEncoderBlock2D(
                    stage_in,
                    stage_out,
                    num_layers=layers_per_block,
                    resnet_groups=norm_num_groups,
                    add_downsample=stage_idx != last_stage,
                )
            )

        self.mid_block = UNetMidBlock2D(final_width, resnet_groups=norm_num_groups)

        self.conv_norm_out = nn.GroupNorm(num_channels=final_width, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        head_channels = out_channels * 2 if double_z else out_channels
        self.conv_out = nn.Conv2d(final_width, head_channels, 3, padding=1)

    def forward(self, x):
        """Encode ``x`` through the stem, down stages, mid block and output head."""
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        return self.conv_out(self.conv_act(self.conv_norm_out(x)))
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class Decoder(nn.Module):
    """Convolutional decoder: stem conv, mid block, up stages, output head.

    ``block_out_channels`` is walked in reverse; one ``UpDecoderBlock2D`` with
    ``layers_per_block + 1`` residual blocks is built per entry, and every
    stage except the last upsamples.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
    ):
        super().__init__()
        widest = block_out_channels[-1]
        narrowest = block_out_channels[0]

        self.conv_in = nn.Conv2d(in_channels, widest, kernel_size=3, stride=1, padding=1)

        self.mid_block = UNetMidBlock2D(widest, resnet_groups=norm_num_groups)

        self.up_blocks = nn.ModuleList([])
        widths = list(reversed(block_out_channels))
        last_stage = len(block_out_channels) - 1
        # Stage i takes the width of stage i-1 as input; the first stage
        # takes the widest setting.
        stage_inputs = [widths[0]] + widths[:-1]
        for stage_idx, (stage_in, stage_out) in enumerate(zip(stage_inputs, widths)):
            self.up_blocks.append(
                UpDecoderBlock2D(
                    stage_in,
                    stage_out,
                    num_layers=layers_per_block + 1,
                    resnet_groups=norm_num_groups,
                    add_upsample=stage_idx != last_stage,
                )
            )

        self.conv_norm_out = nn.GroupNorm(num_channels=narrowest, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(narrowest, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Decode ``x`` through the stem, mid block, up stages and output head."""
        x = self.mid_block(self.conv_in(x))
        for block in self.up_blocks:
            x = block(x)
        return self.conv_out(self.conv_act(self.conv_norm_out(x)))
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class AutoencoderKL(nn.Module):
    """Autoencoder built from ``Encoder`` and ``Decoder`` with optional 1x1 quant convs.

    The constructor accepts ``down_block_types``, ``up_block_types``,
    ``act_fn``, ``sample_size``, ``force_upcast``, ``mid_block_add_attention``
    and extra keyword arguments, but does not use them. Only ``decode`` is
    implemented on this class; ``quant_conv`` is created but not used by it.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        **kwargs,
    ):
        super().__init__()
        # Subset of the constructor arguments exposed to callers via ``.config``.
        self.config = AutoencoderConfig(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            latent_channels=latent_channels,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
        )

        # Settings shared by the encoder and the decoder.
        shared = dict(
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
        )
        self.encoder = Encoder(in_channels=in_channels, out_channels=latent_channels, double_z=True, **shared)
        self.decoder = Decoder(in_channels=latent_channels, out_channels=out_channels, **shared)

        doubled_latents = 2 * latent_channels
        self.quant_conv = nn.Conv2d(doubled_latents, doubled_latents, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

    @property
    def dtype(self):
        """Dtype of the first parameter of the module."""
        first_param = next(self.parameters())
        return first_param.dtype

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        """Decode latents ``z``.

        Applies ``post_quant_conv`` when present, then the decoder. Returns an
        ``AutoencoderKLOutput`` when ``return_dict`` is true, otherwise a
        one-element tuple holding the decoded tensor.
        """
        latents = z if self.post_quant_conv is None else self.post_quant_conv(z)
        decoded = self.decoder(latents)
        if return_dict:
            return AutoencoderKLOutput(sample=decoded)
        return (decoded,)
|
unet/Z-Image/src/zimage/pipeline.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Z-Image Pipeline."""
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
from typing import List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from loguru import logger
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from config import (
|
| 10 |
+
BASE_IMAGE_SEQ_LEN,
|
| 11 |
+
BASE_SHIFT,
|
| 12 |
+
DEFAULT_CFG_TRUNCATION,
|
| 13 |
+
DEFAULT_GUIDANCE_SCALE,
|
| 14 |
+
DEFAULT_HEIGHT,
|
| 15 |
+
DEFAULT_INFERENCE_STEPS,
|
| 16 |
+
DEFAULT_MAX_SEQUENCE_LENGTH,
|
| 17 |
+
DEFAULT_WIDTH,
|
| 18 |
+
MAX_IMAGE_SEQ_LEN,
|
| 19 |
+
MAX_SHIFT,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def calculate_shift(
    image_seq_len,
    base_seq_len: int = BASE_IMAGE_SEQ_LEN,
    max_seq_len: int = MAX_IMAGE_SEQ_LEN,
    base_shift: float = BASE_SHIFT,
    max_shift: float = MAX_SHIFT,
):
    """Linearly map an image sequence length to a shift value.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; the value at ``image_seq_len`` is returned
    without clamping.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.

    Exactly one schedule source is used: custom ``timesteps``, custom
    ``sigmas``, or ``num_inference_steps``. For a custom schedule the
    returned step count is the length of ``scheduler.timesteps``. Extra
    keyword arguments are forwarded to ``scheduler.set_timesteps``.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or if
            ``scheduler.set_timesteps`` has no parameter for the one supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    # Default path: let the scheduler derive its schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError("The scheduler does not support custom timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError("The scheduler does not support custom sigmas schedules.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _encode_prompt_batch(tokenizer, text_encoder, prompts, max_sequence_length, device):
    """Encode ``prompts`` into a list of per-prompt embedding tensors with padding removed."""
    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for p in prompts
    ]

    text_inputs = tokenizer(
        formatted_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(device)
    masks = text_inputs.attention_mask.to(device).bool()

    # The second-to-last hidden state is used as the text feature.
    embeds = text_encoder(
        input_ids=input_ids,
        attention_mask=masks,
        output_hidden_states=True,
    ).hidden_states[-2]

    # Keep only the positions the attention mask marks as real tokens.
    return [embed[mask] for embed, mask in zip(embeds, masks)]


@torch.no_grad()
def generate(
    transformer,
    vae,
    text_encoder,
    tokenizer,
    scheduler,
    prompt: Union[str, List[str]],
    height: int = DEFAULT_HEIGHT,
    width: int = DEFAULT_WIDTH,
    num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
    guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    cfg_normalization: bool = False,
    cfg_truncation: float = DEFAULT_CFG_TRUNCATION,
    max_sequence_length: int = DEFAULT_MAX_SEQUENCE_LENGTH,
    output_type: str = "pil",
):
    """Run the text-to-image sampling loop and decode the result.

    Args:
        transformer: Denoising model; called with a list of latents, a timestep
            tensor and a list of prompt embeddings. Its first parameter's
            device is used for all tensors created here.
        vae: Decoder providing ``config.scaling_factor``, ``dtype`` and
            ``decode(...)``.
        text_encoder: Model returning ``hidden_states`` when called with
            ``output_hidden_states=True``.
        tokenizer: Tokenizer with ``apply_chat_template`` support.
        scheduler: Scheduler configured through ``retrieve_timesteps`` and
            advanced with ``step(...)``.
        prompt: One prompt or a list of prompts.
        height: Output height; must be divisible by twice the VAE scale factor.
        width: Output width; must be divisible by twice the VAE scale factor.
        num_inference_steps: Number of scheduler steps to request.
        guidance_scale: Classifier-free guidance is enabled when this is
            greater than ``1.0``.
        negative_prompt: ``None`` (empty negatives), a single string (applied
            to every prompt), or one string per prompt. Only used when
            guidance is enabled.
        num_images_per_prompt: Number of samples generated for each prompt.
        generator: Optional RNG used for the initial latents.
        cfg_normalization: When truthy, the guided prediction's norm is capped
            at ``float(cfg_normalization)`` times the positive prediction's
            norm.
        cfg_truncation: When not ``None`` and ``<= 1``, guidance is switched
            off for steps whose normalised time exceeds this value.
        max_sequence_length: Tokenizer padding/truncation length.
        output_type: ``"latent"`` returns the final latents, ``"pil"`` returns
            a list of PIL images, anything else returns the raw decoder output.

    Returns:
        Latents, a list of PIL images, or the decoder output tensor, depending
        on ``output_type``.

    Raises:
        ValueError: If ``height`` or ``width`` is not divisible by the required
            factor, or if a list ``negative_prompt`` does not have one entry
            per prompt.
    """
    device = next(transformer.parameters()).device

    if hasattr(vae, "config") and hasattr(vae.config, "block_out_channels"):
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    else:
        vae_scale_factor = 8
    # Latents are additionally patchified 2x2 by the transformer.
    vae_scale = vae_scale_factor * 2

    if height % vae_scale != 0:
        raise ValueError(f"Height must be divisible by {vae_scale} (got {height}).")
    if width % vae_scale != 0:
        raise ValueError(f"Width must be divisible by {vae_scale} (got {width}).")

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    do_classifier_free_guidance = guidance_scale > 1.0
    logger.info(f"Generating image: {height}x{width}, steps={num_inference_steps}, cfg={guidance_scale}")

    prompt_embeds_list = _encode_prompt_batch(tokenizer, text_encoder, prompt, max_sequence_length, device)

    negative_prompt_embeds_list = []
    if do_classifier_free_guidance:
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            # A single negative prompt applies to every prompt in the batch;
            # without broadcasting, the negative list would be shorter than the
            # duplicated latent batch fed to the transformer.
            negative_prompt = [negative_prompt] * batch_size
        else:
            negative_prompt = list(negative_prompt)

        if len(negative_prompt) != batch_size:
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}; "
                "pass a single string or one negative prompt per prompt."
            )

        negative_prompt_embeds_list = _encode_prompt_batch(
            tokenizer, text_encoder, negative_prompt, max_sequence_length, device
        )

    if num_images_per_prompt > 1:
        # Repeat each prompt's embedding consecutively, once per requested image.
        prompt_embeds_list = [pe for pe in prompt_embeds_list for _ in range(num_images_per_prompt)]
        if do_classifier_free_guidance:
            negative_prompt_embeds_list = [
                npe for npe in negative_prompt_embeds_list for _ in range(num_images_per_prompt)
            ]

    height_latent = 2 * (int(height) // vae_scale)
    width_latent = 2 * (int(width) // vae_scale)
    actual_batch_size = batch_size * num_images_per_prompt
    shape = (actual_batch_size, transformer.in_channels, height_latent, width_latent)

    latents = torch.randn(shape, generator=generator, device=device, dtype=torch.float32)

    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    scheduler.sigma_min = 0.0
    scheduler_kwargs = {"mu": mu}
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
        sigmas=None,
        **scheduler_kwargs,
    )

    logger.info(f"Sampling loop start: {num_inference_steps} steps")

    from tqdm import tqdm

    # Resolve the model dtype once; it does not change during sampling.
    model_dtype = transformer.dtype if hasattr(transformer, "dtype") else next(transformer.parameters()).dtype

    # Denoising loop with progress bar
    for i, t in enumerate(tqdm(timesteps, desc="Denoising", total=len(timesteps))):
        # If current t is 0 and it's the last step, skip computation
        if t == 0 and i == len(timesteps) - 1:
            logger.debug(f"Step {i+1}/{num_inference_steps} | t: {t.item():.2f} | Skipping last step")
            continue

        # The transformer receives time normalised as (1000 - t) / 1000.
        timestep = t.expand(latents.shape[0])
        timestep = (1000 - timestep) / 1000
        t_norm = timestep[0].item()

        current_guidance_scale = guidance_scale
        if do_classifier_free_guidance and cfg_truncation is not None and float(cfg_truncation) <= 1:
            if t_norm > cfg_truncation:
                current_guidance_scale = 0.0

        apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0

        if apply_cfg:
            # Positive half first, negative half second.
            latent_model_input = latents.to(model_dtype).repeat(2, 1, 1, 1)
            prompt_embeds_model_input = prompt_embeds_list + negative_prompt_embeds_list
            timestep_model_input = timestep.repeat(2)
        else:
            latent_model_input = latents.to(model_dtype)
            prompt_embeds_model_input = prompt_embeds_list
            timestep_model_input = timestep

        # Insert a singleton axis at dim 2 and split the batch into a list.
        latent_model_input = latent_model_input.unsqueeze(2)
        latent_model_input_list = list(latent_model_input.unbind(dim=0))

        model_out_list = transformer(
            latent_model_input_list,
            timestep_model_input,
            prompt_embeds_model_input,
        )[0]

        if apply_cfg:
            pos_out = model_out_list[:actual_batch_size]
            neg_out = model_out_list[actual_batch_size:]
            noise_pred = []
            for j in range(actual_batch_size):
                pos = pos_out[j].float()
                neg = neg_out[j].float()
                pred = pos + current_guidance_scale * (pos - neg)

                if cfg_normalization and float(cfg_normalization) > 0.0:
                    # Cap the guided prediction's norm relative to the positive one.
                    ori_pos_norm = torch.linalg.vector_norm(pos)
                    new_pos_norm = torch.linalg.vector_norm(pred)
                    max_new_norm = ori_pos_norm * float(cfg_normalization)
                    if new_pos_norm > max_new_norm:
                        pred = pred * (max_new_norm / new_pos_norm)
                noise_pred.append(pred)
            noise_pred = torch.stack(noise_pred, dim=0)
        else:
            noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)

        # Drop the singleton axis and flip the sign before the scheduler step.
        noise_pred = -noise_pred.squeeze(2)
        latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
        assert latents.dtype == torch.float32

    if output_type == "latent":
        return latents

    shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
    latents = (latents.to(vae.dtype) / vae.config.scaling_factor) + shift_factor
    image = vae.decode(latents, return_dict=False)[0]

    if output_type == "pil":
        from PIL import Image

        # Map from [-1, 1] to [0, 255] uint8, channels-last.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(img) for img in image]

    return image
|
unet/Z-Image/src/zimage/transformer.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Z-Image Transformer."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
|
| 11 |
+
from config import (
|
| 12 |
+
ADALN_EMBED_DIM,
|
| 13 |
+
FREQUENCY_EMBEDDING_SIZE,
|
| 14 |
+
MAX_PERIOD,
|
| 15 |
+
ROPE_AXES_DIMS,
|
| 16 |
+
ROPE_AXES_LENS,
|
| 17 |
+
ROPE_THETA,
|
| 18 |
+
SEQ_MULTI_OF,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to vectors via sinusoidal features and a two-layer MLP."""

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=FREQUENCY_EMBEDDING_SIZE):
        """Build the MLP ``Linear -> SiLU -> Linear``.

        Args:
            out_size: Width of the returned embedding.
            mid_size: Hidden width of the MLP; defaults to ``out_size``.
            frequency_embedding_size: Width of the sinusoidal feature vector.
        """
        super().__init__()
        hidden_size = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=MAX_PERIOD):
        """Return ``[cos | sin]`` features of ``t`` with ``dim`` columns.

        ``t`` is indexed as ``t[:, None]``, so it is expected to be 1-D. When
        ``dim`` is odd, a trailing zero column is appended.
        """
        # Autocast is disabled so the frequencies are computed in float32.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            steps = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            freqs = torch.exp(-math.log(max_period) * steps / half)
            args = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:
                zero_column = torch.zeros_like(embedding[:, :1])
                embedding = torch.cat([embedding, zero_column], dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` and return the MLP output."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the MLP's parameter dtype when it is a floating-point type.
        target_dtype = self.mlp[0].weight.dtype
        if target_dtype.is_floating_point:
            features = features.to(target_dtype)
        return self.mlp(features)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the layer.

        Args:
            dim: Size of the last dimension; also the length of ``weight``.
            eps: Constant added to the mean square before the reciprocal root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` along the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class FeedForward(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, dim: int, hidden_dim: int):
        """Create the three bias-free projections.

        Args:
            dim: Input and output width.
            hidden_dim: Width of the gated hidden representation.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        gate = F.silu(self.w1(x))
        candidate = self.w3(x)
        return self.w2(gate * candidate)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate ``x_in`` by the complex factors in ``freqs_cis``.

    Consecutive pairs along the last axis of ``x_in`` are treated as the real
    and imaginary parts of complex numbers, multiplied by ``freqs_cis`` (which
    gets a broadcast axis inserted at position 2), then flattened back from
    axis 3 onward. The result has the dtype of ``x_in``.
    """
    # Autocast is disabled; the rotation itself runs on a float32 copy.
    with torch.amp.autocast("cuda", enabled=False):
        pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        rotated = as_complex * freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(rotated).flatten(3)
        return x_out.type_as(x_in)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ZImageAttention(nn.Module):
    """Multi-head self-attention with optional per-head RMSNorm on Q/K and rotary embeddings."""

    # Backend selector forwarded to ``dispatch_attention``; ``None`` by default.
    _attention_backend = None

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, qk_norm: bool = True, eps: float = 1e-5):
        """Create the projections.

        Args:
            dim: Model width; the per-head width is ``dim // n_heads``.
            n_heads: Number of query heads.
            n_kv_heads: Number of key/value heads.
            qk_norm: Whether to RMS-normalise queries and keys per head.
            eps: Epsilon passed to the Q/K norms.
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads

        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        self.to_q = nn.Linear(dim, q_width, bias=False)
        self.to_k = nn.Linear(dim, kv_width, bias=False)
        self.to_v = nn.Linear(dim, kv_width, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(q_width, dim, bias=False)])

        if qk_norm:
            self.norm_q = RMSNorm(self.head_dim, eps=eps)
            self.norm_k = RMSNorm(self.head_dim, eps=eps)
        else:
            self.norm_q = None
            self.norm_k = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the output projection.

        Args:
            hidden_states: Input tensor; its last axis is projected to Q/K/V.
            attention_mask: Passed to the attention backend as ``attn_mask``.
            freqs_cis: Optional rotary factors applied to queries and keys.
        """
        # Project, then split the last axis into (heads, head_dim).
        query = self.to_q(hidden_states).unflatten(-1, (self.n_heads, -1))
        key = self.to_k(hidden_states).unflatten(-1, (self.n_kv_heads, -1))
        value = self.to_v(hidden_states).unflatten(-1, (self.n_kv_heads, -1))

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Keys follow the query dtype; the attention output is cast back to it.
        dtype = query.dtype
        key = key.to(dtype)

        # Dispatch
        from utils.attention import dispatch_attention

        attended = dispatch_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, backend=self._attention_backend
        )

        # Merge the head axes (dims 2 and 3) back into one feature axis.
        attended = attended.flatten(2, 3).to(dtype)
        return self.to_out[0](attended)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ZImageTransformerBlock(nn.Module):
    """Attention + feed-forward block with optional adaLN-style scale/gate modulation."""

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        """Create the block.

        Args:
            layer_id: Stored on the instance; not otherwise used here.
            dim: Model width.
            n_heads: Number of attention heads.
            n_kv_heads: Number of key/value heads.
            norm_eps: Epsilon for every RMSNorm in the block.
            qk_norm: Forwarded to the attention layer.
            modulation: When true, a linear layer producing four ``dim``-wide
                modulation chunks is created and ``forward`` requires
                ``adaln_input``.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation

        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        # norm1 is applied before each sub-layer, norm2 to its output.
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        if modulation:
            modulation_in = min(dim, ADALN_EMBED_DIM)
            self.adaLN_modulation = nn.ModuleList([nn.Linear(modulation_in, 4 * dim, bias=True)])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply the block to ``x`` and return the updated tensor.

        Args:
            x: Input hidden states.
            attn_mask: Forwarded to the attention layer.
            freqs_cis: Rotary factors forwarded to the attention layer.
            adaln_input: Conditioning vector; required when the block was
                built with ``modulation=True``.
        """
        if not self.modulation:
            # Plain residual path without scale/gate.
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attn_out)
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
            return x

        assert adaln_input is not None
        # One linear layer yields all four modulation tensors, split along dim 2.
        mod = self.adaLN_modulation[0](adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=2)
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp

        attn_out = self.attention(
            self.attention_norm1(x) * scale_msa,
            attention_mask=attn_mask,
            freqs_cis=freqs_cis,
        )
        x = x + gate_msa * self.attention_norm2(attn_out)

        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        x = x + gate_mlp * self.ffn_norm2(ffn_out)
        return x
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class FinalLayer(nn.Module):
    """Output head: non-affine LayerNorm, conditioning-dependent scale, then a linear projection."""

    def __init__(self, hidden_size, out_channels):
        """Create the head.

        Args:
            hidden_size: Width of the incoming hidden states.
            out_channels: Width of the projected output.
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        conditioning_dim = min(hidden_size, ADALN_EMBED_DIM)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(conditioning_dim, hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalise ``x``, scale it by ``1 + adaLN_modulation(c)`` and project.

        The scale is unsqueezed at axis 1 so it broadcasts over that axis of ``x``.
        """
        scale = 1.0 + self.adaLN_modulation(c)
        modulated = self.norm_final(x) * scale.unsqueeze(1)
        return self.linear(modulated)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class RopeEmbedder:
    """Look up per-axis rotary factors for integer position ids.

    One table of complex factors is built per axis on first use and cached on
    the instance; the tables are moved when ids arrive on a different device.
    """

    def __init__(
        self,
        theta: float = ROPE_THETA,
        axes_dims: List[int] = ROPE_AXES_DIMS,
        axes_lens: List[int] = ROPE_AXES_LENS,
    ):
        """Store the configuration.

        Args:
            theta: Base of the frequency progression.
            axes_dims: Per-axis feature width used to build each table.
            axes_lens: Per-axis number of positions in each table.
        """
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens)
        # Built lazily by ``__call__``.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = ROPE_THETA):
        """Return one complex64 table per axis, built on the CPU.

        For an axis with width ``d`` and length ``e`` the table has ``e`` rows
        and one column per even index in ``range(0, d, 2)``.
        """
        with torch.device("cpu"):
            tables = []
            for axis_dim, axis_len in zip(dim, end):
                # Frequencies and positions are formed in float64, then the
                # angle matrix is reduced to float32.
                exponents = torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu") / axis_dim
                inv_freq = 1.0 / (theta ** exponents)
                positions = torch.arange(axis_len, device=inv_freq.device, dtype=torch.float64)
                angles = torch.outer(positions, inv_freq).float()
                unit = torch.polar(torch.ones_like(angles), angles).to(torch.complex64)
                tables.append(unit)
            return tables

    def __call__(self, ids: torch.Tensor):
        """Gather the factors for ``ids`` and concatenate them along the last axis.

        Args:
            ids: 2-D integer tensor with one column per configured axis.
        """
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [table.to(device) for table in self.freqs_cis]
        elif self.freqs_cis[0].device != device:
            self.freqs_cis = [table.to(device) for table in self.freqs_cis]

        gathered = [self.freqs_cis[axis][ids[:, axis]] for axis in range(len(self.axes_dims))]
        return torch.cat(gathered, dim=-1)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class ZImageTransformer2DModel(nn.Module):
|
| 267 |
+
def __init__(
    self,
    all_patch_size=(2,),
    all_f_patch_size=(1,),
    in_channels=16,
    dim=3840,
    n_layers=30,
    n_refiner_layers=2,
    n_heads=30,
    n_kv_heads=30,
    norm_eps=1e-5,
    qk_norm=True,
    cap_feat_dim=2560,
    rope_theta=ROPE_THETA,
    t_scale=1000.0,
    axes_dims=ROPE_AXES_DIMS,
    axes_lens=ROPE_AXES_LENS,
):
    """Build the embedders, refiner stacks, main layer stack and rotary embedder.

    Args:
        all_patch_size: Spatial patch sizes; paired element-wise with
            ``all_f_patch_size``.
        all_f_patch_size: Patch sizes along the first (``F``) axis.
        in_channels: Latent channel count; also used as the output channel count.
        dim: Model width.
        n_layers: Number of blocks in the main stack.
        n_refiner_layers: Number of blocks in each of the two refiner stacks.
        n_heads: Attention heads per block.
        n_kv_heads: Key/value heads per block.
        norm_eps: Epsilon for the RMSNorm layers.
        qk_norm: Whether attention layers normalise queries and keys.
        cap_feat_dim: Width of the incoming caption features.
        rope_theta: Base for the rotary frequencies.
        t_scale: Stored on the instance as ``self.t_scale``.
        axes_dims: Per-axis rotary widths; must sum to ``dim // n_heads``.
        axes_lens: Per-axis rotary table lengths.
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = in_channels
    self.all_patch_size = all_patch_size
    self.all_f_patch_size = all_f_patch_size
    self.dim = dim
    self.n_heads = n_heads
    self.rope_theta = rope_theta
    self.t_scale = t_scale

    assert len(all_patch_size) == len(all_f_patch_size)

    # One input embedder and one output head per (patch, f_patch) pair,
    # keyed as "<patch>-<f_patch>".
    embedders = {}
    heads = {}
    for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
        key = f"{patch_size}-{f_patch_size}"
        embedders[key] = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
        heads[key] = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)

    self.all_x_embedder = nn.ModuleDict(embedders)
    self.all_final_layer = nn.ModuleDict(heads)

    # Modulated refiner for the image tokens (layer ids offset by 1000).
    self.noise_refiner = nn.ModuleList(
        [
            ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)
            for layer_id in range(n_refiner_layers)
        ]
    )

    # Unmodulated refiner for the caption tokens.
    self.context_refiner = nn.ModuleList(
        [
            ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)
            for layer_id in range(n_refiner_layers)
        ]
    )

    self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
    self.cap_embedder = nn.Sequential(
        RMSNorm(cap_feat_dim, eps=norm_eps),
        nn.Linear(cap_feat_dim, dim, bias=True),
    )

    # Pad tokens are allocated uninitialised (torch.empty).
    # NOTE(review): presumably populated from a checkpoint — confirm.
    self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
    self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

    self.layers = nn.ModuleList(
        [
            ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
            for layer_id in range(n_layers)
        ]
    )

    # The rotary axes must exactly tile the per-head width.
    head_dim = dim // n_heads
    assert head_dim == sum(axes_dims)
    self.axes_dims = axes_dims
    self.axes_lens = axes_lens

    self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
| 344 |
+
|
| 345 |
+
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
    """Fold token sequences back into ``(out_channels, F, H, W)`` tensors.

    Each entry of ``x`` is replaced in place, and the same list is returned.
    Tokens beyond the number implied by the matching ``size`` entry are
    dropped before reshaping.

    Args:
        x: One 2-D token tensor per sample.
        size: One ``(F, H, W)`` tuple per sample.
        patch_size: Patch edge used for both ``H`` and ``W``.
        f_patch_size: Patch length along ``F``.
    """
    patch_h = patch_w = patch_size
    patch_f = f_patch_size
    assert len(size) == len(x)

    channels = self.out_channels
    for idx, (frames, height, width) in enumerate(size):
        grid_f = frames // patch_f
        grid_h = height // patch_h
        grid_w = width // patch_w
        n_tokens = grid_f * grid_h * grid_w

        tokens = x[idx][:n_tokens]
        # (gF, gH, gW, pF, pH, pW, C) -> (C, gF, pF, gH, pH, gW, pW)
        blocks = tokens.view(grid_f, grid_h, grid_w, patch_f, patch_h, patch_w, channels)
        blocks = blocks.permute(6, 0, 3, 1, 4, 2, 5)
        x[idx] = blocks.reshape(channels, frames, height, width)
    return x
|
| 360 |
+
|
| 361 |
+
@staticmethod
|
| 362 |
+
def create_coordinate_grid(size, start=None, device=None):
|
| 363 |
+
if start is None:
|
| 364 |
+
start = (0 for _ in size)
|
| 365 |
+
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
| 366 |
+
grids = torch.meshgrid(axes, indexing="ij")
|
| 367 |
+
return torch.stack(grids, dim=-1)
|
| 368 |
+
|
| 369 |
+
def patchify_and_embed(
    self,
    all_image: List[torch.Tensor],
    all_cap_feats: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
):
    """Patchify images and pad both token streams to a multiple of ``SEQ_MULTI_OF``.

    Despite the name, no embedding layer is applied here; the method only
    reshapes, pads and builds position ids and padding masks.

    Args:
        all_image: One ``(C, F, H, W)`` tensor per sample.
        all_cap_feats: One 2-D caption feature tensor per sample.
        patch_size: Patch edge used for both ``H`` and ``W``.
        f_patch_size: Patch length along ``F``.

    Returns:
        A 7-tuple of per-sample lists, in this order: padded image tokens,
        padded caption features, original ``(F, H, W)`` sizes, image position
        ids, caption position ids, image padding masks, caption padding masks.
        Masks are boolean with ``True`` marking padded positions.
    """
    pH = pW = patch_size
    pF = f_patch_size
    device = all_image[0].device

    all_image_out = []
    all_image_size = []
    all_image_pos_ids = []
    all_image_pad_mask = []
    all_cap_pos_ids = []
    all_cap_pad_mask = []
    all_cap_feats_out = []

    for image, cap_feat in zip(all_image, all_cap_feats):
        # ---- caption stream -------------------------------------------
        cap_len = len(cap_feat)
        cap_pad = (-cap_len) % SEQ_MULTI_OF
        cap_total = cap_len + cap_pad

        # Caption positions run along the first axis, starting at 1.
        cap_pos_ids = self.create_coordinate_grid(
            size=(cap_total, 1, 1),
            start=(1, 0, 0),
            device=device,
        ).flatten(0, 2)
        all_cap_pos_ids.append(cap_pos_ids)

        cap_mask = torch.zeros((cap_len,), dtype=torch.bool, device=device)
        if cap_pad > 0:
            cap_mask = torch.cat(
                [cap_mask, torch.ones((cap_pad,), dtype=torch.bool, device=device)],
                dim=0,
            )
            # Padding rows repeat the last real caption feature.
            cap_feat = torch.cat(
                [cap_feat, cap_feat[-1:].repeat(cap_pad, 1)],
                dim=0,
            )
        all_cap_pad_mask.append(cap_mask)
        all_cap_feats_out.append(cap_feat)

        # ---- image stream ---------------------------------------------
        channels, frames, height, width = image.size()
        all_image_size.append((frames, height, width))
        f_tokens, h_tokens, w_tokens = frames // pF, height // pH, width // pW

        # (C, F, H, W) -> (tokens, pF * pH * pW * C)
        image = image.view(channels, f_tokens, pF, h_tokens, pH, w_tokens, pW)
        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(
            f_tokens * h_tokens * w_tokens, pF * pH * pW * channels
        )

        image_len = len(image)
        image_pad = (-image_len) % SEQ_MULTI_OF

        # Image positions continue on the first axis after the padded caption.
        image_pos_ids = self.create_coordinate_grid(
            size=(f_tokens, h_tokens, w_tokens),
            start=(cap_total + 1, 0, 0),
            device=device,
        ).flatten(0, 2)
        image_mask = torch.zeros((image_len,), dtype=torch.bool, device=device)

        if image_pad > 0:
            # Padded image tokens all share position (0, 0, 0).
            pad_pos_ids = (
                self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                .flatten(0, 2)
                .repeat(image_pad, 1)
            )
            image_pos_ids = torch.cat([image_pos_ids, pad_pos_ids], dim=0)
            image_mask = torch.cat(
                [image_mask, torch.ones((image_pad,), dtype=torch.bool, device=device)],
                dim=0,
            )
            # Padding rows repeat the last real image token.
            image = torch.cat([image, image[-1:].repeat(image_pad, 1)], dim=0)

        all_image_pos_ids.append(image_pos_ids)
        all_image_pad_mask.append(image_mask)
        all_image_out.append(image)

    return (
        all_image_out,
        all_cap_feats_out,
        all_image_size,
        all_image_pos_ids,
        all_cap_pos_ids,
        all_image_pad_mask,
        all_cap_pad_mask,
    )
|
| 473 |
+
|
| 474 |
+
def forward(
    self,
    x: List[torch.Tensor],
    t,
    cap_feats: List[torch.Tensor],
    patch_size=2,
    f_patch_size=1,
):
    """Run the transformer on a batch of per-sample image and caption sequences.

    The image tokens and the caption tokens are first processed by their own
    refiner stacks (``self.noise_refiner`` and ``self.context_refiner``). They
    are then concatenated per sample, image tokens first, and passed through
    ``self.layers``, the final layer selected by the patch sizes, and
    ``self.unpatchify``.

    Args:
        x: One tensor per sample; handed to ``self.patchify_and_embed``.
        t: Timestep input; multiplied by ``self.t_scale`` and embedded with
            ``self.t_embedder``.
        cap_feats: One caption-feature tensor per sample.
        patch_size: Must be a member of ``self.all_patch_size``.
        f_patch_size: Must be a member of ``self.all_f_patch_size``.

    Returns:
        A 2-tuple ``(out, {})`` where ``out`` is whatever ``self.unpatchify``
        returns and the second element is an empty dict.
    """
    assert patch_size in self.all_patch_size
    assert f_patch_size in self.all_f_patch_size

    batch_size = len(x)
    device = x[0].device
    # Key shared by the patch embedder and the final-layer lookup tables.
    module_key = f"{patch_size}-{f_patch_size}"

    t_emb = self.t_embedder(t * self.t_scale)

    (
        img_seqs,
        cap_seqs,
        img_sizes,
        img_pos_ids,
        cap_pos_ids,
        img_pad_masks,
        cap_pad_masks,
    ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    # ---- image branch -------------------------------------------------
    img_lens = [len(seq) for seq in img_seqs]
    assert all(n % SEQ_MULTI_OF == 0 for n in img_lens)
    img_max_len = max(img_lens)

    img_flat = self.all_x_embedder[module_key](torch.cat(img_seqs, dim=0))

    adaln_input = t_emb.type_as(img_flat)
    # Overwrite the embeddings at padded positions with the x_pad_token.
    img_flat[torch.cat(img_pad_masks)] = self.x_pad_token
    img_chunks = list(img_flat.split(img_lens, dim=0))
    img_rope = self.rope_embedder(torch.cat(img_pos_ids, dim=0))
    img_rope_chunks = list(img_rope.split([len(ids) for ids in img_pos_ids], dim=0))

    img = pad_sequence(img_chunks, batch_first=True, padding_value=0.0)
    img_freqs_cis = pad_sequence(img_rope_chunks, batch_first=True, padding_value=0.0)
    # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
    img_freqs_cis = img_freqs_cis[:, : img.shape[1]]

    # True for the first `length` positions of each row, False for batch padding.
    img_attn_mask = torch.zeros((batch_size, img_max_len), dtype=torch.bool, device=device)
    for row, length in enumerate(img_lens):
        img_attn_mask[row, :length] = 1

    for block in self.noise_refiner:
        img = block(img, img_attn_mask, img_freqs_cis, adaln_input)

    # ---- caption branch -----------------------------------------------
    cap_lens = [len(seq) for seq in cap_seqs]
    assert all(n % SEQ_MULTI_OF == 0 for n in cap_lens)
    cap_max_len = max(cap_lens)

    cap_flat = self.cap_embedder(torch.cat(cap_seqs, dim=0))
    # Overwrite the embeddings at padded positions with the cap_pad_token.
    cap_flat[torch.cat(cap_pad_masks)] = self.cap_pad_token
    cap_chunks = list(cap_flat.split(cap_lens, dim=0))
    cap_rope = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
    cap_rope_chunks = list(cap_rope.split([len(ids) for ids in cap_pos_ids], dim=0))

    cap = pad_sequence(cap_chunks, batch_first=True, padding_value=0.0)
    cap_freqs_cis = pad_sequence(cap_rope_chunks, batch_first=True, padding_value=0.0)
    cap_freqs_cis = cap_freqs_cis[:, : cap.shape[1]]  # same for dynamo compatibility

    cap_attn_mask = torch.zeros((batch_size, cap_max_len), dtype=torch.bool, device=device)
    for row, length in enumerate(cap_lens):
        cap_attn_mask[row, :length] = 1

    # The context refiner is called without the timestep conditioning.
    for block in self.context_refiner:
        cap = block(cap, cap_attn_mask, cap_freqs_cis)

    # ---- joint sequence: image tokens followed by caption tokens ------
    joint_chunks = [torch.cat([img[i][: img_lens[i]], cap[i][: cap_lens[i]]]) for i in range(batch_size)]
    joint_rope_chunks = [
        torch.cat([img_freqs_cis[i][: img_lens[i]], cap_freqs_cis[i][: cap_lens[i]]]) for i in range(batch_size)
    ]
    joint_lens = [cap_len + img_len for cap_len, img_len in zip(cap_lens, img_lens)]
    assert joint_lens == [len(chunk) for chunk in joint_chunks]
    joint_max_len = max(joint_lens)

    joint = pad_sequence(joint_chunks, batch_first=True, padding_value=0.0)
    joint_freqs_cis = pad_sequence(joint_rope_chunks, batch_first=True, padding_value=0.0)
    joint_attn_mask = torch.zeros((batch_size, joint_max_len), dtype=torch.bool, device=device)
    for row, length in enumerate(joint_lens):
        joint_attn_mask[row, :length] = 1

    for block in self.layers:
        joint = block(joint, joint_attn_mask, joint_freqs_cis, adaln_input)

    joint = self.all_final_layer[module_key](joint, adaln_input)
    per_sample = list(joint.unbind(dim=0))
    out = self.unpatchify(per_sample, img_sizes, patch_size, f_patch_size)

    return out, {}
|
upscale_models/1x-ITF-SkinDiffDetail-Lite-v1.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94d368b633614958f84f335b129fd85abd30200e8fbc575b859ba6762116222b
|
| 3 |
+
size 20099337
|
upscale_models/1x_PureVision.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c3bec2111d0f7c3926f3171f37ce28e72502b744e084c566d8960c06a6a06a3
|
| 3 |
+
size 67120607
|
upscale_models/2x_PureVision.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c49109c8257b80b3cddf0110786ad11c1ca05214a470bfcc7fd49b6461dfcaee
|
| 3 |
+
size 67037663
|
upscale_models/4x-ClearRealityV1.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a4cd3a25b00e0be949d4302fc774eb4d7f2ed5f47cdb51551e2d75fa6562e51e
|
| 3 |
+
size 9016074
|
upscale_models/4x-UltraSharp.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5812231fc936b42af08a5edba784195495d303d5b3248c24489ef0c4021fe01
|
| 3 |
+
size 66961958
|
upscale_models/4xFFHQDAT.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa91faa8c1f72c32646d71abf51e952c81b4984948e89e6e5a8c40822a6cf3cc
|
| 3 |
+
size 154152604
|
upscale_models/4xNomos8k_atd_jpg.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b2bfb0e822c79594288dd43efaec213b6f0244384bd98db75072b0ce5a729fe
|
| 3 |
+
size 81959074
|
upscale_models/4xNomos8k_span_otf_weak.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26e7ef6483faf93b47f48af262ed4bb8dededa20af34a923e344cb63cafeec0a
|
| 3 |
+
size 9015866
|
upscale_models/4x_NMKD-Siax_200k.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:560424d9f68625713fc47e9e7289a98aabe1d744e1cd6a9ae5a35e9957fd127e
|
| 3 |
+
size 66957746
|
upscale_models/4x_NMKD-Superscale-SP_178000_G.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d1b0078fe71446e0469d8d4df59e96baa80d83cda600d68237d655830821bcc
|
| 3 |
+
size 66958607
|
upscale_models/4x_foolhardy_Remacri.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e1a73bd89c2da1ae494774746398689048b5a892bd9653e146713f9df8bca86a
|
| 3 |
+
size 67025055
|
upscale_models/RealESRGAN_x4plus.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1
|
| 3 |
+
size 67040989
|
vae_approx/taew2_1.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d26151e76cdc2c9424bef988de874b33d9a53f30ef3060cd556c429c469c797e
|
| 3 |
+
size 22678901
|
vitmatte/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bda9289db1bb6762d978b42d1c62ae3f34daf7497171a347a1d09657efd788cb
|
| 3 |
+
size 103294572
|