Snapmap committed on
Commit
201ff98
·
verified ·
1 Parent(s): 5f16a95

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. LLM/Florence-2-base/model.safetensors +3 -0
  3. LLM/Florence-2-base/pytorch_model.bin +3 -0
  4. SEEDVR2/ema_vae_fp16.safetensors +3 -0
  5. assets/bpe_simple_vocab_16e6.txt.gz +3 -0
  6. checkpoints/qwen_image_fp8_hq.safetensors +3 -0
  7. depthcrafter/stabilityai_stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors +3 -0
  8. detection/vitpose_h_wholebody_model.onnx +3 -0
  9. detection/yolov10m.onnx +3 -0
  10. mediapipe/selfie_multiclass_256x256.tflite +3 -0
  11. sams/bpe_simple_vocab_16e6.txt.gz +3 -0
  12. ultralytics/bbox/adetailerFootYolov8x_v20.pt +3 -0
  13. ultralytics/bbox/face_yolov8m.pt +3 -0
  14. ultralytics/bbox/face_yolov8m[1].pt +3 -0
  15. unet/Z-Image/assets/DMDR.webp +3 -0
  16. unet/Z-Image/assets/Z-Image-Gallery.pdf +3 -0
  17. unet/Z-Image/assets/architecture.webp +3 -0
  18. unet/Z-Image/assets/decoupled-dmd.webp +3 -0
  19. unet/Z-Image/assets/image_arena_all.jpg +3 -0
  20. unet/Z-Image/assets/reasoning.png +3 -0
  21. unet/Z-Image/assets/showcase.jpg +3 -0
  22. unet/Z-Image/src/config/__init__.py +91 -0
  23. unet/Z-Image/src/config/inference.py +8 -0
  24. unet/Z-Image/src/config/manifests/z-image-turbo.txt +20 -0
  25. unet/Z-Image/src/config/model.py +45 -0
  26. unet/Z-Image/src/tools/__init__.py +9 -0
  27. unet/Z-Image/src/tools/generate_manifest.py +127 -0
  28. unet/Z-Image/src/utils/__init__.py +15 -0
  29. unet/Z-Image/src/utils/attention.py +516 -0
  30. unet/Z-Image/src/utils/helpers.py +260 -0
  31. unet/Z-Image/src/utils/import_utils.py +31 -0
  32. unet/Z-Image/src/utils/loader.py +224 -0
  33. unet/Z-Image/src/zimage/__init__.py +9 -0
  34. unet/Z-Image/src/zimage/autoencoder.py +369 -0
  35. unet/Z-Image/src/zimage/pipeline.py +293 -0
  36. unet/Z-Image/src/zimage/transformer.py +571 -0
  37. upscale_models/1x-ITF-SkinDiffDetail-Lite-v1.pth +3 -0
  38. upscale_models/1x_PureVision.pth +3 -0
  39. upscale_models/2x_PureVision.pth +3 -0
  40. upscale_models/4x-ClearRealityV1.pth +3 -0
  41. upscale_models/4x-UltraSharp.pth +3 -0
  42. upscale_models/4xFFHQDAT.safetensors +3 -0
  43. upscale_models/4xNomos8k_atd_jpg.pth +3 -0
  44. upscale_models/4xNomos8k_span_otf_weak.pth +3 -0
  45. upscale_models/4x_NMKD-Siax_200k.pth +3 -0
  46. upscale_models/4x_NMKD-Superscale-SP_178000_G.pth +3 -0
  47. upscale_models/4x_foolhardy_Remacri.pth +3 -0
  48. upscale_models/RealESRGAN_x4plus.pth +3 -0
  49. vae_approx/taew2_1.pth +3 -0
  50. vitmatte/model.safetensors +3 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ unet/Z-Image/assets/Z-Image-Gallery.pdf filter=lfs diff=lfs merge=lfs -text
37
+ unet/Z-Image/assets/image_arena_all.jpg filter=lfs diff=lfs merge=lfs -text
38
+ unet/Z-Image/assets/decoupled-dmd.webp filter=lfs diff=lfs merge=lfs -text
39
+ unet/Z-Image/assets/architecture.webp filter=lfs diff=lfs merge=lfs -text
40
+ unet/Z-Image/assets/reasoning.png filter=lfs diff=lfs merge=lfs -text
41
+ unet/Z-Image/assets/DMDR.webp filter=lfs diff=lfs merge=lfs -text
42
+ unet/Z-Image/assets/showcase.jpg filter=lfs diff=lfs merge=lfs -text
LLM/Florence-2-base/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03075d2d2d2bbd3e180b9ba0afae4aa8563226e2d32911656966e05b2f2ee060
3
+ size 463221266
LLM/Florence-2-base/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b480ac374593b0dcb18ffa63b23213734e04cd43eab0d620d23e39708d4a4a7e
3
+ size 464421827
SEEDVR2/ema_vae_fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20678548f420d98d26f11442d3528f8b8c94e57ee046ef93dbb7633da8612ca1
3
+ size 501324814
assets/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
checkpoints/qwen_image_fp8_hq.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a61c42a58c181813fd94748df62ca1cdb53d0ec4b32c34af09375e5309126fa
3
+ size 89460
depthcrafter/stabilityai_stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af602cd0eb4ad6086ec94fbf1438dfb1be5ec9ac03fd0215640854e90d6463a3
3
+ size 195531910
detection/vitpose_h_wholebody_model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f21466cd6c93d0066782ad5923c14a4e6569133def212dc2895c73596c2e553b
3
+ size 420252
detection/yolov10m.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89b526498a6d55f869a6ab52e3a2eb20ad45b3711c1f7de3dd9ca0b399dfd6d7
3
+ size 61659339
mediapipe/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
3
+ size 16371837
sams/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
ultralytics/bbox/adetailerFootYolov8x_v20.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f39f32ab83b43002ca466605603b6c6dcff124ecfb23dab1c74c36ecb95cb4b
3
+ size 136712062
ultralytics/bbox/face_yolov8m.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:717923c19b3f4bbf5250b728f1fa6b2cb72a33aed1d236ea9caf0e21ad943e5f
3
+ size 52026019
ultralytics/bbox/face_yolov8m[1].pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:717923c19b3f4bbf5250b728f1fa6b2cb72a33aed1d236ea9caf0e21ad943e5f
3
+ size 52026019
unet/Z-Image/assets/DMDR.webp ADDED

Git LFS Details

  • SHA256: 2e6f3053b98d097f2aa11d3892bd9307326db41b65336bea54dc5825a0e03077
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
unet/Z-Image/assets/Z-Image-Gallery.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f9895b3246d2547bac74bbe0be975da500eaae93f2cad4248ad3281786b1ac6
3
+ size 15767436
unet/Z-Image/assets/architecture.webp ADDED

Git LFS Details

  • SHA256: 261af62ecc7e9749ae28e1d3a84e2f70a6c192d2017b7d8f020c7bff982ef59c
  • Pointer size: 131 Bytes
  • Size of remote file: 422 kB
unet/Z-Image/assets/decoupled-dmd.webp ADDED

Git LFS Details

  • SHA256: 4568ca559b997fc38f57dc1c3f5b1da3a3c144ae12419caa855ced972bf8c7aa
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB
unet/Z-Image/assets/image_arena_all.jpg ADDED

Git LFS Details

  • SHA256: 899a87527d6fe44068bf1928dc7af60baefaca9b9566034e7ec0f5b15e5e3833
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
unet/Z-Image/assets/reasoning.png ADDED

Git LFS Details

  • SHA256: 96c16b2c8d8dc67bb92ecc22d54b9955ab55136977f515bb76f4b2eb42eb3cdb
  • Pointer size: 132 Bytes
  • Size of remote file: 7.7 MB
unet/Z-Image/assets/showcase.jpg ADDED

Git LFS Details

  • SHA256: f6ee74e066e00596e429f5a08140aebae1678e5935ce1e11ca6c1c6cd72432ee
  • Pointer size: 132 Bytes
  • Size of remote file: 6.43 MB
unet/Z-Image/src/config/__init__.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Z-Image Configuration."""
2
+
3
+ from .inference import (
4
+ DEFAULT_CFG_TRUNCATION,
5
+ DEFAULT_GUIDANCE_SCALE,
6
+ DEFAULT_HEIGHT,
7
+ DEFAULT_INFERENCE_STEPS,
8
+ DEFAULT_MAX_SEQUENCE_LENGTH,
9
+ DEFAULT_WIDTH,
10
+ )
11
+ from .model import (
12
+ ADALN_EMBED_DIM,
13
+ BASE_IMAGE_SEQ_LEN,
14
+ BASE_SHIFT,
15
+ BYTES_PER_GB,
16
+ DEFAULT_LOAD_DEVICE,
17
+ DEFAULT_LOAD_DTYPE_STR,
18
+ DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
19
+ DEFAULT_SCHEDULER_SHIFT,
20
+ DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
21
+ DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
22
+ DEFAULT_TRANSFORMER_DIM,
23
+ DEFAULT_TRANSFORMER_F_PATCH_SIZE,
24
+ DEFAULT_TRANSFORMER_IN_CHANNELS,
25
+ DEFAULT_TRANSFORMER_N_HEADS,
26
+ DEFAULT_TRANSFORMER_N_KV_HEADS,
27
+ DEFAULT_TRANSFORMER_N_LAYERS,
28
+ DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
29
+ DEFAULT_TRANSFORMER_NORM_EPS,
30
+ DEFAULT_TRANSFORMER_PATCH_SIZE,
31
+ DEFAULT_TRANSFORMER_QK_NORM,
32
+ DEFAULT_TRANSFORMER_T_SCALE,
33
+ DEFAULT_VAE_IN_CHANNELS,
34
+ DEFAULT_VAE_LATENT_CHANNELS,
35
+ DEFAULT_VAE_NORM_NUM_GROUPS,
36
+ DEFAULT_VAE_OUT_CHANNELS,
37
+ DEFAULT_VAE_SCALE_FACTOR,
38
+ DEFAULT_VAE_SCALING_FACTOR,
39
+ FREQUENCY_EMBEDDING_SIZE,
40
+ MAX_IMAGE_SEQ_LEN,
41
+ MAX_PERIOD,
42
+ MAX_SHIFT,
43
+ ROPE_AXES_DIMS,
44
+ ROPE_AXES_LENS,
45
+ ROPE_THETA,
46
+ SEQ_MULTI_OF,
47
+ )
48
+
49
+ __all__ = [
50
+ "ADALN_EMBED_DIM",
51
+ "SEQ_MULTI_OF",
52
+ "ROPE_THETA",
53
+ "ROPE_AXES_DIMS",
54
+ "ROPE_AXES_LENS",
55
+ "FREQUENCY_EMBEDDING_SIZE",
56
+ "MAX_PERIOD",
57
+ "BASE_IMAGE_SEQ_LEN",
58
+ "MAX_IMAGE_SEQ_LEN",
59
+ "BASE_SHIFT",
60
+ "MAX_SHIFT",
61
+ "DEFAULT_VAE_SCALE_FACTOR",
62
+ "DEFAULT_VAE_IN_CHANNELS",
63
+ "DEFAULT_VAE_OUT_CHANNELS",
64
+ "DEFAULT_VAE_LATENT_CHANNELS",
65
+ "DEFAULT_VAE_NORM_NUM_GROUPS",
66
+ "DEFAULT_VAE_SCALING_FACTOR",
67
+ "DEFAULT_TRANSFORMER_PATCH_SIZE",
68
+ "DEFAULT_TRANSFORMER_F_PATCH_SIZE",
69
+ "DEFAULT_TRANSFORMER_IN_CHANNELS",
70
+ "DEFAULT_TRANSFORMER_DIM",
71
+ "DEFAULT_TRANSFORMER_N_LAYERS",
72
+ "DEFAULT_TRANSFORMER_N_REFINER_LAYERS",
73
+ "DEFAULT_TRANSFORMER_N_HEADS",
74
+ "DEFAULT_TRANSFORMER_N_KV_HEADS",
75
+ "DEFAULT_TRANSFORMER_NORM_EPS",
76
+ "DEFAULT_TRANSFORMER_QK_NORM",
77
+ "DEFAULT_TRANSFORMER_CAP_FEAT_DIM",
78
+ "DEFAULT_TRANSFORMER_T_SCALE",
79
+ "DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS",
80
+ "DEFAULT_SCHEDULER_SHIFT",
81
+ "DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING",
82
+ "DEFAULT_LOAD_DEVICE",
83
+ "DEFAULT_LOAD_DTYPE_STR",
84
+ "BYTES_PER_GB",
85
+ "DEFAULT_HEIGHT",
86
+ "DEFAULT_WIDTH",
87
+ "DEFAULT_INFERENCE_STEPS",
88
+ "DEFAULT_GUIDANCE_SCALE",
89
+ "DEFAULT_CFG_TRUNCATION",
90
+ "DEFAULT_MAX_SEQUENCE_LENGTH",
91
+ ]
unet/Z-Image/src/config/inference.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Inference-specific configuration for Z-Image."""
2
+
3
# Default inference settings. All six names are re-exported by config/__init__.py.
# NOTE(review): height/width are presumably output image sizes in pixels — confirm in pipeline.py.
+ DEFAULT_HEIGHT = 1024
4
+ DEFAULT_WIDTH = 1024
5
# Default number of inference steps.
+ DEFAULT_INFERENCE_STEPS = 8
6
# NOTE(review): a guidance scale of 0.0 presumably disables classifier-free guidance — confirm in pipeline.py.
+ DEFAULT_GUIDANCE_SCALE = 0.0
7
+ DEFAULT_CFG_TRUNCATION = 1.0
8
# NOTE(review): presumably the maximum prompt length in tokens — confirm against the tokenizer call.
+ DEFAULT_MAX_SEQUENCE_LENGTH = 512
unet/Z-Image/src/config/manifests/z-image-turbo.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Z-Image Model Manifest
2
+ # Format: <md5hash> <filepath>
3
+ # Generated automatically - DO NOT edit manually
4
+
5
+ 5e3226ed72a9a4a080f2a4ca78b98ddc model_index.json
6
+ ca682fcc6c5a94cf726b7187e64b9411 scheduler/scheduler_config.json
7
+ 1e97eb35d9d0b6aa60c58a8df8d7d99a text_encoder/config.json
8
+ 30b85686b9a9b002e012494fadc027cb text_encoder/model-00001-of-00003.safetensors
9
+ e6a24ea164404a01ad2800dbae4e1a13 text_encoder/model-00002-of-00003.safetensors
10
+ 09e190ed15ff14795b6277e023cfcb2d text_encoder/model-00003-of-00003.safetensors
11
+ 589f5395156900f49d617aee8a8d8708 text_encoder/model.safetensors.index.json
12
+ 6423133b9cc1a2077b57822c30c211aa tokenizer/tokenizer.json
13
+ b06e103ac555ec4b51266078b518c0f0 tokenizer/tokenizer_config.json
14
+ baed87136fe5f848e24b072f99856cc3 transformer/config.json
15
+ 54889d0dd179b4fa2fd7bd0e487d856e transformer/diffusion_pytorch_model-00001-of-00003.safetensors
16
+ fe81e804658d345323512c63224b0604 transformer/diffusion_pytorch_model-00002-of-00003.safetensors
17
+ 4e074e09129f98ad840414951f122feb transformer/diffusion_pytorch_model-00003-of-00003.safetensors
18
+ 76d788eb0d42c59cc8f8ec007db639aa transformer/diffusion_pytorch_model.safetensors.index.json
19
+ ba9e2980c8630b4abccc643bc9f4a542 vae/config.json
20
+ 6f83de55cb720c7fae051b14528577bf vae/diffusion_pytorch_model.safetensors
unet/Z-Image/src/config/model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model configuration constants for Z-Image."""
2
+
3
# Model-level constants. Every name below is re-exported by config/__init__.py.
+ ADALN_EMBED_DIM = 256
4
# NOTE(review): presumably sequence lengths are padded to a multiple of this value — confirm in transformer.py.
+ SEQ_MULTI_OF = 32
5
+
6
# Rotary position embedding settings, one entry per axis (three axes).
# The per-axis dims sum to 32 + 48 + 48 = 128, which equals
# DEFAULT_TRANSFORMER_DIM / DEFAULT_TRANSFORMER_N_HEADS (3840 / 30) defined below.
+ ROPE_THETA = 256.0
7
+ ROPE_AXES_DIMS = [32, 48, 48]
8
+ ROPE_AXES_LENS = [1536, 512, 512]
9
+
10
+ FREQUENCY_EMBEDDING_SIZE = 256
11
+ MAX_PERIOD = 10000
12
+
13
# NOTE(review): these four presumably parameterise a sequence-length-dependent scheduler shift — confirm in pipeline.py.
+ BASE_IMAGE_SEQ_LEN = 256
14
+ MAX_IMAGE_SEQ_LEN = 4096
15
+ BASE_SHIFT = 0.5
16
+ MAX_SHIFT = 1.15
17
+
18
# VAE defaults.
# NOTE(review): DEFAULT_VAE_LATENT_CHANNELS is 4 while DEFAULT_TRANSFORMER_IN_CHANNELS below is 16;
# confirm these defaults are meant to differ (a loaded vae/config.json may override them).
+ DEFAULT_VAE_SCALE_FACTOR = 8
19
+ DEFAULT_VAE_IN_CHANNELS = 3
20
+ DEFAULT_VAE_OUT_CHANNELS = 3
21
+ DEFAULT_VAE_LATENT_CHANNELS = 4
22
+ DEFAULT_VAE_NORM_NUM_GROUPS = 32
23
+ DEFAULT_VAE_SCALING_FACTOR = 0.18215
24
+
25
# Transformer defaults. The two patch-size values are one-element tuples, not bare ints.
+ DEFAULT_TRANSFORMER_PATCH_SIZE = (2,)
26
+ DEFAULT_TRANSFORMER_F_PATCH_SIZE = (1,)
27
+ DEFAULT_TRANSFORMER_IN_CHANNELS = 16
28
+ DEFAULT_TRANSFORMER_DIM = 3840
29
+ DEFAULT_TRANSFORMER_N_LAYERS = 30
30
+ DEFAULT_TRANSFORMER_N_REFINER_LAYERS = 2
31
# The head count and the KV head count are equal (30 and 30).
+ DEFAULT_TRANSFORMER_N_HEADS = 30
32
+ DEFAULT_TRANSFORMER_N_KV_HEADS = 30
33
+ DEFAULT_TRANSFORMER_NORM_EPS = 1e-5
34
+ DEFAULT_TRANSFORMER_QK_NORM = True
35
+ DEFAULT_TRANSFORMER_CAP_FEAT_DIM = 2560
36
+ DEFAULT_TRANSFORMER_T_SCALE = 1000.0
37
+
38
# Scheduler defaults.
+ DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS = 1000
39
+ DEFAULT_SCHEDULER_SHIFT = 3.0
40
+ DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING = False
41
+
42
# Loading defaults. The dtype is stored as a string, not as a torch.dtype object.
+ DEFAULT_LOAD_DEVICE = "cuda"
43
+ DEFAULT_LOAD_DTYPE_STR = "bfloat16"
44
+
45
# 2**30 = 1,073,741,824 bytes, i.e. a binary gigabyte (GiB).
+ BYTES_PER_GB = 2**30
unet/Z-Image/src/tools/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Tools for Z-Image model management."""
2
+
3
+ from .generate_manifest import compute_md5, get_essential_files
4
+
5
+ __all__ = [
6
+ "compute_md5",
7
+ "get_essential_files",
8
+ ]
9
+
unet/Z-Image/src/tools/generate_manifest.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate manifest file with MD5 checksums for model weights.
3
+
4
+ Usage:
5
+ python -m tools.generate_manifest ckpts/Z-Image-Turbo
6
+ python -m tools.generate_manifest ckpts/Z-Image-Turbo --no-checksums # Only list files
7
+ """
8
+
9
+ import argparse
10
+ import hashlib
11
+ from pathlib import Path
12
+ from typing import List
13
+
14
+
15
def compute_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Compute the MD5 hex digest of a file, reading it in fixed-size chunks.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes to read per iteration; must be positive.

    Returns:
        The 32-character lowercase hexadecimal MD5 digest of the file contents.

    Raises:
        ValueError: If ``chunk_size`` is not positive. A size of 0 would end the
            read loop immediately and silently yield the digest of empty input.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    md5_hash = hashlib.md5()
    with open(file_path, "rb") as f:
        # Stream the file so multi-gigabyte weight shards never sit in memory at once.
        while chunk := f.read(chunk_size):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()
22
+
23
+
24
def get_essential_files(model_dir: Path) -> List[Path]:
    """Return the essential model files present under ``model_dir``.

    Each entry of the pattern list is either a literal relative path or a glob
    pattern (recognised by the presence of ``*``). Only regular files are
    returned: a directory that happens to match a pattern or a literal name is
    skipped, because it cannot be hashed by ``compute_md5``.

    Args:
        model_dir: Root directory of the model checkpoint.

    Returns:
        A sorted list of unique paths (each prefixed with ``model_dir``). The
        list is empty when nothing matches or ``model_dir`` does not exist.
    """
    essential_patterns = (
        "model_index.json",
        "transformer/config.json",
        "transformer/*.safetensors*",
        "vae/config.json",
        "vae/*.safetensors",
        "text_encoder/config.json",
        "text_encoder/*.safetensors*",
        "tokenizer/tokenizer.json",
        "tokenizer/tokenizer_config.json",
        "scheduler/scheduler_config.json",
    )

    found = set()  # a set keeps the result unique even if patterns ever overlap
    for pattern in essential_patterns:
        if "*" in pattern:
            candidates = model_dir.glob(pattern)
        else:
            candidates = (model_dir / pattern,)
        # is_file() is False for missing paths and for directories alike.
        found.update(path for path in candidates if path.is_file())

    return sorted(found)
49
+
50
+
51
def main(argv=None) -> int:
    """Command-line entry point: write a manifest of a model directory's essential files.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` makes argparse read ``sys.argv``, which is the
            behaviour of the original zero-argument call.

    Returns:
        0 on success. 1 when the model directory is missing, when no essential
        file is found, or when at least one file could not be hashed (the
        manifest is still written, but without the failed entries).
    """
    parser = argparse.ArgumentParser(description="Generate manifest file for model weights")
    parser.add_argument("model_dir", type=str, help="Path to model directory")
    parser.add_argument("--output", "-o", type=str, default=None,
                        help="Output manifest file path (default: auto-detect to config/manifests/)")
    parser.add_argument("--no-checksums", action="store_true",
                        help="Only list files without computing checksums")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Print progress")

    args = parser.parse_args(argv)

    model_dir = Path(args.model_dir)
    # is_dir() also rejects a path that exists but is a regular file.
    if not model_dir.is_dir():
        print(f"Error: Model directory not found: {model_dir}")
        return 1

    # Determine output path
    if args.output:
        output_file = Path(args.output)
        # Create the parent so an explicit --output into a new directory works.
        output_file.parent.mkdir(parents=True, exist_ok=True)
    else:
        # Auto-detect: save to config/manifests/{model-name}.txt
        model_name = model_dir.name.lower()  # e.g., "Z-Image-Turbo" -> "z-image-turbo"
        script_dir = Path(__file__).parent
        config_dir = script_dir.parent / "config" / "manifests"
        config_dir.mkdir(parents=True, exist_ok=True)
        output_file = config_dir / f"{model_name}.txt"

    # Get essential files
    files = get_essential_files(model_dir)

    if not files:
        print(f"Warning: No essential files found in {model_dir}")
        return 1

    print(f"Found {len(files)} essential files")

    written = 0
    failed = []

    # Generate manifest
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("# Z-Image Model Manifest\n")
        if args.no_checksums:
            f.write("# Format: <filepath>\n")
        else:
            f.write("# Format: <md5hash> <filepath>\n")
        f.write("# Generated automatically - DO NOT edit manually\n\n")

        for file_path in files:
            # POSIX separators keep the manifest identical across operating systems.
            rel_path = file_path.relative_to(model_dir).as_posix()

            if args.no_checksums:
                f.write(f"{rel_path}\n")
                written += 1
                if args.verbose:
                    print(f" {rel_path}")
                continue

            if args.verbose:
                print(f"Computing MD5 for {rel_path}...", end=" ", flush=True)

            try:
                md5_hash = compute_md5(file_path)
            except OSError as e:
                # Best effort: record the failure and keep hashing the remaining files.
                print(f"✗ Error: {e}")
                failed.append(rel_path)
                continue

            f.write(f"{md5_hash} {rel_path}\n")
            written += 1
            if args.verbose:
                print(f"✓ {md5_hash}")

    print(f"\n✓ Manifest saved to: {output_file}")
    # Report what was actually written, not what was merely discovered.
    print(f" Total files: {written}")
    if not args.no_checksums:
        print(" With MD5 checksums for integrity verification")

    if failed:
        print(f"Warning: {len(failed)} file(s) could not be hashed and are missing from the manifest")
        return 1

    return 0
123
+
124
+
125
if __name__ == "__main__":
    # The bare `exit` helper is injected by the `site` module and can be absent
    # (e.g. `python -S` or frozen builds); raising SystemExit carries main()'s
    # status code without depending on it.
    raise SystemExit(main())
127
+
unet/Z-Image/src/utils/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for Z-Image."""
2
+
3
+ from .attention import AttentionBackend, dispatch_attention, set_attention_backend
4
+ from .helpers import format_bytes, print_memory_stats, ensure_model_weights
5
+ from .loader import load_from_local_dir
6
+
7
+ __all__ = [
8
+ "load_from_local_dir",
9
+ "format_bytes",
10
+ "print_memory_stats",
11
+ "ensure_model_weights",
12
+ "AttentionBackend",
13
+ "set_attention_backend",
14
+ "dispatch_attention",
15
+ ]
unet/Z-Image/src/utils/attention.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Attention backend utilities for Z-Image."""
2
+
3
+ # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py
4
+ from enum import Enum
5
+ import functools
6
+ import inspect
7
+ from typing import Callable, Dict, List, Optional, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from .import_utils import is_flash_attn_3_available, is_flash_attn_available, is_torch_version
13
+
14
# Import-time capability probing. Each optional backend gets a boolean flag, and the
# backend's entry points are bound to None when it is unavailable, so the names always exist.
+ _CAN_USE_FLASH_ATTN_2 = is_flash_attn_available()
15
+ _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
16
+
17
+ # MPS Flash Attention (Apple Silicon)
18
+ try:
19
+ import mps_flash_attn
20
+ _CAN_USE_MPS_FLASH = mps_flash_attn.is_available()
21
+ except ImportError:
22
+ _CAN_USE_MPS_FLASH = False
23
+ mps_flash_attn = None
24
# NOTE: the version gate below runs at import time. Importing this module raises
# RuntimeError on PyTorch < 2.5.0 and prints a confirmation line to stdout otherwise.
+ _TORCH_VERSION_CHECK = is_torch_version(">=", "2.5.0") # SDPA's enable_gqa argument needs torch >= 2.5.0
25
+
26
+ if not _TORCH_VERSION_CHECK:
27
+ raise RuntimeError("PyTorch version must be >= 2.5.0 to use this backend.")
28
+ else:
29
+ print("PyTorch version is >= 2.5.0, check pass.")
30
+
31
# Flash Attention 2 entry points (None when the package is unavailable).
+ if _CAN_USE_FLASH_ATTN_2:
32
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
33
+ else:
34
+ flash_attn_func = None
35
+ flash_attn_varlen_func = None
36
+
37
# Flash Attention 3 entry points, imported from `flash_attn_interface` under *_3_* aliases.
+ if _CAN_USE_FLASH_ATTN_3:
38
+ from flash_attn_interface import (
39
+ flash_attn_func as flash_attn_3_func,
40
+ flash_attn_varlen_func as flash_attn_3_varlen_func,
41
+ )
42
+
43
# Inspect the signature once at import so the FA3 wrapper can decide whether to pass
# `return_attn_probs` without re-inspecting on every call.
+ _flash_attn_3_sig = inspect.signature(flash_attn_3_func)
44
+ _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = "return_attn_probs" in _flash_attn_3_sig.parameters
45
+ else:
46
+ flash_attn_3_func = None
47
+ flash_attn_3_varlen_func = None
48
+ _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS = False
49
+
50
+
51
# String-valued enum of backend identifiers. Because the class also inherits from `str`,
# each member compares equal to (and hashes like) its plain string value.
+ class AttentionBackend(str, Enum):
52
+ """Supported attention backends."""
53
+
54
+ # Flash Attention
55
+ FLASH = "flash"
56
+ FLASH_VARLEN = "flash_varlen"
57
+ FLASH_3 = "_flash_3"
58
+ FLASH_VARLEN_3 = "_flash_varlen_3"
59
+ # MPS Flash Attention (Apple Silicon)
60
+ MPS_FLASH = "mps_flash"
61
+ # PyTorch Native Backends
62
+ NATIVE = "native"
63
+ NATIVE_FLASH = "_native_flash"
64
+ NATIVE_MATH = "_native_math"
65
+
66
# Prints the string values of ALL declared members and returns None. It does not
# consult the _CAN_USE_* flags, so "available" here means "declared", not "installed".
+ @classmethod
67
+ def print_available_backends(cls):
68
+ available_backends = [backend.value for backend in cls.__members__.values()]
69
+ print(f"Available attention backends list: {available_backends}")
70
+
71
+
72
+ # Registry for attention implementations
73
+ _ATTENTION_BACKENDS: Dict[str, Callable] = {}
74
+ _ATTENTION_CONSTRAINTS: Dict[str, List[Callable]] = {}
75
+
76
+
77
# Decorator factory: stores the decorated function in _ATTENTION_BACKENDS[name] and its
# constraint callables in _ATTENTION_CONSTRAINTS[name] (an empty list when none are given).
# The function is returned unchanged. Registering the same name twice overwrites the
# earlier entry without warning.
+ def register_backend(name: str, constraints: Optional[List[Callable]] = None):
78
+ def decorator(func):
79
+ _ATTENTION_BACKENDS[name] = func
80
+ _ATTENTION_CONSTRAINTS[name] = constraints or []
81
+ return func
82
+
83
+ return decorator
84
+
85
+
86
+ # --- Checks ---
87
# Backend constraint: raises ValueError unless `query` lives on a CUDA device.
# Only `query` is inspected; any other keyword arguments are accepted and ignored.
+ def _check_device_cuda(query: torch.Tensor, **kwargs) -> None:
88
+ if query.device.type != "cuda":
89
+ raise ValueError("Query must be on a CUDA device.")
90
+
91
+
92
# Backend constraint: raises ValueError unless `query` is bfloat16 or float16.
# NOTE(review): despite the "qkv" in the name, only `query` is checked here; key and
# value dtypes are not validated by this function.
+ def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, **kwargs) -> None:
93
+ if query.dtype not in (torch.bfloat16, torch.float16):
94
+ raise ValueError("Query must be either bfloat16 or float16.")
95
+
96
+
97
# Backend constraint: raises ValueError unless `query` lives on an MPS device.
# Only `query` is inspected; any other keyword arguments are accepted and ignored.
+ def _check_device_mps(query: torch.Tensor, **kwargs) -> None:
98
+ if query.device.type != "mps":
99
+ raise ValueError("Query must be on MPS device.")
100
+
101
+
102
# Prepare an attention mask for an additive-mask consumer.
#   * None is passed through unchanged.
#   * A 2-D mask gains two singleton axes in the middle: [d0, d1] -> [d0, 1, 1, d1].
#   * A bool mask becomes a tensor of `dtype` holding 0 where the mask is True and -inf
#     where it is False.
#   * Any non-bool mask is returned as-is; it is NOT cast to `dtype`.
# NOTE(review): a row whose mask is entirely False becomes all -inf, which would presumably
# produce NaN after a softmax — confirm callers never pass fully masked rows.
+ def _process_mask(attn_mask: Optional[torch.Tensor], dtype: torch.dtype):
103
+ if attn_mask is None:
104
+ return None
105
+
106
+ if attn_mask.ndim == 2:
107
+ attn_mask = attn_mask[:, None, None, :]
108
+
109
+ # Convert bool mask to float additive mask
110
+ if attn_mask.dtype == torch.bool:
111
+ # NOTE: We skip checking for all-True mask (torch.all) to avoid graph breaks in torch.compile
112
+ new_mask = torch.zeros_like(attn_mask, dtype=dtype)
113
+ new_mask.masked_fill_(~attn_mask, float("-inf"))
114
+ return new_mask
115
+
116
+ return attn_mask
117
+
118
+
119
+ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
120
+ """Normalize an attention mask to shape [batch_size, seq_len_k] (bool)."""
121
+ if attn_mask.dtype != torch.bool:
122
+ # Try to convert float mask back to bool if possible, or assume it's float mask
123
+ # For varlen flash attn, we strictly need bool mask indicating valid tokens
124
+ if torch.is_floating_point(attn_mask):
125
+ return attn_mask > -1 # Assuming -inf is masked
126
+ # raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
127
+
128
+ if attn_mask.ndim == 1:
129
+ attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
130
+ elif attn_mask.ndim == 2:
131
+ if attn_mask.size(0) not in [1, batch_size]:
132
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
133
+ elif attn_mask.ndim == 3:
134
+ attn_mask = attn_mask.any(dim=1)
135
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
136
+ elif attn_mask.ndim == 4:
137
+ attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)
138
+ attn_mask = attn_mask.any(dim=(1, 2))
139
+
140
+ if attn_mask.shape != (batch_size, seq_len_k):
141
+ # Fallback reshape
142
+ return attn_mask.view(batch_size, seq_len_k)
143
+
144
+ return attn_mask
145
+
146
+
147
# Build the varlen bookkeeping tensors for the unmasked case, where every sequence in the
# batch has the same length. Returns three pairs:
#   (seqlens_q, seqlens_k)        int32 tensors of shape [batch_size], each filled with the length
#   (cu_seqlens_q, cu_seqlens_k)  int32 cumulative offsets [0, L, 2L, ..., batch_size * L]
#   (seq_len_q, seq_len_kv)       the maximum lengths, passed through as plain ints
# Because of lru_cache, repeated calls with the same arguments return the SAME tensor
# objects, so callers must not modify them in place.
+ @functools.lru_cache(maxsize=128)
148
+ def _prepare_for_flash_attn_varlen_without_mask(
149
+ batch_size: int,
150
+ seq_len_q: int,
151
+ seq_len_kv: int,
152
+ device: Optional[torch.device] = None,
153
+ ):
154
+ # Optimized to avoid Inductor "pointless_cumsum_replacement" crash and remove graph breaks
155
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
156
+ seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
157
+
158
# arange(batch_size + 1) * L gives the cumulative offsets directly, with no cumsum.
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q
159
+ cu_seqlens_k = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_kv
160
+
161
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (seq_len_q, seq_len_kv)
162
+
163
+
164
def _prepare_for_flash_attn_varlen_with_mask(
    batch_size: int,
    seq_len_q: int,
    attn_mask: torch.Tensor,
    device: Optional[torch.device] = None,
):
    """Build varlen bookkeeping tensors when a key-validity mask is present.

    Queries are treated as full length for every batch item; key lengths are
    taken from the number of valid positions per row of ``attn_mask``.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len_q: Query length shared by all batch items.
        attn_mask: 2-D mask; it is summed over dim 1 to get per-item key
            lengths, and its second dimension is used as the static maximum
            key length.
        device: Device for the newly created tensors. Defaults to the mask's
            device, so that the offsets live next to the lengths derived from
            the mask instead of on the process-wide default device.

    Returns:
        ``(seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_q, max_seqlen_k)`` where the first four are int32 tensors
        and the last two are plain ints.
    """
    if device is None:
        device = attn_mask.device

    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
    # Use arange for Q to avoid Inductor crash
    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len_q

    # Leading zero followed by the running total of valid key counts.
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)

    max_seqlen_q = seq_len_q
    max_seqlen_k = attn_mask.shape[1]  # not max().item(), static shape to avoid graph break

    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
182
+
183
+
184
def _prepare_for_flash_attn_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> tuple:
    """Build the varlen bookkeeping tensors, with or without an attention mask.

    Dispatches to the cached uniform-length helper when ``attn_mask`` is None
    and to the mask-aware helper otherwise. ``seq_len_kv`` is only used in the
    unmasked case; with a mask, key lengths come from the mask itself.

    Returns:
        The helper's result, a tuple of three pairs:
        ``(seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_q, max_seqlen_k)``.
    """
    if attn_mask is None:
        return _prepare_for_flash_attn_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
194
+
195
+
196
# Flash Attention 2 backend (fixed-length). Registered under AttentionBackend.FLASH with
# the CUDA-device and bf16/fp16 constraints attached.
# NOTE(review): `attn_mask` is accepted for signature compatibility but is NOT forwarded to
# flash_attn_func, so any mask passed here is silently ignored — confirm callers only pick
# this backend when no masking is needed.
+ @register_backend(AttentionBackend.FLASH, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
197
+ def _flash_attention(
198
+ query: torch.Tensor,
199
+ key: torch.Tensor,
200
+ value: torch.Tensor,
201
+ attn_mask: Optional[torch.Tensor] = None,
202
+ dropout_p: float = 0.0,
203
+ is_causal: bool = False,
204
+ scale: Optional[float] = None,
205
+ ) -> torch.Tensor:
206
# Fail with a clear message instead of calling the None placeholder bound at import time.
+ if not _CAN_USE_FLASH_ATTN_2:
207
+ raise RuntimeError(
208
+ f"Flash Attention backend '{AttentionBackend.FLASH}' is not usable because of missing package."
209
+ )
210
+
211
+ out = flash_attn_func(
212
+ q=query,
213
+ k=key,
214
+ v=value,
215
+ dropout_p=dropout_p,
216
+ softmax_scale=scale,
217
+ causal=is_causal,
218
+ )
219
+ return out
220
+
221
+
222
@register_backend(AttentionBackend.FLASH_VARLEN, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Flash Attention 2 variable-length backend.

    Queries are packed at full length for every batch item. Keys and values
    are packed so that only positions marked valid by ``attn_mask`` remain;
    without a mask all positions are kept.

    Args:
        query, key, value: 4-D tensors whose first two dims are unpacked as
            (batch, sequence). NOTE(review): the trailing dims are presumably
            (heads, head_dim), as flash-attn expects — confirm at the caller.
        attn_mask: Optional key-validity mask; normalised to
            ``[batch, seq_len_kv]`` by ``_normalize_attn_mask``.
        dropout_p: Dropout probability forwarded to flash-attn.
        is_causal: Forwarded as ``causal``.
        scale: Forwarded as ``softmax_scale``.

    Returns:
        Attention output with the batch dimension restored via ``unflatten``.

    Raises:
        RuntimeError: If flash-attn is not installed.
    """
    if not _CAN_USE_FLASH_ATTN_2:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN}' requires flash-attn.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)

    if attn_mask is not None:
        # Select exactly the positions the mask marks valid. Boolean indexing walks the
        # [batch, seq] grid in row-major order, so the result is the per-item concatenation
        # that cu_seqlens_k describes. Slicing the first `valid_len` positions instead is
        # only correct when all padding sits at the end of each sequence.
        keep = attn_mask.to(torch.bool)
        key_packed = key[keep]
        value_packed = value[keep]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
    )
    # Queries were packed at full length, so the packed output splits evenly per batch item.
    out = out.unflatten(0, (batch_size, -1))
    return out
274
+
275
+
276
# Flash Attention 3 backend (fixed-length). Registered under AttentionBackend.FLASH_3 with
# the CUDA-device and bf16/fp16 constraints attached.
# NOTE(review): neither `attn_mask` nor `dropout_p` is forwarded to flash_attn_3_func;
# both are accepted only for signature compatibility — confirm callers do not rely on them.
+ @register_backend(AttentionBackend.FLASH_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
277
+ def _flash_attention_3(
278
+ query: torch.Tensor,
279
+ key: torch.Tensor,
280
+ value: torch.Tensor,
281
+ attn_mask: Optional[torch.Tensor] = None, # Unused in simple FA3 func
282
+ dropout_p: float = 0.0,
283
+ is_causal: bool = False,
284
+ scale: Optional[float] = None,
285
+ ) -> torch.Tensor:
286
+ if not _CAN_USE_FLASH_ATTN_3:
287
+ raise RuntimeError(f"Backend '{AttentionBackend.FLASH_3}' requires Flash Attention 3 beta.")
288
+
289
+ kwargs = {
290
+ "q": query,
291
+ "k": key,
292
+ "v": value,
293
+ "softmax_scale": scale,
294
+ "causal": is_causal,
295
+ }
296
+
297
# Pass `return_attn_probs` only when the installed FA3 build accepts it (probed at import).
+ if _FLASH_ATTN_3_SUPPORTS_RETURN_PROBS:
298
+ kwargs["return_attn_probs"] = False
299
+
300
+ out = flash_attn_3_func(**kwargs)
301
+
302
# If the call returns a tuple, the attention output is its first element.
+ if isinstance(out, tuple):
303
+ out = out[0]
304
+
305
+ return out
306
+
307
+
308
@functools.lru_cache(maxsize=None)
def _flash_attn_3_varlen_accepts_return_probs() -> bool:
    """Report whether the FA3 varlen function takes ``return_attn_probs`` (inspected once)."""
    return "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters


@register_backend(AttentionBackend.FLASH_VARLEN_3, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16])
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Flash Attention 3 variable-length backend.

    Queries are packed at full length for every batch item. Keys and values
    are packed so that only positions marked valid by ``attn_mask`` remain;
    without a mask all positions are kept.

    Args:
        query, key, value: 4-D tensors whose first two dims are unpacked as
            (batch, sequence). NOTE(review): the trailing dims are presumably
            (heads, head_dim) — confirm at the caller.
        attn_mask: Optional key-validity mask; normalised to
            ``[batch, seq_len_kv]`` by ``_normalize_attn_mask``.
        dropout_p: Accepted for signature compatibility; not forwarded to FA3.
        is_causal: Forwarded as ``causal``.
        scale: Forwarded as ``softmax_scale``.

    Returns:
        Attention output with the batch dimension restored via ``unflatten``.

    Raises:
        RuntimeError: If Flash Attention 3 is not installed.
    """
    if not _CAN_USE_FLASH_ATTN_3:
        raise RuntimeError(f"Backend '{AttentionBackend.FLASH_VARLEN_3}' requires Flash Attention 3 beta.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_varlen(
        batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
    )

    query_packed = query.flatten(0, 1)

    if attn_mask is not None:
        # Select exactly the positions the mask marks valid. Boolean indexing walks the
        # [batch, seq] grid in row-major order, so the result is the per-item concatenation
        # that cu_seqlens_k describes. Slicing the first `valid_len` positions instead is
        # only correct when all padding sits at the end of each sequence.
        keep = attn_mask.to(torch.bool)
        key_packed = key[keep]
        value_packed = value[keep]
    else:
        key_packed = key.flatten(0, 1)
        value_packed = value.flatten(0, 1)

    kwargs = {
        "q": query_packed,
        "k": key_packed,
        "v": value_packed,
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seqlen_q": max_seqlen_q,
        "max_seqlen_k": max_seqlen_k,
        "softmax_scale": scale,
        "causal": is_causal,
    }

    # The signature cannot change at runtime, so the probe is memoised rather than
    # re-running inspect.signature on every attention call.
    if _flash_attn_3_varlen_accepts_return_probs():
        kwargs["return_attn_probs"] = False

    out = flash_attn_3_varlen_func(**kwargs)

    # If the call returns a tuple, the attention output is its first element.
    if isinstance(out, tuple):
        out = out[0]

    # Queries were packed at full length, so the packed output splits evenly per batch item.
    out = out.unflatten(0, (batch_size, -1))
    return out
370
+
371
+
372
@register_backend(AttentionBackend.MPS_FLASH, constraints=[_check_device_mps, _check_qkv_dtype_bf16_or_fp16])
def _mps_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """MPS Flash Attention for Apple Silicon (M1/M2/M3/M4).

    Inputs and output use the (B, S, H, D) layout; tensors are transposed to
    (B, H, S, D) around the ``mps_flash_attn`` call. ``dropout_p`` is not
    forwarded to the kernel.

    Raises:
        RuntimeError: If ``_CAN_USE_MPS_FLASH`` is falsy.
    """
    if not _CAN_USE_MPS_FLASH:
        raise RuntimeError(
            f"MPS Flash Attention backend '{AttentionBackend.MPS_FLASH}' requires mps-flash-attn package. "
            "Install with: pip install mps-flash-attn"
        )

    # (B, S, H, D) -> (B, H, S, D), the layout mps-flash-attn expects.
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))

    # Convert the mask to MFA format (bool, True = masked) when one is given.
    mfa_mask = None if attn_mask is None else mps_flash_attn.convert_mask(_process_mask(attn_mask, q.dtype))

    attended = mps_flash_attn.flash_attention(
        q, k, v,
        is_causal=is_causal,
        scale=scale,
        attn_mask=mfa_mask,
    )

    # Back to (B, S, H, D).
    return attended.transpose(1, 2).contiguous()
408
+
409
+
410
def _native_attention_wrapper(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend_kernel=None,
) -> torch.Tensor:
    """Call ``F.scaled_dot_product_attention``, optionally pinned to one SDPA kernel.

    Dimensions 1 and 2 of query/key/value are swapped before the call and
    swapped back (then made contiguous) afterwards. When ``backend_kernel`` is
    given, the call runs inside ``torch.nn.attention.sdpa_kernel(backend_kernel)``.
    """
    q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    sdpa_kwargs = dict(
        attn_mask=_process_mask(attn_mask, q.dtype),
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )

    if backend_kernel is None:
        out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
    else:
        with torch.nn.attention.sdpa_kernel(backend_kernel):
            out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)

    return out.transpose(1, 2).contiguous()
437
+
438
+
439
@register_backend(AttentionBackend.NATIVE_FLASH)
def _native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Native SDPA restricted to ``SDPBackend.FLASH_ATTENTION``.

    ``attn_mask`` is accepted for signature parity but is always replaced with
    ``None`` before delegating to ``_native_attention_wrapper``.
    """
    flash_kernel = torch.nn.attention.SDPBackend.FLASH_ATTENTION
    return _native_attention_wrapper(
        query,
        key,
        value,
        attn_mask=None,  # the caller's mask is deliberately not forwarded
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        backend_kernel=flash_kernel,
    )
459
+
460
+
461
@register_backend(AttentionBackend.NATIVE_MATH)
def _math_attention(*args, **kwargs):
    """Native SDPA restricted to ``SDPBackend.MATH``; arguments are passed through unchanged."""
    math_kernel = torch.nn.attention.SDPBackend.MATH
    return _native_attention_wrapper(*args, backend_kernel=math_kernel, **kwargs)
464
+
465
+
466
@register_backend(AttentionBackend.NATIVE)
def _native_attention(*args, **kwargs):
    """Native SDPA with no kernel pinned; arguments are passed through unchanged."""
    return _native_attention_wrapper(*args, backend_kernel=None, **kwargs)
469
+
470
+
471
def dispatch_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    backend: Union[str, AttentionBackend, None] = None,
) -> torch.Tensor:
    """Route an attention call to the implementation selected by ``backend``.

    ``backend`` may be an ``AttentionBackend`` member, its string form, or
    ``None`` (treated as ``AttentionBackend.NATIVE``). Any value that matches
    none of the explicit branches falls through to the native implementation.
    """
    if backend is None:
        backend = AttentionBackend.NATIVE
    elif isinstance(backend, AttentionBackend):
        backend = backend.value
    else:
        backend = str(backend)

    call_args = (query, key, value, attn_mask, dropout_p, is_causal, scale)

    # Explicit dispatch to avoid dynamo guard issues on global dict
    if backend == AttentionBackend.FLASH:
        return _flash_attention(*call_args)
    if backend == AttentionBackend.FLASH_VARLEN:
        return _flash_varlen_attention(*call_args)
    if backend == AttentionBackend.FLASH_3:
        return _flash_attention_3(*call_args)
    if backend == AttentionBackend.FLASH_VARLEN_3:
        return _flash_varlen_attention_3(*call_args)
    if backend == AttentionBackend.MPS_FLASH:
        return _mps_flash_attention(*call_args)
    if backend == AttentionBackend.NATIVE_FLASH:
        return _native_flash_attention(*call_args)
    if backend == AttentionBackend.NATIVE_MATH:
        return _math_attention(*call_args)
    return _native_attention(*call_args)
506
+
507
+
508
def set_attention_backend(backend: Union[str, AttentionBackend, None]):
    """Set the class-wide attention backend on ``ZImageAttention``.

    Args:
        backend: An ``AttentionBackend`` member, a backend name, or ``None``.
            ``None`` is stored as-is.

    Enum members are stored by ``.value``, the same normalisation
    ``dispatch_attention`` applies. ``str()`` on an enum member may yield
    ``"AttentionBackend.X"`` rather than the backend name, which would match
    no dispatch branch and silently select the native backend.

    This is best-effort: if ``zimage.transformer`` cannot be imported the call
    does nothing.
    """
    try:
        from zimage.transformer import ZImageAttention
    except ImportError:
        # Deliberate no-op when the transformer module is unavailable.
        return

    if isinstance(backend, AttentionBackend):
        backend = backend.value
    elif backend is not None:
        backend = str(backend)
    ZImageAttention._attention_backend = backend
unet/Z-Image/src/utils/helpers.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helper utilities for Z-Image."""
2
+
3
+ import hashlib
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Optional, List, Tuple, Dict
7
+
8
+ from loguru import logger
9
+ import torch
10
+
11
+ from config import BYTES_PER_GB
12
+
13
+
14
def format_bytes(size: float) -> str:
    """
    Render a byte count as a gigabyte string with two decimals.

    Args:
        size: Size in bytes

    Returns:
        ``size / BYTES_PER_GB`` formatted as ``"<value> GB"``
    """
    return "{:.2f} GB".format(size / BYTES_PER_GB)
26
+
27
+
28
def print_memory_stats(stage: str) -> None:
    """
    Log current and peak CUDA memory usage.

    Logs a warning and returns early when CUDA is not available.

    Args:
        stage: Label placed in the header line of the report
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping memory stats")
        return

    # Synchronize before reading the counters.
    torch.cuda.synchronize()
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()
    now_allocated = torch.cuda.memory_allocated()
    now_reserved = torch.cuda.memory_reserved()

    report = (
        ("Current Allocated", now_allocated),
        ("Current Reserved", now_reserved),
        ("Peak Allocated", peak_allocated),
        ("Peak Reserved", peak_reserved),
    )

    logger.info(f"[{stage}] Memory Stats:")
    for label, amount in report:
        logger.info(f" {label}: {format_bytes(amount)}")
50
+
51
+
52
def compute_file_md5(file_path: Path, chunk_size: int = 8192) -> str:
    """Return the hex MD5 digest of a file, reading it ``chunk_size`` bytes at a time."""
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
59
+
60
+
61
def load_manifest(manifest_file: Path) -> Dict[str, Optional[str]]:
    """Load a manifest file into a ``{file path: MD5 hash or None}`` mapping.

    Accepted line formats (blank lines and lines starting with ``#`` are skipped):

    * ``<path>``               -- no checksum, stored as ``None``
    * ``<md5> <path>``         -- chosen when the first token is 32 hex digits
    * ``<path> <md5>``         -- any other two-token line

    Lines with more than two whitespace-separated tokens are logged as invalid
    and skipped. Hashes are stored lower-cased so they compare equal to
    ``hashlib`` hex digests, which are lower-case; without this an upper-case
    hash in the manifest would never match.

    Args:
        manifest_file: Path to the manifest file.

    Returns:
        The parsed mapping; empty when the file does not exist.
    """
    manifest: Dict[str, Optional[str]] = {}
    if not manifest_file.exists():
        return manifest

    with open(manifest_file, "r", encoding="utf-8") as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue

            parts = line.split()

            if len(parts) == 1:
                # Only file path, no checksum
                manifest[parts[0]] = None
            elif len(parts) == 2:
                first = parts[0].lower()
                if len(first) == 32 and all(c in "0123456789abcdef" for c in first):
                    md5_hash, file_path = first, parts[1]
                else:
                    file_path, md5_hash = parts[0], parts[1].lower()
                manifest[file_path] = md5_hash
            else:
                logger.warning(f"Invalid manifest format at line {line_num}: {line}")

    return manifest
92
+
93
+
94
def verify_file_integrity(
    base_dir: Path,
    manifest: Dict[str, Optional[str]],
    verify_checksums: bool = True
) -> Tuple[bool, List[str], List[str]]:
    """
    Check the files listed in a manifest against the file system.

    Args:
        base_dir: Directory that the manifest's relative paths are resolved against
        manifest: Mapping of relative path to expected MD5 hash (None when no hash is known)
        verify_checksums: When True, compare MD5 hashes for entries that have one;
            when False, only existence is checked

    Returns:
        Tuple of (all_valid: bool, missing_files: List[str], corrupted_files: List[str]).
        A file whose checksum cannot be computed is reported as corrupted.
    """
    missing: List[str] = []
    corrupted: List[str] = []

    for rel_path, expected_md5 in manifest.items():
        candidate = base_dir / rel_path

        if not candidate.exists():
            missing.append(rel_path)
            continue

        # Nothing more to do unless a checksum was requested and is available.
        if not verify_checksums or expected_md5 is None:
            continue

        try:
            actual_md5 = compute_file_md5(candidate)
            if actual_md5 != expected_md5:
                corrupted.append(rel_path)
                logger.debug(f"Checksum mismatch for {rel_path}: expected {expected_md5}, got {actual_md5}")
        except Exception as e:
            logger.error(f"Failed to compute checksum for {rel_path}: {e}")
            corrupted.append(rel_path)

    return (not missing and not corrupted), missing, corrupted
133
+
134
+
135
def _resolve_manifest_path(target_dir: Path, manifest_name: Optional[str]) -> Path:
    """Pick the manifest: explicit name, bundled per-model file, or ``manifest.txt`` in the model dir."""
    manifests_dir = Path(__file__).parent.parent / "config" / "manifests"
    if manifest_name:
        # Explicitly specified manifest from config/manifests/
        return manifests_dir / manifest_name

    # Auto-detect from the lower-cased directory name, e.g. "Z-Image-Turbo" -> "z-image-turbo.txt"
    bundled = manifests_dir / f"{target_dir.name.lower()}.txt"
    if bundled.exists():
        return bundled
    return target_dir / "manifest.txt"


def _log_file_list(log, header: str, files: List[str], limit: int = 10) -> None:
    """Emit ``header`` and up to ``limit`` entries of ``files`` through the ``log`` callable."""
    log(header)
    for name in files[:limit]:
        log(f" - {name}")
    if len(files) > limit:
        log(f" ... and {len(files) - limit} more")


def ensure_model_weights(
    model_path: str,
    repo_id: str = "Tongyi-MAI/Z-Image-Turbo",
    verify: bool = False,
    manifest_name: Optional[str] = None
) -> Path:
    """
    Ensure model weights exist locally, downloading them when they do not.

    With a non-empty manifest, the listed files are checked for existence, and
    for checksums when ``verify`` is True; a download is attempted if anything
    is missing or corrupted, and the check is repeated afterwards. Without a
    usable manifest, an existing model directory is accepted as-is.

    Args:
        model_path: Path to model directory
        repo_id: HuggingFace repo ID for download
        verify: If True, verify MD5 checksums; if False, only check existence
        manifest_name: Manifest file name in src/config/manifests/ (auto-detect if None)

    Returns:
        Path to validated model directory

    Raises:
        RuntimeError: If the download fails (the original error is chained).
        FileNotFoundError: If files are still missing or corrupted after download.
    """
    target_dir = Path(model_path)
    manifest_path = _resolve_manifest_path(target_dir, manifest_name)
    manifest = load_manifest(manifest_path)

    if not manifest:
        logger.warning(f"Manifest file not found: {manifest_path}")
        logger.warning("Skipping file verification (assuming model exists)")
        if target_dir.exists():
            logger.info(f"✓ Model directory exists: {target_dir}")
            return target_dir
        logger.warning(f"Model directory not found: {target_dir}")
        missing_files = ["entire model directory"]
        corrupted_files = []
    else:
        # Count files with checksums
        files_with_checksums = sum(1 for v in manifest.values() if v is not None)

        if verify and files_with_checksums == 0:
            logger.info("Verify requested but no checksums in manifest, only checking existence")
        elif verify and files_with_checksums > 0:
            logger.info(f"Verifying {files_with_checksums} file(s) with MD5 checksums...")

        all_valid, missing_files, corrupted_files = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )

        if all_valid:
            if verify and files_with_checksums > 0:
                logger.success(f"✓ All files verified with MD5 checksums in {target_dir}")
            else:
                logger.info(f"✓ All {len(manifest)} required files exist in {target_dir}")
            return target_dir

    # Report missing and corrupted files
    if missing_files:
        _log_file_list(logger.warning, f"Missing {len(missing_files)} file(s):", missing_files)
    if corrupted_files:
        _log_file_list(
            logger.error, f"Corrupted {len(corrupted_files)} file(s) (checksum mismatch):", corrupted_files
        )

    # Imported only when a download is actually needed, so a complete local
    # copy can be used without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    logger.info(f"\nAttempting to download from {repo_id}...")
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        logger.success("✓ Download completed")
    except Exception as e:
        logger.error(f"✗ Download failed: {e}")
        logger.info(
            f"\nIf you are offline, please manually download from:\n"
            f" https://huggingface.co/{repo_id}\n"
            f"and place in: {target_dir.absolute()}"
        )
        raise RuntimeError(f"Failed to download model weights: {e}") from e

    # Verify after download
    if manifest:
        all_valid, missing_after, corrupted_after = verify_file_integrity(
            target_dir, manifest, verify_checksums=verify
        )

        if not all_valid:
            error_msg = []
            if missing_after:
                error_msg.append(f"Still missing {len(missing_after)} file(s)")
            if corrupted_after:
                error_msg.append(f"Still corrupted {len(corrupted_after)} file(s)")

            raise FileNotFoundError(
                f"After download: {', '.join(error_msg)}\n"
                f"Please verify the download or manually place files in:\n"
                f" {target_dir.absolute()}"
            )

    logger.success("✓ All model weights validated successfully")
    return target_dir
unet/Z-Image/src/utils/import_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+
3
+ import torch
4
+
5
+
6
def is_flash_attn_available():
    """Return True when an import spec for the ``flash_attn`` module can be found."""
    spec = importlib.util.find_spec("flash_attn")
    return spec is not None
8
+
9
+
10
def is_flash_attn_3_available():
    """Return True when an import spec for the ``flash_attn_interface`` module can be found."""
    spec = importlib.util.find_spec("flash_attn_interface")
    return spec is not None
12
+
13
+
14
def is_torch_version(operator: str, version: str):
    """Compare the installed torch version against ``version``.

    Args:
        operator: One of ``">"``, ``">="``, ``"=="``, ``"!="``, ``"<="``, ``"<"``.
        version: Version string to compare against, parsed with ``packaging``.

    Returns:
        The comparison result, or ``False`` for an unrecognised operator.
        ``"!="`` used to fall into the unrecognised case and always return
        ``False``; it is now evaluated.
    """
    # Aliased because the parameter name shadows the stdlib module.
    import operator as _op

    from packaging import version as pversion

    torch_version = pversion.parse(torch.__version__)
    target_version = pversion.parse(version)

    comparators = {
        ">": _op.gt,
        ">=": _op.ge,
        "==": _op.eq,
        "!=": _op.ne,
        "<=": _op.le,
        "<": _op.lt,
    }
    compare = comparators.get(operator)
    if compare is None:
        return False
    return compare(torch_version, target_version)
unet/Z-Image/src/utils/loader.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading utilities for Z-Image components."""
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ import sys
7
+ from typing import Optional, Union
8
+
9
+ from loguru import logger
10
+ from safetensors.torch import load_file
11
+ import torch
12
+ from transformers import AutoModel, AutoTokenizer
13
+
14
+ from config import (
15
+ DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
16
+ DEFAULT_SCHEDULER_SHIFT,
17
+ DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
18
+ DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
19
+ DEFAULT_TRANSFORMER_DIM,
20
+ DEFAULT_TRANSFORMER_F_PATCH_SIZE,
21
+ DEFAULT_TRANSFORMER_IN_CHANNELS,
22
+ DEFAULT_TRANSFORMER_N_HEADS,
23
+ DEFAULT_TRANSFORMER_N_KV_HEADS,
24
+ DEFAULT_TRANSFORMER_N_LAYERS,
25
+ DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
26
+ DEFAULT_TRANSFORMER_NORM_EPS,
27
+ DEFAULT_TRANSFORMER_PATCH_SIZE,
28
+ DEFAULT_TRANSFORMER_QK_NORM,
29
+ DEFAULT_TRANSFORMER_T_SCALE,
30
+ DEFAULT_VAE_IN_CHANNELS,
31
+ DEFAULT_VAE_LATENT_CHANNELS,
32
+ DEFAULT_VAE_NORM_NUM_GROUPS,
33
+ DEFAULT_VAE_OUT_CHANNELS,
34
+ DEFAULT_VAE_SCALING_FACTOR,
35
+ ROPE_AXES_DIMS,
36
+ ROPE_AXES_LENS,
37
+ ROPE_THETA,
38
+ )
39
+ from zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL
40
+ from zimage.scheduler import FlowMatchEulerDiscreteScheduler
41
+
42
+ DIFFUSERS_AVAILABLE = False
43
+
44
+
45
def load_config(config_path: str) -> dict:
    """Read a JSON config file and return the parsed object.

    The file is opened as UTF-8 explicitly, so decoding does not depend on
    the platform's default locale encoding.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
48
+
49
+
50
def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> dict:
    """Load safetensors weights from a directory into a single state dict.

    If a ``*.safetensors.index.json`` file is present, every shard named in its
    ``weight_map`` is loaded and merged. Otherwise one ``*.safetensors`` file
    is loaded. Glob results are sorted, so when several files match the
    lexicographically first one is chosen every time rather than whichever the
    file system lists first.

    Args:
        weight_dir: Directory containing the weight files.
        device: Device passed to ``load_file``.
        dtype: If given, tensors of a different dtype are cast to it.

    Returns:
        Mapping of parameter name to tensor.

    Raises:
        FileNotFoundError: If the directory has no index and no safetensors file.
    """
    weight_dir = Path(weight_dir)
    index_files = sorted(weight_dir.glob("*.safetensors.index.json"))

    state_dict = {}
    if index_files:
        # Load sharded weights
        with open(index_files[0], "r", encoding="utf-8") as f:
            index = json.load(f)
        weight_map = index.get("weight_map", {})
        # Several parameters map to the same shard; load each shard once.
        for shard_file in sorted(set(weight_map.values())):
            shard_state = load_file(str(weight_dir / shard_file), device=str(device))
            state_dict.update(shard_state)
    else:
        # Load single safetensors file
        safetensors_files = sorted(weight_dir.glob("*.safetensors"))
        if not safetensors_files:
            raise FileNotFoundError(f"No safetensors files found in {weight_dir}")
        state_dict = load_file(str(safetensors_files[0]), device=str(device))

    # Cast to target dtype if specified
    if dtype is not None:
        state_dict = {k: v.to(dtype) if v.dtype != dtype else v for k, v in state_dict.items()}

    return state_dict
78
+
79
+
80
def load_from_local_dir(
    model_dir: Union[str, Path],
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = False,
    compile: bool = False,
) -> dict:
    """
    Load all Z-Image components from local directory.

    Reads ``transformer/``, ``vae/``, ``text_encoder/``, ``scheduler/`` and, if
    present, ``tokenizer/`` below ``model_dir`` (the tokenizer falls back to the
    text-encoder directory).

    Args:
        model_dir: Path to model directory
        device: Device to load models on
        dtype: Data type for the transformer and text encoder weights
        verbose: Whether to display loading logs
        compile: Whether to compile transformer and vae with torch.compile

    Returns:
        Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler

    Side effects:
        Prepends the Z-Image ``src`` directory to ``sys.path`` (once, not on
        every call) and sets ``TOKENIZERS_PARALLELISM=false`` in the environment.
    """

    def _announce(message: str) -> None:
        if verbose:
            logger.info(message)

    model_dir = Path(model_dir)

    # Guard the insert so repeated calls do not pile up duplicate entries.
    src_dir = str(model_dir.parent.parent / "Z-Image" / "src")
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)
    from zimage.transformer import ZImageTransformer2DModel

    _announce(f"Loading Z-Image from: {model_dir}")

    # DiT
    _announce("Loading DiT...")
    transformer_dir = model_dir / "transformer"
    config = load_config(str(transformer_dir / "config.json"))

    # Build on the meta device; real tensors are attached by load_state_dict(assign=True).
    with torch.device("meta"):
        transformer = ZImageTransformer2DModel(
            all_patch_size=tuple(config.get("all_patch_size", DEFAULT_TRANSFORMER_PATCH_SIZE)),
            all_f_patch_size=tuple(config.get("all_f_patch_size", DEFAULT_TRANSFORMER_F_PATCH_SIZE)),
            in_channels=config.get("in_channels", DEFAULT_TRANSFORMER_IN_CHANNELS),
            dim=config.get("dim", DEFAULT_TRANSFORMER_DIM),
            n_layers=config.get("n_layers", DEFAULT_TRANSFORMER_N_LAYERS),
            n_refiner_layers=config.get("n_refiner_layers", DEFAULT_TRANSFORMER_N_REFINER_LAYERS),
            n_heads=config.get("n_heads", DEFAULT_TRANSFORMER_N_HEADS),
            n_kv_heads=config.get("n_kv_heads", DEFAULT_TRANSFORMER_N_KV_HEADS),
            norm_eps=config.get("norm_eps", DEFAULT_TRANSFORMER_NORM_EPS),
            qk_norm=config.get("qk_norm", DEFAULT_TRANSFORMER_QK_NORM),
            cap_feat_dim=config.get("cap_feat_dim", DEFAULT_TRANSFORMER_CAP_FEAT_DIM),
            rope_theta=config.get("rope_theta", ROPE_THETA),
            t_scale=config.get("t_scale", DEFAULT_TRANSFORMER_T_SCALE),
            axes_dims=config.get("axes_dims", ROPE_AXES_DIMS),
            axes_lens=config.get("axes_lens", ROPE_AXES_LENS),
        ).to(dtype)

    # DiT (weights to CPU then move to GPU to optimize memory)
    state_dict = load_sharded_safetensors(transformer_dir, device="cpu", dtype=dtype)
    transformer.load_state_dict(state_dict, strict=False, assign=True)
    del state_dict

    _announce("Moving DiT to GPU...")
    transformer = transformer.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    transformer.eval()

    # VAE
    _announce("Loading VAE...")
    vae_dir = model_dir / "vae"
    vae_config = load_config(str(vae_dir / "config.json"))

    vae = LocalAutoencoderKL(
        in_channels=vae_config.get("in_channels", DEFAULT_VAE_IN_CHANNELS),
        out_channels=vae_config.get("out_channels", DEFAULT_VAE_OUT_CHANNELS),
        down_block_types=tuple(vae_config.get("down_block_types", ("DownEncoderBlock2D",))),
        up_block_types=tuple(vae_config.get("up_block_types", ("UpDecoderBlock2D",))),
        block_out_channels=tuple(vae_config.get("block_out_channels", (64,))),
        layers_per_block=vae_config.get("layers_per_block", 1),
        latent_channels=vae_config.get("latent_channels", DEFAULT_VAE_LATENT_CHANNELS),
        norm_num_groups=vae_config.get("norm_num_groups", DEFAULT_VAE_NORM_NUM_GROUPS),
        scaling_factor=vae_config.get("scaling_factor", DEFAULT_VAE_SCALING_FACTOR),
        shift_factor=vae_config.get("shift_factor", None),
        use_quant_conv=vae_config.get("use_quant_conv", True),
        use_post_quant_conv=vae_config.get("use_post_quant_conv", True),
        mid_block_add_attention=vae_config.get("mid_block_add_attention", True),
    )

    # VAE (fp32 for better precision)
    vae_state_dict = load_sharded_safetensors(vae_dir, device="cpu")
    vae.load_state_dict(vae_state_dict, strict=False)
    del vae_state_dict
    vae.to(device=device, dtype=torch.float32)
    vae.eval()
    # Guarded the same way as the cache flush after the DiT load.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Text Encoder
    _announce("Loading Text Encoder...")
    text_encoder_dir = model_dir / "text_encoder"
    text_encoder = AutoModel.from_pretrained(
        str(text_encoder_dir),
        # Some transformers versions spell this keyword ``torch_dtype`` instead.
        dtype=dtype,
        trust_remote_code=True,
    )
    text_encoder.to(device)
    text_encoder.eval()

    # Tokenizer
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    _announce("Loading Tokenizer...")
    tokenizer_dir = model_dir / "tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(
        str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir),
        trust_remote_code=True,
    )

    # Scheduler
    _announce("Loading Scheduler...")
    scheduler_dir = model_dir / "scheduler"
    scheduler_config = load_config(str(scheduler_dir / "scheduler_config.json"))
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=scheduler_config.get("num_train_timesteps", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS),
        shift=scheduler_config.get("shift", DEFAULT_SCHEDULER_SHIFT),
        use_dynamic_shifting=scheduler_config.get("use_dynamic_shifting", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING),
    )

    if compile:
        _announce("Compiling DiT and VAE...")
        transformer = torch.compile(transformer)
        vae = torch.compile(vae)

    if verbose:
        logger.success("All components loaded successfully")

    return {
        "transformer": transformer,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
    }
unet/Z-Image/src/zimage/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """Z-Image PyTorch Native Implementation."""
2
+
3
+ from .pipeline import generate
4
+ from .transformer import ZImageTransformer2DModel
5
+
6
+ __all__ = [
7
+ "ZImageTransformer2DModel",
8
+ "generate",
9
+ ]
unet/Z-Image/src/zimage/autoencoder.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AutoencoderKL implementation compatible with diffusers weights."""
2
+
3
+ # Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
@dataclass
class AutoencoderKLOutput:
    """Single-field result container holding a tensor under ``sample``."""

    # The tensor payload; the code that fills it in is outside this block.
    sample: torch.Tensor
14
+
15
+
16
class AutoencoderConfig:
    """Lenient attribute bag for autoencoder settings.

    Keyword arguments become instance attributes. Reading an attribute that
    was never set yields ``None`` instead of raising, and ``get`` offers
    dict-style access with a default.

    Dunder names are the exception: they raise ``AttributeError`` as usual.
    Answering ``None`` for them makes ``hasattr(cfg, "__setstate__")`` true,
    so ``copy.deepcopy`` and unpickling try to call ``None`` and fail.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        """Return the stored value for ``key``, or ``default`` when it is absent."""
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.__dict__.get(name)
25
+
26
+
27
def swish(x):
    """Return ``x`` multiplied element-wise by ``sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
29
+
30
+
31
class ResnetBlock2D(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip connection.

    When ``in_channels`` differs from ``out_channels`` the skip path goes
    through a 1x1 convolution; otherwise the input is added unchanged.
    ``temb_channels`` and the ``temb`` argument of ``forward`` are accepted
    but not used by this implementation.
    """

    def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6):
        super().__init__()
        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Submodules are created in the same order as before (norm1, conv1, norm2, dropout, conv2, shortcut).
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = swish

        needs_projection = self.in_channels != self.out_channels
        self.conv_shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) if needs_projection else None
        )

    def forward(self, input_tensor, temb=None):
        # First stage: norm -> activation -> conv.
        branch = self.conv1(self.nonlinearity(self.norm1(input_tensor)))

        # Second stage: norm -> activation -> dropout -> conv.
        branch = self.conv2(self.dropout(self.nonlinearity(self.norm2(branch))))

        skip = input_tensor if self.conv_shortcut is None else self.conv_shortcut(input_tensor)
        return (skip + branch) / 1.0
67
+
68
+
69
class Attention(nn.Module):
    """Self-attention over the spatial positions of a (B, C, H, W) feature map.

    The input is group-normalised, flattened to (B, H*W, C), projected to
    query/key/value, passed through scaled dot-product attention, projected
    back and added to the unnormalised input. ``heads`` is stored but the
    forward pass does not split channels into heads; ``dim_head`` is unused.
    """

    def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        # Same creation order as before: to_q, to_k, to_v, then to_out.
        for proj_name in ("to_q", "to_k", "to_v"):
            setattr(self, proj_name, nn.Linear(in_channels, in_channels))
        self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)])

    def forward(self, hidden_states):
        batch, channels, height, width = hidden_states.shape

        # (B, C, H, W) -> (B, H*W, C)
        tokens = self.group_norm(hidden_states).view(batch, channels, -1).transpose(1, 2)

        q = self.to_q(tokens)
        k = self.to_k(tokens)
        v = self.to_v(tokens)
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        projected = self.to_out[0](attended)
        restored = projected.transpose(1, 2).view(batch, channels, height, width)

        return hidden_states + restored
99
+
100
+
101
class Downsample2D(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution or 2x2 average pooling.

    The ``conv`` submodule exists only when ``with_conv`` is True.
    """

    def __init__(self, channels, with_conv=True, out_channels=None, padding=1):
        super().__init__()
        target_channels = out_channels or channels
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, target_channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, hidden_states):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
        return self.conv(hidden_states)
114
+
115
+
116
class Upsample2D(nn.Module):
    """Double spatial resolution by nearest-neighbour interpolation, optionally followed by a 3x3 conv.

    The ``conv`` submodule exists only when ``with_conv`` is True.
    """

    def __init__(self, channels, with_conv=True, out_channels=None):
        super().__init__()
        target_channels = out_channels or channels
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, target_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        upscaled = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upscaled) if self.with_conv else upscaled
129
+
130
+
131
class DownEncoderBlock2D(nn.Module):
    """Encoder stage: a stack of ``ResnetBlock2D`` layers, optionally followed by a downsampler.

    Only the first resnet maps ``in_channels`` to ``out_channels``; the rest
    keep ``out_channels``. The downsampler's convolution is built with
    ``padding=0`` and ``forward`` zero-pads one pixel on the right and bottom
    before calling it.
    """

    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels if layer_idx == 0 else out_channels,
                    out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.downsamplers = (
            nn.ModuleList([Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)])
            if add_downsample
            else None
        )

    def forward(self, hidden_states):
        for block in self.resnets:
            hidden_states = block(hidden_states)

        if self.downsamplers is None:
            return hidden_states

        for reducer in self.downsamplers:
            # Pad (left, right, top, bottom) = (0, 1, 0, 1) ahead of the unpadded stride-2 conv.
            padded = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
            hidden_states = reducer(padded)

        return hidden_states
158
+
159
+
160
class UpDecoderBlock2D(nn.Module):
    """Decoder stage: a stack of ``ResnetBlock2D`` layers with an optional upsampler.

    Args:
        in_channels: Channels entering the first resnet.
        out_channels: Channels produced by every resnet and by the upsampler.
        num_layers: Number of resnet blocks.
        resnet_eps: ``eps`` forwarded to each resnet block.
        resnet_groups: ``groups`` forwarded to each resnet block.
        add_upsample: Whether to append a conv ``Upsample2D``.
    """

    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True):
        super().__init__()
        # Only the first resnet sees ``in_channels``; later ones map out -> out.
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    out_channels if layer_idx else in_channels,
                    out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                )
                for layer_idx in range(num_layers)
            ]
        )

        if not add_upsample:
            self.upsamplers = None
        else:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)])

    def forward(self, hidden_states):
        """Apply the resnets, then (if configured) the upsampler(s)."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)
        return hidden_states
183
+
184
+
185
class UNetMidBlock2D(nn.Module):
    """Bottleneck block: resnet -> single-head attention -> resnet, all at ``in_channels``.

    Args:
        in_channels: Channel count kept constant through the block.
        resnet_eps: ``eps`` forwarded to the resnets and the attention layer.
        resnet_groups: ``groups`` forwarded to the resnets and the attention layer.
        attention_head_dim: Accepted but not read by this implementation.
    """

    def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups)
                for _ in range(2)
            ]
        )
        self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)])

    def forward(self, hidden_states):
        """Run the first resnet, every attention layer, then the second resnet."""
        entry_resnet, exit_resnet = self.resnets
        hidden_states = entry_resnet(hidden_states)
        for attention in self.attentions:
            hidden_states = attention(hidden_states)
        return exit_resnet(hidden_states)
202
+
203
+
204
class Encoder(nn.Module):
    """Convolutional encoder: input conv, down blocks, mid block, then norm/SiLU/output conv.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output (doubled when ``double_z`` is true).
        block_out_channels: Output channels of each down block, in order.
        layers_per_block: Number of resnets per down block.
        norm_num_groups: Group count for the resnets and the final ``GroupNorm``.
        double_z: Whether ``conv_out`` emits ``2 * out_channels`` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        double_z=True,
    ):
        super().__init__()
        first_channels = block_out_channels[0]
        final_channels = block_out_channels[-1]
        self.conv_in = nn.Conv2d(in_channels, first_channels, kernel_size=3, stride=1, padding=1)

        self.down_blocks = nn.ModuleList([])
        last_index = len(block_out_channels) - 1
        previous_channels = first_channels
        for index, current_channels in enumerate(block_out_channels):
            self.down_blocks.append(
                DownEncoderBlock2D(
                    previous_channels,
                    current_channels,
                    num_layers=layers_per_block,
                    resnet_groups=norm_num_groups,
                    # Every block except the last one downsamples.
                    add_downsample=index != last_index,
                )
            )
            previous_channels = current_channels

        self.mid_block = UNetMidBlock2D(
            final_channels,
            resnet_groups=norm_num_groups,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=final_channels, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        if double_z:
            conv_out_channels = 2 * out_channels
        else:
            conv_out_channels = out_channels
        self.conv_out = nn.Conv2d(final_channels, conv_out_channels, 3, padding=1)

    def forward(self, x):
        """Encode ``x`` and return the output of ``conv_out``."""
        hidden = self.conv_in(x)
        for down_block in self.down_blocks:
            hidden = down_block(hidden)
        hidden = self.mid_block(hidden)
        return self.conv_out(self.conv_act(self.conv_norm_out(hidden)))
253
+
254
+
255
class Decoder(nn.Module):
    """Convolutional decoder: input conv, mid block, up blocks, then norm/SiLU/output conv.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the decoded output.
        block_out_channels: Per-stage channels, listed in encoder order (used reversed here).
        layers_per_block: Each up block gets ``layers_per_block + 1`` resnets.
        norm_num_groups: Group count for the resnets and the final ``GroupNorm``.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
    ):
        super().__init__()
        deepest_channels = block_out_channels[-1]
        shallowest_channels = block_out_channels[0]
        self.conv_in = nn.Conv2d(in_channels, deepest_channels, kernel_size=3, stride=1, padding=1)

        self.mid_block = UNetMidBlock2D(
            deepest_channels,
            resnet_groups=norm_num_groups,
        )

        self.up_blocks = nn.ModuleList([])
        stage_channels = list(reversed(block_out_channels))
        last_index = len(block_out_channels) - 1
        previous_channels = stage_channels[0]
        for index, current_channels in enumerate(stage_channels):
            self.up_blocks.append(
                UpDecoderBlock2D(
                    previous_channels,
                    current_channels,
                    num_layers=layers_per_block + 1,
                    resnet_groups=norm_num_groups,
                    # Every block except the last one upsamples.
                    add_upsample=index != last_index,
                )
            )
            previous_channels = current_channels

        self.conv_norm_out = nn.GroupNorm(num_channels=shallowest_channels, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(shallowest_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Decode ``x`` and return the output of ``conv_out``."""
        hidden = self.mid_block(self.conv_in(x))
        for up_block in self.up_blocks:
            hidden = up_block(hidden)
        return self.conv_out(self.conv_act(self.conv_norm_out(hidden)))
302
+
303
+
304
class AutoencoderKL(nn.Module):
    """VAE wrapper holding an ``Encoder``, a ``Decoder`` and optional 1x1 quant convolutions.

    The constructor records a subset of its arguments in ``self.config``
    (an ``AutoencoderConfig``). ``down_block_types``, ``up_block_types``,
    ``act_fn``, ``sample_size``, ``force_upcast``, ``mid_block_add_attention``
    and any extra ``**kwargs`` are accepted but not read here.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.config = AutoencoderConfig(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            latent_channels=latent_channels,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
        )

        # Settings shared by the encoder and the decoder.
        shared_kwargs = dict(
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
        )
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            double_z=True,
            **shared_kwargs,
        )
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            **shared_kwargs,
        )

        if use_quant_conv:
            self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        else:
            self.quant_conv = None

        if use_post_quant_conv:
            self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        else:
            self.post_quant_conv = None

    @property
    def dtype(self):
        """Dtype of the first parameter yielded by ``self.parameters()``."""
        return next(self.parameters()).dtype

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        """Decode latents ``z``.

        Applies ``post_quant_conv`` when it exists, then the decoder. Returns
        ``AutoencoderKLOutput(sample=...)`` or, when ``return_dict`` is false,
        a one-element tuple holding the decoded tensor.
        """
        latents = z if self.post_quant_conv is None else self.post_quant_conv(z)
        decoded = self.decoder(latents)
        if return_dict:
            return AutoencoderKLOutput(sample=decoded)
        return (decoded,)
unet/Z-Image/src/zimage/pipeline.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Z-Image Pipeline."""
2
+
3
+ import inspect
4
+ from typing import List, Optional, Union
5
+
6
+ from loguru import logger
7
+ import torch
8
+
9
+ from config import (
10
+ BASE_IMAGE_SEQ_LEN,
11
+ BASE_SHIFT,
12
+ DEFAULT_CFG_TRUNCATION,
13
+ DEFAULT_GUIDANCE_SCALE,
14
+ DEFAULT_HEIGHT,
15
+ DEFAULT_INFERENCE_STEPS,
16
+ DEFAULT_MAX_SEQUENCE_LENGTH,
17
+ DEFAULT_WIDTH,
18
+ MAX_IMAGE_SEQ_LEN,
19
+ MAX_SHIFT,
20
+ )
21
+
22
+
23
def calculate_shift(
    image_seq_len,
    base_seq_len: int = BASE_IMAGE_SEQ_LEN,
    max_seq_len: int = MAX_IMAGE_SEQ_LEN,
    base_shift: float = BASE_SHIFT,
    max_shift: float = MAX_SHIFT,
):
    """Linearly interpolate/extrapolate the shift ``mu`` for a given image sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; the value at ``image_seq_len`` is returned.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
34
+
35
+
36
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return ``(scheduler.timesteps, num_inference_steps)``.

    Exactly one schedule source is used: custom ``timesteps``, custom
    ``sigmas``, or otherwise ``num_inference_steps``. With a custom schedule
    the returned step count is ``len(scheduler.timesteps)``. Extra ``kwargs``
    are forwarded to ``scheduler.set_timesteps``.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or the
            scheduler's ``set_timesteps`` lacks the requested parameter.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported:
            raise ValueError("The scheduler does not support custom timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError("The scheduler does not support custom sigmas schedules.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
64
+
65
+
66
def _encode_prompts(tokenizer, text_encoder, prompts, max_sequence_length, device):
    """Chat-format, tokenize and encode ``prompts``; return one unpadded embedding per prompt."""
    formatted = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for text in prompts
    ]
    inputs = tokenizer(
        formatted,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(device)
    masks = inputs.attention_mask.to(device).bool()
    embeds = text_encoder(
        input_ids=input_ids,
        attention_mask=masks,
        output_hidden_states=True,
    ).hidden_states[-2]
    # Keep only the positions the attention mask marks as real tokens.
    return [embed[mask] for embed, mask in zip(embeds, masks)]


@torch.no_grad()
def generate(
    transformer,
    vae,
    text_encoder,
    tokenizer,
    scheduler,
    prompt: Union[str, List[str]],
    height: int = DEFAULT_HEIGHT,
    width: int = DEFAULT_WIDTH,
    num_inference_steps: int = DEFAULT_INFERENCE_STEPS,
    guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    cfg_normalization: bool = False,
    cfg_truncation: float = DEFAULT_CFG_TRUNCATION,
    max_sequence_length: int = DEFAULT_MAX_SEQUENCE_LENGTH,
    output_type: str = "pil",
):
    """Run the text-to-image sampling loop and decode the result.

    Args:
        transformer: Denoising model; called with lists of latents, a timestep
            tensor and a list of prompt embeddings. Its first element is used.
        vae: Autoencoder providing ``config``, ``dtype`` and ``decode``.
        text_encoder: Model returning ``hidden_states``; the second-to-last
            entry is used as the prompt embedding.
        tokenizer: Tokenizer with ``apply_chat_template``.
        scheduler: Scheduler driven through ``retrieve_timesteps`` / ``step``.
        prompt: One prompt or a sequence of prompts.
        height: Output height; must be divisible by twice the VAE scale factor.
        width: Output width; same divisibility requirement as ``height``.
        num_inference_steps: Number of sampling steps requested.
        guidance_scale: Classifier-free guidance is enabled when this is > 1.
        negative_prompt: Negative prompt(s) for guidance. A string or a
            single-entry list is applied to every prompt; otherwise one entry
            per prompt is required. Ignored when guidance is disabled.
        num_images_per_prompt: Images generated for each prompt.
        generator: Optional RNG used for the initial latents.
        cfg_normalization: When truthy and ``float(...) > 0``, the guided
            prediction's norm is capped at this multiple of the positive
            prediction's norm.
        cfg_truncation: If set and ``<= 1``, guidance is switched off for
            steps whose normalized time exceeds this value.
        max_sequence_length: Tokenizer padding/truncation length.
        output_type: ``"latent"`` returns the final latents, ``"pil"`` returns
            a list of PIL images, anything else returns the decoded tensor.

    Raises:
        ValueError: If ``height``/``width`` are not divisible by the required
            factor, or ``negative_prompt`` cannot be matched to ``prompt``.
    """
    device = next(transformer.parameters()).device

    if hasattr(vae, "config") and hasattr(vae.config, "block_out_channels"):
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    else:
        vae_scale_factor = 8
    vae_scale = vae_scale_factor * 2

    if height % vae_scale != 0:
        raise ValueError(f"Height must be divisible by {vae_scale} (got {height}).")
    if width % vae_scale != 0:
        raise ValueError(f"Width must be divisible by {vae_scale} (got {width}).")

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    do_classifier_free_guidance = guidance_scale > 1.0
    logger.info(f"Generating image: {height}x{width}, steps={num_inference_steps}, cfg={guidance_scale}")

    prompt_embeds_list = _encode_prompts(tokenizer, text_encoder, prompt, max_sequence_length, device)

    negative_prompt_embeds_list = []
    if do_classifier_free_guidance:
        # The CFG batch concatenates positive and negative embeddings, so there
        # must be exactly one negative prompt per positive prompt.
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        elif len(negative_prompt) == 1 and batch_size > 1:
            negative_prompt = list(negative_prompt) * batch_size
        elif len(negative_prompt) != batch_size:
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {batch_size}; "
                "pass a single string or one negative prompt per prompt."
            )
        negative_prompt_embeds_list = _encode_prompts(
            tokenizer, text_encoder, negative_prompt, max_sequence_length, device
        )

    if num_images_per_prompt > 1:
        prompt_embeds_list = [pe for pe in prompt_embeds_list for _ in range(num_images_per_prompt)]
        if do_classifier_free_guidance:
            negative_prompt_embeds_list = [
                npe for npe in negative_prompt_embeds_list for _ in range(num_images_per_prompt)
            ]

    actual_batch_size = batch_size * num_images_per_prompt
    height_latent = 2 * (int(height) // vae_scale)
    width_latent = 2 * (int(width) // vae_scale)
    shape = (actual_batch_size, transformer.in_channels, height_latent, width_latent)

    latents = torch.randn(shape, generator=generator, device=device, dtype=torch.float32)

    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    scheduler.sigma_min = 0.0
    scheduler_kwargs = {"mu": mu}
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps,
        device,
        sigmas=None,
        **scheduler_kwargs,
    )

    logger.info(f"Sampling loop start: {num_inference_steps} steps")

    from tqdm import tqdm

    # The model dtype does not change during sampling, so resolve it once.
    model_dtype = transformer.dtype if hasattr(transformer, "dtype") else next(transformer.parameters()).dtype

    # Denoising loop with progress bar
    for i, t in enumerate(tqdm(timesteps, desc="Denoising", total=len(timesteps))):
        # If current t is 0 and it's the last step, skip computation
        if t == 0 and i == len(timesteps) - 1:
            logger.debug(f"Step {i+1}/{num_inference_steps} | t: {t.item():.2f} | Skipping last step")
            continue

        timestep = t.expand(latents.shape[0])
        timestep = (1000 - timestep) / 1000
        t_norm = timestep[0].item()

        current_guidance_scale = guidance_scale
        if do_classifier_free_guidance and cfg_truncation is not None and float(cfg_truncation) <= 1:
            if t_norm > cfg_truncation:
                current_guidance_scale = 0.0

        apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0

        latent_model_input = latents.to(model_dtype)
        if apply_cfg:
            # First half of the batch is conditioned on the prompt, second half on the negative prompt.
            latent_model_input = latent_model_input.repeat(2, 1, 1, 1)
            prompt_embeds_model_input = prompt_embeds_list + negative_prompt_embeds_list
            timestep_model_input = timestep.repeat(2)
        else:
            prompt_embeds_model_input = prompt_embeds_list
            timestep_model_input = timestep

        latent_model_input = latent_model_input.unsqueeze(2)
        latent_model_input_list = list(latent_model_input.unbind(dim=0))

        model_out_list = transformer(
            latent_model_input_list,
            timestep_model_input,
            prompt_embeds_model_input,
        )[0]

        if apply_cfg:
            pos_out = model_out_list[:actual_batch_size]
            neg_out = model_out_list[actual_batch_size:]
            noise_pred = []
            for j in range(actual_batch_size):
                pos = pos_out[j].float()
                neg = neg_out[j].float()
                pred = pos + current_guidance_scale * (pos - neg)

                if cfg_normalization and float(cfg_normalization) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(pos)
                    new_pos_norm = torch.linalg.vector_norm(pred)
                    max_new_norm = ori_pos_norm * float(cfg_normalization)
                    if new_pos_norm > max_new_norm:
                        pred = pred * (max_new_norm / new_pos_norm)
                noise_pred.append(pred)
            noise_pred = torch.stack(noise_pred, dim=0)
        else:
            noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)

        noise_pred = -noise_pred.squeeze(2)
        latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
        assert latents.dtype == torch.float32

    if output_type == "latent":
        return latents

    shift_factor = getattr(vae.config, "shift_factor", 0.0) or 0.0
    latents = (latents.to(vae.dtype) / vae.config.scaling_factor) + shift_factor
    image = vae.decode(latents, return_dict=False)[0]

    if output_type == "pil":
        from PIL import Image

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(img) for img in image]

    return image
unet/Z-Image/src/zimage/transformer.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Z-Image Transformer."""
2
+
3
+ import math
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ from config import (
12
+ ADALN_EMBED_DIM,
13
+ FREQUENCY_EMBEDDING_SIZE,
14
+ MAX_PERIOD,
15
+ ROPE_AXES_DIMS,
16
+ ROPE_AXES_LENS,
17
+ ROPE_THETA,
18
+ SEQ_MULTI_OF,
19
+ )
20
+
21
+
22
class TimestepEmbedder(nn.Module):
    """Map scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.

    Args:
        out_size: Width of the returned embedding.
        mid_size: Hidden width of the MLP (defaults to ``out_size``).
        frequency_embedding_size: Width of the sinusoidal feature vector.
    """

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=FREQUENCY_EMBEDDING_SIZE):
        super().__init__()
        hidden_size = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=MAX_PERIOD):
        """Return ``[cos | sin]`` features of ``t`` with ``dim`` columns (zero-padded if ``dim`` is odd)."""
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            freqs = torch.exp(-math.log(max_period) * exponents / half)
            angles = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
            if dim % 2:
                # Odd widths get one trailing zero column.
                zero_column = torch.zeros_like(embedding[:, :1])
                embedding = torch.cat([embedding, zero_column], dim=-1)
            return embedding

    def forward(self, t):
        """Embed timesteps ``t`` and return the MLP output."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        weight_dtype = self.mlp[0].weight.dtype
        # Only cast when the first layer's weights are a floating-point dtype.
        if weight_dtype.is_floating_point:
            features = features.to(weight_dtype)
        return self.mlp(features)
54
+
55
+
56
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable per-channel scale.

    Args:
        dim: Size of the last dimension (length of ``weight``).
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2) + eps) * weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return x * inverse_rms * self.weight
65
+
66
+
67
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` with bias-free linears.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the gated hidden projection.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x``."""
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)
76
+
77
+
78
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive channel pairs of ``x_in`` by the complex factors in ``freqs_cis``.

    The last dimension of ``x_in`` is viewed as pairs forming complex numbers,
    multiplied by ``freqs_cis`` (with a singleton axis inserted at dim 2) and
    flattened back. The computation runs in float32 with CUDA autocast
    disabled; the result is cast back to the dtype of ``x_in``.
    """
    with torch.amp.autocast("cuda", enabled=False):
        pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * freqs_cis.unsqueeze(2)
        return torch.view_as_real(rotated).flatten(3).type_as(x_in)
84
+
85
+
86
class ZImageAttention(nn.Module):
    """Multi-head self-attention with optional per-head RMS norm on Q/K and rotary embeddings.

    Args:
        dim: Model width.
        n_heads: Number of query heads; ``head_dim`` is ``dim // n_heads``.
        n_kv_heads: Number of key/value heads.
        qk_norm: Whether to RMS-normalize queries and keys per head.
        eps: Epsilon for the Q/K norms.
    """

    # Backend selector handed to ``dispatch_attention``.
    _attention_backend = None

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int, qk_norm: bool = True, eps: float = 1e-5):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads

        query_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        self.to_q = nn.Linear(dim, query_width, bias=False)
        self.to_k = nn.Linear(dim, kv_width, bias=False)
        self.to_v = nn.Linear(dim, kv_width, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(query_width, dim, bias=False)])

        if qk_norm:
            self.norm_q = RMSNorm(self.head_dim, eps=eps)
            self.norm_k = RMSNorm(self.head_dim, eps=eps)
        else:
            self.norm_q = None
            self.norm_k = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply norms and rotary embedding, attend, and project out."""
        # Split the projected features into (heads, head_dim) along the last axis.
        query = self.to_q(hidden_states).unflatten(-1, (self.n_heads, -1))
        key = self.to_k(hidden_states).unflatten(-1, (self.n_kv_heads, -1))
        value = self.to_v(hidden_states).unflatten(-1, (self.n_kv_heads, -1))

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Keys follow the query dtype; the attention output is cast back to it below.
        dtype = query.dtype
        key = key.to(dtype)

        # Dispatch
        from utils.attention import dispatch_attention

        attended = dispatch_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, backend=self._attention_backend
        )

        # Merge the head axes back into one feature axis.
        attended = attended.flatten(2, 3).to(dtype)
        return self.to_out[0](attended)
141
+
142
+
143
class ZImageTransformerBlock(nn.Module):
    """Transformer block: attention and gated feed-forward, each wrapped in pre/post RMS norms.

    With ``modulation`` enabled, a linear layer on ``adaln_input`` yields
    scale and gate terms for both sub-layers; otherwise plain residual
    connections are used.

    Args:
        layer_id: Identifier stored on the block.
        dim: Model width.
        n_heads: Number of attention heads.
        n_kv_heads: Number of key/value heads.
        norm_eps: Epsilon for all RMS norms.
        qk_norm: Whether the attention normalizes Q/K.
        modulation: Whether to build and use the adaLN modulation layer.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.layer_id = layer_id
        self.modulation = modulation

        self.attention = ZImageAttention(dim, n_heads, n_kv_heads, qk_norm, norm_eps)
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        if modulation:
            adaln_width = min(dim, ADALN_EMBED_DIM)
            self.adaLN_modulation = nn.ModuleList([nn.Linear(adaln_width, 4 * dim, bias=True)])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Apply the block to ``x``; ``adaln_input`` is required when modulation is enabled."""
        if not self.modulation:
            attn_out = self.attention(
                self.attention_norm1(x),
                attention_mask=attn_mask,
                freqs_cis=freqs_cis,
            )
            x = x + self.attention_norm2(attn_out)
            return x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        assert adaln_input is not None
        modulation = self.adaLN_modulation[0](adaln_input).unsqueeze(1)
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=2)
        # Gates are squashed with tanh; scales are offsets around 1.
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()
        scale_msa = 1.0 + scale_msa
        scale_mlp = 1.0 + scale_mlp

        attn_out = self.attention(
            self.attention_norm1(x) * scale_msa,
            attention_mask=attn_mask,
            freqs_cis=freqs_cis,
        )
        x = x + gate_msa * self.attention_norm2(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        return x + gate_mlp * self.ffn_norm2(ffn_out)
203
+
204
+
205
class FinalLayer(nn.Module):
    """Output head: non-affine LayerNorm scaled by a conditioning-derived factor, then a linear map.

    Args:
        hidden_size: Width of the incoming tokens.
        out_channels: Width of the projected output.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        conditioning_width = min(hidden_size, ADALN_EMBED_DIM)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(conditioning_width, hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalize ``x``, scale it by ``1 + adaLN_modulation(c)``, and project."""
        scale = 1.0 + self.adaLN_modulation(c)
        # ``unsqueeze(1)`` inserts an axis so the scale broadcasts against ``x``.
        modulated = self.norm_final(x) * scale.unsqueeze(1)
        return self.linear(modulated)
220
+
221
+
222
class RopeEmbedder:
    """Look up multi-axis rotary factors for integer position ids.

    One complex table per axis is built lazily on the CPU at first call, then
    cached and moved to the device of the ids it is asked about.

    Args:
        theta: Base of the rotary frequencies.
        axes_dims: Channel count assigned to each axis.
        axes_lens: Number of positions tabulated for each axis.
    """

    def __init__(
        self,
        theta: float = ROPE_THETA,
        axes_dims: List[int] = ROPE_AXES_DIMS,
        axes_lens: List[int] = ROPE_AXES_LENS,
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens)
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = ROPE_THETA):
        """Return one ``complex64`` table of shape ``(end[i], dim[i] // 2)`` per axis."""
        with torch.device("cpu"):
            tables = []
            for axis_dim, axis_len in zip(dim, end):
                exponents = torch.arange(0, axis_dim, 2, dtype=torch.float64, device="cpu") / axis_dim
                inv_freq = 1.0 / (theta**exponents)
                positions = torch.arange(axis_len, device=inv_freq.device, dtype=torch.float64)
                angles = torch.outer(positions, inv_freq).float()
                # Unit-magnitude complex numbers whose phase is the angle.
                tables.append(torch.polar(torch.ones_like(angles), angles).to(torch.complex64))
            return tables

    def __call__(self, ids: torch.Tensor):
        """Gather the factors for ``ids`` (one column per axis) and concatenate them along the last dim."""
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            needs_move = True
        else:
            needs_move = self.freqs_cis[0].device != device
        if needs_move:
            self.freqs_cis = [table.to(device) for table in self.freqs_cis]

        gathered = [self.freqs_cis[axis][ids[:, axis]] for axis in range(len(self.axes_dims))]
        return torch.cat(gathered, dim=-1)
264
+
265
+
266
+ class ZImageTransformer2DModel(nn.Module):
267
+ def __init__(
268
+ self,
269
+ all_patch_size=(2,),
270
+ all_f_patch_size=(1,),
271
+ in_channels=16,
272
+ dim=3840,
273
+ n_layers=30,
274
+ n_refiner_layers=2,
275
+ n_heads=30,
276
+ n_kv_heads=30,
277
+ norm_eps=1e-5,
278
+ qk_norm=True,
279
+ cap_feat_dim=2560,
280
+ rope_theta=ROPE_THETA,
281
+ t_scale=1000.0,
282
+ axes_dims=ROPE_AXES_DIMS,
283
+ axes_lens=ROPE_AXES_LENS,
284
+ ):
285
+ super().__init__()
286
+ self.in_channels = in_channels
287
+ self.out_channels = in_channels
288
+ self.all_patch_size = all_patch_size
289
+ self.all_f_patch_size = all_f_patch_size
290
+ self.dim = dim
291
+ self.n_heads = n_heads
292
+ self.rope_theta = rope_theta
293
+ self.t_scale = t_scale
294
+
295
+ assert len(all_patch_size) == len(all_f_patch_size)
296
+
297
+ all_x_embedder = {}
298
+ all_final_layer = {}
299
+ for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
300
+ x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
301
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
302
+ final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
303
+ all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
304
+
305
+ self.all_x_embedder = nn.ModuleDict(all_x_embedder)
306
+ self.all_final_layer = nn.ModuleDict(all_final_layer)
307
+
308
+ self.noise_refiner = nn.ModuleList(
309
+ [
310
+ ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)
311
+ for layer_id in range(n_refiner_layers)
312
+ ]
313
+ )
314
+
315
+ self.context_refiner = nn.ModuleList(
316
+ [
317
+ ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)
318
+ for layer_id in range(n_refiner_layers)
319
+ ]
320
+ )
321
+
322
+ self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
323
+ self.cap_embedder = nn.Sequential(
324
+ RMSNorm(cap_feat_dim, eps=norm_eps),
325
+ nn.Linear(cap_feat_dim, dim, bias=True),
326
+ )
327
+
328
+ self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
329
+ self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
330
+
331
+ self.layers = nn.ModuleList(
332
+ [
333
+ ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
334
+ for layer_id in range(n_layers)
335
+ ]
336
+ )
337
+
338
+ head_dim = dim // n_heads
339
+ assert head_dim == sum(axes_dims)
340
+ self.axes_dims = axes_dims
341
+ self.axes_lens = axes_lens
342
+
343
+ self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
344
+
345
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
    """Fold per-sample token sequences back into ``(out_channels, F, H, W)`` tensors.

    Args:
        x: One token tensor per sample. Only the first
            ``(F // f_patch_size) * (H // patch_size) * (W // patch_size)``
            rows of each entry are used; trailing rows are dropped.
        size: One ``(F, H, W)`` triple per sample.
        patch_size: Patch edge length, used for both height and width.
        f_patch_size: Patch length along the first (F) axis.

    Returns:
        The same list object ``x``; its entries are replaced in place.
    """
    n_items = len(x)
    assert len(size) == n_items
    channels = self.out_channels
    for idx in range(n_items):
        frames, height, width = size[idx]
        f_tok = frames // f_patch_size
        h_tok = height // patch_size
        w_tok = width // patch_size
        # Rows beyond the real token count are not part of the output.
        tokens = x[idx][: f_tok * h_tok * w_tok]
        # Each token row is laid out as (f_patch, h_patch, w_patch, channel).
        patches = tokens.view(f_tok, h_tok, w_tok, f_patch_size, patch_size, patch_size, channels)
        # Bring channel to the front and pair every token axis with its patch axis.
        patches = patches.permute(6, 0, 3, 1, 4, 2, 5)
        x[idx] = patches.reshape(channels, frames, height, width)
    return x
360
+
361
+ @staticmethod
362
+ def create_coordinate_grid(size, start=None, device=None):
363
+ if start is None:
364
+ start = (0 for _ in size)
365
+ axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
366
+ grids = torch.meshgrid(axes, indexing="ij")
367
+ return torch.stack(grids, dim=-1)
368
+
369
def patchify_and_embed(
    self,
    all_image: List[torch.Tensor],
    all_cap_feats: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
):
    """Turn images into patch tokens and pad images and captions per sample.

    Every image is unpacked as ``(C, F, H, W)`` and cut into patches of
    ``f_patch_size x patch_size x patch_size``. Both the patch sequence and
    the caption sequence are padded up to a multiple of ``SEQ_MULTI_OF`` by
    repeating their last row.

    Position ids have three components. Caption ids start at 1 on the first
    axis; image ids start right after the padded caption on the first axis
    and at 0 on the other two. Image padding rows get the id ``(0, 0, 0)``.

    Args:
        all_image: One image tensor per sample.
        all_cap_feats: One caption feature tensor per sample
            (presumably ``(seq_len, feat_dim)`` -- TODO confirm).
        patch_size: Patch edge length for height and width.
        f_patch_size: Patch length along the F axis.

    Returns:
        A 7-tuple of per-sample lists, in this order: padded image tokens,
        padded caption features, ``(F, H, W)`` sizes, image position ids,
        caption position ids, image pad masks, caption pad masks. Pad masks
        are boolean with ``True`` on padding rows.
    """
    device = all_image[0].device

    def _pad_mask(real_len, pad_len):
        # False for real rows, True for the padding rows appended after them.
        mask = torch.zeros((real_len,), dtype=torch.bool, device=device)
        if pad_len > 0:
            tail = torch.ones((pad_len,), dtype=torch.bool, device=device)
            mask = torch.cat([mask, tail], dim=0)
        return mask

    def _pad_rows(rows, pad_len):
        # Padding repeats the last row; unpadded input is passed through untouched.
        if pad_len > 0:
            return torch.cat([rows, rows[-1:].repeat(pad_len, 1)], dim=0)
        return rows

    image_tokens, image_sizes, image_pos, image_masks = [], [], [], []
    caption_tokens, caption_pos, caption_masks = [], [], []

    for image, cap_feat in zip(all_image, all_cap_feats):
        # ---- caption side ----
        n_cap = len(cap_feat)
        cap_pad = (-n_cap) % SEQ_MULTI_OF
        cap_total = n_cap + cap_pad
        cap_grid = self.create_coordinate_grid(
            size=(cap_total, 1, 1),
            start=(1, 0, 0),
            device=device,
        )
        caption_pos.append(cap_grid.flatten(0, 2))
        caption_masks.append(_pad_mask(n_cap, cap_pad))
        caption_tokens.append(_pad_rows(cap_feat, cap_pad))

        # ---- image side ----
        channels, frames, height, width = image.size()
        image_sizes.append((frames, height, width))
        f_tok = frames // f_patch_size
        h_tok = height // patch_size
        w_tok = width // patch_size

        patches = image.view(channels, f_tok, f_patch_size, h_tok, patch_size, w_tok, patch_size)
        # Token axes first, then the in-patch axes, channel last.
        patches = patches.permute(1, 3, 5, 2, 4, 6, 0)
        patches = patches.reshape(
            f_tok * h_tok * w_tok,
            f_patch_size * patch_size * patch_size * channels,
        )

        n_img = len(patches)
        img_pad = (-n_img) % SEQ_MULTI_OF

        pos_ids = self.create_coordinate_grid(
            size=(f_tok, h_tok, w_tok),
            start=(cap_total + 1, 0, 0),
            device=device,
        ).flatten(0, 2)
        if img_pad > 0:
            zero_pos = self.create_coordinate_grid(
                size=(1, 1, 1), start=(0, 0, 0), device=device
            ).flatten(0, 2)
            pos_ids = torch.cat([pos_ids, zero_pos.repeat(img_pad, 1)], dim=0)
        image_pos.append(pos_ids)
        image_masks.append(_pad_mask(n_img, img_pad))
        image_tokens.append(_pad_rows(patches, img_pad))

    return (
        image_tokens,
        caption_tokens,
        image_sizes,
        image_pos,
        caption_pos,
        image_masks,
        caption_masks,
    )
473
+
474
def forward(
    self,
    x: List[torch.Tensor],
    t,
    cap_feats: List[torch.Tensor],
    patch_size=2,
    f_patch_size=1,
):
    """Run the model on a batch given as per-sample lists.

    Stages, in order: embed the scaled timestep; patchify and pad images and
    captions; refine image tokens with ``noise_refiner`` (conditioned on the
    timestep embedding); refine caption tokens with ``context_refiner``;
    concatenate image tokens followed by caption tokens per sample and run
    ``layers``; apply the final layer and unpatchify.

    Args:
        x: One image tensor per sample.
        t: Timestep input; multiplied by ``self.t_scale`` before embedding.
        cap_feats: One caption feature tensor per sample.
        patch_size: Must be in ``self.all_patch_size``.
        f_patch_size: Must be in ``self.all_f_patch_size``.

    Returns:
        ``(outputs, {})`` where ``outputs`` is the list returned by
        ``self.unpatchify`` and the second element is an empty dict.
    """
    assert patch_size in self.all_patch_size
    assert f_patch_size in self.all_f_patch_size

    batch = len(x)
    device = x[0].device
    patch_key = f"{patch_size}-{f_patch_size}"

    def _valid_mask(lengths):
        # True on the first ``length`` positions of each row, False on batch padding.
        mask = torch.zeros((batch, max(lengths)), dtype=torch.bool, device=device)
        for row, length in enumerate(lengths):
            mask[row, :length] = 1
        return mask

    def _rope_chunks(pos_ids):
        # Embed all position ids in one call, then cut back into per-sample chunks.
        freqs = self.rope_embedder(torch.cat(pos_ids, dim=0))
        return list(freqs.split([len(ids) for ids in pos_ids], dim=0))

    t_emb = self.t_embedder(t * self.t_scale)

    (
        img,
        cap,
        x_size,
        img_pos_ids,
        cap_pos_ids,
        img_pad_mask,
        cap_pad_mask,
    ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    # ---- image stream ----
    img_lens = [len(item) for item in img]
    assert all(n % SEQ_MULTI_OF == 0 for n in img_lens)

    img = self.all_x_embedder[patch_key](torch.cat(img, dim=0))
    adaln_input = t_emb.type_as(img)
    # Rows flagged as padding are overwritten with the learned pad token.
    img[torch.cat(img_pad_mask)] = self.x_pad_token
    img_chunks = list(img.split(img_lens, dim=0))
    img_freq_chunks = _rope_chunks(img_pos_ids)

    img = pad_sequence(img_chunks, batch_first=True, padding_value=0.0)
    img_freqs = pad_sequence(img_freq_chunks, batch_first=True, padding_value=0.0)
    # Make the matching length explicit so Dynamo's symbolic shape inference
    # does not hit compilation errors.
    img_freqs = img_freqs[:, : img.shape[1]]
    img_attn = _valid_mask(img_lens)

    for block in self.noise_refiner:
        img = block(img, img_attn, img_freqs, adaln_input)

    # ---- caption stream ----
    cap_lens = [len(item) for item in cap]
    assert all(n % SEQ_MULTI_OF == 0 for n in cap_lens)

    cap = self.cap_embedder(torch.cat(cap, dim=0))
    cap[torch.cat(cap_pad_mask)] = self.cap_pad_token
    cap_chunks = list(cap.split(cap_lens, dim=0))
    cap_freq_chunks = _rope_chunks(cap_pos_ids)

    cap = pad_sequence(cap_chunks, batch_first=True, padding_value=0.0)
    cap_freqs = pad_sequence(cap_freq_chunks, batch_first=True, padding_value=0.0)
    cap_freqs = cap_freqs[:, : cap.shape[1]]  # same for dynamo compatibility
    cap_attn = _valid_mask(cap_lens)

    for block in self.context_refiner:
        cap = block(cap, cap_attn, cap_freqs)

    # ---- joint stream: image tokens first, caption tokens after ----
    unified = [
        torch.cat([img[i][: img_lens[i]], cap[i][: cap_lens[i]]])
        for i in range(batch)
    ]
    unified_freqs = [
        torch.cat([img_freqs[i][: img_lens[i]], cap_freqs[i][: cap_lens[i]]])
        for i in range(batch)
    ]
    unified_lens = [n_cap + n_img for n_cap, n_img in zip(cap_lens, img_lens)]
    assert unified_lens == [len(item) for item in unified]

    unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
    unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
    unified_attn = _valid_mask(unified_lens)

    for block in self.layers:
        unified = block(unified, unified_attn, unified_freqs, adaln_input)

    unified = self.all_final_layer[patch_key](unified, adaln_input)
    # unpatchify keeps only the leading image tokens of each sequence.
    outputs = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size)

    return outputs, {}
upscale_models/1x-ITF-SkinDiffDetail-Lite-v1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94d368b633614958f84f335b129fd85abd30200e8fbc575b859ba6762116222b
3
+ size 20099337
upscale_models/1x_PureVision.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c3bec2111d0f7c3926f3171f37ce28e72502b744e084c566d8960c06a6a06a3
3
+ size 67120607
upscale_models/2x_PureVision.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c49109c8257b80b3cddf0110786ad11c1ca05214a470bfcc7fd49b6461dfcaee
3
+ size 67037663
upscale_models/4x-ClearRealityV1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4cd3a25b00e0be949d4302fc774eb4d7f2ed5f47cdb51551e2d75fa6562e51e
3
+ size 9016074
upscale_models/4x-UltraSharp.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5812231fc936b42af08a5edba784195495d303d5b3248c24489ef0c4021fe01
3
+ size 66961958
upscale_models/4xFFHQDAT.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa91faa8c1f72c32646d71abf51e952c81b4984948e89e6e5a8c40822a6cf3cc
3
+ size 154152604
upscale_models/4xNomos8k_atd_jpg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b2bfb0e822c79594288dd43efaec213b6f0244384bd98db75072b0ce5a729fe
3
+ size 81959074
upscale_models/4xNomos8k_span_otf_weak.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26e7ef6483faf93b47f48af262ed4bb8dededa20af34a923e344cb63cafeec0a
3
+ size 9015866
upscale_models/4x_NMKD-Siax_200k.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:560424d9f68625713fc47e9e7289a98aabe1d744e1cd6a9ae5a35e9957fd127e
3
+ size 66957746
upscale_models/4x_NMKD-Superscale-SP_178000_G.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d1b0078fe71446e0469d8d4df59e96baa80d83cda600d68237d655830821bcc
3
+ size 66958607
upscale_models/4x_foolhardy_Remacri.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1a73bd89c2da1ae494774746398689048b5a892bd9653e146713f9df8bca86a
3
+ size 67025055
upscale_models/RealESRGAN_x4plus.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1
3
+ size 67040989
vae_approx/taew2_1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d26151e76cdc2c9424bef988de874b33d9a53f30ef3060cd556c429c469c797e
3
+ size 22678901
vitmatte/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bda9289db1bb6762d978b42d1c62ae3f34daf7497171a347a1d09657efd788cb
3
+ size 103294572