Diffusers
Safetensors
BiliSakura commited on
Commit
c5cfae9
·
verified ·
1 Parent(s): 16fb05f

Upload folder using huggingface_hub

Browse files
Files changed (46) hide show
  1. README.md +267 -0
  2. __init__.py +9 -0
  3. convert_forward_renderer.py +99 -0
  4. convert_inverse_renderer_512.py +118 -0
  5. dinov2/README.md +61 -0
  6. dinov2/config.json +24 -0
  7. dinov2/model.safetensors +3 -0
  8. dinov2/preprocessor_config.json +27 -0
  9. imaa/config.json +11 -0
  10. imaa/imaa.py +205 -0
  11. imaa/model.safetensors +3 -0
  12. model_index.json +39 -0
  13. pipeline_intrinsic_weather.py +486 -0
  14. pipeline_intrinsic_weather_forward.py +1191 -0
  15. pipeline_intrinsic_weather_inverse.py +1119 -0
  16. pipeline_utils.py +104 -0
  17. scheduler/scheduler_config.json +9 -0
  18. test_all_pipelines.py +141 -0
  19. text_encoder/config.json +24 -0
  20. text_encoder/model.safetensors +3 -0
  21. text_encoder_2/config.json +24 -0
  22. text_encoder_2/model.safetensors +3 -0
  23. text_encoder_3/config.json +31 -0
  24. text_encoder_3/model-00001-of-00002.safetensors +3 -0
  25. text_encoder_3/model-00002-of-00002.safetensors +3 -0
  26. text_encoder_3/model.safetensors.index.json +226 -0
  27. tokenizer/merges.txt +0 -0
  28. tokenizer/special_tokens_map.json +30 -0
  29. tokenizer/tokenizer_config.json +30 -0
  30. tokenizer/vocab.json +0 -0
  31. tokenizer_2/merges.txt +0 -0
  32. tokenizer_2/special_tokens_map.json +30 -0
  33. tokenizer_2/tokenizer_config.json +38 -0
  34. tokenizer_2/vocab.json +0 -0
  35. tokenizer_3/special_tokens_map.json +125 -0
  36. tokenizer_3/spiece.model +3 -0
  37. tokenizer_3/tokenizer.json +0 -0
  38. tokenizer_3/tokenizer_config.json +940 -0
  39. transformer/forward/config.json +31 -0
  40. transformer/forward/diffusion_pytorch_model.safetensors +3 -0
  41. transformer/forward/lora/pytorch_lora_weights.safetensors +3 -0
  42. transformer/inverse-512/config.json +31 -0
  43. transformer/inverse-512/diffusion_pytorch_model.safetensors +3 -0
  44. transformer/inverse-512/transformer_intrinsic_weather.py +1527 -0
  45. vae/config.json +36 -0
  46. vae/diffusion_pytorch_model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IntrinsicWeather (Diffusers)
2
+
3
+ Diffusers-format checkpoint for **[IntrinsicWeather: Controllable Weather Editing in Intrinsic Space](https://arxiv.org/pdf/2508.06982v6)** (CVPR 2026 Highlight).
4
+
5
+ This repo bundles inverse rendering, forward weather rendering, and the IMAA gating module into a single Hugging Face–compatible layout. Shared Stable Diffusion 3 components (VAE, text encoders, tokenizers, scheduler) are stored once; task-specific transformers live under `transformer/<variant>/`.
6
+
7
+ ## Model layout
8
+
9
+ ```
10
+ IntrisicWeather-diffusers/
11
+ ├── dinov2/ # bundled DINOv2 weights (for IMAA / decomposition)
12
+ ├── imaa/ # Intrinsic Map-Aware Attention weights
13
+ ├── text_encoder/, text_encoder_2/, text_encoder_3/
14
+ ├── tokenizer/, tokenizer_2/, tokenizer_3/
15
+ ├── vae/, scheduler/
16
+ ├── transformer/
17
+ │ ├── inverse-512/ # IntrinsicWeatherSD3Transformer2DModel (in_channels=32)
18
+ │ │ └── transformer_intrinsic_weather.py
19
+ │ └── forward/ # SD3Transformer2DModel (in_channels=96)
20
+ │ └── lora/ # forward-renderer LoRA (loaded by default)
21
+ ├── pipeline_intrinsic_weather.py # unified: RGB → maps → weather RGB
22
+ ├── pipeline_intrinsic_weather_inverse.py # inverse only
23
+ ├── pipeline_intrinsic_weather_forward.py # forward only
24
+ ├── pipeline_utils.py
25
+ ├── model_index.json
26
+ ├── convert_inverse_renderer_512.py
27
+ ├── convert_forward_renderer.py
28
+ └── test_all_pipelines.py
29
+ ```
30
+
31
+ | Component | Source | Notes |
32
+ |-----------|--------|-------|
33
+ | Inverse transformer | [GilgameshYX/InverseRenderer-512](https://huggingface.co/GilgameshYX/InverseRenderer-512) | 512×512 decomposition |
34
+ | Forward transformer + LoRA | [GilgameshYX/ForwardRenderer](https://huggingface.co/GilgameshYX/ForwardRenderer) | LoRA in `transformer/forward/lora/` |
35
+ | IMAA | InverseRenderer-512 `imaa.pth` | Required for map-aware inverse attention |
36
+ | SD3 shared weights | `stabilityai/stable-diffusion-3-medium-diffusers` | VAE + text encoders only |
37
+ | Transformer config | `stabilityai/stable-diffusion-3.5-medium` | Architecture template for weight loading |
38
+
39
+ ## Requirements
40
+
41
+ - Python 3.10+
42
+ - CUDA GPU recommended (~20 GB VRAM for full end-to-end inference at 512×512)
43
+ - `torch`, `diffusers>=0.38`, `transformers`, `safetensors`, `torchvision`, `Pillow`
44
+
45
+ ```bash
46
+ pip install torch diffusers transformers safetensors torchvision pillow accelerate
47
+ ```
48
+
49
+ ## Quick start (end-to-end weather edit)
50
+
51
+ The unified pipeline decomposes an input RGB image into intrinsic maps, then renders a weather-conditioned result. **DINOv2** is required for decomposition (bundled under `dinov2/`, or use `facebook/dinov2-base` from Hugging Face).
52
+
53
+ ```python
54
+ from pathlib import Path
55
+
56
+ import torch
57
+ from PIL import Image
58
+ from transformers import AutoImageProcessor, AutoModel
59
+
60
+ from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
61
+
62
+ repo_dir = Path(".").resolve() # path to this folder
63
+ device = "cuda"
64
+ dtype = torch.bfloat16
65
+
66
+ pipe = IntrinsicWeatherPipeline.from_pretrained(
67
+ repo_dir,
68
+ inverse_transformer_subfolder="inverse-512",
69
+ forward_transformer_subfolder="forward",
70
+ device=device,
71
+ local_files_only=True,
72
+ torch_dtype=dtype,
73
+ load_lora=True,
74
+ load_imaa=True,
75
+ )
76
+
77
+ dino_path = repo_dir / "dinov2"
78
+ dino_processor = AutoImageProcessor.from_pretrained(dino_path, local_files_only=True)
79
+ dino_model = AutoModel.from_pretrained(dino_path, local_files_only=True).to(device)
80
+ dino_model.eval()
81
+
82
+ image = Image.open("input.png").convert("RGB")
83
+ result = pipe(
84
+ image=image,
85
+ weather="snowy", # rainy | sunny | snowy | foggy | overcast | night
86
+ dino_model=dino_model,
87
+ dino_processor=dino_processor,
88
+ image_size=512,
89
+ render_size=512,
90
+ num_inverse_steps=50,
91
+ num_forward_steps=50,
92
+ guidance_scale=6.0,
93
+ image_guidance_scale=1.5,
94
+ generator=torch.Generator(device=device).manual_seed(42),
95
+ )
96
+ result.images[0].save("output_snowy.png")
97
+ ```
98
+
99
+ Run from inside this directory (or add it to `PYTHONPATH`) so `pipeline_intrinsic_weather.py` and `imaa/` resolve correctly.
100
+
101
+ ## Pipelines
102
+
103
+ ### 1. `IntrinsicWeatherPipeline` (unified)
104
+
105
+ Full pipeline: **RGB → intrinsic maps → weather RGB**.
106
+
107
+ ```python
108
+ pipe = IntrinsicWeatherPipeline.from_pretrained(
109
+ repo_dir,
110
+ inverse_transformer_subfolder="inverse-512",
111
+ forward_transformer_subfolder="forward",
112
+ device="cuda",
113
+ torch_dtype=torch.bfloat16,
114
+ )
115
+ ```
116
+
117
+ Useful kwargs:
118
+
119
+ | Argument | Default | Description |
120
+ |----------|---------|-------------|
121
+ | `inverse_transformer_subfolder` | `"inverse-512"` | Inverse transformer under `transformer/` |
122
+ | `forward_transformer_subfolder` | `"forward"` | Forward transformer under `transformer/` |
123
+ | `load_lora` | `True` | Load LoRA from `transformer/forward/lora/` |
124
+ | `load_imaa` | `True` | Load IMAA weights from `imaa/` |
125
+ | `device` | `None` | Moves all modules to device (IMAA stays float32) |
126
+
127
+ Sub-methods:
128
+
129
+ - `pipe.decompose(image, dino_model, dino_processor, ...)` → dict of intrinsic maps
130
+ - `pipe.render(maps, weather="rainy", ...)` → weather-conditioned RGB
131
+
132
+ ### 2. `IntrinsicWeatherInversePipeline`
133
+
134
+ Inverse rendering only (single intrinsic map per call).
135
+
136
+ ```python
137
+ from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
138
+
139
+ pipe = IntrinsicWeatherInversePipeline.from_pretrained(
140
+ repo_dir,
141
+ transformer_subfolder="inverse-512",
142
+ device="cuda",
143
+ torch_dtype=torch.bfloat16,
144
+ )
145
+ ```
146
+
147
+ Load the transformer separately if needed:
148
+
149
+ ```python
150
+ transformer = IntrinsicWeatherInversePipeline.load_transformer(
151
+ "inverse-512", repo_dir, device="cuda"
152
+ )
153
+ pipe = IntrinsicWeatherInversePipeline.from_pretrained(
154
+ repo_dir, transformer=transformer, device="cuda"
155
+ )
156
+ ```
157
+
158
+ IMAA and DINO are used by the unified pipeline’s `decompose()` path; for standalone inverse calls, pass `map_aware_mask` from IMAA manually (see `test_all_pipelines.py`).
159
+
160
+ ### 3. `IntrinsicWeatherForwardPipeline`
161
+
162
+ Forward weather rendering from intrinsic maps.
163
+
164
+ ```python
165
+ from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
166
+
167
+ pipe = IntrinsicWeatherForwardPipeline.from_pretrained(
168
+ repo_dir,
169
+ transformer_subfolder="forward",
170
+ device="cuda",
171
+ torch_dtype=torch.bfloat16,
172
+ load_lora=True,
173
+ )
174
+ ```
175
+
176
+ LoRA weights are read from `transformer/forward/lora/` when `load_lora=True`.
177
+
178
+ ## Weather presets
179
+
180
+ Built-in weather keys (or pass a custom prompt string):
181
+
182
+ | Key | Prompt |
183
+ |-----|--------|
184
+ | `rainy` | A rainy day. |
185
+ | `sunny` | A sunny day. |
186
+ | `snowy` | A snowy day. |
187
+ | `foggy` | A foggy day. |
188
+ | `overcast` | An overcast day. |
189
+ | `night` | A night scene. |
190
+
191
+ ## Intrinsic maps (AoVs)
192
+
193
+ The inverse renderer produces five appearance-of-variety maps:
194
+
195
+ `albedo`, `normal`, `roughness`, `metallic`, `irradiance`
196
+
197
+ ## Loading transformers manually
198
+
199
+ Transformers are stored per variant under `transformer/<subfolder>/`. Use `pipeline_utils.load_transformer_from_subfolder`:
200
+
201
+ ```python
202
+ from pipeline_utils import load_transformer_from_subfolder, load_transformer_lora
203
+
204
+ inverse = load_transformer_from_subfolder(repo_dir, "inverse-512", device="cuda")
205
+ forward = load_transformer_from_subfolder(repo_dir, "forward", device="cuda")
206
+ ```
207
+
208
+ - `inverse-512` uses a custom `IntrinsicWeatherSD3Transformer2DModel` (`in_channels=32`).
209
+ - `forward` uses the standard `SD3Transformer2DModel` (`in_channels=96`).
210
+
211
+ ## Dtype and device notes
212
+
213
+ - Default dtype is **`torch.bfloat16`** for transformers, VAE, and text encoders.
214
+ - **IMAA** stays in **float32** (DINO patch tokens are float32).
215
+ - Pass `device="cuda"` to `from_pretrained` on all three pipeline classes; the unified pipeline moves every registered module to the target device automatically.
216
+
217
+ ## Testing
218
+
219
+ Smoke-test all pipelines on CUDA:
220
+
221
+ ```bash
222
+ python test_all_pipelines.py
223
+ ```
224
+
225
+ Runs 2-step inverse, forward (with LoRA), and unified load checks with `bfloat16`.
226
+
227
+ ## Re-converting from original checkpoints
228
+
229
+ If you have the raw GilgameshYX checkpoints:
230
+
231
+ ```bash
232
+ # Inverse renderer (512) + IMAA
233
+ python convert_inverse_renderer_512.py
234
+
235
+ # Forward renderer + LoRA
236
+ python convert_forward_renderer.py
237
+ ```
238
+
239
+ See `conversion_metadata.json` and `conversion_metadata_forward.json` for source paths used during conversion.
240
+
241
+ ## Hugging Face Hub loading
242
+
243
+ When published to the Hub, load with `trust_remote_code=True`:
244
+
245
+ ```python
246
+ from diffusers import DiffusionPipeline
247
+
248
+ pipe = DiffusionPipeline.from_pretrained(
249
+ "BiliSakura/IntrisicWeather-diffusers",
250
+ custom_pipeline="pipeline_intrinsic_weather.py",
251
+ trust_remote_code=True,
252
+ torch_dtype=torch.bfloat16,
253
+ )
254
+ ```
255
+
256
+ For local use, importing `IntrinsicWeatherPipeline` directly (as in Quick start) is simpler and avoids Hub cache path issues with custom modules.
257
+
258
+ ## References
259
+
260
+ - **Paper:** [IntrinsicWeather (arXiv:2508.06982)](https://arxiv.org/pdf/2508.06982v6)
261
+ - **Project page:** https://yixinzhu042.github.io/IntrinsicWeather/
262
+ - **Upstream diffusers repo:** [IntrinsicWeather-diffusers](https://github.com/YixinZhu042/IntrinsicWeather)
263
+ - **Original weights:** [GilgameshYX/InverseRenderer-512](https://huggingface.co/GilgameshYX/InverseRenderer-512), [GilgameshYX/ForwardRenderer](https://huggingface.co/GilgameshYX/ForwardRenderer)
264
+
265
+ ## License
266
+
267
+ Weights and code follow the licenses of the upstream IntrinsicWeather project and the Stable Diffusion 3 components used for shared modules.
__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
2
+ from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
3
+ from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
4
+
5
+ __all__ = [
6
+ "IntrinsicWeatherPipeline",
7
+ "IntrinsicWeatherForwardPipeline",
8
+ "IntrinsicWeatherInversePipeline",
9
+ ]
convert_forward_renderer.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert GilgameshYX ForwardRenderer into BiliSakura IntrisicWeather-diffusers layout."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import shutil
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ from diffusers.models.transformers import SD3Transformer2DModel
12
+
13
+ COLLECTION_ROOT = Path(__file__).resolve().parent
14
+ INTRINSIC_REPO = Path("/data/projects/IntrinsicWeather-diffusers")
15
+ sys.path.insert(0, str(INTRINSIC_REPO / "src"))
16
+ sys.path.insert(0, str(INTRINSIC_REPO))
17
+
18
+ from scripts._conversion_utils import ( # noqa: E402
19
+ expand_sd3_input_projection,
20
+ load_torch,
21
+ write_scheduler_config,
22
+ )
23
+
24
+ from _collection_setup import install_hub_pipelines # noqa: E402
25
+
26
+ SD3_PATH = Path(
27
+ "/data/projects/Visual-Generative-Foundation-Model-Collection/models/stabilityai/stable-diffusion-3-medium-diffusers"
28
+ )
29
+ SD35_TRANSFORMER_REPO = "stabilityai/stable-diffusion-3.5-medium"
30
+ CKPT_PATH = Path(
31
+ "/data/projects/Visual-Generative-Foundation-Model-Collection/models/GilgameshYX/ForwardRenderer"
32
+ )
33
+ OUTPUT_ROOT = COLLECTION_ROOT
34
+ TRANSFORMER_VARIANT = "forward"
35
+ SHARED_COMPONENTS = (
36
+ "text_encoder",
37
+ "text_encoder_2",
38
+ "text_encoder_3",
39
+ "tokenizer",
40
+ "tokenizer_2",
41
+ "tokenizer_3",
42
+ "vae",
43
+ "scheduler",
44
+ )
45
+
46
+
47
+ def copy_sd3_shared_components(sd3_path: Path, output_path: Path) -> None:
48
+ for name in SHARED_COMPONENTS:
49
+ src = sd3_path / name
50
+ dst = output_path / name
51
+ if dst.exists():
52
+ print(f"Skipping existing shared component: {dst}")
53
+ continue
54
+ print(f"Copying {name} ...")
55
+ shutil.copytree(src, dst)
56
+
57
+
58
+ def main() -> None:
59
+ transformer_dir = OUTPUT_ROOT / "transformer" / TRANSFORMER_VARIANT
60
+ transformer_dir.mkdir(parents=True, exist_ok=True)
61
+
62
+ print(f"Ensuring shared SD3 components from {SD3_PATH} ...")
63
+ copy_sd3_shared_components(SD3_PATH, OUTPUT_ROOT)
64
+ write_scheduler_config(OUTPUT_ROOT)
65
+ install_hub_pipelines(OUTPUT_ROOT)
66
+
67
+ print("Converting forward renderer transformer ...")
68
+ transformer = SD3Transformer2DModel.from_config(
69
+ SD3Transformer2DModel.load_config(SD35_TRANSFORMER_REPO, subfolder="transformer")
70
+ )
71
+ transformer = expand_sd3_input_projection(transformer, in_channels=96)
72
+ transformer.load_state_dict(load_torch(CKPT_PATH / "pytorch_model.bin"), strict=True)
73
+ transformer.save_pretrained(transformer_dir.as_posix(), safe_serialization=True)
74
+
75
+ print("Saving LoRA weights ...")
76
+ lora_dir = transformer_dir / "lora"
77
+ lora_dir.mkdir(parents=True, exist_ok=True)
78
+ shutil.copy2(CKPT_PATH / "pytorch_lora_weights.safetensors", lora_dir / "pytorch_lora_weights.safetensors")
79
+
80
+ conversion_metadata = {
81
+ "task": "forward_renderer",
82
+ "transformer_variant": TRANSFORMER_VARIANT,
83
+ "source_transformer_checkpoint": str((CKPT_PATH / "pytorch_model.bin").resolve()),
84
+ "source_lora_checkpoint": str((CKPT_PATH / "pytorch_lora_weights.safetensors").resolve()),
85
+ "lora_dir": str((lora_dir).resolve()),
86
+ "sd3_path": str(SD3_PATH.resolve()),
87
+ "sd35_transformer_repo": SD35_TRANSFORMER_REPO,
88
+ "in_channels": 96,
89
+ }
90
+ (OUTPUT_ROOT / "conversion_metadata_forward.json").write_text(
91
+ json.dumps(conversion_metadata, indent=2) + "\n",
92
+ encoding="utf-8",
93
+ )
94
+ print(f"Saved transformer to: {transformer_dir}")
95
+ print("Load with: load_forward_pipeline(transformer_subfolder='forward')")
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()
convert_inverse_renderer_512.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert GilgameshYX InverseRenderer-512 into BiliSakura IntrisicWeather-diffusers layout."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import shutil
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ from diffusers.models.transformers import SD3Transformer2DModel
12
+
13
+ COLLECTION_ROOT = Path(__file__).resolve().parent
14
+ INTRINSIC_REPO = Path("/data/projects/IntrinsicWeather-diffusers")
15
+ sys.path.insert(0, str(INTRINSIC_REPO / "src"))
16
+ sys.path.insert(0, str(INTRINSIC_REPO))
17
+
18
+ from intrinsic_weather.models.transformers.transformer_intrinsic_weather import ( # noqa: E402
19
+ IntrinsicWeatherSD3Transformer2DModel,
20
+ )
21
+ from scripts._conversion_utils import ( # noqa: E402
22
+ ROOT as REPO_ROOT,
23
+ expand_sd3_input_projection,
24
+ merge_sharded_state_dict,
25
+ save_imaa_bundle,
26
+ write_scheduler_config,
27
+ )
28
+
29
+ from _collection_setup import install_hub_pipelines # noqa: E402
30
+
31
+ SD3_PATH = Path(
32
+ "/data/projects/Visual-Generative-Foundation-Model-Collection/models/stabilityai/stable-diffusion-3-medium-diffusers"
33
+ )
34
+ SD35_TRANSFORMER_REPO = "stabilityai/stable-diffusion-3.5-medium"
35
+ CKPT_PATH = Path(
36
+ "/data/projects/Visual-Generative-Foundation-Model-Collection/models/GilgameshYX/InverseRenderer-512"
37
+ )
38
+ OUTPUT_ROOT = COLLECTION_ROOT
39
+ TRANSFORMER_VARIANT = "inverse-512"
40
+ SHARED_COMPONENTS = (
41
+ "text_encoder",
42
+ "text_encoder_2",
43
+ "text_encoder_3",
44
+ "tokenizer",
45
+ "tokenizer_2",
46
+ "tokenizer_3",
47
+ "vae",
48
+ "scheduler",
49
+ )
50
+
51
+
52
+ def copy_sd3_shared_components(sd3_path: Path, output_path: Path) -> None:
53
+ for name in SHARED_COMPONENTS:
54
+ src = sd3_path / name
55
+ dst = output_path / name
56
+ if dst.exists():
57
+ print(f"Skipping existing shared component: {dst}")
58
+ continue
59
+ print(f"Copying {name} ...")
60
+ shutil.copytree(src, dst)
61
+
62
+
63
+ def main() -> None:
64
+ transformer_dir = OUTPUT_ROOT / "transformer" / TRANSFORMER_VARIANT
65
+ transformer_dir.mkdir(parents=True, exist_ok=True)
66
+
67
+ print(f"Copying shared SD3 components from {SD3_PATH} ...")
68
+ copy_sd3_shared_components(SD3_PATH, OUTPUT_ROOT)
69
+ write_scheduler_config(OUTPUT_ROOT)
70
+ install_hub_pipelines(OUTPUT_ROOT)
71
+
72
+ print("Converting inverse renderer transformer ...")
73
+ base_transformer = SD3Transformer2DModel.from_config(
74
+ SD3Transformer2DModel.load_config(SD35_TRANSFORMER_REPO, subfolder="transformer")
75
+ )
76
+ base_transformer = expand_sd3_input_projection(base_transformer, in_channels=32)
77
+ custom_blocks = IntrinsicWeatherSD3Transformer2DModel.from_config(base_transformer.config)
78
+ custom_blocks.load_state_dict(
79
+ merge_sharded_state_dict(
80
+ [
81
+ CKPT_PATH / "pytorch_model-00001-of-00002.bin",
82
+ CKPT_PATH / "pytorch_model-00002-of-00002.bin",
83
+ ]
84
+ ),
85
+ strict=True,
86
+ )
87
+ custom_blocks.save_pretrained(transformer_dir.as_posix(), safe_serialization=True)
88
+ shutil.copy2(
89
+ REPO_ROOT / "src" / "intrinsic_weather" / "models" / "transformers" / "transformer_intrinsic_weather.py",
90
+ transformer_dir / "transformer_intrinsic_weather.py",
91
+ )
92
+
93
+ print("Saving IMAA weights ...")
94
+ save_imaa_bundle(CKPT_PATH / "imaa.pth", OUTPUT_ROOT, safe_serialization=True)
95
+
96
+ conversion_metadata = {
97
+ "task": "inverse_renderer",
98
+ "resolution": 512,
99
+ "transformer_variant": TRANSFORMER_VARIANT,
100
+ "source_transformer_checkpoints": [
101
+ str((CKPT_PATH / "pytorch_model-00001-of-00002.bin").resolve()),
102
+ str((CKPT_PATH / "pytorch_model-00002-of-00002.bin").resolve()),
103
+ ],
104
+ "source_imaa_checkpoint": str((CKPT_PATH / "imaa.pth").resolve()),
105
+ "sd3_path": str(SD3_PATH.resolve()),
106
+ "sd35_transformer_repo": SD35_TRANSFORMER_REPO,
107
+ "in_channels": 32,
108
+ }
109
+ (OUTPUT_ROOT / "conversion_metadata.json").write_text(
110
+ json.dumps(conversion_metadata, indent=2) + "\n",
111
+ encoding="utf-8",
112
+ )
113
+ print(f"Saved transformer to: {transformer_dir}")
114
+ print("Load with: load_inverse_pipeline(transformer_subfolder='inverse-512')")
115
+
116
+
117
+ if __name__ == "__main__":
118
+ main()
dinov2/README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - dino
5
+ - vision
6
+ inference: false
7
+ ---
8
+
9
+ # Vision Transformer (base-sized model) trained using DINOv2
10
+
11
+ Vision Transformer (ViT) model trained using the DINOv2 method. It was introduced in the paper [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by Oquab et al. and first released in [this repository](https://github.com/facebookresearch/dinov2).
12
+
13
+ Disclaimer: The team releasing DINOv2 did not write a model card for this model so this model card has been written by the Hugging Face team.
14
+
15
+ ## Model description
16
+
17
+ The Vision Transformer (ViT) is a transformer encoder model (BERT-like) pretrained on a large collection of images in a self-supervised fashion.
18
+
19
+ Images are presented to the model as a sequence of fixed-size patches, which are linearly embedded. One also adds a [CLS] token to the beginning of a sequence to use it for classification tasks. One also adds absolute position embeddings before feeding the sequence to the layers of the Transformer encoder.
20
+
21
+ Note that this model does not include any fine-tuned heads.
22
+
23
+ By pre-training the model, it learns an inner representation of images that can then be used to extract features useful for downstream tasks: if you have a dataset of labeled images for instance, you can train a standard classifier by placing a linear layer on top of the pre-trained encoder. One typically places a linear layer on top of the [CLS] token, as the last hidden state of this token can be seen as a representation of an entire image.
24
+
25
+ ## Intended uses & limitations
26
+
27
+ You can use the raw model for feature extraction. See the [model hub](https://huggingface.co/models?search=facebook/dinov2) to look for
28
+ fine-tuned versions on a task that interests you.
29
+
30
+ ### How to use
31
+
32
+ Here is how to use this model:
33
+
34
+ ```python
35
+ from transformers import AutoImageProcessor, AutoModel
36
+ from PIL import Image
37
+ import requests
38
+
39
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
40
+ image = Image.open(requests.get(url, stream=True).raw)
41
+
42
+ processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
43
+ model = AutoModel.from_pretrained('facebook/dinov2-base')
44
+
45
+ inputs = processor(images=image, return_tensors="pt")
46
+ outputs = model(**inputs)
47
+ last_hidden_states = outputs.last_hidden_state
48
+ ```
49
+
50
+ ### BibTeX entry and citation info
51
+
52
+ ```bibtex
53
+ misc{oquab2023dinov2,
54
+ title={DINOv2: Learning Robust Visual Features without Supervision},
55
+ author={Maxime Oquab and Timothée Darcet and Théo Moutakanni and Huy Vo and Marc Szafraniec and Vasil Khalidov and Pierre Fernandez and Daniel Haziza and Francisco Massa and Alaaeldin El-Nouby and Mahmoud Assran and Nicolas Ballas and Wojciech Galuba and Russell Howes and Po-Yao Huang and Shang-Wen Li and Ishan Misra and Michael Rabbat and Vasu Sharma and Gabriel Synnaeve and Hu Xu and Hervé Jegou and Julien Mairal and Patrick Labatut and Armand Joulin and Piotr Bojanowski},
56
+ year={2023},
57
+ eprint={2304.07193},
58
+ archivePrefix={arXiv},
59
+ primaryClass={cs.CV}
60
+ }
61
+ ```
dinov2/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Dinov2Model"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "drop_path_rate": 0.0,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.0,
9
+ "hidden_size": 768,
10
+ "image_size": 518,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_eps": 1e-06,
13
+ "layerscale_value": 1.0,
14
+ "mlp_ratio": 4,
15
+ "model_type": "dinov2",
16
+ "num_attention_heads": 12,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 12,
19
+ "patch_size": 14,
20
+ "qkv_bias": true,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.31.0.dev0",
23
+ "use_swiglu_ffn": false
24
+ }
dinov2/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d73036b56966966d07975d696bde331762f37297e2f095de8cea0040c3aa0841
3
+ size 346345912
dinov2/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.485,
13
+ 0.456,
14
+ 0.406
15
+ ],
16
+ "image_processor_type": "BitImageProcessor",
17
+ "image_std": [
18
+ 0.229,
19
+ 0.224,
20
+ 0.225
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 256
26
+ }
27
+ }
imaa/config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "IMAA",
3
+ "num_maps": 5,
4
+ "map_embedding_dim": 256,
5
+ "common_dim": 128,
6
+ "conv_channels": [
7
+ 128,
8
+ 64
9
+ ],
10
+ "dino_patch_dim": 768
11
+ }
imaa/imaa.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+
23
+ def extract_patch_tokens_min_windows(
24
+ images: torch.Tensor,
25
+ model: nn.Module,
26
+ processor,
27
+ window_size: int = 224,
28
+ device: str | torch.device = "cuda",
29
+ ) -> torch.Tensor:
30
+ r"""
31
+ Tile each image with a minimal window set and return averaged DINO patch tokens.
32
+
33
+ Args:
34
+ images (`torch.Tensor`): Batch of RGB images `(B, C, H, W)`.
35
+ model: DINO vision transformer.
36
+ processor: Hugging Face image processor for DINO.
37
+ window_size (`int`): Sliding-window size in pixels.
38
+ device: Device for intermediate tensors.
39
+
40
+ Returns:
41
+ `torch.Tensor` of shape `(B, H//patch, W//patch, hidden_size)`.
42
+ """
43
+ batch_size, _, height, width = images.shape
44
+ hidden_size = model.config.hidden_size
45
+ patch_size = model.config.patch_size
46
+ token_avgs = []
47
+
48
+ for batch_idx in range(batch_size):
49
+ image = images[batch_idx].float()
50
+ if image.max() <= 1.0:
51
+ image_np = (image.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype("uint8")
52
+ else:
53
+ image_np = image.permute(1, 2, 0).cpu().numpy().clip(0, 255).astype("uint8")
54
+
55
+ token_sum = torch.zeros((height // patch_size, width // patch_size, hidden_size), device=device)
56
+ token_count = torch.zeros((height // patch_size, width // patch_size, 1), device=device)
57
+
58
+ num_y = (height + window_size - 1) // window_size
59
+ num_x = (width + window_size - 1) // window_size
60
+ y_positions = [index * window_size for index in range(num_y - 1)] + [height - window_size]
61
+ x_positions = [index * window_size for index in range(num_x - 1)] + [width - window_size]
62
+
63
+ for y in y_positions:
64
+ for x in x_positions:
65
+ patch = image_np[y : y + window_size, x : x + window_size, :]
66
+ inputs = processor(images=patch, return_tensors="pt").to(device)
67
+ with torch.no_grad():
68
+ outputs = model(**inputs)
69
+ patch_tokens = outputs.last_hidden_state[:, 1:, :]
70
+ patch_tokens = patch_tokens.reshape(
71
+ 1, window_size // patch_size, window_size // patch_size, hidden_size
72
+ ).squeeze(0)
73
+
74
+ y0, x0 = y // patch_size, x // patch_size
75
+ y1, x1 = y0 + window_size // patch_size, x0 + window_size // patch_size
76
+ token_sum[y0:y1, x0:x1, :] += patch_tokens
77
+ token_count[y0:y1, x0:x1, 0] += 1
78
+
79
+ token_avgs.append(token_sum / token_count)
80
+
81
+ return torch.stack(token_avgs, dim=0)
82
+
83
+
84
+ class LayerNorm2d(nn.Module):
85
+ def __init__(self, channels: int) -> None:
86
+ super().__init__()
87
+ self.norm = nn.LayerNorm([channels])
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ x = x.permute(0, 2, 3, 1)
91
+ x = self.norm(x)
92
+ return x.permute(0, 3, 1, 2)
93
+
94
+
95
+ class IMAA(nn.Module):
96
+ r"""
97
+ Intrinsic Map-Aware Attention (IMAA) gating module.
98
+
99
+ Produces per-map attention biases from DINO patch tokens and learnable map embeddings.
100
+ """
101
+
102
+ def __init__(
103
+ self,
104
+ dino_model: Optional[nn.Module] = None,
105
+ processor=None,
106
+ num_maps: int = 5,
107
+ map_embedding_dim: int = 256,
108
+ common_dim: int = 128,
109
+ conv_channels: Optional[list[int]] = None,
110
+ dino_patch_dim: int = 768,
111
+ ) -> None:
112
+ super().__init__()
113
+ conv_channels = conv_channels or [128, 64]
114
+ self.dino = dino_model
115
+ self.processor = processor
116
+ if self.dino is not None:
117
+ self.dino.eval()
118
+ for param in self.dino.parameters():
119
+ param.requires_grad = False
120
+
121
+ self.num_maps = num_maps
122
+ self.map_embedding_dim = map_embedding_dim
123
+ self.common_dim = common_dim
124
+ self.dino_patch_dim = dino_patch_dim
125
+ self.map_embedding = nn.Parameter(torch.randn(num_maps, map_embedding_dim))
126
+ self.dino_proj = nn.Conv2d(dino_patch_dim, common_dim, kernel_size=1)
127
+ self.map_proj = nn.Linear(map_embedding_dim, common_dim)
128
+ self.fusion_layer = nn.Sequential(
129
+ nn.Conv2d(common_dim * 2, common_dim, 1),
130
+ LayerNorm2d(common_dim),
131
+ nn.ReLU(),
132
+ nn.Conv2d(common_dim, common_dim, 3, padding=1),
133
+ )
134
+
135
+ conv_layers: list[nn.Module] = []
136
+ in_channels = common_dim
137
+ for out_channels in conv_channels:
138
+ conv_layers.extend([nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()])
139
+ in_channels = out_channels
140
+ conv_layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))
141
+ self.conv_head = nn.Sequential(*conv_layers)
142
+
143
+ def forward(
144
+ self,
145
+ image: Optional[torch.Tensor] = None,
146
+ patch_tokens: Optional[torch.Tensor] = None,
147
+ output_size: Optional[Tuple[int, int]] = None,
148
+ map_ids: Optional[torch.Tensor] = None,
149
+ ) -> torch.Tensor:
150
+ if patch_tokens is None:
151
+ if self.dino is None or image is None:
152
+ raise ValueError("Either `patch_tokens` or (`image` and a frozen DINO model) must be provided.")
153
+ patch_tokens = extract_patch_tokens_min_windows(
154
+ image, self.dino, self.processor, window_size=224, device=image.device
155
+ )
156
+
157
+ dino_feat_map = patch_tokens.permute(0, 3, 1, 2)
158
+ dino_proj = self.dino_proj(dino_feat_map)
159
+ map_emb = self.map_embedding[map_ids]
160
+ map_proj = self.map_proj(map_emb).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, dino_proj.size(2), dino_proj.size(3))
161
+ fused_map = self.fusion_layer(torch.cat([dino_proj, map_proj], dim=1))
162
+ raw_gating_map = self.conv_head(fused_map)
163
+ aligned_map = (
164
+ F.interpolate(raw_gating_map, size=output_size, mode="bilinear", align_corners=False)
165
+ if output_size is not None
166
+ else raw_gating_map
167
+ )
168
+ return torch.sigmoid(aligned_map)
169
+
170
+
171
+ def build_attn_mask(
172
+ w_gating: torch.Tensor,
173
+ text_len: int,
174
+ img_len: int,
175
+ lam: float,
176
+ ) -> torch.Tensor:
177
+ r"""
178
+ Build an additive attention mask from IMAA gating weights.
179
+
180
+ Args:
181
+ w_gating (`torch.Tensor`): Gating map `[B, 1, H, W]` or flattened `[B, img_len]`.
182
+ text_len (`int`): Number of text tokens prepended to image tokens.
183
+ img_len (`int`): Expected number of image tokens.
184
+ lam (`float`): Mask scaling factor.
185
+
186
+ Returns:
187
+ Attention bias tensor shaped for SD3 joint attention.
188
+ """
189
+ batch_size = w_gating.shape[0]
190
+ total_len = text_len + img_len
191
+ if w_gating.dim() == 4:
192
+ w_gating = w_gating.view(batch_size, -1)
193
+
194
+ gating = lam * w_gating
195
+ actual_img_len = gating.shape[1]
196
+ if actual_img_len != img_len:
197
+ if actual_img_len > img_len:
198
+ gating = gating[:, :img_len]
199
+ else:
200
+ padding = torch.zeros(batch_size, img_len - actual_img_len, device=gating.device, dtype=gating.dtype)
201
+ gating = torch.cat([gating, padding], dim=1)
202
+
203
+ col_bias = torch.zeros(batch_size, total_len, device=w_gating.device, dtype=w_gating.dtype)
204
+ col_bias[:, text_len:] = gating
205
+ return col_bias.view(batch_size, 1, 1, total_len)
imaa/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5b9f0ff31b173c101a24bcf78cd951639acc72bebdd5231f7a8d500d41b0457
3
+ size 2140596
model_index.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "IntrinsicWeatherInversePipeline"
5
+ ],
6
+ "_diffusers_version": "0.38.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "vae": [
12
+ "diffusers",
13
+ "AutoencoderKL"
14
+ ],
15
+ "text_encoder": [
16
+ "transformers",
17
+ "CLIPTextModelWithProjection"
18
+ ],
19
+ "text_encoder_2": [
20
+ "transformers",
21
+ "CLIPTextModelWithProjection"
22
+ ],
23
+ "text_encoder_3": [
24
+ "transformers",
25
+ "T5EncoderModel"
26
+ ],
27
+ "tokenizer": [
28
+ "transformers",
29
+ "CLIPTokenizer"
30
+ ],
31
+ "tokenizer_2": [
32
+ "transformers",
33
+ "CLIPTokenizer"
34
+ ],
35
+ "tokenizer_3": [
36
+ "transformers",
37
+ "T5TokenizerFast"
38
+ ]
39
+ }
pipeline_intrinsic_weather.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: IntrinsicWeatherPipeline.
2
+ Inverse decomposition + forward weather rendering in one call.
3
+ Load with native Hugging Face diffusers and trust_remote_code=True.
4
+ """
5
+
6
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ from __future__ import annotations
21
+
22
+ from dataclasses import dataclass
23
+ from typing import Dict, List, Optional, Union
24
+
25
+ import numpy as np
26
+ import PIL.Image
27
+ import torch
28
+ import torchvision.transforms as T
29
+ from pathlib import Path
30
+ from diffusers.image_processor import PipelineImageInput
31
+ from diffusers.loaders import SD3LoraLoaderMixin
32
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
33
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
34
+ from diffusers.utils import logging, replace_example_docstring
35
+ from transformers import (
36
+ CLIPTextModelWithProjection,
37
+ CLIPTokenizer,
38
+ PreTrainedModel,
39
+ PreTrainedTokenizer,
40
+ T5EncoderModel,
41
+ T5TokenizerFast,
42
+ )
43
+
44
+ from imaa.imaa import IMAA, build_attn_mask, extract_patch_tokens_min_windows
45
+ from pipeline_utils import load_transformer_from_subfolder, load_transformer_lora, resolve_repo_dir
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ AOVS = ["albedo", "normal", "roughness", "metallic", "irradiance"]
50
+ INVERSE_PROMPTS = {
51
+ "albedo": "Albedo (diffuse basecolor)",
52
+ "normal": "Camera-space Normal",
53
+ "roughness": "Roughness",
54
+ "metallic": "Metallicness",
55
+ "irradiance": "Irradiance (lighting)",
56
+ }
57
+ WEATHER_PROMPTS = {
58
+ "rainy": "A rainy day.",
59
+ "sunny": "A sunny day.",
60
+ "snowy": "A snowy day.",
61
+ "foggy": "A foggy day.",
62
+ "overcast": "An overcast day.",
63
+ "night": "A night scene.",
64
+ }
65
+
66
+
67
+ EXAMPLE_DOC_STRING = """
68
+ Examples:
69
+ ```py
70
+ >>> from pathlib import Path
71
+ >>> import torch
72
+ >>> from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
73
+ >>> from transformers import AutoImageProcessor, AutoModel
74
+
75
+ >>> repo_dir = Path("./IntrisicWeather-diffusers").resolve()
76
+ >>> pipe = IntrinsicWeatherPipeline.from_pretrained(
77
+ ... repo_dir,
78
+ ... inverse_transformer_subfolder="inverse-512",
79
+ ... forward_transformer_subfolder="forward",
80
+ ... local_files_only=True,
81
+ ... torch_dtype=torch.float16,
82
+ ... )
83
+ >>> pipe.to("cuda")
84
+ >>> dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
85
+ >>> dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to("cuda")
86
+
87
+ >>> from PIL import Image
88
+ >>> image = Image.open("input.png")
89
+ >>> result = pipe(
90
+ ... image=image,
91
+ ... weather="rainy",
92
+ ... dino_model=dino_model,
93
+ ... dino_processor=dino_processor,
94
+ ... )
95
+ >>> result.images[0].save("rainy.png")
96
+ ```
97
+ """
98
+
99
+
100
+ @dataclass
101
+ class IntrinsicWeatherPipelineOutput(StableDiffusion3PipelineOutput):
102
+ maps: Optional[Dict[str, Union[PIL.Image.Image, np.ndarray]]] = None
103
+
104
+
105
+ class IntrinsicWeatherPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
106
+ r"""
107
+ End-to-end IntrinsicWeather pipeline: RGB → intrinsic maps → weather-conditioned RGB.
108
+
109
+ Parameters:
110
+ inverse_transformer ([`IntrinsicWeatherSD3Transformer2DModel`]):
111
+ Map-aware SD3 transformer for intrinsic decomposition.
112
+ forward_transformer ([`SD3Transformer2DModel`]):
113
+ SD3 transformer for forward weather rendering.
114
+ imaa ([`IMAA`]):
115
+ Map-aware attention gating module used during inverse rendering.
116
+ vae, text_encoder(s), tokenizer(s), scheduler:
117
+ Shared Stable Diffusion 3 components.
118
+ """
119
+
120
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->inverse_transformer->forward_transformer->vae"
121
+ _optional_components: List[str] = []
122
+
123
+ @classmethod
124
+ def from_pretrained(
125
+ cls,
126
+ pretrained_model_name_or_path,
127
+ *args,
128
+ inverse_transformer_subfolder: str = "inverse-512",
129
+ forward_transformer_subfolder: str = "forward",
130
+ inverse_transformer=None,
131
+ forward_transformer=None,
132
+ load_lora: bool = True,
133
+ load_imaa: bool = True,
134
+ **kwargs,
135
+ ):
136
+ repo_dir = resolve_repo_dir(pretrained_model_name_or_path)
137
+ dtype = kwargs.get("torch_dtype", torch.bfloat16)
138
+ device = kwargs.get("device")
139
+ if inverse_transformer is None:
140
+ inverse_transformer = load_transformer_from_subfolder(
141
+ repo_dir,
142
+ inverse_transformer_subfolder,
143
+ dtype=dtype,
144
+ device=device,
145
+ )
146
+ if forward_transformer is None:
147
+ forward_transformer = load_transformer_from_subfolder(
148
+ repo_dir,
149
+ forward_transformer_subfolder,
150
+ dtype=dtype,
151
+ device=device,
152
+ )
153
+
154
+ imaa = IMAA(dino_model=None, processor=None, num_maps=5, map_embedding_dim=256, common_dim=128)
155
+ if device is not None:
156
+ imaa = imaa.to(device)
157
+
158
+ pipe = DiffusionPipeline.from_pretrained(
159
+ repo_dir.as_posix(),
160
+ *args,
161
+ inverse_transformer=inverse_transformer,
162
+ forward_transformer=forward_transformer,
163
+ imaa=imaa,
164
+ custom_pipeline=str(repo_dir / "pipeline_intrinsic_weather.py"),
165
+ trust_remote_code=True,
166
+ **kwargs,
167
+ )
168
+ if load_imaa:
169
+ pipe.load_imaa_weights(repo_dir / "imaa")
170
+ if load_lora:
171
+ load_transformer_lora(pipe._forward, repo_dir, forward_transformer_subfolder)
172
+ if device is not None:
173
+ for name in pipe.components.keys():
174
+ module = getattr(pipe, name, None)
175
+ if module is not None and hasattr(module, "to"):
176
+ if name == "imaa":
177
+ module.to(device=device)
178
+ else:
179
+ module.to(device=device, dtype=dtype)
180
+ return pipe
181
+
182
+ def __init__(
183
+ self,
184
+ inverse_transformer,
185
+ forward_transformer,
186
+ imaa: IMAA,
187
+ scheduler,
188
+ vae,
189
+ text_encoder: CLIPTextModelWithProjection,
190
+ tokenizer: CLIPTokenizer,
191
+ text_encoder_2: CLIPTextModelWithProjection,
192
+ tokenizer_2: CLIPTokenizer,
193
+ text_encoder_3: T5EncoderModel,
194
+ tokenizer_3: T5TokenizerFast,
195
+ ) -> None:
196
+ super().__init__()
197
+ from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
198
+ from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
199
+
200
+ self.register_modules(
201
+ inverse_transformer=inverse_transformer,
202
+ forward_transformer=forward_transformer,
203
+ imaa=imaa,
204
+ scheduler=scheduler,
205
+ vae=vae,
206
+ text_encoder=text_encoder,
207
+ tokenizer=tokenizer,
208
+ text_encoder_2=text_encoder_2,
209
+ tokenizer_2=tokenizer_2,
210
+ text_encoder_3=text_encoder_3,
211
+ tokenizer_3=tokenizer_3,
212
+ )
213
+ shared = dict(
214
+ scheduler=scheduler,
215
+ vae=vae,
216
+ text_encoder=text_encoder,
217
+ tokenizer=tokenizer,
218
+ text_encoder_2=text_encoder_2,
219
+ tokenizer_2=tokenizer_2,
220
+ text_encoder_3=text_encoder_3,
221
+ tokenizer_3=tokenizer_3,
222
+ )
223
+ self._inverse = IntrinsicWeatherInversePipeline(transformer=inverse_transformer, **shared)
224
+ self._forward = IntrinsicWeatherForwardPipeline(transformer=forward_transformer, **shared)
225
+
226
+ def load_imaa_weights(self, imaa_dir: Union[str, Path]) -> None:
227
+ r"""Load IMAA weights from an `imaa/` subfolder produced by the conversion script."""
228
+ imaa_dir = Path(imaa_dir)
229
+ weights_path = imaa_dir / "model.safetensors"
230
+ if weights_path.exists():
231
+ from safetensors.torch import load_file
232
+
233
+ state_dict = load_file(weights_path.as_posix())
234
+ else:
235
+ payload = torch.load(imaa_dir / "pytorch_model.bin", map_location="cpu", weights_only=False)
236
+ state_dict = payload["model_state_dict"] if isinstance(payload, dict) and "model_state_dict" in payload else payload
237
+ self.imaa.load_state_dict(state_dict)
238
+ self.imaa.eval()
239
+
240
+ @staticmethod
241
+ def _resolve_weather_prompt(weather: str) -> str:
242
+ key = weather.lower().strip()
243
+ if key in WEATHER_PROMPTS:
244
+ return WEATHER_PROMPTS[key]
245
+ return weather
246
+
247
+ def _preprocess_rgb(self, image: PipelineImageInput, size: int) -> torch.Tensor:
248
+ if isinstance(image, PIL.Image.Image):
249
+ pil = image.convert("RGB")
250
+ elif isinstance(image, torch.Tensor):
251
+ if image.ndim == 3:
252
+ image = image.unsqueeze(0)
253
+ tensor = image
254
+ if tensor.shape[1] != 3:
255
+ raise ValueError(f"Expected 3 RGB channels, got shape {tuple(tensor.shape)}")
256
+ if tensor.min() < 0:
257
+ tensor = (tensor + 1.0) / 2.0
258
+ return tensor.to(device=self._execution_device, dtype=self.dtype)
259
+ else:
260
+ raise TypeError(f"`image` must be PIL.Image or torch.Tensor, got {type(image)}")
261
+
262
+ transform = T.Compose([T.Resize((size, size), interpolation=T.InterpolationMode.BILINEAR), T.ToTensor()])
263
+ tensor = transform(pil)
264
+ return tensor.unsqueeze(0).to(device=self._execution_device, dtype=self.dtype)
265
+
266
+ @staticmethod
267
+ def _map_array_to_tensor(
268
+ array: Union[np.ndarray, PIL.Image.Image, torch.Tensor],
269
+ size: int,
270
+ device: torch.device,
271
+ dtype: torch.dtype,
272
+ ) -> torch.Tensor:
273
+ if isinstance(array, torch.Tensor):
274
+ tensor = array
275
+ if tensor.ndim == 3:
276
+ tensor = tensor.unsqueeze(0)
277
+ if tensor.min() < 0:
278
+ tensor = (tensor + 1.0) / 2.0
279
+ return tensor.to(device=device, dtype=dtype)
280
+
281
+ if isinstance(array, np.ndarray):
282
+ if array.dtype != np.uint8:
283
+ array = np.clip(array * 255.0, 0, 255).astype(np.uint8)
284
+ pil = PIL.Image.fromarray(array)
285
+ else:
286
+ pil = array.convert("RGB")
287
+
288
+ transform = T.Compose(
289
+ [
290
+ T.Resize((size, size), interpolation=T.InterpolationMode.BILINEAR),
291
+ T.ToTensor(),
292
+ ]
293
+ )
294
+ return transform(pil).unsqueeze(0).to(device=device, dtype=dtype)
295
+
296
+ @torch.no_grad()
297
+ def decompose(
298
+ self,
299
+ image: PipelineImageInput,
300
+ dino_model: PreTrainedModel,
301
+ dino_processor: PreTrainedTokenizer,
302
+ num_inference_steps: int = 50,
303
+ image_size: Optional[int] = None,
304
+ output_type: str = "pil",
305
+ ) -> Dict[str, Union[PIL.Image.Image, np.ndarray]]:
306
+ r"""
307
+ Run inverse rendering and return intrinsic maps.
308
+
309
+ Args:
310
+ image: Input RGB image.
311
+ dino_model: Frozen DINOv2 vision model.
312
+ dino_processor: Matching image processor.
313
+ num_inference_steps: Denoising steps per intrinsic map.
314
+ image_size: Square resize before decomposition. Defaults to input size.
315
+ output_type: `"pil"` or `"np"` per map.
316
+ """
317
+ if image_size is None:
318
+ if isinstance(image, PIL.Image.Image):
319
+ image_size = max(image.size)
320
+ elif isinstance(image, torch.Tensor):
321
+ image_size = image.shape[-1]
322
+ else:
323
+ image_size = 1024
324
+
325
+ image_tensor = self._preprocess_rgb(image, image_size)
326
+ patch_tokens = extract_patch_tokens_min_windows(
327
+ image_tensor,
328
+ dino_model,
329
+ dino_processor,
330
+ window_size=224,
331
+ device=image_tensor.device,
332
+ )
333
+ output_size = (image_tensor.shape[2] // 16, image_tensor.shape[3] // 16)
334
+ img_len = output_size[0] * output_size[1]
335
+
336
+ maps: Dict[str, Union[PIL.Image.Image, np.ndarray]] = {}
337
+ for map_index, aov_name in enumerate(AOVS):
338
+ prompt_embeds, _, pooled_prompt_embeds, _ = self._inverse.encode_prompt(
339
+ prompt=INVERSE_PROMPTS[aov_name],
340
+ prompt_2=None,
341
+ prompt_3=None,
342
+ do_classifier_free_guidance=False,
343
+ )
344
+ map_aware_mask = self.imaa(
345
+ patch_tokens=patch_tokens,
346
+ output_size=output_size,
347
+ map_ids=torch.tensor([map_index], device=image_tensor.device),
348
+ )
349
+ attn_mask = build_attn_mask(map_aware_mask, 154, img_len, 0.7)
350
+ result = self._inverse(
351
+ image=image_tensor,
352
+ prompt_embeds=prompt_embeds,
353
+ pooled_prompt_embeds=pooled_prompt_embeds,
354
+ guidance_scale=0.0,
355
+ image_guidance_scale=0.0,
356
+ num_inference_steps=num_inference_steps,
357
+ output_type=output_type,
358
+ aov=[aov_name],
359
+ map_aware_mask=attn_mask.to(image_tensor.device),
360
+ )
361
+ maps[aov_name] = result.images[0] if output_type == "pil" else result.images[0]
362
+ return maps
363
+
364
+ @torch.no_grad()
365
+ def render(
366
+ self,
367
+ maps: Dict[str, PipelineImageInput],
368
+ weather: str = "rainy",
369
+ num_inference_steps: int = 50,
370
+ guidance_scale: float = 6.0,
371
+ image_guidance_scale: float = 1.5,
372
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
373
+ output_type: str = "pil",
374
+ render_size: Optional[int] = None,
375
+ ) -> Union[PIL.Image.Image, np.ndarray]:
376
+ r"""
377
+ Run forward weather rendering from intrinsic maps.
378
+
379
+ Args:
380
+ maps: Dict with keys `albedo`, `normal`, `roughness`, `metallic`, `irradiance`.
381
+ weather: Preset name (`rainy`, `sunny`, ...) or a custom prompt string.
382
+ """
383
+ if render_size is None:
384
+ sample = next(iter(maps.values()))
385
+ if isinstance(sample, PIL.Image.Image):
386
+ render_size = max(sample.size)
387
+ elif isinstance(sample, np.ndarray):
388
+ render_size = sample.shape[0]
389
+ else:
390
+ render_size = sample.shape[-1]
391
+
392
+ device = self._execution_device
393
+ dtype = self.dtype
394
+ aov_tensors = {
395
+ name: self._map_array_to_tensor(maps[name], render_size, device, dtype) for name in AOVS if name in maps
396
+ }
397
+ for name in AOVS:
398
+ if name not in aov_tensors:
399
+ aov_tensors[name] = torch.zeros((1, 3, render_size, render_size), device=device, dtype=dtype)
400
+
401
+ result = self._forward(
402
+ albedo=aov_tensors["albedo"],
403
+ normal=aov_tensors["normal"],
404
+ roughness=aov_tensors["roughness"],
405
+ metallic=aov_tensors["metallic"],
406
+ irradiance=aov_tensors.get("irradiance"),
407
+ prompt=[self._resolve_weather_prompt(weather)],
408
+ guidance_scale=guidance_scale,
409
+ image_guidance_scale=image_guidance_scale,
410
+ num_inference_steps=num_inference_steps,
411
+ required_aovs=AOVS,
412
+ generator=generator,
413
+ output_type=output_type,
414
+ )
415
+ return result.images[0]
416
+
417
+ @torch.no_grad()
418
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
419
+ def __call__(
420
+ self,
421
+ image: PipelineImageInput,
422
+ weather: str = "rainy",
423
+ dino_model: Optional[PreTrainedModel] = None,
424
+ dino_processor: Optional[PreTrainedTokenizer] = None,
425
+ num_inverse_steps: int = 50,
426
+ num_forward_steps: int = 50,
427
+ guidance_scale: float = 6.0,
428
+ image_guidance_scale: float = 1.5,
429
+ image_size: Optional[int] = None,
430
+ render_size: Optional[int] = None,
431
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
432
+ output_type: str = "pil",
433
+ return_maps: bool = False,
434
+ return_dict: bool = True,
435
+ ):
436
+ r"""
437
+ Decompose an RGB image into intrinsic maps, then render a weather-conditioned RGB image.
438
+
439
+ Examples:
440
+
441
+ Args:
442
+ image: Source RGB photograph or render.
443
+ weather: Weather preset (`rainy`, `sunny`, `snowy`, `foggy`, ...) or custom prompt.
444
+ dino_model (`transformers.PreTrainedModel`, *optional*):
445
+ DINOv2 model for IMAA gating. Required unless maps are passed via `maps=`.
446
+ dino_processor: DINO image processor paired with `dino_model`.
447
+ num_inverse_steps: Denoising steps for each intrinsic map.
448
+ num_forward_steps: Denoising steps for forward rendering.
449
+ guidance_scale: Text CFG for forward rendering.
450
+ image_guidance_scale: Image CFG for forward rendering.
451
+ image_size: Square size for inverse stage.
452
+ render_size: Square size for forward stage. Defaults to `image_size`.
453
+ return_maps: If `True`, also return decomposed intrinsic maps.
454
+ return_dict: Return [`IntrinsicWeatherPipelineOutput`] when `True`.
455
+ """
456
+ if dino_model is None or dino_processor is None:
457
+ raise ValueError("`dino_model` and `dino_processor` are required for end-to-end weather editing.")
458
+
459
+ maps = self.decompose(
460
+ image=image,
461
+ dino_model=dino_model,
462
+ dino_processor=dino_processor,
463
+ num_inference_steps=num_inverse_steps,
464
+ image_size=image_size,
465
+ output_type="np",
466
+ )
467
+ rendered = self.render(
468
+ maps=maps,
469
+ weather=weather,
470
+ num_inference_steps=num_forward_steps,
471
+ guidance_scale=guidance_scale,
472
+ image_guidance_scale=image_guidance_scale,
473
+ generator=generator,
474
+ output_type=output_type,
475
+ render_size=render_size or image_size,
476
+ )
477
+
478
+ self.maybe_free_model_hooks()
479
+
480
+ if not return_dict:
481
+ return (rendered, maps) if return_maps else (rendered,)
482
+
483
+ return IntrinsicWeatherPipelineOutput(
484
+ images=[rendered] if output_type == "pil" else rendered,
485
+ maps=maps if return_maps else None,
486
+ )
pipeline_intrinsic_weather_forward.py ADDED
@@ -0,0 +1,1191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: IntrinsicWeatherForwardPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from __future__ import annotations
20
+
21
+ import inspect
22
+ from pathlib import Path
23
+ from typing import Any, Callable, Dict, List, Optional, Union
24
+
25
+ import PIL.Image
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torchvision
29
+ from transformers import (
30
+ CLIPTextModelWithProjection,
31
+ CLIPTokenizer,
32
+ T5EncoderModel,
33
+ T5TokenizerFast,
34
+ )
35
+
36
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
37
+ from diffusers.loaders import SD3LoraLoaderMixin
38
+ from diffusers.models.autoencoders import AutoencoderKL
39
+ from diffusers.models.transformers import SD3Transformer2DModel
40
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
41
+ from diffusers.utils import (
42
+ deprecate,
43
+ is_torch_xla_available,
44
+ logging,
45
+ replace_example_docstring,
46
+ )
47
+ from diffusers.utils.torch_utils import randn_tensor
48
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
49
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
50
+
51
+ from pipeline_utils import (
52
+ load_transformer_from_subfolder,
53
+ load_transformer_lora,
54
+ resolve_repo_dir,
55
+ resolve_transformer_lora_dir,
56
+ set_flow_timesteps,
57
+ )
58
+
59
+
60
+ if is_torch_xla_available():
61
+ import torch_xla.core.xla_model as xm
62
+
63
+ XLA_AVAILABLE = True
64
+ else:
65
+ XLA_AVAILABLE = False
66
+
67
+
68
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
69
+
70
+ EXAMPLE_DOC_STRING = """
71
+ Examples:
72
+ ```py
73
+ >>> from pathlib import Path
74
+ >>> import torch
75
+ >>> from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
76
+
77
+ >>> repo_dir = Path("./IntrisicWeather-diffusers").resolve()
78
+ >>> pipe = IntrinsicWeatherForwardPipeline.from_pretrained(
79
+ ... repo_dir,
80
+ ... transformer_subfolder="forward",
81
+ ... local_files_only=True,
82
+ ... torch_dtype=torch.float16,
83
+ ... load_lora=True,
84
+ ... )
85
+ >>> pipe.to("cuda")
86
+ >>> image = pipe(
87
+ ... prompt="A rainy day.",
88
+ ... albedo=albedo_tensor,
89
+ ... normal=normal_tensor,
90
+ ... roughness=roughness_tensor,
91
+ ... metallic=metallic_tensor,
92
+ ... irradiance=irradiance_tensor,
93
+ ... num_inference_steps=50,
94
+ ... ).images[0]
95
+ ```
96
+ """
97
+
98
+
99
+
100
+
101
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
102
+ def retrieve_latents(
103
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
104
+ ):
105
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
106
+ return encoder_output.latent_dist.sample(generator)
107
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
108
+ return encoder_output.latent_dist.mode()
109
+ elif hasattr(encoder_output, "latents"):
110
+ return encoder_output.latents
111
+ else:
112
+ raise AttributeError("Could not access latents of provided encoder_output")
113
+
114
+ class VaeAoVProcessor(VaeImageProcessor):
115
+ def postprocess(
116
+ self,
117
+ image: torch.FloatTensor,
118
+ output_type: str = "pil",
119
+ do_denormalize: Optional[List[bool]] = None,
120
+ do_gamma_correction: bool = True,
121
+ ):
122
+ if not isinstance(image, torch.Tensor):
123
+ raise ValueError(
124
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
125
+ )
126
+ if output_type not in ["latent", "pt", "np", "pil"]:
127
+ deprecation_message = (
128
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
129
+ "`pil`, `np`, `pt`, `latent`"
130
+ )
131
+ deprecate(
132
+ "Unsupported output_type",
133
+ "1.0.0",
134
+ deprecation_message,
135
+ standard_warn=False,
136
+ )
137
+ output_type = "np"
138
+
139
+ if output_type == "latent":
140
+ return image
141
+
142
+ if do_denormalize is None:
143
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
144
+
145
+ image = torch.stack(
146
+ [
147
+ self.denormalize(image[i]) if do_denormalize[i] else image[i]
148
+ for i in range(image.shape[0])
149
+ ]
150
+ )
151
+
152
+ # Gamma correction
153
+ if do_gamma_correction:
154
+ # image = torch.pow(image, 1.0 / 2.2)
155
+ image = image ** (1 / 2.2)
156
+
157
+ if output_type == "pt":
158
+ return image
159
+
160
+ image = self.pt_to_numpy(image)
161
+
162
+ if output_type == "np":
163
+ return image
164
+
165
+ if output_type == "pil":
166
+ return self.numpy_to_pil(image)
167
+
168
+ class IntrinsicWeatherForwardPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
169
+ r"""
170
+ Args:
171
+ transformer ([`SD3Transformer2DModel`]):
172
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
173
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
174
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
175
+ vae ([`AutoencoderKL`]):
176
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
177
+ text_encoder ([`CLIPTextModelWithProjection`]):
178
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
179
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
180
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
181
+ as its dimension.
182
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
183
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
184
+ specifically the
185
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
186
+ variant.
187
+ text_encoder_3 ([`T5EncoderModel`]):
188
+ Frozen text-encoder. Stable Diffusion 3 uses
189
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
190
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
191
+ tokenizer (`CLIPTokenizer`):
192
+ Tokenizer of class
193
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
194
+ tokenizer_2 (`CLIPTokenizer`):
195
+ Second Tokenizer of class
196
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
197
+ tokenizer_3 (`T5TokenizerFast`):
198
+ Tokenizer of class
199
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
200
+ """
201
+
202
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
203
+ _optional_components = []
204
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
205
+
206
+ @classmethod
207
+ def load_transformer(
208
+ cls,
209
+ transformer_subfolder: str,
210
+ pretrained_model_name_or_path: str | Path,
211
+ *,
212
+ dtype: torch.dtype = torch.bfloat16,
213
+ device: str | torch.device | None = None,
214
+ ):
215
+ return load_transformer_from_subfolder(
216
+ pretrained_model_name_or_path,
217
+ transformer_subfolder,
218
+ dtype=dtype,
219
+ device=device,
220
+ )
221
+
222
+ @classmethod
223
+ def from_pretrained(
224
+ cls,
225
+ pretrained_model_name_or_path,
226
+ *args,
227
+ transformer_subfolder: str = "forward",
228
+ transformer=None,
229
+ load_lora: bool = True,
230
+ **kwargs,
231
+ ):
232
+ repo_dir = resolve_repo_dir(pretrained_model_name_or_path)
233
+ dtype = kwargs.get("torch_dtype", torch.bfloat16)
234
+ device = kwargs.get("device")
235
+ if transformer is None:
236
+ transformer = cls.load_transformer(
237
+ transformer_subfolder,
238
+ repo_dir,
239
+ dtype=dtype,
240
+ device=device,
241
+ )
242
+ pipe = DiffusionPipeline.from_pretrained(
243
+ repo_dir.as_posix(),
244
+ *args,
245
+ transformer=transformer,
246
+ custom_pipeline=str(repo_dir / "pipeline_intrinsic_weather_forward.py"),
247
+ trust_remote_code=True,
248
+ **kwargs,
249
+ )
250
+ lora_dir = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
251
+ if load_lora and lora_dir is not None:
252
+ pipe.load_lora_weights(lora_dir.as_posix())
253
+ if device is not None:
254
+ pipe = pipe.to(device)
255
+ return pipe
256
+
257
+ def __init__(
258
+ self,
259
+ transformer: SD3Transformer2DModel,
260
+ scheduler: FlowMatchEulerDiscreteScheduler,
261
+ vae: AutoencoderKL,
262
+ text_encoder: CLIPTextModelWithProjection,
263
+ tokenizer: CLIPTokenizer,
264
+ text_encoder_2: CLIPTextModelWithProjection,
265
+ tokenizer_2: CLIPTokenizer,
266
+ text_encoder_3: T5EncoderModel,
267
+ tokenizer_3: T5TokenizerFast,
268
+ ):
269
+ super().__init__()
270
+
271
+ self.register_modules(
272
+ vae=vae,
273
+ text_encoder=text_encoder,
274
+ text_encoder_2=text_encoder_2,
275
+ text_encoder_3=text_encoder_3,
276
+ tokenizer=tokenizer,
277
+ tokenizer_2=tokenizer_2,
278
+ tokenizer_3=tokenizer_3,
279
+ transformer=transformer,
280
+ scheduler=scheduler,
281
+ )
282
+ self.vae_scale_factor = (
283
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
284
+ )
285
+ self.image_processor = VaeAoVProcessor(vae_scale_factor=self.vae_scale_factor)
286
+ self.tokenizer_max_length = (
287
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
288
+ )
289
+ self.default_sample_size = (
290
+ self.transformer.config.sample_size
291
+ if hasattr(self, "transformer") and self.transformer is not None
292
+ else 128
293
+ )
294
+
295
+ def _get_t5_prompt_embeds(
296
+ self,
297
+ prompt: Union[str, List[str]] = None,
298
+ num_images_per_prompt: int = 1,
299
+ device: Optional[torch.device] = None,
300
+ dtype: Optional[torch.dtype] = None,
301
+ ):
302
+ device = device or self._execution_device
303
+ dtype = dtype or self.text_encoder.dtype
304
+
305
+ prompt = [prompt] if isinstance(prompt, str) else prompt
306
+ batch_size = len(prompt)
307
+
308
+ if self.text_encoder_3 is None:
309
+ return torch.zeros(
310
+ (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
311
+ device=device,
312
+ dtype=dtype,
313
+ )
314
+
315
+ text_inputs = self.tokenizer_3(
316
+ prompt,
317
+ padding="max_length",
318
+ max_length=self.tokenizer_max_length,
319
+ truncation=True,
320
+ add_special_tokens=True,
321
+ return_tensors="pt",
322
+ )
323
+ text_input_ids = text_inputs.input_ids
324
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
325
+
326
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
327
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
328
+ logger.warning(
329
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
330
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
331
+ )
332
+
333
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
334
+
335
+ dtype = self.text_encoder_3.dtype
336
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
337
+
338
+ _, seq_len, _ = prompt_embeds.shape
339
+
340
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
341
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
342
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
343
+
344
+ return prompt_embeds
345
+
346
+ def _get_clip_prompt_embeds(
347
+ self,
348
+ prompt: Union[str, List[str]],
349
+ num_images_per_prompt: int = 1,
350
+ device: Optional[torch.device] = None,
351
+ clip_skip: Optional[int] = None,
352
+ clip_model_index: int = 0,
353
+ ):
354
+ device = device or self._execution_device
355
+
356
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
357
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
358
+
359
+ tokenizer = clip_tokenizers[clip_model_index]
360
+ text_encoder = clip_text_encoders[clip_model_index]
361
+
362
+ prompt = [prompt] if isinstance(prompt, str) else prompt
363
+ batch_size = len(prompt)
364
+
365
+ text_inputs = tokenizer(
366
+ prompt,
367
+ padding="max_length",
368
+ max_length=self.tokenizer_max_length,
369
+ truncation=True,
370
+ return_tensors="pt",
371
+ )
372
+
373
+ text_input_ids = text_inputs.input_ids
374
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
375
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
376
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
377
+ logger.warning(
378
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
379
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
380
+ )
381
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
382
+ pooled_prompt_embeds = prompt_embeds[0]
383
+
384
+ if clip_skip is None:
385
+ prompt_embeds = prompt_embeds.hidden_states[-2]
386
+ else:
387
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
388
+
389
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
390
+
391
+ _, seq_len, _ = prompt_embeds.shape
392
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
393
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
394
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
395
+
396
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
397
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
398
+
399
+ return prompt_embeds, pooled_prompt_embeds
400
+
401
+ def encode_prompt(
402
+ self,
403
+ prompt: Union[str, List[str]],
404
+ prompt_2: Union[str, List[str]],
405
+ prompt_3: Union[str, List[str]],
406
+ device: Optional[torch.device] = None,
407
+ num_images_per_prompt: int = 1,
408
+ do_classifier_free_guidance: bool = True,
409
+ negative_prompt: Optional[Union[str, List[str]]] = None,
410
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
411
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
412
+ prompt_embeds: Optional[torch.FloatTensor] = None,
413
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
414
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
415
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
416
+ clip_skip: Optional[int] = None,
417
+ ):
418
+ r"""
419
+
420
+ Args:
421
+ prompt (`str` or `List[str]`, *optional*):
422
+ prompt to be encoded
423
+ prompt_2 (`str` or `List[str]`, *optional*):
424
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
425
+ used in all text-encoders
426
+ prompt_3 (`str` or `List[str]`, *optional*):
427
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
428
+ used in all text-encoders
429
+ device: (`torch.device`):
430
+ torch device
431
+ num_images_per_prompt (`int`):
432
+ number of images that should be generated per prompt
433
+ do_classifier_free_guidance (`bool`):
434
+ whether to use classifier free guidance or not
435
+ negative_prompt (`str` or `List[str]`, *optional*):
436
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
437
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
438
+ less than `1`).
439
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
440
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
441
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
442
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
443
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
444
+ `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
445
+ prompt_embeds (`torch.FloatTensor`, *optional*):
446
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
447
+ provided, text embeddings will be generated from `prompt` input argument.
448
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
449
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
450
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
451
+ argument.
452
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
453
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
454
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
455
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
456
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
457
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
458
+ input argument.
459
+ clip_skip (`int`, *optional*):
460
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
461
+ the output of the pre-final layer will be used for computing the prompt embeddings.
462
+ """
463
+ device = device or self._execution_device
464
+
465
+ prompt = [prompt] if isinstance(prompt, str) else prompt
466
+ if prompt is not None:
467
+ batch_size = len(prompt)
468
+ else:
469
+ batch_size = prompt_embeds.shape[0]
470
+
471
+ if prompt_embeds is None:
472
+ prompt_2 = prompt_2 or prompt
473
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
474
+
475
+ prompt_3 = prompt_3 or prompt
476
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
477
+
478
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
479
+ prompt=prompt,
480
+ device=device,
481
+ num_images_per_prompt=num_images_per_prompt,
482
+ clip_skip=clip_skip,
483
+ clip_model_index=0,
484
+ )
485
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
486
+ prompt=prompt_2,
487
+ device=device,
488
+ num_images_per_prompt=num_images_per_prompt,
489
+ clip_skip=clip_skip,
490
+ clip_model_index=1,
491
+ )
492
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
493
+
494
+ t5_prompt_embed = self._get_t5_prompt_embeds(
495
+ prompt=prompt_3,
496
+ num_images_per_prompt=num_images_per_prompt,
497
+ device=device,
498
+ )
499
+
500
+ clip_prompt_embeds = torch.nn.functional.pad(
501
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
502
+ )
503
+
504
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
505
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
506
+
507
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
508
+ negative_prompt = negative_prompt or ""
509
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
510
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
511
+
512
+ # normalize str to list
513
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
514
+ negative_prompt_2 = (
515
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
516
+ )
517
+ negative_prompt_3 = (
518
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
519
+ )
520
+
521
+ if prompt is not None and type(prompt) is not type(negative_prompt):
522
+ raise TypeError(
523
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
524
+ f" {type(prompt)}."
525
+ )
526
+ elif batch_size != len(negative_prompt):
527
+ raise ValueError(
528
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
529
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
530
+ " the batch size of `prompt`."
531
+ )
532
+
533
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
534
+ negative_prompt,
535
+ device=device,
536
+ num_images_per_prompt=num_images_per_prompt,
537
+ clip_skip=None,
538
+ clip_model_index=0,
539
+ )
540
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
541
+ negative_prompt_2,
542
+ device=device,
543
+ num_images_per_prompt=num_images_per_prompt,
544
+ clip_skip=None,
545
+ clip_model_index=1,
546
+ )
547
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
548
+
549
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
550
+ prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
551
+ )
552
+
553
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
554
+ negative_clip_prompt_embeds,
555
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
556
+ )
557
+
558
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
559
+ negative_pooled_prompt_embeds = torch.cat(
560
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
561
+ )
562
+
563
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
564
+
565
+ def check_inputs(
566
+ self,
567
+ prompt,
568
+ prompt_2,
569
+ prompt_3,
570
+ # height,
571
+ # width,
572
+ negative_prompt=None,
573
+ negative_prompt_2=None,
574
+ negative_prompt_3=None,
575
+ prompt_embeds=None,
576
+ negative_prompt_embeds=None,
577
+ pooled_prompt_embeds=None,
578
+ negative_pooled_prompt_embeds=None,
579
+ callback_on_step_end_tensor_inputs=None,
580
+ ):
581
+ # if height % 8 != 0 or width % 8 != 0:
582
+ # raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
583
+
584
+ if callback_on_step_end_tensor_inputs is not None and not all(
585
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
586
+ ):
587
+ raise ValueError(
588
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
589
+ )
590
+
591
+ if prompt is not None and prompt_embeds is not None:
592
+ raise ValueError(
593
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
594
+ " only forward one of the two."
595
+ )
596
+ elif prompt_2 is not None and prompt_embeds is not None:
597
+ raise ValueError(
598
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
599
+ " only forward one of the two."
600
+ )
601
+ elif prompt_3 is not None and prompt_embeds is not None:
602
+ raise ValueError(
603
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
604
+ " only forward one of the two."
605
+ )
606
+ elif prompt is None and prompt_embeds is None:
607
+ raise ValueError(
608
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
609
+ )
610
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
611
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
612
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
613
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
614
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
615
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
616
+
617
+ if negative_prompt is not None and negative_prompt_embeds is not None:
618
+ raise ValueError(
619
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
620
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
621
+ )
622
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
623
+ raise ValueError(
624
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
625
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
626
+ )
627
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
628
+ raise ValueError(
629
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
630
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
631
+ )
632
+
633
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
634
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
635
+ raise ValueError(
636
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
637
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
638
+ f" {negative_prompt_embeds.shape}."
639
+ )
640
+
641
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
642
+ raise ValueError(
643
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
644
+ )
645
+
646
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
647
+ raise ValueError(
648
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
649
+ )
650
+
651
+ def prepare_latents(
652
+ self,
653
+ batch_size,
654
+ num_channels_latents,
655
+ height,
656
+ width,
657
+ dtype,
658
+ device,
659
+ generator,
660
+ latents=None,
661
+ ):
662
+ if latents is not None:
663
+ return latents.to(device=device, dtype=dtype)
664
+
665
+ shape = (
666
+ batch_size,
667
+ num_channels_latents,
668
+ int(height) // self.vae_scale_factor,
669
+ int(width) // self.vae_scale_factor,
670
+ )
671
+
672
+ if isinstance(generator, list) and len(generator) != batch_size:
673
+ raise ValueError(
674
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
675
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
676
+ )
677
+
678
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
679
+
680
+ return latents
681
+
682
+ def prepare_image_latents(
683
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
684
+ ):
685
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
686
+ raise ValueError(
687
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
688
+ )
689
+
690
+ image = image.to(device=device, dtype=dtype)
691
+
692
+ batch_size = batch_size * num_images_per_prompt
693
+
694
+ if image.shape[1] == self.vae.config.latent_channels:
695
+ image_latents = image
696
+ else:
697
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
698
+ # ? normalize image latents
699
+ # image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
700
+
701
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
702
+ # expand image_latents for batch_size
703
+ deprecation_message = (
704
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
705
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
706
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
707
+ " your script to pass as many initial images as text prompts to suppress this warning."
708
+ )
709
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
710
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
711
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
712
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
713
+ raise ValueError(
714
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
715
+ )
716
+ else:
717
+ image_latents = torch.cat([image_latents], dim=0)
718
+
719
+ if do_classifier_free_guidance:
720
+ uncond_image_latents = torch.zeros_like(image_latents)
721
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
722
+
723
+ return image_latents
724
+
725
+
726
+ @property
727
+ def guidance_scale(self):
728
+ return self._guidance_scale
729
+ @property
730
+ def image_guidance_scale(self):
731
+ return self._image_guidance_scale
732
+
733
+ @property
734
+ def clip_skip(self):
735
+ return self._clip_skip
736
+
737
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
738
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
739
+ # corresponds to doing no classifier free guidance.
740
+ @property
741
+ def do_classifier_free_guidance(self):
742
+ return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
743
+
744
+ @property
745
+ def joint_attention_kwargs(self):
746
+ return self._joint_attention_kwargs
747
+
748
+ @property
749
+ def num_timesteps(self):
750
+ return self._num_timesteps
751
+
752
+ @property
753
+ def interrupt(self):
754
+ return self._interrupt
755
+
756
+ @torch.no_grad()
757
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
758
+ def __call__(
759
+ self,
760
+ prompt: Union[str, List[str]] = None,
761
+ prompt_2: Optional[Union[str, List[str]]] = None,
762
+ prompt_3: Optional[Union[str, List[str]]] = None,
763
+ albedo: PipelineImageInput = None,
764
+ normal : PipelineImageInput = None,
765
+ metallic : PipelineImageInput = None,
766
+ roughness : PipelineImageInput = None,
767
+ irradiance : PipelineImageInput = None,
768
+ height: Optional[int] = None,
769
+ width: Optional[int] = None,
770
+ num_inference_steps: int = 28,
771
+ timesteps: List[int] = None,
772
+ guidance_scale: float = 7.0,
773
+ image_guidance_scale: float = 1.5,
774
+ negative_prompt: Optional[Union[str, List[str]]] = None,
775
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
776
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
777
+ num_images_per_prompt: Optional[int] = 1,
778
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
779
+ latents: Optional[torch.FloatTensor] = None,
780
+ prompt_embeds: Optional[torch.FloatTensor] = None,
781
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
782
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
783
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
784
+ output_type: Optional[str] = "pil",
785
+ return_dict: bool = True,
786
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
787
+ clip_skip: Optional[int] = None,
788
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
789
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
790
+ mask_img: Optional[PipelineImageInput] = None,
791
+ required_aovs: List[str] = ["albedo"],
792
+ **kwargs
793
+ ):
794
+ r"""
795
+ Function invoked when calling the pipeline for generation.
796
+
797
+ Args:
798
+ prompt (`str` or `List[str]`, *optional*):
799
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
800
+ instead.
801
+ prompt_2 (`str` or `List[str]`, *optional*):
802
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
803
+ will be used instead
804
+ prompt_3 (`str` or `List[str]`, *optional*):
805
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
806
+ will be used instead
807
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
808
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
809
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
810
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
811
+ num_inference_steps (`int`, *optional*, defaults to 50):
812
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
813
+ expense of slower inference.
814
+ timesteps (`List[int]`, *optional*):
815
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
816
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
817
+ passed will be used. Must be in descending order.
818
+ guidance_scale (`float`, *optional*, defaults to 5.0):
819
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
820
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
821
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
822
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
823
+ usually at the expense of lower image quality.
824
+ negative_prompt (`str` or `List[str]`, *optional*):
825
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
826
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
827
+ less than `1`).
828
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
829
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
830
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
831
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
832
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
833
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
834
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
835
+ The number of images to generate per prompt.
836
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
837
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
838
+ to make generation deterministic.
839
+ latents (`torch.FloatTensor`, *optional*):
840
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
841
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
842
+ tensor will ge generated by sampling using the supplied random `generator`.
843
+ prompt_embeds (`torch.FloatTensor`, *optional*):
844
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
845
+ provided, text embeddings will be generated from `prompt` input argument.
846
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
847
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
848
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
849
+ argument.
850
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
851
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
852
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
853
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
854
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
855
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
856
+ input argument.
857
+ output_type (`str`, *optional*, defaults to `"pil"`):
858
+ The output format of the generate image. Choose between
859
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
860
+ return_dict (`bool`, *optional*, defaults to `True`):
861
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
862
+ of a plain tuple.
863
+ joint_attention_kwargs (`dict`, *optional*):
864
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
865
+ `self.processor` in
866
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
867
+ callback_on_step_end (`Callable`, *optional*):
868
+ A function that calls at the end of each denoising steps during the inference. The function is called
869
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
870
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
871
+ `callback_on_step_end_tensor_inputs`.
872
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
873
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
874
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
875
+ `._callback_tensor_inputs` attribute of your pipeline class.
876
+
877
+ Examples:
878
+
879
+ Returns:
880
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
881
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
882
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
883
+ """
884
+
885
+ # height = height or self.default_sample_size * self.vae_scale_factor
886
+ # width = width or self.default_sample_size * self.vae_scale_factor
887
+
888
+ # 1. Check inputs. Raise error if not correct
889
+ self.check_inputs(
890
+ prompt,
891
+ prompt_2,
892
+ prompt_3,
893
+ negative_prompt=negative_prompt,
894
+ negative_prompt_2=negative_prompt_2,
895
+ negative_prompt_3=negative_prompt_3,
896
+ prompt_embeds=prompt_embeds,
897
+ negative_prompt_embeds=negative_prompt_embeds,
898
+ pooled_prompt_embeds=pooled_prompt_embeds,
899
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
900
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
901
+ )
902
+
903
+ self._guidance_scale = guidance_scale
904
+ self._image_guidance_scale = image_guidance_scale
905
+ self._clip_skip = clip_skip
906
+ self._joint_attention_kwargs = joint_attention_kwargs
907
+ self._interrupt = False
908
+
909
+ # 2. Define call parameters
910
+ if prompt is not None and isinstance(prompt, str):
911
+ batch_size = 1
912
+ elif prompt is not None and isinstance(prompt, list):
913
+ batch_size = len(prompt)
914
+ else:
915
+ batch_size = prompt_embeds.shape[0]
916
+
917
+ device = self._execution_device
918
+
919
+ (
920
+ prompt_embeds,
921
+ negative_prompt_embeds,
922
+ pooled_prompt_embeds,
923
+ negative_pooled_prompt_embeds,
924
+ ) = self.encode_prompt(
925
+ prompt=prompt,
926
+ prompt_2=prompt_2,
927
+ prompt_3=prompt_3,
928
+ negative_prompt=negative_prompt,
929
+ negative_prompt_2=negative_prompt_2,
930
+ negative_prompt_3=negative_prompt_3,
931
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
932
+ prompt_embeds=prompt_embeds,
933
+ negative_prompt_embeds=negative_prompt_embeds,
934
+ pooled_prompt_embeds=pooled_prompt_embeds,
935
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
936
+ device=device,
937
+ clip_skip=self.clip_skip,
938
+ num_images_per_prompt=num_images_per_prompt,
939
+ )
940
+ # print("prompt:", prompt_embeds.shape)
941
+ if self.do_classifier_free_guidance:
942
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
943
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
944
+
945
+ # Similiarly
946
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0)
947
+
948
+ # if self.do_classifier_free_guidance:
949
+ # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
950
+ # pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
951
+
952
+ # 3. Preprocess image
953
+ # Get image dimensions from the first available AOV (albedo, normal, roughness, metallic, irradiance)
954
+ old_height, old_width = None, None
955
+ aov_dict = {"albedo": albedo, "normal": normal, "roughness": roughness, "metallic": metallic, "irradiance": irradiance}
956
+ for aov_name in required_aovs:
957
+ aov_value = aov_dict.get(aov_name)
958
+ if aov_value is not None:
959
+ if isinstance(aov_value, PIL.Image.Image):
960
+ old_width, old_height = aov_value.size
961
+ break
962
+ elif isinstance(aov_value, torch.Tensor):
963
+ if len(aov_value.shape) == 4:
964
+ old_height = aov_value.shape[2]
965
+ old_width = aov_value.shape[3]
966
+ else:
967
+ old_height = aov_value.shape[1]
968
+ old_width = aov_value.shape[2]
969
+ break
970
+
971
+ # Fallback: if no AOV available, use default dimensions
972
+ if old_height is None or old_width is None:
973
+ old_height, old_width = 1024, 1024 # default fallback
974
+
975
+
976
+ # For all aovs, the preprocessing remap the values to [-1, 1]
977
+ preprocessed_aovs = {}
978
+ for aov_name in required_aovs:
979
+ if aov_name == "albedo":
980
+ if albedo is not None:
981
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
982
+ albedo
983
+ )
984
+ else:
985
+ preprocessed_aovs[aov_name] = None
986
+
987
+ if aov_name == "normal":
988
+ if normal is not None:
989
+ preprocessed_aovs[aov_name] = (
990
+ self.image_processor.preprocess(normal)
991
+ )
992
+ else:
993
+ preprocessed_aovs[aov_name] = None
994
+
995
+ if aov_name == "roughness":
996
+ if roughness is not None:
997
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
998
+ roughness
999
+ )
1000
+ else:
1001
+ preprocessed_aovs[aov_name] = None
1002
+ if aov_name == "metallic":
1003
+ if metallic is not None:
1004
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
1005
+ metallic
1006
+ )
1007
+ else:
1008
+ preprocessed_aovs[aov_name] = None
1009
+ if aov_name == "irradiance":
1010
+ if irradiance is not None:
1011
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
1012
+ irradiance
1013
+ )
1014
+ else:
1015
+ preprocessed_aovs[aov_name] = None
1016
+ # print("aovs:", preprocessed_aovs.keys())
1017
+ # 4. Prepare latent variables
1018
+ num_channels_latents = self.vae.config.latent_channels
1019
+ # height, width = image_latents.shape[-2:]
1020
+ height = old_height / 8
1021
+ width = old_width / 8
1022
+ height = height * self.vae_scale_factor
1023
+ width = width * self.vae_scale_factor
1024
+ latents = self.prepare_latents(
1025
+ batch_size * num_images_per_prompt,
1026
+ num_channels_latents,
1027
+ height,
1028
+ width,
1029
+ prompt_embeds.dtype,
1030
+ device,
1031
+ generator,
1032
+ latents,
1033
+ )
1034
+ height_latent, width_latent = latents.shape[-2:]
1035
+
1036
+ # 5. Prepare timesteps
1037
+ set_flow_timesteps(
1038
+ self.scheduler,
1039
+ self.transformer,
1040
+ num_inference_steps,
1041
+ height_latent,
1042
+ width_latent,
1043
+ device,
1044
+ )
1045
+ timesteps = self.scheduler.timesteps
1046
+
1047
+ # timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1048
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1049
+ self._num_timesteps = len(timesteps)
1050
+
1051
+ # 6. Prepare Image latent
1052
+ image_latents = []
1053
+ for aov_name, aov in preprocessed_aovs.items():
1054
+ # print(aov_name)
1055
+ if aov is None:
1056
+ image_latent = torch.zeros(
1057
+ batch_size,
1058
+ num_channels_latents,
1059
+ height_latent,
1060
+ width_latent,
1061
+ dtype=prompt_embeds.dtype,
1062
+ device=device,
1063
+ )
1064
+ # print(image_latent.shape)
1065
+ if self.do_classifier_free_guidance:
1066
+ image_latents.append(
1067
+ torch.cat([image_latent, image_latent, image_latent], dim=0)
1068
+ )
1069
+ else:
1070
+ image_latents.append(image_latent)
1071
+ else:
1072
+ image_latent = (
1073
+ self.prepare_image_latents(
1074
+ aov,
1075
+ batch_size,
1076
+ num_images_per_prompt,
1077
+ prompt_embeds.dtype,
1078
+ device,
1079
+ self.do_classifier_free_guidance,
1080
+ generator,
1081
+ )
1082
+ )
1083
+ image_latents.append(image_latent)
1084
+ image_latents = torch.cat(image_latents, dim=1)
1085
+
1086
+ # 7. Check that shapes of latents and image match the DIT in_channels
1087
+ num_channels_image = image_latents.shape[1]
1088
+
1089
+ if num_channels_latents + num_channels_image != self.transformer.config.in_channels:
1090
+ raise ValueError(
1091
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
1092
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1093
+ f" `num_channels_image`: {num_channels_image} "
1094
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
1095
+ " `pipeline.transformer` or your `image` input."
1096
+ )
1097
+
1098
+ # 8. Denoising loop
1099
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1100
+ for i, t in enumerate(timesteps):
1101
+ if self.interrupt:
1102
+ continue
1103
+
1104
+ # expand the latents if we are doing classifier free guidance
1105
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
1106
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1107
+ timestep = t.expand(latent_model_input.shape[0])
1108
+
1109
+ scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
1110
+
1111
+ # if "mask_index" in kwargs and kwargs['mask_index'] is not None:
1112
+ # mask_index = kwargs['mask_index']
1113
+ # else:
1114
+ # mask_index = None
1115
+ noise_pred = self.transformer(
1116
+ hidden_states=scaled_latent_model_input,
1117
+ timestep=timestep,
1118
+ encoder_hidden_states=prompt_embeds,
1119
+ pooled_projections=pooled_prompt_embeds,
1120
+ joint_attention_kwargs=self.joint_attention_kwargs,
1121
+ return_dict=False,
1122
+ # mask_index= mask_index,
1123
+ )[0]
1124
+
1125
+ # perform guidance
1126
+ if self.do_classifier_free_guidance:
1127
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
1128
+ noise_pred = (
1129
+ noise_pred_uncond
1130
+ + self.guidance_scale * (noise_pred_text - noise_pred_image)
1131
+ + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
1132
+ )
1133
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # neg, prompt
1134
+ # noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1135
+
1136
+ # compute the previous noisy sample x_t -> x_t-1
1137
+ latents_dtype = latents.dtype
1138
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1139
+
1140
+ if latents.dtype != latents_dtype:
1141
+ if torch.backends.mps.is_available():
1142
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1143
+ latents = latents.to(latents_dtype)
1144
+
1145
+ if callback_on_step_end is not None:
1146
+ callback_kwargs = {}
1147
+ for k in callback_on_step_end_tensor_inputs:
1148
+ callback_kwargs[k] = locals()[k]
1149
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1150
+
1151
+ latents = callback_outputs.pop("latents", latents)
1152
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1153
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1154
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1155
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1156
+ )
1157
+ image_latents = callback_outputs.pop("image_latents", image_latents)
1158
+ if mask_img is not None:
1159
+ mask_image_latents = callback_outputs.pop("mask_image_latents", mask_image_latents)
1160
+ # call the callback, if provided
1161
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1162
+ progress_bar.update()
1163
+
1164
+ if XLA_AVAILABLE:
1165
+ xm.mark_step()
1166
+
1167
+ if output_type == "latent":
1168
+ image = latents
1169
+
1170
+ else:
1171
+ # latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1172
+ latents = latents / self.vae.config.scaling_factor
1173
+ image = self.vae.decode(latents, return_dict=False)[0]
1174
+
1175
+ do_denormalize = [True] * image.shape[0]
1176
+
1177
+ image = torchvision.transforms.Resize((old_height, old_width),interpolation=PIL.Image.BICUBIC)(image)
1178
+ image = self.image_processor.postprocess(
1179
+ image,
1180
+ output_type=output_type,
1181
+ do_denormalize=do_denormalize,
1182
+ do_gamma_correction=False
1183
+ )
1184
+
1185
+ # Offload all models
1186
+ self.maybe_free_model_hooks()
1187
+
1188
+ if not return_dict:
1189
+ return (image,)
1190
+
1191
+ return StableDiffusion3PipelineOutput(images=image)
pipeline_intrinsic_weather_inverse.py ADDED
@@ -0,0 +1,1119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: IntrinsicWeatherInversePipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ # Copyright 2026 The HuggingFace Team. All rights reserved.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ from __future__ import annotations
20
+
21
+ import inspect
22
+ from pathlib import Path
23
+ from typing import Any, Callable, Dict, List, Optional, Union
24
+
25
+ import PIL.Image
26
+ import torch
27
+ import torchvision
28
+ from transformers import (
29
+ CLIPTextModelWithProjection,
30
+ CLIPTokenizer,
31
+ T5EncoderModel,
32
+ T5TokenizerFast,
33
+ )
34
+
35
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
36
+ from diffusers.loaders import SD3LoraLoaderMixin
37
+ from diffusers.models.autoencoders import AutoencoderKL
38
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
39
+ from diffusers.utils import (
40
+ deprecate,
41
+ is_torch_xla_available,
42
+ logging,
43
+ replace_example_docstring,
44
+ )
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
47
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
48
+
49
+ from pipeline_utils import load_transformer_from_subfolder, resolve_repo_dir, set_flow_timesteps
50
+
51
+
52
+ if is_torch_xla_available():
53
+ import torch_xla.core.xla_model as xm
54
+
55
+ XLA_AVAILABLE = True
56
+ else:
57
+ XLA_AVAILABLE = False
58
+
59
+
60
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
61
+
62
+ EXAMPLE_DOC_STRING = """
63
+ Examples:
64
+ ```py
65
+ >>> from pathlib import Path
66
+ >>> import torch
67
+ >>> from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
68
+
69
+ >>> repo_dir = Path("./IntrisicWeather-diffusers").resolve()
70
+ >>> transformer = IntrinsicWeatherInversePipeline.load_transformer("inverse-512", repo_dir)
71
+ >>> pipe = IntrinsicWeatherInversePipeline.from_pretrained(
72
+ ... repo_dir,
73
+ ... transformer_subfolder="inverse-512",
74
+ ... local_files_only=True,
75
+ ... torch_dtype=torch.float16,
76
+ ... )
77
+ >>> pipe.to("cuda")
78
+ >>> output = pipe(
79
+ ... image=input_rgb_tensor,
80
+ ... prompt="Albedo (diffuse basecolor)",
81
+ ... aov=["albedo"],
82
+ ... num_inference_steps=50,
83
+ ... )
84
+ ```
85
+ """
86
+
87
+
88
+
89
+
90
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
91
+ def retrieve_latents(
92
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
93
+ ):
94
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
95
+ return encoder_output.latent_dist.sample(generator)
96
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
97
+ return encoder_output.latent_dist.mode()
98
+ elif hasattr(encoder_output, "latents"):
99
+ return encoder_output.latents
100
+ else:
101
+ raise AttributeError("Could not access latents of provided encoder_output")
102
+
103
+ class VaeAoVProcessor(VaeImageProcessor):
104
+ def postprocess(
105
+ self,
106
+ image: torch.FloatTensor,
107
+ output_type: str = "pil",
108
+ do_denormalize: Optional[List[bool]] = None,
109
+ do_gamma_correction: bool = True,
110
+ ):
111
+ if not isinstance(image, torch.Tensor):
112
+ raise ValueError(
113
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
114
+ )
115
+ if output_type not in ["latent", "pt", "np", "pil"]:
116
+ deprecation_message = (
117
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
118
+ "`pil`, `np`, `pt`, `latent`"
119
+ )
120
+ deprecate(
121
+ "Unsupported output_type",
122
+ "1.0.0",
123
+ deprecation_message,
124
+ standard_warn=False,
125
+ )
126
+ output_type = "np"
127
+
128
+ if output_type == "latent":
129
+ return image
130
+
131
+ if do_denormalize is None:
132
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
133
+
134
+ image = torch.stack(
135
+ [
136
+ self.denormalize(image[i]) if do_denormalize[i] else image[i]
137
+ for i in range(image.shape[0])
138
+ ]
139
+ )
140
+
141
+ # Gamma correction
142
+ if do_gamma_correction:
143
+ # image = torch.pow(image, 1.0 / 2.2)
144
+ image = image ** (1 / 2.2)
145
+
146
+ if output_type == "pt":
147
+ return image
148
+
149
+ image = self.pt_to_numpy(image)
150
+
151
+ if output_type == "np":
152
+ return image
153
+
154
+ if output_type == "pil":
155
+ return self.numpy_to_pil(image)
156
+
157
+ class IntrinsicWeatherInversePipeline(DiffusionPipeline, SD3LoraLoaderMixin):
158
+ r"""
159
+ Args:
160
+ transformer ([`SD3Transformer2DModel`]):
161
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
162
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
163
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
164
+ vae ([`AutoencoderKL`]):
165
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
166
+ text_encoder ([`CLIPTextModelWithProjection`]):
167
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
168
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
169
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
170
+ as its dimension.
171
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
172
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
173
+ specifically the
174
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
175
+ variant.
176
+ text_encoder_3 ([`T5EncoderModel`]):
177
+ Frozen text-encoder. Stable Diffusion 3 uses
178
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
179
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
180
+ tokenizer (`CLIPTokenizer`):
181
+ Tokenizer of class
182
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
183
+ tokenizer_2 (`CLIPTokenizer`):
184
+ Second Tokenizer of class
185
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
186
+ tokenizer_3 (`T5TokenizerFast`):
187
+ Tokenizer of class
188
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
189
+ """
190
+
191
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
192
+ _optional_components = []
193
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
194
+
195
+ @classmethod
196
+ def load_transformer(
197
+ cls,
198
+ transformer_subfolder: str,
199
+ pretrained_model_name_or_path: str | Path,
200
+ *,
201
+ dtype: torch.dtype = torch.bfloat16,
202
+ device: str | torch.device | None = None,
203
+ ):
204
+ return load_transformer_from_subfolder(
205
+ pretrained_model_name_or_path,
206
+ transformer_subfolder,
207
+ dtype=dtype,
208
+ device=device,
209
+ )
210
+
211
+ @classmethod
212
+ def from_pretrained(
213
+ cls,
214
+ pretrained_model_name_or_path,
215
+ *args,
216
+ transformer_subfolder: str = "inverse-512",
217
+ transformer=None,
218
+ **kwargs,
219
+ ):
220
+ repo_dir = resolve_repo_dir(pretrained_model_name_or_path)
221
+ dtype = kwargs.get("torch_dtype", torch.bfloat16)
222
+ device = kwargs.get("device")
223
+ if transformer is None:
224
+ transformer = cls.load_transformer(
225
+ transformer_subfolder,
226
+ repo_dir,
227
+ dtype=dtype,
228
+ device=device,
229
+ )
230
+ pipe = DiffusionPipeline.from_pretrained(
231
+ repo_dir.as_posix(),
232
+ *args,
233
+ transformer=transformer,
234
+ custom_pipeline=str(repo_dir / "pipeline_intrinsic_weather_inverse.py"),
235
+ trust_remote_code=True,
236
+ **kwargs,
237
+ )
238
+ if device is not None:
239
+ pipe = pipe.to(device)
240
+ return pipe
241
+
242
+ def __init__(
243
+ self,
244
+ transformer,
245
+ scheduler: FlowMatchEulerDiscreteScheduler,
246
+ vae: AutoencoderKL,
247
+ text_encoder: CLIPTextModelWithProjection,
248
+ tokenizer: CLIPTokenizer,
249
+ text_encoder_2: CLIPTextModelWithProjection,
250
+ tokenizer_2: CLIPTokenizer,
251
+ text_encoder_3: T5EncoderModel,
252
+ tokenizer_3: T5TokenizerFast,
253
+ ):
254
+ super().__init__()
255
+
256
+ self.register_modules(
257
+ vae=vae,
258
+ text_encoder=text_encoder,
259
+ text_encoder_2=text_encoder_2,
260
+ text_encoder_3=text_encoder_3,
261
+ tokenizer=tokenizer,
262
+ tokenizer_2=tokenizer_2,
263
+ tokenizer_3=tokenizer_3,
264
+ transformer=transformer,
265
+ scheduler=scheduler,
266
+ )
267
+ self.vae_scale_factor = (
268
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
269
+ )
270
+ self.image_processor = VaeAoVProcessor(vae_scale_factor=self.vae_scale_factor)
271
+ self.tokenizer_max_length = (
272
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
273
+ )
274
+ self.default_sample_size = (
275
+ self.transformer.config.sample_size
276
+ if hasattr(self, "transformer") and self.transformer is not None
277
+ else 128
278
+ )
279
+
280
+ def _get_t5_prompt_embeds(
281
+ self,
282
+ prompt: Union[str, List[str]] = None,
283
+ num_images_per_prompt: int = 1,
284
+ device: Optional[torch.device] = None,
285
+ dtype: Optional[torch.dtype] = None,
286
+ ):
287
+ device = device or self._execution_device
288
+ dtype = dtype or self.text_encoder.dtype
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ batch_size = len(prompt)
292
+
293
+ if self.text_encoder_3 is None:
294
+ return torch.zeros(
295
+ (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
296
+ device=device,
297
+ dtype=dtype,
298
+ )
299
+
300
+ text_inputs = self.tokenizer_3(
301
+ prompt,
302
+ padding="max_length",
303
+ max_length=self.tokenizer_max_length,
304
+ truncation=True,
305
+ add_special_tokens=True,
306
+ return_tensors="pt",
307
+ )
308
+ text_input_ids = text_inputs.input_ids
309
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
310
+
311
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
312
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
313
+ logger.warning(
314
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
315
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
316
+ )
317
+
318
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
319
+
320
+ dtype = self.text_encoder_3.dtype
321
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
322
+
323
+ _, seq_len, _ = prompt_embeds.shape
324
+
325
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
326
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
328
+
329
+ return prompt_embeds
330
+
331
+ def _get_clip_prompt_embeds(
332
+ self,
333
+ prompt: Union[str, List[str]],
334
+ num_images_per_prompt: int = 1,
335
+ device: Optional[torch.device] = None,
336
+ clip_skip: Optional[int] = None,
337
+ clip_model_index: int = 0,
338
+ ):
339
+ device = device or self._execution_device
340
+
341
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
342
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
343
+
344
+ tokenizer = clip_tokenizers[clip_model_index]
345
+ text_encoder = clip_text_encoders[clip_model_index]
346
+
347
+ prompt = [prompt] if isinstance(prompt, str) else prompt
348
+ batch_size = len(prompt)
349
+
350
+ text_inputs = tokenizer(
351
+ prompt,
352
+ padding="max_length",
353
+ max_length=self.tokenizer_max_length,
354
+ truncation=True,
355
+ return_tensors="pt",
356
+ )
357
+
358
+ text_input_ids = text_inputs.input_ids
359
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
360
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
361
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
362
+ logger.warning(
363
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
364
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
365
+ )
366
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
367
+ pooled_prompt_embeds = prompt_embeds[0]
368
+
369
+ if clip_skip is None:
370
+ prompt_embeds = prompt_embeds.hidden_states[-2]
371
+ else:
372
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
373
+
374
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
375
+
376
+ _, seq_len, _ = prompt_embeds.shape
377
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
378
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
379
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
380
+
381
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
382
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
383
+
384
+ return prompt_embeds, pooled_prompt_embeds
385
+
386
+ def encode_prompt(
387
+ self,
388
+ prompt: Union[str, List[str]],
389
+ prompt_2: Union[str, List[str]] = None,
390
+ prompt_3: Union[str, List[str]] = None,
391
+ device: Optional[torch.device] = None,
392
+ num_images_per_prompt: int = 1,
393
+ do_classifier_free_guidance: bool = True,
394
+ negative_prompt: Optional[Union[str, List[str]]] = None,
395
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
396
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
397
+ prompt_embeds: Optional[torch.FloatTensor] = None,
398
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
399
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
400
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
401
+ clip_skip: Optional[int] = None,
402
+ ):
403
+ r"""
404
+
405
+ Args:
406
+ prompt (`str` or `List[str]`, *optional*):
407
+ prompt to be encoded
408
+ prompt_2 (`str` or `List[str]`, *optional*):
409
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
410
+ used in all text-encoders
411
+ prompt_3 (`str` or `List[str]`, *optional*):
412
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
413
+ used in all text-encoders
414
+ device: (`torch.device`):
415
+ torch device
416
+ num_images_per_prompt (`int`):
417
+ number of images that should be generated per prompt
418
+ do_classifier_free_guidance (`bool`):
419
+ whether to use classifier free guidance or not
420
+ negative_prompt (`str` or `List[str]`, *optional*):
421
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
422
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
423
+ less than `1`).
424
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
425
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
426
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
427
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
428
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
429
+ `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
430
+ prompt_embeds (`torch.FloatTensor`, *optional*):
431
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
432
+ provided, text embeddings will be generated from `prompt` input argument.
433
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
434
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
435
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
436
+ argument.
437
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
438
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
439
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
440
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
441
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
442
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
443
+ input argument.
444
+ clip_skip (`int`, *optional*):
445
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
446
+ the output of the pre-final layer will be used for computing the prompt embeddings.
447
+ """
448
+ device = device or self._execution_device
449
+
450
+ prompt = [prompt] if isinstance(prompt, str) else prompt
451
+ if prompt is not None:
452
+ batch_size = len(prompt)
453
+ else:
454
+ batch_size = prompt_embeds.shape[0]
455
+
456
+ if prompt_embeds is None:
457
+ prompt_2 = prompt_2 or prompt
458
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
459
+
460
+ prompt_3 = prompt_3 or prompt
461
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
462
+
463
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
464
+ prompt=prompt,
465
+ device=device,
466
+ num_images_per_prompt=num_images_per_prompt,
467
+ clip_skip=clip_skip,
468
+ clip_model_index=0,
469
+ )
470
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
471
+ prompt=prompt_2,
472
+ device=device,
473
+ num_images_per_prompt=num_images_per_prompt,
474
+ clip_skip=clip_skip,
475
+ clip_model_index=1,
476
+ )
477
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
478
+
479
+ t5_prompt_embed = self._get_t5_prompt_embeds(
480
+ prompt=prompt_3,
481
+ num_images_per_prompt=num_images_per_prompt,
482
+ device=device,
483
+ )
484
+
485
+ clip_prompt_embeds = torch.nn.functional.pad(
486
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
487
+ )
488
+
489
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
490
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
491
+
492
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
493
+ negative_prompt = negative_prompt or ""
494
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
495
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
496
+
497
+ # normalize str to list
498
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
499
+ negative_prompt_2 = (
500
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
501
+ )
502
+ negative_prompt_3 = (
503
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
504
+ )
505
+
506
+ if prompt is not None and type(prompt) is not type(negative_prompt):
507
+ raise TypeError(
508
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
509
+ f" {type(prompt)}."
510
+ )
511
+ elif batch_size != len(negative_prompt):
512
+ raise ValueError(
513
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
514
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
515
+ " the batch size of `prompt`."
516
+ )
517
+
518
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
519
+ negative_prompt,
520
+ device=device,
521
+ num_images_per_prompt=num_images_per_prompt,
522
+ clip_skip=None,
523
+ clip_model_index=0,
524
+ )
525
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
526
+ negative_prompt_2,
527
+ device=device,
528
+ num_images_per_prompt=num_images_per_prompt,
529
+ clip_skip=None,
530
+ clip_model_index=1,
531
+ )
532
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
533
+
534
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
535
+ prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
536
+ )
537
+
538
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
539
+ negative_clip_prompt_embeds,
540
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
541
+ )
542
+
543
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
544
+ negative_pooled_prompt_embeds = torch.cat(
545
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
546
+ )
547
+
548
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
549
+
550
+ def check_inputs(
551
+ self,
552
+ prompt,
553
+ prompt_2,
554
+ prompt_3,
555
+ # height,
556
+ # width,
557
+ negative_prompt=None,
558
+ negative_prompt_2=None,
559
+ negative_prompt_3=None,
560
+ prompt_embeds=None,
561
+ negative_prompt_embeds=None,
562
+ pooled_prompt_embeds=None,
563
+ negative_pooled_prompt_embeds=None,
564
+ callback_on_step_end_tensor_inputs=None,
565
+ ):
566
+ # if height % 8 != 0 or width % 8 != 0:
567
+ # raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
568
+
569
+ if callback_on_step_end_tensor_inputs is not None and not all(
570
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
571
+ ):
572
+ raise ValueError(
573
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
574
+ )
575
+
576
+ if prompt is not None and prompt_embeds is not None:
577
+ raise ValueError(
578
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
579
+ " only forward one of the two."
580
+ )
581
+ elif prompt_2 is not None and prompt_embeds is not None:
582
+ raise ValueError(
583
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
584
+ " only forward one of the two."
585
+ )
586
+ elif prompt_3 is not None and prompt_embeds is not None:
587
+ raise ValueError(
588
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
589
+ " only forward one of the two."
590
+ )
591
+ elif prompt is None and prompt_embeds is None:
592
+ raise ValueError(
593
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
594
+ )
595
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
596
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
597
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
598
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
599
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
600
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
601
+
602
+ if negative_prompt is not None and negative_prompt_embeds is not None:
603
+ raise ValueError(
604
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
605
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
606
+ )
607
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
608
+ raise ValueError(
609
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
610
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
611
+ )
612
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
613
+ raise ValueError(
614
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
615
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
616
+ )
617
+
618
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
619
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
620
+ raise ValueError(
621
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
622
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
623
+ f" {negative_prompt_embeds.shape}."
624
+ )
625
+
626
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
627
+ raise ValueError(
628
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
629
+ )
630
+
631
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
632
+ raise ValueError(
633
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
634
+ )
635
+
636
+ def prepare_latents(
637
+ self,
638
+ batch_size,
639
+ num_channels_latents,
640
+ height,
641
+ width,
642
+ dtype,
643
+ device,
644
+ generator,
645
+ latents=None,
646
+ ):
647
+ if latents is not None:
648
+ return latents.to(device=device, dtype=dtype)
649
+
650
+ shape = (
651
+ batch_size,
652
+ num_channels_latents,
653
+ int(height) // self.vae_scale_factor,
654
+ int(width) // self.vae_scale_factor,
655
+ )
656
+
657
+ if isinstance(generator, list) and len(generator) != batch_size:
658
+ raise ValueError(
659
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
660
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
661
+ )
662
+
663
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
664
+
665
+ return latents
666
+
667
+ def prepare_image_latents(
668
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
669
+ ):
670
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
671
+ raise ValueError(
672
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
673
+ )
674
+
675
+ image = image.to(device=device, dtype=dtype)
676
+
677
+ batch_size = batch_size * num_images_per_prompt
678
+
679
+ if image.shape[1] == self.vae.config.latent_channels:
680
+ image_latents = image
681
+ else:
682
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
683
+ # ? normalize image latents
684
+ # image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
685
+
686
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
687
+ # expand image_latents for batch_size
688
+ deprecation_message = (
689
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
690
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
691
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
692
+ " your script to pass as many initial images as text prompts to suppress this warning."
693
+ )
694
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
695
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
696
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
697
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
698
+ raise ValueError(
699
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
700
+ )
701
+ else:
702
+ image_latents = torch.cat([image_latents], dim=0)
703
+
704
+ if do_classifier_free_guidance:
705
+ uncond_image_latents = torch.zeros_like(image_latents)
706
+ image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
707
+
708
+ return image_latents
709
+
710
+
711
+
712
+
713
+
714
+
715
+
716
+
717
+ @property
718
+ def guidance_scale(self):
719
+ return self._guidance_scale
720
+ @property
721
+ def image_guidance_scale(self):
722
+ return self._image_guidance_scale
723
+
724
+ @property
725
+ def clip_skip(self):
726
+ return self._clip_skip
727
+
728
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
729
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
730
+ # corresponds to doing no classifier free guidance.
731
+ @property
732
+ def do_classifier_free_guidance(self):
733
+ return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
734
+
735
+ @property
736
+ def joint_attention_kwargs(self):
737
+ return self._joint_attention_kwargs
738
+
739
+ @property
740
+ def num_timesteps(self):
741
+ return self._num_timesteps
742
+
743
+ @property
744
+ def interrupt(self):
745
+ return self._interrupt
746
+
747
+ @torch.no_grad()
748
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
749
+ def __call__(
750
+ self,
751
+ prompt: Union[str, List[str]] = None,
752
+ prompt_2: Optional[Union[str, List[str]]] = None,
753
+ prompt_3: Optional[Union[str, List[str]]] = None,
754
+ image: PipelineImageInput = None,
755
+ height: Optional[int] = None,
756
+ width: Optional[int] = None,
757
+ num_inference_steps: int = 28,
758
+ timesteps: List[int] = None,
759
+ guidance_scale: float = 7.0,
760
+ image_guidance_scale: float = 1.5,
761
+ negative_prompt: Optional[Union[str, List[str]]] = None,
762
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
763
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
764
+ num_images_per_prompt: Optional[int] = 1,
765
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
766
+ latents: Optional[torch.FloatTensor] = None,
767
+ prompt_embeds: Optional[torch.FloatTensor] = None,
768
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
769
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
770
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
771
+ output_type: Optional[str] = "pil",
772
+ return_dict: bool = True,
773
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
774
+ clip_skip: Optional[int] = None,
775
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
776
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
777
+ mask_img: Optional[PipelineImageInput] = None,
778
+ aov: List[str] = ["albedo"],
779
+ map_aware_mask: Optional[torch.FloatTensor] = None,
780
+ **kwargs
781
+ ):
782
+ r"""
783
+ Function invoked when calling the pipeline for generation.
784
+
785
+ Args:
786
+ prompt (`str` or `List[str]`, *optional*):
787
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
788
+ instead.
789
+ prompt_2 (`str` or `List[str]`, *optional*):
790
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
791
+ will be used instead
792
+ prompt_3 (`str` or `List[str]`, *optional*):
793
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
794
+ will be used instead
795
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
796
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
797
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
798
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
799
+ num_inference_steps (`int`, *optional*, defaults to 50):
800
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
801
+ expense of slower inference.
802
+ timesteps (`List[int]`, *optional*):
803
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
804
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
805
+ passed will be used. Must be in descending order.
806
+ guidance_scale (`float`, *optional*, defaults to 5.0):
807
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
808
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
809
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
810
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
811
+ usually at the expense of lower image quality.
812
+ negative_prompt (`str` or `List[str]`, *optional*):
813
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
814
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
815
+ less than `1`).
816
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
817
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
818
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
819
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
820
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
821
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
822
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
823
+ The number of images to generate per prompt.
824
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
825
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
826
+ to make generation deterministic.
827
+ latents (`torch.FloatTensor`, *optional*):
828
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
829
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
830
+ tensor will ge generated by sampling using the supplied random `generator`.
831
+ prompt_embeds (`torch.FloatTensor`, *optional*):
832
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
833
+ provided, text embeddings will be generated from `prompt` input argument.
834
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
835
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
836
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
837
+ argument.
838
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
839
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
840
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
841
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
842
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
843
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
844
+ input argument.
845
+ output_type (`str`, *optional*, defaults to `"pil"`):
846
+ The output format of the generate image. Choose between
847
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
848
+ return_dict (`bool`, *optional*, defaults to `True`):
849
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
850
+ of a plain tuple.
851
+ joint_attention_kwargs (`dict`, *optional*):
852
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
853
+ `self.processor` in
854
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
855
+ callback_on_step_end (`Callable`, *optional*):
856
+ A function that calls at the end of each denoising steps during the inference. The function is called
857
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
858
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
859
+ `callback_on_step_end_tensor_inputs`.
860
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
861
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
862
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
863
+ `._callback_tensor_inputs` attribute of your pipeline class.
864
+
865
+ Examples:
866
+
867
+ Returns:
868
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
869
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
870
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
871
+ """
872
+
873
+ # height = height or self.default_sample_size * self.vae_scale_factor
874
+ # width = width or self.default_sample_size * self.vae_scale_factor
875
+
876
+ # 1. Check inputs. Raise error if not correct
877
+ self.check_inputs(
878
+ prompt,
879
+ prompt_2,
880
+ prompt_3,
881
+ negative_prompt=negative_prompt,
882
+ negative_prompt_2=negative_prompt_2,
883
+ negative_prompt_3=negative_prompt_3,
884
+ prompt_embeds=prompt_embeds,
885
+ negative_prompt_embeds=negative_prompt_embeds,
886
+ pooled_prompt_embeds=pooled_prompt_embeds,
887
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
888
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
889
+ )
890
+
891
+ self._guidance_scale = guidance_scale
892
+ self._image_guidance_scale = image_guidance_scale
893
+ self._clip_skip = clip_skip
894
+ self._joint_attention_kwargs = joint_attention_kwargs
895
+ self._interrupt = False
896
+
897
+ # 2. Define call parameters
898
+ if prompt is not None and isinstance(prompt, str):
899
+ batch_size = 1
900
+ elif prompt is not None and isinstance(prompt, list):
901
+ batch_size = len(prompt)
902
+ else:
903
+ batch_size = prompt_embeds.shape[0]
904
+
905
+ device = self._execution_device
906
+
907
+ (
908
+ prompt_embeds,
909
+ negative_prompt_embeds,
910
+ pooled_prompt_embeds,
911
+ negative_pooled_prompt_embeds,
912
+ ) = self.encode_prompt(
913
+ prompt=prompt,
914
+ prompt_2=prompt_2,
915
+ prompt_3=prompt_3,
916
+ negative_prompt=negative_prompt,
917
+ negative_prompt_2=negative_prompt_2,
918
+ negative_prompt_3=negative_prompt_3,
919
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
920
+ prompt_embeds=prompt_embeds,
921
+ negative_prompt_embeds=negative_prompt_embeds,
922
+ pooled_prompt_embeds=pooled_prompt_embeds,
923
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
924
+ device=device,
925
+ clip_skip=self.clip_skip,
926
+ num_images_per_prompt=num_images_per_prompt,
927
+ )
928
+ # print("prompt:", prompt_embeds.shape)
929
+ if self.do_classifier_free_guidance:
930
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
931
+ prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
932
+
933
+ # Similiarly
934
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0)
935
+
936
+ # if self.do_classifier_free_guidance:
937
+ # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
938
+ # pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
939
+
940
+ # 3. Preprocess image
941
+ if (isinstance(image, PIL.Image.Image) and len(image.size)==4) or (isinstance(image, torch.Tensor) and len(image.shape) == 4):
942
+ old_height = image.shape[2]
943
+ old_width = image.shape[3]
944
+ else:
945
+ old_height = image.shape[1]
946
+ old_width = image.shape[2]
947
+
948
+ image = self.image_processor.preprocess(image)
949
+ # image = torchvision.transforms.Resize((512,512),interpolation=PIL.Image.BICUBIC)(image)
950
+
951
+ # 4. Prepare Image latent
952
+ image_latents = self.prepare_image_latents(
953
+ image,
954
+ batch_size,
955
+ num_images_per_prompt,
956
+ prompt_embeds.dtype,
957
+ device,
958
+ self.do_classifier_free_guidance,
959
+ )
960
+
961
+ # 5. Prepare timesteps
962
+ set_flow_timesteps(
963
+ self.scheduler,
964
+ self.transformer,
965
+ num_inference_steps,
966
+ image_latents.shape[-2],
967
+ image_latents.shape[-1],
968
+ device,
969
+ )
970
+ timesteps = self.scheduler.timesteps
971
+
972
+ # timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
973
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
974
+ self._num_timesteps = len(timesteps)
975
+
976
+ height, width = image_latents.shape[-2:]
977
+ height = height * self.vae_scale_factor
978
+ width = width * self.vae_scale_factor
979
+ # 6. Prepare latent variables
980
+ num_channels_latents = self.vae.config.latent_channels
981
+ latents = self.prepare_latents(
982
+ batch_size * num_images_per_prompt,
983
+ num_channels_latents,
984
+ height,
985
+ width,
986
+ prompt_embeds.dtype,
987
+ device,
988
+ generator,
989
+ latents,
990
+ )
991
+
992
+ # 7. Check that shapes of latents and image match the DIT in_channels
993
+ num_channels_image = image_latents.shape[1]
994
+ if mask_img is not None:
995
+ mask_img = self.image_processor.preprocess(mask_img)
996
+ mask_image_latents = self.prepare_image_latents(
997
+ mask_img,
998
+ batch_size,
999
+ num_images_per_prompt,
1000
+ prompt_embeds.dtype,
1001
+ device,
1002
+ self.do_classifier_free_guidance,
1003
+ )
1004
+ num_channels_image += mask_image_latents.shape[1]
1005
+
1006
+ if num_channels_latents + num_channels_image != self.transformer.config.in_channels:
1007
+ raise ValueError(
1008
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
1009
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1010
+ f" `num_channels_image`: {num_channels_image} "
1011
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
1012
+ " `pipeline.transformer` or your `image` input."
1013
+ )
1014
+
1015
+ # 8. Denoising loop
1016
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1017
+ for i, t in enumerate(timesteps):
1018
+ if self.interrupt:
1019
+ continue
1020
+
1021
+ # expand the latents if we are doing classifier free guidance
1022
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
1023
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1024
+ timestep = t.expand(latent_model_input.shape[0])
1025
+
1026
+ scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
1027
+ if mask_img is not None:
1028
+ scaled_latent_model_input = torch.cat([scaled_latent_model_input, mask_image_latents], dim=1)
1029
+ # if "mask_index" in kwargs and kwargs['mask_index'] is not None:
1030
+ # mask_index = kwargs['mask_index']
1031
+ # else:
1032
+ # mask_index = None
1033
+ noise_pred = self.transformer(
1034
+ hidden_states=scaled_latent_model_input,
1035
+ timestep=timestep,
1036
+ encoder_hidden_states=prompt_embeds,
1037
+ pooled_projections=pooled_prompt_embeds,
1038
+ map_aware_mask=map_aware_mask,
1039
+ joint_attention_kwargs=self.joint_attention_kwargs,
1040
+ return_dict=False,
1041
+ # mask_index= mask_index,
1042
+ )[0]
1043
+
1044
+ # perform guidance
1045
+ if self.do_classifier_free_guidance:
1046
+ noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
1047
+ noise_pred = (
1048
+ noise_pred_uncond
1049
+ + self.guidance_scale * (noise_pred_text - noise_pred_image)
1050
+ + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
1051
+ )
1052
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # neg, prompt
1053
+ # noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1054
+
1055
+ # compute the previous noisy sample x_t -> x_t-1
1056
+ latents_dtype = latents.dtype
1057
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1058
+
1059
+ if latents.dtype != latents_dtype:
1060
+ if torch.backends.mps.is_available():
1061
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1062
+ latents = latents.to(latents_dtype)
1063
+
1064
+ if callback_on_step_end is not None:
1065
+ callback_kwargs = {}
1066
+ for k in callback_on_step_end_tensor_inputs:
1067
+ callback_kwargs[k] = locals()[k]
1068
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1069
+
1070
+ latents = callback_outputs.pop("latents", latents)
1071
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1072
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1073
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1074
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1075
+ )
1076
+ image_latents = callback_outputs.pop("image_latents", image_latents)
1077
+ if mask_img is not None:
1078
+ mask_image_latents = callback_outputs.pop("mask_image_latents", mask_image_latents)
1079
+ # call the callback, if provided
1080
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1081
+ progress_bar.update()
1082
+
1083
+ if XLA_AVAILABLE:
1084
+ xm.mark_step()
1085
+
1086
+ if output_type == "latent":
1087
+ image = latents
1088
+
1089
+ else:
1090
+ # latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1091
+ latents = latents / self.vae.config.scaling_factor
1092
+ image = self.vae.decode(latents, return_dict=False)[0]
1093
+
1094
+ do_denormalize = [True] * image.shape[0]
1095
+ aov_name = aov[0]
1096
+ if aov_name == "albedo" or aov_name == "irradiance":
1097
+ do_gamma_correction = True
1098
+ else:
1099
+ do_gamma_correction = False
1100
+
1101
+ if aov_name == "roughness" or aov_name == "metallic":
1102
+ image = image[:, 0:1].repeat(1, 3, 1, 1)
1103
+ # print(image.shape)
1104
+ # print(old_height, old_width)
1105
+ image = torchvision.transforms.Resize((old_height, old_width),interpolation=PIL.Image.BICUBIC)(image)
1106
+ image = self.image_processor.postprocess(
1107
+ image,
1108
+ output_type=output_type,
1109
+ do_denormalize=do_denormalize,
1110
+ do_gamma_correction=do_gamma_correction,
1111
+ )
1112
+
1113
+ # Offload all models
1114
+ self.maybe_free_model_hooks()
1115
+
1116
+ if not return_dict:
1117
+ return (image,)
1118
+
1119
+ return StableDiffusion3PipelineOutput(images=image)
pipeline_utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers for loading transformer variants from ``transformer/<subfolder>/``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib.util
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ from diffusers.models.transformers import SD3Transformer2DModel
10
+
11
+
12
+ def calculate_shift(
13
+ image_seq_len: int,
14
+ base_seq_len: int = 256,
15
+ max_seq_len: int = 4096,
16
+ base_shift: float = 0.5,
17
+ max_shift: float = 1.15,
18
+ ) -> float:
19
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
20
+ b = base_shift - m * base_seq_len
21
+ return image_seq_len * m + b
22
+
23
+
24
+ def set_flow_timesteps(
25
+ scheduler,
26
+ transformer,
27
+ num_inference_steps: int,
28
+ latent_height: int,
29
+ latent_width: int,
30
+ device: torch.device,
31
+ ) -> None:
32
+ if scheduler.config.get("use_dynamic_shifting", False):
33
+ image_seq_len = (latent_height // transformer.config.patch_size) * (
34
+ latent_width // transformer.config.patch_size
35
+ )
36
+ mu = calculate_shift(
37
+ image_seq_len,
38
+ scheduler.config.get("base_image_seq_len", 256),
39
+ scheduler.config.get("max_image_seq_len", 4096),
40
+ scheduler.config.get("base_shift", 0.5),
41
+ scheduler.config.get("max_shift", 1.15),
42
+ )
43
+ scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
44
+ else:
45
+ scheduler.set_timesteps(num_inference_steps, device=device)
46
+
47
+
48
+ def resolve_repo_dir(pretrained_model_name_or_path: str | Path) -> Path:
49
+ return Path(pretrained_model_name_or_path).resolve()
50
+
51
+
52
+ def load_transformer_from_subfolder(
53
+ repo_dir: str | Path,
54
+ transformer_subfolder: str,
55
+ *,
56
+ dtype: torch.dtype = torch.bfloat16,
57
+ device: str | torch.device | None = None,
58
+ ):
59
+ """Load a transformer checkpoint from ``<repo_dir>/transformer/<transformer_subfolder>/``."""
60
+ repo_dir = resolve_repo_dir(repo_dir)
61
+ transformer_path = repo_dir / "transformer" / transformer_subfolder
62
+ if not transformer_path.is_dir():
63
+ raise FileNotFoundError(f"Transformer folder not found: {transformer_path}")
64
+
65
+ custom_module = transformer_path / "transformer_intrinsic_weather.py"
66
+ if custom_module.exists():
67
+ spec = importlib.util.spec_from_file_location("transformer_intrinsic_weather", custom_module)
68
+ if spec is None or spec.loader is None:
69
+ raise ImportError(f"Cannot import custom transformer module: {custom_module}")
70
+ module = importlib.util.module_from_spec(spec)
71
+ spec.loader.exec_module(module)
72
+ cls = module.IntrinsicWeatherSD3Transformer2DModel
73
+ transformer = cls.from_pretrained(
74
+ transformer_path.as_posix(),
75
+ torch_dtype=dtype,
76
+ local_files_only=True,
77
+ )
78
+ else:
79
+ transformer = SD3Transformer2DModel.from_pretrained(
80
+ transformer_path.as_posix(),
81
+ torch_dtype=dtype,
82
+ local_files_only=True,
83
+ )
84
+
85
+ if device is not None:
86
+ transformer = transformer.to(device)
87
+ return transformer
88
+
89
+
90
+ def resolve_transformer_lora_dir(repo_dir: str | Path, transformer_subfolder: str) -> Path | None:
91
+ """Return ``transformer/<subfolder>/lora`` when present."""
92
+ lora_dir = resolve_repo_dir(repo_dir) / "transformer" / transformer_subfolder / "lora"
93
+ if lora_dir.is_dir() and any(lora_dir.glob("*.safetensors")):
94
+ return lora_dir
95
+ return None
96
+
97
+
98
+ def load_transformer_lora(pipe, repo_dir: str | Path, transformer_subfolder: str) -> bool:
99
+ """Load LoRA weights bundled with a transformer variant. Returns True if loaded."""
100
+ lora_dir = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
101
+ if lora_dir is None:
102
+ return False
103
+ pipe.load_lora_weights(lora_dir.as_posix())
104
+ return True
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.35.1",
4
+ "num_train_timesteps": 1000,
5
+ "shift": 3.0,
6
+ "use_dynamic_shifting": true,
7
+ "base_shift": 0.5,
8
+ "max_shift": 1.15
9
+ }
test_all_pipelines.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Smoke-test all IntrinsicWeather pipelines on CUDA with bfloat16."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import gc
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ import torch
11
+
12
+ REPO = Path(__file__).resolve().parent
13
+ sys.path.insert(0, str(REPO))
14
+
15
+ DTYPE = torch.bfloat16
16
+ DEVICE = "cuda"
17
+ IMAGE_SIZE = 512
18
+ STEPS = 2
19
+
20
+
21
+ def _clear():
22
+ gc.collect()
23
+ if torch.cuda.is_available():
24
+ torch.cuda.empty_cache()
25
+
26
+
27
+ def test_inverse() -> None:
28
+ from imaa.imaa import IMAA
29
+ from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
30
+ from safetensors.torch import load_file
31
+
32
+ print("[inverse] loading pipeline ...")
33
+ pipe = IntrinsicWeatherInversePipeline.from_pretrained(
34
+ REPO,
35
+ transformer_subfolder="inverse-512",
36
+ device=DEVICE,
37
+ local_files_only=True,
38
+ torch_dtype=DTYPE,
39
+ )
40
+ assert next(pipe.transformer.parameters()).dtype == DTYPE
41
+ assert next(pipe.transformer.parameters()).device.type == "cuda"
42
+ print(f"[inverse] transformer in_channels={pipe.transformer.config.in_channels}")
43
+
44
+ imaa = IMAA(dino_model=None, processor=None, num_maps=5, map_embedding_dim=256, common_dim=128).to(DEVICE)
45
+ imaa.load_state_dict(load_file((REPO / "imaa" / "model.safetensors").as_posix()))
46
+ imaa.eval()
47
+
48
+ image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE, dtype=DTYPE)
49
+ prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
50
+ prompt="Albedo (diffuse basecolor)",
51
+ prompt_2=None,
52
+ prompt_3=None,
53
+ do_classifier_free_guidance=False,
54
+ )
55
+ print("[inverse] running 2-step inference ...")
56
+ out = pipe(
57
+ image=image,
58
+ prompt_embeds=prompt_embeds,
59
+ pooled_prompt_embeds=pooled_prompt_embeds,
60
+ guidance_scale=0.0,
61
+ image_guidance_scale=0.0,
62
+ num_inference_steps=STEPS,
63
+ output_type="pt",
64
+ aov=["albedo"],
65
+ map_aware_mask=None,
66
+ )
67
+ print(f"[inverse] output shape={tuple(out.images[0].shape)} dtype={out.images[0].dtype}")
68
+ del pipe, imaa, out
69
+ _clear()
70
+
71
+
72
+ def test_forward() -> None:
73
+ from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
74
+
75
+ print("[forward] loading pipeline ...")
76
+ pipe = IntrinsicWeatherForwardPipeline.from_pretrained(
77
+ REPO,
78
+ transformer_subfolder="forward",
79
+ device=DEVICE,
80
+ local_files_only=True,
81
+ torch_dtype=DTYPE,
82
+ load_lora=True,
83
+ )
84
+ assert next(pipe.transformer.parameters()).dtype == DTYPE
85
+ assert next(pipe.transformer.parameters()).device.type == "cuda"
86
+ print(f"[forward] transformer in_channels={pipe.transformer.config.in_channels}")
87
+
88
+ aov = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE, dtype=DTYPE)
89
+ print("[forward] running 2-step inference ...")
90
+ out = pipe(
91
+ albedo=aov,
92
+ normal=aov,
93
+ roughness=aov,
94
+ metallic=aov,
95
+ irradiance=aov,
96
+ prompt=["A rainy day."],
97
+ guidance_scale=6.0,
98
+ image_guidance_scale=1.5,
99
+ num_inference_steps=STEPS,
100
+ required_aovs=["albedo", "normal", "roughness", "metallic", "irradiance"],
101
+ generator=torch.Generator(device=DEVICE).manual_seed(0),
102
+ )
103
+ print(f"[forward] output size={out.images[0].size}")
104
+ del pipe, out
105
+ _clear()
106
+
107
+
108
+ def test_unified() -> None:
109
+ from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
110
+
111
+ print("[unified] loading pipeline ...")
112
+ pipe = IntrinsicWeatherPipeline.from_pretrained(
113
+ REPO,
114
+ inverse_transformer_subfolder="inverse-512",
115
+ forward_transformer_subfolder="forward",
116
+ device=DEVICE,
117
+ local_files_only=True,
118
+ torch_dtype=DTYPE,
119
+ load_lora=True,
120
+ load_imaa=True,
121
+ )
122
+ assert next(pipe.inverse_transformer.parameters()).dtype == DTYPE
123
+ assert next(pipe.forward_transformer.parameters()).dtype == DTYPE
124
+ print("[unified] loaded inverse + forward transformers and IMAA")
125
+ del pipe
126
+ _clear()
127
+
128
+
129
+ def main() -> None:
130
+ if not torch.cuda.is_available():
131
+ raise SystemExit("CUDA is required for this test.")
132
+
133
+ print(f"device={torch.cuda.get_device_name(0)} dtype={DTYPE}")
134
+ test_inverse()
135
+ test_forward()
136
+ test_unified()
137
+ print("All pipeline tests passed.")
138
+
139
+
140
+ if __name__ == "__main__":
141
+ main()
text_encoder/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModelWithProjection"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 2,
9
+ "hidden_act": "quick_gelu",
10
+ "hidden_size": 768,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 768,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.41.2",
23
+ "vocab_size": 49408
24
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71e183d11db0c6b6282a4d9e0abb74125edc8692393e89ed8ee5571005f35cb1
3
+ size 247323896
text_encoder_2/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModelWithProjection"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 2,
9
+ "hidden_act": "gelu",
10
+ "hidden_size": 1280,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 20,
18
+ "num_hidden_layers": 32,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 1280,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.41.2",
23
+ "vocab_size": 49408
24
+ }
text_encoder_2/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec310df2af79c318e24d20511b601a591ca8cd4f1fce1d8dff822a356bcdb1f4
3
+ size 1389382176
text_encoder_3/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "T5EncoderModel"
4
+ ],
5
+ "classifier_dropout": 0.0,
6
+ "d_ff": 10240,
7
+ "d_kv": 64,
8
+ "d_model": 4096,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "gelu_new",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "gated-gelu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": true,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "num_decoder_layers": 24,
20
+ "num_heads": 64,
21
+ "num_layers": 24,
22
+ "output_past": true,
23
+ "pad_token_id": 0,
24
+ "relative_attention_max_distance": 128,
25
+ "relative_attention_num_buckets": 32,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "float16",
28
+ "transformers_version": "4.41.2",
29
+ "use_cache": true,
30
+ "vocab_size": 32128
31
+ }
text_encoder_3/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f2751ceeb2a96edd693e539dc5d6bba0b8d3814f49a9b3798403a0cec4b2e3d
3
+ size 4994582104
text_encoder_3/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f63154532130422309532ff56f11945fbea8266c958e3133e8e5aef85c6293c7
3
+ size 4530066248
text_encoder_3/model.safetensors.index.json ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 9524621312
4
+ },
5
+ "weight_map": {
6
+ "encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
7
+ "encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
8
+ "encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
9
+ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
10
+ "encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
11
+ "encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
12
+ "encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
13
+ "encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
14
+ "encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
15
+ "encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
16
+ "encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
17
+ "encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
18
+ "encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
19
+ "encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
20
+ "encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
21
+ "encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
22
+ "encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
23
+ "encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
24
+ "encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
25
+ "encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
26
+ "encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
27
+ "encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
28
+ "encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
29
+ "encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
30
+ "encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
31
+ "encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
32
+ "encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
33
+ "encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
34
+ "encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
35
+ "encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
36
+ "encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
37
+ "encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
38
+ "encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
39
+ "encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
40
+ "encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
41
+ "encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
42
+ "encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
43
+ "encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
44
+ "encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
45
+ "encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
46
+ "encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
47
+ "encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
48
+ "encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
49
+ "encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
50
+ "encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
51
+ "encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
52
+ "encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
53
+ "encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
54
+ "encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
55
+ "encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
56
+ "encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
57
+ "encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
58
+ "encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
59
+ "encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
60
+ "encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
61
+ "encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
62
+ "encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
63
+ "encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
64
+ "encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
65
+ "encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
66
+ "encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
67
+ "encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
68
+ "encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
69
+ "encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
70
+ "encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
71
+ "encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
72
+ "encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
73
+ "encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
74
+ "encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
75
+ "encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
76
+ "encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
77
+ "encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
78
+ "encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
79
+ "encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
80
+ "encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
81
+ "encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
82
+ "encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
83
+ "encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
84
+ "encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
85
+ "encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
86
+ "encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
87
+ "encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
88
+ "encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
89
+ "encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
90
+ "encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
91
+ "encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
92
+ "encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
93
+ "encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
94
+ "encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
95
+ "encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
96
+ "encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
97
+ "encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
98
+ "encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
99
+ "encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
100
+ "encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
101
+ "encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
102
+ "encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
103
+ "encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
104
+ "encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
105
+ "encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
106
+ "encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
107
+ "encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
108
+ "encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
109
+ "encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
110
+ "encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
111
+ "encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
112
+ "encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
113
+ "encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
114
+ "encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
115
+ "encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
116
+ "encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
117
+ "encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
118
+ "encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
119
+ "encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
120
+ "encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
121
+ "encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
122
+ "encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
123
+ "encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
124
+ "encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
125
+ "encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
126
+ "encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
127
+ "encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
128
+ "encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
129
+ "encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
130
+ "encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
131
+ "encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
132
+ "encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
133
+ "encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
134
+ "encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
135
+ "encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
136
+ "encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
137
+ "encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
138
+ "encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
139
+ "encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
140
+ "encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
141
+ "encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
142
+ "encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
143
+ "encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
144
+ "encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
145
+ "encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
146
+ "encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
147
+ "encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
148
+ "encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
149
+ "encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
150
+ "encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
151
+ "encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
152
+ "encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
153
+ "encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
154
+ "encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
155
+ "encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
156
+ "encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
157
+ "encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
158
+ "encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
159
+ "encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
160
+ "encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
161
+ "encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
162
+ "encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
163
+ "encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
164
+ "encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
165
+ "encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
166
+ "encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
167
+ "encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
168
+ "encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
169
+ "encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
170
+ "encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
171
+ "encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
172
+ "encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
173
+ "encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
174
+ "encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
175
+ "encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
176
+ "encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
177
+ "encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
178
+ "encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
179
+ "encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
180
+ "encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
181
+ "encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
182
+ "encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
183
+ "encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
184
+ "encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
185
+ "encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
186
+ "encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
187
+ "encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
188
+ "encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
189
+ "encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
190
+ "encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
191
+ "encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
192
+ "encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
193
+ "encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
194
+ "encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
195
+ "encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
196
+ "encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
197
+ "encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
198
+ "encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
199
+ "encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
200
+ "encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
201
+ "encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
202
+ "encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
203
+ "encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
204
+ "encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
205
+ "encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
206
+ "encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
207
+ "encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
208
+ "encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
209
+ "encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
210
+ "encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
211
+ "encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
212
+ "encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
213
+ "encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
214
+ "encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
215
+ "encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
216
+ "encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
217
+ "encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
218
+ "encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
219
+ "encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
220
+ "encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
221
+ "encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
222
+ "encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
223
+ "encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
224
+ "shared.weight": "model-00001-of-00002.safetensors"
225
+ }
226
+ }
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "tokenizer_class": "CLIPTokenizer",
29
+ "unk_token": "<|endoftext|>"
30
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "!",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer_2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_3/special_tokens_map.json ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": {
105
+ "content": "</s>",
106
+ "lstrip": false,
107
+ "normalized": false,
108
+ "rstrip": false,
109
+ "single_word": false
110
+ },
111
+ "pad_token": {
112
+ "content": "<pad>",
113
+ "lstrip": false,
114
+ "normalized": false,
115
+ "rstrip": false,
116
+ "single_word": false
117
+ },
118
+ "unk_token": {
119
+ "content": "<unk>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false
124
+ }
125
+ }
tokenizer_3/spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer_3/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_3/tokenizer_config.json ADDED
@@ -0,0 +1,940 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": true,
31
+ "normalized": false,
32
+ "rstrip": true,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": true,
39
+ "normalized": false,
40
+ "rstrip": true,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": true,
47
+ "normalized": false,
48
+ "rstrip": true,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": true,
55
+ "normalized": false,
56
+ "rstrip": true,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": true,
63
+ "normalized": false,
64
+ "rstrip": true,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": true,
71
+ "normalized": false,
72
+ "rstrip": true,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": true,
79
+ "normalized": false,
80
+ "rstrip": true,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": true,
87
+ "normalized": false,
88
+ "rstrip": true,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": true,
95
+ "normalized": false,
96
+ "rstrip": true,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": true,
103
+ "normalized": false,
104
+ "rstrip": true,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": true,
111
+ "normalized": false,
112
+ "rstrip": true,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": true,
119
+ "normalized": false,
120
+ "rstrip": true,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": true,
127
+ "normalized": false,
128
+ "rstrip": true,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": true,
135
+ "normalized": false,
136
+ "rstrip": true,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": true,
143
+ "normalized": false,
144
+ "rstrip": true,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": true,
151
+ "normalized": false,
152
+ "rstrip": true,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": true,
159
+ "normalized": false,
160
+ "rstrip": true,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": true,
167
+ "normalized": false,
168
+ "rstrip": true,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": true,
175
+ "normalized": false,
176
+ "rstrip": true,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": true,
183
+ "normalized": false,
184
+ "rstrip": true,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": true,
191
+ "normalized": false,
192
+ "rstrip": true,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": true,
199
+ "normalized": false,
200
+ "rstrip": true,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": true,
207
+ "normalized": false,
208
+ "rstrip": true,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": true,
215
+ "normalized": false,
216
+ "rstrip": true,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": true,
223
+ "normalized": false,
224
+ "rstrip": true,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": true,
231
+ "normalized": false,
232
+ "rstrip": true,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": true,
239
+ "normalized": false,
240
+ "rstrip": true,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": true,
247
+ "normalized": false,
248
+ "rstrip": true,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": true,
255
+ "normalized": false,
256
+ "rstrip": true,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": true,
263
+ "normalized": false,
264
+ "rstrip": true,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": true,
271
+ "normalized": false,
272
+ "rstrip": true,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": true,
279
+ "normalized": false,
280
+ "rstrip": true,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": true,
287
+ "normalized": false,
288
+ "rstrip": true,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": true,
295
+ "normalized": false,
296
+ "rstrip": true,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": true,
303
+ "normalized": false,
304
+ "rstrip": true,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": true,
311
+ "normalized": false,
312
+ "rstrip": true,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": true,
319
+ "normalized": false,
320
+ "rstrip": true,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": true,
327
+ "normalized": false,
328
+ "rstrip": true,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": true,
335
+ "normalized": false,
336
+ "rstrip": true,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": true,
343
+ "normalized": false,
344
+ "rstrip": true,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": true,
351
+ "normalized": false,
352
+ "rstrip": true,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": true,
359
+ "normalized": false,
360
+ "rstrip": true,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": true,
367
+ "normalized": false,
368
+ "rstrip": true,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": true,
375
+ "normalized": false,
376
+ "rstrip": true,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": true,
383
+ "normalized": false,
384
+ "rstrip": true,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": true,
391
+ "normalized": false,
392
+ "rstrip": true,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": true,
399
+ "normalized": false,
400
+ "rstrip": true,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": true,
407
+ "normalized": false,
408
+ "rstrip": true,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": true,
415
+ "normalized": false,
416
+ "rstrip": true,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": true,
423
+ "normalized": false,
424
+ "rstrip": true,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": true,
431
+ "normalized": false,
432
+ "rstrip": true,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": true,
439
+ "normalized": false,
440
+ "rstrip": true,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": true,
447
+ "normalized": false,
448
+ "rstrip": true,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": true,
455
+ "normalized": false,
456
+ "rstrip": true,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": true,
463
+ "normalized": false,
464
+ "rstrip": true,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": true,
471
+ "normalized": false,
472
+ "rstrip": true,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": true,
479
+ "normalized": false,
480
+ "rstrip": true,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": true,
487
+ "normalized": false,
488
+ "rstrip": true,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": true,
495
+ "normalized": false,
496
+ "rstrip": true,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": true,
503
+ "normalized": false,
504
+ "rstrip": true,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": true,
511
+ "normalized": false,
512
+ "rstrip": true,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": true,
519
+ "normalized": false,
520
+ "rstrip": true,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": true,
527
+ "normalized": false,
528
+ "rstrip": true,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": true,
535
+ "normalized": false,
536
+ "rstrip": true,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": true,
543
+ "normalized": false,
544
+ "rstrip": true,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": true,
551
+ "normalized": false,
552
+ "rstrip": true,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": true,
559
+ "normalized": false,
560
+ "rstrip": true,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": true,
567
+ "normalized": false,
568
+ "rstrip": true,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": true,
575
+ "normalized": false,
576
+ "rstrip": true,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": true,
583
+ "normalized": false,
584
+ "rstrip": true,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": true,
591
+ "normalized": false,
592
+ "rstrip": true,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": true,
599
+ "normalized": false,
600
+ "rstrip": true,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": true,
607
+ "normalized": false,
608
+ "rstrip": true,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": true,
615
+ "normalized": false,
616
+ "rstrip": true,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": true,
623
+ "normalized": false,
624
+ "rstrip": true,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": true,
631
+ "normalized": false,
632
+ "rstrip": true,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": true,
639
+ "normalized": false,
640
+ "rstrip": true,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": true,
647
+ "normalized": false,
648
+ "rstrip": true,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": true,
655
+ "normalized": false,
656
+ "rstrip": true,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": true,
663
+ "normalized": false,
664
+ "rstrip": true,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": true,
671
+ "normalized": false,
672
+ "rstrip": true,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": true,
679
+ "normalized": false,
680
+ "rstrip": true,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": true,
687
+ "normalized": false,
688
+ "rstrip": true,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": true,
695
+ "normalized": false,
696
+ "rstrip": true,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": true,
703
+ "normalized": false,
704
+ "rstrip": true,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": true,
711
+ "normalized": false,
712
+ "rstrip": true,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": true,
719
+ "normalized": false,
720
+ "rstrip": true,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": true,
727
+ "normalized": false,
728
+ "rstrip": true,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": true,
735
+ "normalized": false,
736
+ "rstrip": true,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": true,
743
+ "normalized": false,
744
+ "rstrip": true,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": true,
751
+ "normalized": false,
752
+ "rstrip": true,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": true,
759
+ "normalized": false,
760
+ "rstrip": true,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": true,
767
+ "normalized": false,
768
+ "rstrip": true,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": true,
775
+ "normalized": false,
776
+ "rstrip": true,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": true,
783
+ "normalized": false,
784
+ "rstrip": true,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": true,
791
+ "normalized": false,
792
+ "rstrip": true,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": true,
799
+ "normalized": false,
800
+ "rstrip": true,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": true,
807
+ "normalized": false,
808
+ "rstrip": true,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": true,
815
+ "normalized": false,
816
+ "rstrip": true,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": true,
823
+ "normalized": false,
824
+ "rstrip": true,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": true,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "legacy": true,
935
+ "model_max_length": 512,
936
+ "pad_token": "<pad>",
937
+ "sp_model_kwargs": {},
938
+ "tokenizer_class": "T5Tokenizer",
939
+ "unk_token": "<unk>"
940
+ }
transformer/forward/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "SD3Transformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "attention_head_dim": 64,
5
+ "caption_projection_dim": 1536,
6
+ "dual_attention_layers": [
7
+ 0,
8
+ 1,
9
+ 2,
10
+ 3,
11
+ 4,
12
+ 5,
13
+ 6,
14
+ 7,
15
+ 8,
16
+ 9,
17
+ 10,
18
+ 11,
19
+ 12
20
+ ],
21
+ "in_channels": 96,
22
+ "joint_attention_dim": 4096,
23
+ "num_attention_heads": 24,
24
+ "num_layers": 24,
25
+ "out_channels": 16,
26
+ "patch_size": 2,
27
+ "pooled_projection_dim": 2048,
28
+ "pos_embed_max_size": 384,
29
+ "qk_norm": "rms_norm",
30
+ "sample_size": 128
31
+ }
transformer/forward/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1f917f113024a1a10ad868d578d522639296062f937e0f7f8b8b8b31ec9de38
3
+ size 9880726944
transformer/forward/lora/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7947c779fa0b3bba7124a21a5ffa29507ca09cef45ffd8e84f6a8eb738afa7fc
3
+ size 75154632
transformer/inverse-512/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "IntrinsicWeatherSD3Transformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "attention_head_dim": 64,
5
+ "caption_projection_dim": 1536,
6
+ "dual_attention_layers": [
7
+ 0,
8
+ 1,
9
+ 2,
10
+ 3,
11
+ 4,
12
+ 5,
13
+ 6,
14
+ 7,
15
+ 8,
16
+ 9,
17
+ 10,
18
+ 11,
19
+ 12
20
+ ],
21
+ "in_channels": 32,
22
+ "joint_attention_dim": 4096,
23
+ "num_attention_heads": 24,
24
+ "num_layers": 24,
25
+ "out_channels": 16,
26
+ "patch_size": 2,
27
+ "pooled_projection_dim": 2048,
28
+ "pos_embed_max_size": 384,
29
+ "qk_norm": "rms_norm",
30
+ "sample_size": 128
31
+ }
transformer/inverse-512/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1cdcd40fc9770b2397b533ff4fd3132bfee99990ae59f31f22895fbdae76d797
3
+ size 9879154080
transformer/inverse-512/transformer_intrinsic_weather.py ADDED
@@ -0,0 +1,1527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch import nn as torch_nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
27
+ from diffusers.models.attention import FeedForward, JointTransformerBlock, _chunked_feed_forward
28
+ from diffusers.models.attention_processor import (
29
+ Attention,
30
+ AttentionProcessor,
31
+ AttnProcessor,
32
+ AttnProcessor2_0,
33
+ FusedJointAttnProcessor2_0,
34
+ JointAttnProcessor2_0,
35
+ SpatialNorm,
36
+ )
37
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed, SinusoidalPositionalEmbedding
38
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
39
+ from diffusers.models.modeling_utils import ModelMixin
40
+ from diffusers.models.normalization import (
41
+ AdaLayerNorm,
42
+ AdaLayerNormContinuous,
43
+ AdaLayerNormZero,
44
+ RMSNorm,
45
+ SD35AdaLayerNormZeroX,
46
+ )
47
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
48
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ class MapAwareAttention(nn.Module):
54
+ r"""
55
+ A cross attention layer.
56
+
57
+ Parameters:
58
+ query_dim (`int`):
59
+ The number of channels in the query.
60
+ cross_attention_dim (`int`, *optional*):
61
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
62
+ heads (`int`, *optional*, defaults to 8):
63
+ The number of heads to use for multi-head attention.
64
+ kv_heads (`int`, *optional*, defaults to `None`):
65
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
66
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
67
+ Query Attention (MQA) otherwise GQA is used.
68
+ dim_head (`int`, *optional*, defaults to 64):
69
+ The number of channels in each head.
70
+ dropout (`float`, *optional*, defaults to 0.0):
71
+ The dropout probability to use.
72
+ bias (`bool`, *optional*, defaults to False):
73
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
74
+ upcast_attention (`bool`, *optional*, defaults to False):
75
+ Set to `True` to upcast the attention computation to `float32`.
76
+ upcast_softmax (`bool`, *optional*, defaults to False):
77
+ Set to `True` to upcast the softmax computation to `float32`.
78
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
79
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
80
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
81
+ The number of groups to use for the group norm in the cross attention.
82
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
83
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
84
+ norm_num_groups (`int`, *optional*, defaults to `None`):
85
+ The number of groups to use for the group norm in the attention.
86
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
87
+ The number of channels to use for the spatial normalization.
88
+ out_bias (`bool`, *optional*, defaults to `True`):
89
+ Set to `True` to use a bias in the output linear layer.
90
+ scale_qk (`bool`, *optional*, defaults to `True`):
91
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
92
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
93
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
94
+ `added_kv_proj_dim` is not `None`.
95
+ eps (`float`, *optional*, defaults to 1e-5):
96
+ An additional value added to the denominator in group normalization that is used for numerical stability.
97
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
98
+ A factor to rescale the output by dividing it with this value.
99
+ residual_connection (`bool`, *optional*, defaults to `False`):
100
+ Set to `True` to add the residual connection to the output.
101
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
102
+ Set to `True` if the attention block is loaded from a deprecated state dict.
103
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
104
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
105
+ `AttnProcessor` otherwise.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ query_dim: int,
111
+ cross_attention_dim: Optional[int] = None,
112
+ heads: int = 8,
113
+ kv_heads: Optional[int] = None,
114
+ dim_head: int = 64,
115
+ dropout: float = 0.0,
116
+ bias: bool = False,
117
+ upcast_attention: bool = False,
118
+ upcast_softmax: bool = False,
119
+ cross_attention_norm: Optional[str] = None,
120
+ cross_attention_norm_num_groups: int = 32,
121
+ qk_norm: Optional[str] = None,
122
+ added_kv_proj_dim: Optional[int] = None,
123
+ added_proj_bias: Optional[bool] = True,
124
+ norm_num_groups: Optional[int] = None,
125
+ spatial_norm_dim: Optional[int] = None,
126
+ out_bias: bool = True,
127
+ scale_qk: bool = True,
128
+ only_cross_attention: bool = False,
129
+ eps: float = 1e-5,
130
+ rescale_output_factor: float = 1.0,
131
+ residual_connection: bool = False,
132
+ _from_deprecated_attn_block: bool = False,
133
+ processor: Optional["AttnProcessor"] = None,
134
+ out_dim: int = None,
135
+ out_context_dim: int = None,
136
+ context_pre_only=None,
137
+ pre_only=False,
138
+ elementwise_affine: bool = True,
139
+ is_causal: bool = False,
140
+ ):
141
+ super().__init__()
142
+
143
+ # To prevent circular import.
144
+ from diffusers.models.normalization import FP32LayerNorm, LpNorm, RMSNorm
145
+
146
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
147
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
148
+ self.query_dim = query_dim
149
+ self.use_bias = bias
150
+ self.is_cross_attention = cross_attention_dim is not None
151
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
152
+ self.upcast_attention = upcast_attention
153
+ self.upcast_softmax = upcast_softmax
154
+ self.rescale_output_factor = rescale_output_factor
155
+ self.residual_connection = residual_connection
156
+ self.dropout = dropout
157
+ self.fused_projections = False
158
+ self.out_dim = out_dim if out_dim is not None else query_dim
159
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
160
+ self.context_pre_only = context_pre_only
161
+ self.pre_only = pre_only
162
+ self.is_causal = is_causal
163
+
164
+ # we make use of this private variable to know whether this class is loaded
165
+ # with an deprecated state dict so that we can convert it on the fly
166
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
167
+
168
+ self.scale_qk = scale_qk
169
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
170
+
171
+ self.heads = out_dim // dim_head if out_dim is not None else heads
172
+ # for slice_size > 0 the attention score computation
173
+ # is split across the batch axis to save memory
174
+ # You can set slice_size with `set_attention_slice`
175
+ self.sliceable_head_dim = heads
176
+
177
+ self.added_kv_proj_dim = added_kv_proj_dim
178
+ self.only_cross_attention = only_cross_attention
179
+
180
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
181
+ raise ValueError(
182
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
183
+ )
184
+
185
+ if norm_num_groups is not None:
186
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
187
+ else:
188
+ self.group_norm = None
189
+
190
+ if spatial_norm_dim is not None:
191
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
192
+ else:
193
+ self.spatial_norm = None
194
+
195
+ if qk_norm is None:
196
+ self.norm_q = None
197
+ self.norm_k = None
198
+ elif qk_norm == "layer_norm":
199
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
200
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
201
+ elif qk_norm == "fp32_layer_norm":
202
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
203
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
204
+ elif qk_norm == "layer_norm_across_heads":
205
+ # Lumina applies qk norm across all heads
206
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
207
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
208
+ elif qk_norm == "rms_norm":
209
+ self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
210
+ self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
211
+ elif qk_norm == "rms_norm_across_heads":
212
+ # LTX applies qk norm across all heads
213
+ self.norm_q = RMSNorm(dim_head * heads, eps=eps)
214
+ self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
215
+ elif qk_norm == "l2":
216
+ self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
217
+ self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
218
+ else:
219
+ raise ValueError(
220
+ f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
221
+ )
222
+
223
+ if cross_attention_norm is None:
224
+ self.norm_cross = None
225
+ elif cross_attention_norm == "layer_norm":
226
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
227
+ elif cross_attention_norm == "group_norm":
228
+ if self.added_kv_proj_dim is not None:
229
+ # The given `encoder_hidden_states` are initially of shape
230
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
231
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
232
+ # before the projection, so we need to use `added_kv_proj_dim` as
233
+ # the number of channels for the group norm.
234
+ norm_cross_num_channels = added_kv_proj_dim
235
+ else:
236
+ norm_cross_num_channels = self.cross_attention_dim
237
+
238
+ self.norm_cross = nn.GroupNorm(
239
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
240
+ )
241
+ else:
242
+ raise ValueError(
243
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
244
+ )
245
+
246
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
247
+
248
+ if not self.only_cross_attention:
249
+ # only relevant for the `AddedKVProcessor` classes
250
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
251
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
252
+ else:
253
+ self.to_k = None
254
+ self.to_v = None
255
+
256
+ self.added_proj_bias = added_proj_bias
257
+ if self.added_kv_proj_dim is not None:
258
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
259
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
260
+ if self.context_pre_only is not None:
261
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
262
+ else:
263
+ self.add_q_proj = None
264
+ self.add_k_proj = None
265
+ self.add_v_proj = None
266
+
267
+ if not self.pre_only:
268
+ self.to_out = nn.ModuleList([])
269
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
270
+ self.to_out.append(nn.Dropout(dropout))
271
+ else:
272
+ self.to_out = None
273
+
274
+ if self.context_pre_only is not None and not self.context_pre_only:
275
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
276
+ else:
277
+ self.to_add_out = None
278
+
279
+ if qk_norm is not None and added_kv_proj_dim is not None:
280
+ if qk_norm == "layer_norm":
281
+ self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
282
+ self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
283
+ elif qk_norm == "fp32_layer_norm":
284
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
285
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
286
+ elif qk_norm == "rms_norm":
287
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
288
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
289
+ elif qk_norm == "rms_norm_across_heads":
290
+ # Wan applies qk norm across all heads
291
+ # Wan also doesn't apply a q norm
292
+ self.norm_added_q = None
293
+ self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
294
+ else:
295
+ raise ValueError(
296
+ f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
297
+ )
298
+ else:
299
+ self.norm_added_q = None
300
+ self.norm_added_k = None
301
+
302
+ # set attention processor
303
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
304
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
305
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
306
+ if processor is None:
307
+ processor = (
308
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
309
+ )
310
+ self.set_processor(processor)
311
+
312
+ # def set_use_xla_flash_attention(
313
+ # self,
314
+ # use_xla_flash_attention: bool,
315
+ # partition_spec: Optional[Tuple[Optional[str], ...]] = None,
316
+ # is_flux=False,
317
+ # ) -> None:
318
+ # r"""
319
+ # Set whether to use xla flash attention from `torch_xla` or not.
320
+
321
+ # Args:
322
+ # use_xla_flash_attention (`bool`):
323
+ # Whether to use pallas flash attention kernel from `torch_xla` or not.
324
+ # partition_spec (`Tuple[]`, *optional*):
325
+ # Specify the partition specification if using SPMD. Otherwise None.
326
+ # """
327
+ # if use_xla_flash_attention:
328
+ # if not is_torch_xla_available:
329
+ # raise "torch_xla is not available"
330
+ # elif is_torch_xla_version("<", "2.3"):
331
+ # raise "flash attention pallas kernel is supported from torch_xla version 2.3"
332
+ # elif is_spmd() and is_torch_xla_version("<", "2.4"):
333
+ # raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
334
+ # else:
335
+ # if is_flux:
336
+ # processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
337
+ # else:
338
+ # processor = XLAFlashAttnProcessor2_0(partition_spec)
339
+ # else:
340
+ # processor = (
341
+ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
342
+ # )
343
+ # self.set_processor(processor)
344
+
345
+ # def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
346
+ # r"""
347
+ # Set whether to use npu flash attention from `torch_npu` or not.
348
+
349
+ # """
350
+ # if use_npu_flash_attention:
351
+ # processor = AttnProcessorNPU()
352
+ # else:
353
+ # # set attention processor
354
+ # # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
355
+ # # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
356
+ # # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
357
+ # processor = (
358
+ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
359
+ # )
360
+ # self.set_processor(processor)
361
+
362
+ # def set_use_memory_efficient_attention_xformers(
363
+ # self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
364
+ # ) -> None:
365
+ # r"""
366
+ # Set whether to use memory efficient attention from `xformers` or not.
367
+
368
+ # Args:
369
+ # use_memory_efficient_attention_xformers (`bool`):
370
+ # Whether to use memory efficient attention from `xformers` or not.
371
+ # attention_op (`Callable`, *optional*):
372
+ # The attention operation to use. Defaults to `None` which uses the default attention operation from
373
+ # `xformers`.
374
+ # """
375
+ # is_custom_diffusion = hasattr(self, "processor") and isinstance(
376
+ # self.processor,
377
+ # (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
378
+ # )
379
+ # is_added_kv_processor = hasattr(self, "processor") and isinstance(
380
+ # self.processor,
381
+ # (
382
+ # AttnAddedKVProcessor,
383
+ # AttnAddedKVProcessor2_0,
384
+ # SlicedAttnAddedKVProcessor,
385
+ # XFormersAttnAddedKVProcessor,
386
+ # ),
387
+ # )
388
+ # is_ip_adapter = hasattr(self, "processor") and isinstance(
389
+ # self.processor,
390
+ # (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
391
+ # )
392
+ # is_joint_processor = hasattr(self, "processor") and isinstance(
393
+ # self.processor,
394
+ # (
395
+ # JointAttnProcessor2_0,
396
+ # XFormersJointAttnProcessor,
397
+ # ),
398
+ # )
399
+
400
+ # if use_memory_efficient_attention_xformers:
401
+ # if is_added_kv_processor and is_custom_diffusion:
402
+ # raise NotImplementedError(
403
+ # f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
404
+ # )
405
+ # if not is_xformers_available():
406
+ # raise ModuleNotFoundError(
407
+ # (
408
+ # "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
409
+ # " xformers"
410
+ # ),
411
+ # name="xformers",
412
+ # )
413
+ # elif not torch.cuda.is_available():
414
+ # raise ValueError(
415
+ # "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
416
+ # " only available for GPU "
417
+ # )
418
+ # else:
419
+ # try:
420
+ # # Make sure we can run the memory efficient attention
421
+ # dtype = None
422
+ # if attention_op is not None:
423
+ # op_fw, op_bw = attention_op
424
+ # dtype, *_ = op_fw.SUPPORTED_DTYPES
425
+ # q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
426
+ # _ = xformers.ops.memory_efficient_attention(q, q, q)
427
+ # except Exception as e:
428
+ # raise e
429
+
430
+ # if is_custom_diffusion:
431
+ # processor = CustomDiffusionXFormersAttnProcessor(
432
+ # train_kv=self.processor.train_kv,
433
+ # train_q_out=self.processor.train_q_out,
434
+ # hidden_size=self.processor.hidden_size,
435
+ # cross_attention_dim=self.processor.cross_attention_dim,
436
+ # attention_op=attention_op,
437
+ # )
438
+ # processor.load_state_dict(self.processor.state_dict())
439
+ # if hasattr(self.processor, "to_k_custom_diffusion"):
440
+ # processor.to(self.processor.to_k_custom_diffusion.weight.device)
441
+ # elif is_added_kv_processor:
442
+ # # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
443
+ # # which uses this type of cross attention ONLY because the attention mask of format
444
+ # # [0, ..., -10.000, ..., 0, ...,] is not supported
445
+ # # throw warning
446
+ # logger.info(
447
+ # "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
448
+ # )
449
+ # processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
450
+ # elif is_ip_adapter:
451
+ # processor = IPAdapterXFormersAttnProcessor(
452
+ # hidden_size=self.processor.hidden_size,
453
+ # cross_attention_dim=self.processor.cross_attention_dim,
454
+ # num_tokens=self.processor.num_tokens,
455
+ # scale=self.processor.scale,
456
+ # attention_op=attention_op,
457
+ # )
458
+ # processor.load_state_dict(self.processor.state_dict())
459
+ # if hasattr(self.processor, "to_k_ip"):
460
+ # processor.to(
461
+ # device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
462
+ # )
463
+ # elif is_joint_processor:
464
+ # processor = XFormersJointAttnProcessor(attention_op=attention_op)
465
+ # else:
466
+ # processor = XFormersAttnProcessor(attention_op=attention_op)
467
+ # else:
468
+ # if is_custom_diffusion:
469
+ # attn_processor_class = (
470
+ # CustomDiffusionAttnProcessor2_0
471
+ # if hasattr(F, "scaled_dot_product_attention")
472
+ # else CustomDiffusionAttnProcessor
473
+ # )
474
+ # processor = attn_processor_class(
475
+ # train_kv=self.processor.train_kv,
476
+ # train_q_out=self.processor.train_q_out,
477
+ # hidden_size=self.processor.hidden_size,
478
+ # cross_attention_dim=self.processor.cross_attention_dim,
479
+ # )
480
+ # processor.load_state_dict(self.processor.state_dict())
481
+ # if hasattr(self.processor, "to_k_custom_diffusion"):
482
+ # processor.to(self.processor.to_k_custom_diffusion.weight.device)
483
+ # elif is_ip_adapter:
484
+ # processor = IPAdapterAttnProcessor2_0(
485
+ # hidden_size=self.processor.hidden_size,
486
+ # cross_attention_dim=self.processor.cross_attention_dim,
487
+ # num_tokens=self.processor.num_tokens,
488
+ # scale=self.processor.scale,
489
+ # )
490
+ # processor.load_state_dict(self.processor.state_dict())
491
+ # if hasattr(self.processor, "to_k_ip"):
492
+ # processor.to(
493
+ # device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
494
+ # )
495
+ # else:
496
+ # # set attention processor
497
+ # # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
498
+ # # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
499
+ # # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
500
+ # processor = (
501
+ # AttnProcessor2_0()
502
+ # if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
503
+ # else AttnProcessor()
504
+ # )
505
+
506
+ # self.set_processor(processor)
507
+
508
+ # def set_attention_slice(self, slice_size: int) -> None:
509
+ # r"""
510
+ # Set the slice size for attention computation.
511
+
512
+ # Args:
513
+ # slice_size (`int`):
514
+ # The slice size for attention computation.
515
+ # """
516
+ # if slice_size is not None and slice_size > self.sliceable_head_dim:
517
+ # raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
518
+
519
+ # if slice_size is not None and self.added_kv_proj_dim is not None:
520
+ # processor = SlicedAttnAddedKVProcessor(slice_size)
521
+ # elif slice_size is not None:
522
+ # processor = SlicedAttnProcessor(slice_size)
523
+ # elif self.added_kv_proj_dim is not None:
524
+ # processor = AttnAddedKVProcessor()
525
+ # else:
526
+ # # set attention processor
527
+ # # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
528
+ # # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
529
+ # # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
530
+ # processor = (
531
+ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
532
+ # )
533
+
534
+ # self.set_processor(processor)
535
+
536
+ def set_processor(self, processor: "AttnProcessor") -> None:
537
+ r"""
538
+ Set the attention processor to use.
539
+
540
+ Args:
541
+ processor (`AttnProcessor`):
542
+ The attention processor to use.
543
+ """
544
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
545
+ # pop `processor` from `self._modules`
546
+ if (
547
+ hasattr(self, "processor")
548
+ and isinstance(self.processor, torch.nn.Module)
549
+ and not isinstance(processor, torch.nn.Module)
550
+ ):
551
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
552
+ self._modules.pop("processor")
553
+
554
+ self.processor = processor
555
+
556
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
557
+ r"""
558
+ Get the attention processor in use.
559
+
560
+ Args:
561
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
562
+ Set to `True` to return the deprecated LoRA attention processor.
563
+
564
+ Returns:
565
+ "AttentionProcessor": The attention processor in use.
566
+ """
567
+ if not return_deprecated_lora:
568
+ return self.processor
569
+
570
+ def forward(
571
+ self,
572
+ hidden_states: torch.Tensor,
573
+ encoder_hidden_states: Optional[torch.Tensor] = None,
574
+ map_aware_mask: Optional[torch.FloatTensor] = None,
575
+ **cross_attention_kwargs,
576
+ ) -> torch.Tensor:
577
+ r"""
578
+ The forward method of the `Attention` class.
579
+
580
+ Args:
581
+ hidden_states (`torch.Tensor`):
582
+ The hidden states of the query.
583
+ encoder_hidden_states (`torch.Tensor`, *optional*):
584
+ The hidden states of the encoder.
585
+ map_aware_mask (`torch.Tensor`, *optional*):
586
+ The attention mask to use. If `None`, no mask is applied.
587
+ **cross_attention_kwargs:
588
+ Additional keyword arguments to pass along to the cross attention.
589
+
590
+ Returns:
591
+ `torch.Tensor`: The output of the attention layer.
592
+ """
593
+ # The `Attention` class can call different attention processors / attention functions
594
+ # here we simply pass along all tensors to the selected processor class
595
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
596
+
597
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
598
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
599
+ unused_kwargs = [
600
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
601
+ ]
602
+ if len(unused_kwargs) > 0:
603
+ logger.warning(
604
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
605
+ )
606
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
607
+
608
+ return self.processor(
609
+ self,
610
+ hidden_states,
611
+ encoder_hidden_states=encoder_hidden_states,
612
+ attention_mask=map_aware_mask,
613
+ **cross_attention_kwargs,
614
+ )
615
+
616
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
617
+ r"""
618
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
619
+ is the number of heads initialized while constructing the `Attention` class.
620
+
621
+ Args:
622
+ tensor (`torch.Tensor`): The tensor to reshape.
623
+
624
+ Returns:
625
+ `torch.Tensor`: The reshaped tensor.
626
+ """
627
+ head_size = self.heads
628
+ batch_size, seq_len, dim = tensor.shape
629
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
630
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
631
+ return tensor
632
+
633
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
634
+ r"""
635
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
636
+ the number of heads initialized while constructing the `Attention` class.
637
+
638
+ Args:
639
+ tensor (`torch.Tensor`): The tensor to reshape.
640
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
641
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
642
+
643
+ Returns:
644
+ `torch.Tensor`: The reshaped tensor.
645
+ """
646
+ head_size = self.heads
647
+ if tensor.ndim == 3:
648
+ batch_size, seq_len, dim = tensor.shape
649
+ extra_dim = 1
650
+ else:
651
+ batch_size, extra_dim, seq_len, dim = tensor.shape
652
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
653
+ tensor = tensor.permute(0, 2, 1, 3)
654
+
655
+ if out_dim == 3:
656
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
657
+
658
+ return tensor
659
+
660
+ def get_attention_scores(
661
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
662
+ ) -> torch.Tensor:
663
+ r"""
664
+ Compute the attention scores.
665
+
666
+ Args:
667
+ query (`torch.Tensor`): The query tensor.
668
+ key (`torch.Tensor`): The key tensor.
669
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
670
+
671
+ Returns:
672
+ `torch.Tensor`: The attention probabilities/scores.
673
+ """
674
+ dtype = query.dtype
675
+ if self.upcast_attention:
676
+ query = query.float()
677
+ key = key.float()
678
+
679
+ if attention_mask is None:
680
+ baddbmm_input = torch.empty(
681
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
682
+ )
683
+ beta = 0
684
+ else:
685
+ baddbmm_input = attention_mask
686
+ beta = 1
687
+
688
+ attention_scores = torch.baddbmm(
689
+ baddbmm_input,
690
+ query,
691
+ key.transpose(-1, -2),
692
+ beta=beta,
693
+ alpha=self.scale,
694
+ )
695
+ del baddbmm_input
696
+
697
+ if self.upcast_softmax:
698
+ attention_scores = attention_scores.float()
699
+
700
+ attention_probs = attention_scores.softmax(dim=-1)
701
+ del attention_scores
702
+
703
+ attention_probs = attention_probs.to(dtype)
704
+
705
+ return attention_probs
706
+
707
+ def prepare_attention_mask(
708
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
709
+ ) -> torch.Tensor:
710
+ r"""
711
+ Prepare the attention mask for the attention computation.
712
+
713
+ Args:
714
+ attention_mask (`torch.Tensor`):
715
+ The attention mask to prepare.
716
+ target_length (`int`):
717
+ The target length of the attention mask. This is the length of the attention mask after padding.
718
+ batch_size (`int`):
719
+ The batch size, which is used to repeat the attention mask.
720
+ out_dim (`int`, *optional*, defaults to `3`):
721
+ The output dimension of the attention mask. Can be either `3` or `4`.
722
+
723
+ Returns:
724
+ `torch.Tensor`: The prepared attention mask.
725
+ """
726
+ head_size = self.heads
727
+ if attention_mask is None:
728
+ return attention_mask
729
+
730
+ current_length: int = attention_mask.shape[-1]
731
+ if current_length != target_length:
732
+ if attention_mask.device.type == "mps":
733
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
734
+ # Instead, we can manually construct the padding tensor.
735
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
736
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
737
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
738
+ else:
739
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
740
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
741
+ # remaining_length: int = target_length - current_length
742
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
743
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
744
+
745
+ if out_dim == 3:
746
+ if attention_mask.shape[0] < batch_size * head_size:
747
+ attention_mask = attention_mask.repeat_interleave(
748
+ head_size, dim=0, output_size=attention_mask.shape[0] * head_size
749
+ )
750
+ elif out_dim == 4:
751
+ attention_mask = attention_mask.unsqueeze(1)
752
+ attention_mask = attention_mask.repeat_interleave(
753
+ head_size, dim=1, output_size=attention_mask.shape[1] * head_size
754
+ )
755
+
756
+ return attention_mask
757
+
758
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
759
+ r"""
760
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
761
+ `Attention` class.
762
+
763
+ Args:
764
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
765
+
766
+ Returns:
767
+ `torch.Tensor`: The normalized encoder hidden states.
768
+ """
769
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
770
+
771
+ if isinstance(self.norm_cross, nn.LayerNorm):
772
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
773
+ elif isinstance(self.norm_cross, nn.GroupNorm):
774
+ # Group norm norms along the channels dimension and expects
775
+ # input to be in the shape of (N, C, *). In this case, we want
776
+ # to norm along the hidden dimension, so we need to move
777
+ # (batch_size, sequence_length, hidden_size) ->
778
+ # (batch_size, hidden_size, sequence_length)
779
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
780
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
781
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
782
+ else:
783
+ assert False
784
+
785
+ return encoder_hidden_states
786
+
787
+ @torch.no_grad()
788
+ def fuse_projections(self, fuse=True):
789
+ device = self.to_q.weight.data.device
790
+ dtype = self.to_q.weight.data.dtype
791
+
792
+ if not self.is_cross_attention:
793
+ # fetch weight matrices.
794
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
795
+ in_features = concatenated_weights.shape[1]
796
+ out_features = concatenated_weights.shape[0]
797
+
798
+ # create a new single projection layer and copy over the weights.
799
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
800
+ self.to_qkv.weight.copy_(concatenated_weights)
801
+ if self.use_bias:
802
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
803
+ self.to_qkv.bias.copy_(concatenated_bias)
804
+
805
+ else:
806
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
807
+ in_features = concatenated_weights.shape[1]
808
+ out_features = concatenated_weights.shape[0]
809
+
810
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
811
+ self.to_kv.weight.copy_(concatenated_weights)
812
+ if self.use_bias:
813
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
814
+ self.to_kv.bias.copy_(concatenated_bias)
815
+
816
+ # handle added projections for SD3 and others.
817
+ if (
818
+ getattr(self, "add_q_proj", None) is not None
819
+ and getattr(self, "add_k_proj", None) is not None
820
+ and getattr(self, "add_v_proj", None) is not None
821
+ ):
822
+ concatenated_weights = torch.cat(
823
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
824
+ )
825
+ in_features = concatenated_weights.shape[1]
826
+ out_features = concatenated_weights.shape[0]
827
+
828
+ self.to_added_qkv = nn.Linear(
829
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
830
+ )
831
+ self.to_added_qkv.weight.copy_(concatenated_weights)
832
+ if self.added_proj_bias:
833
+ concatenated_bias = torch.cat(
834
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
835
+ )
836
+ self.to_added_qkv.bias.copy_(concatenated_bias)
837
+
838
+ self.fused_projections = fuse
839
+
840
+ class MapAwareAttnProcessor2_0:
841
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
842
+
843
+ def __init__(self):
844
+ if not hasattr(F, "scaled_dot_product_attention"):
845
+ raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
846
+
847
+ def __call__(
848
+ self,
849
+ attn: Attention,
850
+ hidden_states: torch.FloatTensor,
851
+ encoder_hidden_states: torch.FloatTensor = None,
852
+ attention_mask: Optional[torch.FloatTensor] = None,
853
+ *args,
854
+ **kwargs,
855
+ ) -> torch.FloatTensor:
856
+ # print("attention_mask: ", attention_mask)
857
+ residual = hidden_states
858
+
859
+ batch_size = hidden_states.shape[0]
860
+
861
+ # `sample` projections.
862
+ query = attn.to_q(hidden_states)
863
+ key = attn.to_k(hidden_states)
864
+ value = attn.to_v(hidden_states)
865
+
866
+ inner_dim = key.shape[-1]
867
+ head_dim = inner_dim // attn.heads
868
+
869
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
870
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
871
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
872
+
873
+ if attn.norm_q is not None:
874
+ query = attn.norm_q(query)
875
+ if attn.norm_k is not None:
876
+ key = attn.norm_k(key)
877
+
878
+ # `context` projections.
879
+ if encoder_hidden_states is not None:
880
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
881
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
882
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
883
+
884
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
885
+ batch_size, -1, attn.heads, head_dim
886
+ ).transpose(1, 2)
887
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
888
+ batch_size, -1, attn.heads, head_dim
889
+ ).transpose(1, 2)
890
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
891
+ batch_size, -1, attn.heads, head_dim
892
+ ).transpose(1, 2)
893
+
894
+ if attn.norm_added_q is not None:
895
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
896
+ if attn.norm_added_k is not None:
897
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
898
+
899
+ # print(f"image q: {query.shape}, image k: {key.shape}, image v: {value.shape}") # [B, 24, 1024, 64]
900
+ # print(f"text q: {encoder_hidden_states_query_proj.shape}, text k: {encoder_hidden_states_key_proj.shape}, text v: {encoder_hidden_states_value_proj.shape}")
901
+ # [B, 24, 154, 64]
902
+
903
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
904
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
905
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
906
+
907
+ # print(f"Joint - query shape: {query.shape}, key shape: {key.shape}, value shape: {value.shape}")
908
+ # [B, 24, 1178, 64]
909
+ map_aware_mask = attention_mask
910
+ else:
911
+ map_aware_mask = None
912
+ # print(
913
+ # "map_aware_mask:",
914
+ # None if map_aware_mask is None else (map_aware_mask.shape, map_aware_mask.dtype)
915
+ # )
916
+
917
+ # print("query: ", query.shape, query.dtype)
918
+ if map_aware_mask is not None:
919
+ map_aware_mask = map_aware_mask.to(query.dtype)
920
+
921
+
922
+ hidden_states = F.scaled_dot_product_attention(query, key, value,
923
+ attn_mask=map_aware_mask,
924
+ dropout_p=0.0,
925
+ is_causal=False)
926
+
927
+
928
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
929
+ hidden_states = hidden_states.to(query.dtype)
930
+
931
+ if encoder_hidden_states is not None:
932
+ # Split the attention outputs.
933
+ hidden_states, encoder_hidden_states = (
934
+ hidden_states[:, : residual.shape[1]],
935
+ hidden_states[:, residual.shape[1] :],
936
+ )
937
+ if not attn.context_pre_only:
938
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
939
+
940
+ # linear proj
941
+ hidden_states = attn.to_out[0](hidden_states)
942
+ # dropout
943
+ hidden_states = attn.to_out[1](hidden_states)
944
+
945
+ if encoder_hidden_states is not None:
946
+ return hidden_states, encoder_hidden_states
947
+ else:
948
+ return hidden_states
949
+
950
+ class MapAwareTransformerBlock(nn.Module):
951
+ r"""
952
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
953
+
954
+ Reference: https://huggingface.co/papers/2403.03206
955
+
956
+ Parameters:
957
+ dim (`int`): The number of channels in the input and output.
958
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
959
+ attention_head_dim (`int`): The number of channels in each head.
960
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
961
+ processing of `context` conditions.
962
+ """
963
+
964
+ def __init__(
965
+ self,
966
+ dim: int,
967
+ num_attention_heads: int,
968
+ attention_head_dim: int,
969
+ context_pre_only: bool = False,
970
+ qk_norm: Optional[str] = None,
971
+ use_dual_attention: bool = False,
972
+ ):
973
+ super().__init__()
974
+
975
+ self.use_dual_attention = use_dual_attention
976
+ self.context_pre_only = context_pre_only
977
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
978
+
979
+ if use_dual_attention:
980
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
981
+ else:
982
+ self.norm1 = AdaLayerNormZero(dim)
983
+
984
+ if context_norm_type == "ada_norm_continous":
985
+ self.norm1_context = AdaLayerNormContinuous(
986
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
987
+ )
988
+ elif context_norm_type == "ada_norm_zero":
989
+ self.norm1_context = AdaLayerNormZero(dim)
990
+ else:
991
+ raise ValueError(
992
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
993
+ )
994
+
995
+ # if hasattr(F, "scaled_dot_product_attention"):
996
+ # processor = JointAttnProcessor2_0()
997
+ # else:
998
+ # raise ValueError(
999
+ # "The current PyTorch version does not support the `scaled_dot_product_attention` function."
1000
+ # )
1001
+
1002
+ self.attn = MapAwareAttention(
1003
+ query_dim=dim,
1004
+ cross_attention_dim=None,
1005
+ added_kv_proj_dim=dim,
1006
+ dim_head=attention_head_dim,
1007
+ heads=num_attention_heads,
1008
+ out_dim=dim,
1009
+ context_pre_only=context_pre_only,
1010
+ bias=True,
1011
+ processor=MapAwareAttnProcessor2_0(),
1012
+ qk_norm=qk_norm,
1013
+ eps=1e-6,
1014
+ )
1015
+
1016
+ if use_dual_attention:
1017
+ self.attn2 = Attention(
1018
+ query_dim=dim,
1019
+ cross_attention_dim=None,
1020
+ dim_head=attention_head_dim,
1021
+ heads=num_attention_heads,
1022
+ out_dim=dim,
1023
+ bias=True,
1024
+ processor=JointAttnProcessor2_0(),
1025
+ qk_norm=qk_norm,
1026
+ eps=1e-6,
1027
+ )
1028
+ else:
1029
+ self.attn2 = None
1030
+
1031
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
1032
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
1033
+
1034
+ if not context_pre_only:
1035
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
1036
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
1037
+ else:
1038
+ self.norm2_context = None
1039
+ self.ff_context = None
1040
+
1041
+ # let chunk size default to None
1042
+ self._chunk_size = None
1043
+ self._chunk_dim = 0
1044
+
1045
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
1046
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
1047
+ # Sets chunk feed-forward
1048
+ self._chunk_size = chunk_size
1049
+ self._chunk_dim = dim
1050
+
1051
+ def forward(
1052
+ self,
1053
+ hidden_states: torch.FloatTensor,
1054
+ encoder_hidden_states: torch.FloatTensor,
1055
+ temb: torch.FloatTensor,
1056
+ map_aware_mask: Optional[torch.FloatTensor] = None,
1057
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1058
+ ):
1059
+ joint_attention_kwargs = joint_attention_kwargs or {}
1060
+ if self.use_dual_attention:
1061
+ # print(f"hidden_states: {type(hidden_states)}")
1062
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
1063
+ hidden_states, emb=temb
1064
+ )
1065
+ else:
1066
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
1067
+
1068
+ if self.context_pre_only:
1069
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
1070
+ else:
1071
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
1072
+ encoder_hidden_states, emb=temb
1073
+ )
1074
+
1075
+ # Attention.
1076
+ attn_output, context_attn_output = self.attn(
1077
+ hidden_states=norm_hidden_states,
1078
+ encoder_hidden_states=norm_encoder_hidden_states,
1079
+ map_aware_mask=map_aware_mask,
1080
+ **joint_attention_kwargs,
1081
+ )
1082
+
1083
+ # Process attention outputs for the `hidden_states`.
1084
+ attn_output = gate_msa.unsqueeze(1) * attn_output
1085
+ hidden_states = hidden_states + attn_output
1086
+
1087
+ if self.use_dual_attention:
1088
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
1089
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
1090
+ hidden_states = hidden_states + attn_output2
1091
+
1092
+ norm_hidden_states = self.norm2(hidden_states)
1093
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
1094
+ if self._chunk_size is not None:
1095
+ # "feed_forward_chunk_size" can be used to save memory
1096
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1097
+ else:
1098
+ ff_output = self.ff(norm_hidden_states)
1099
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
1100
+
1101
+ hidden_states = hidden_states + ff_output
1102
+
1103
+ # Process attention outputs for the `encoder_hidden_states`.
1104
+ if self.context_pre_only:
1105
+ encoder_hidden_states = None
1106
+ else:
1107
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
1108
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
1109
+
1110
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
1111
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
1112
+ if self._chunk_size is not None:
1113
+ # "feed_forward_chunk_size" can be used to save memory
1114
+ context_ff_output = _chunked_feed_forward(
1115
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
1116
+ )
1117
+ else:
1118
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
1119
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
1120
+
1121
+ return encoder_hidden_states, hidden_states
1122
+
1123
+ @maybe_allow_in_graph
1124
+ class SD3SingleTransformerBlock(nn.Module):
1125
+ def __init__(
1126
+ self,
1127
+ dim: int,
1128
+ num_attention_heads: int,
1129
+ attention_head_dim: int,
1130
+ ):
1131
+ super().__init__()
1132
+
1133
+ self.norm1 = AdaLayerNormZero(dim)
1134
+ self.attn = Attention(
1135
+ query_dim=dim,
1136
+ dim_head=attention_head_dim,
1137
+ heads=num_attention_heads,
1138
+ out_dim=dim,
1139
+ bias=True,
1140
+ processor=JointAttnProcessor2_0(),
1141
+ eps=1e-6,
1142
+ )
1143
+
1144
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
1145
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
1146
+
1147
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
1148
+ # 1. Attention
1149
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
1150
+ attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
1151
+ attn_output = gate_msa.unsqueeze(1) * attn_output
1152
+ hidden_states = hidden_states + attn_output
1153
+
1154
+ # 2. Feed Forward
1155
+ norm_hidden_states = self.norm2(hidden_states)
1156
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
1157
+ ff_output = self.ff(norm_hidden_states)
1158
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
1159
+ hidden_states = hidden_states + ff_output
1160
+
1161
+ return hidden_states
1162
+
1163
+
1164
+ class IntrinsicWeatherSD3Transformer2DModel(
1165
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
1166
+ ):
1167
+ """
1168
+ The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
1169
+
1170
+ Parameters:
1171
+ sample_size (`int`, defaults to `128`):
1172
+ The width/height of the latents. This is fixed during training since it is used to learn a number of
1173
+ position embeddings.
1174
+ patch_size (`int`, defaults to `2`):
1175
+ Patch size to turn the input data into small patches.
1176
+ in_channels (`int`, defaults to `16`):
1177
+ The number of latent channels in the input.
1178
+ num_layers (`int`, defaults to `18`):
1179
+ The number of layers of transformer blocks to use.
1180
+ attention_head_dim (`int`, defaults to `64`):
1181
+ The number of channels in each head.
1182
+ num_attention_heads (`int`, defaults to `18`):
1183
+ The number of heads to use for multi-head attention.
1184
+ joint_attention_dim (`int`, defaults to `4096`):
1185
+ The embedding dimension to use for joint text-image attention.
1186
+ caption_projection_dim (`int`, defaults to `1152`):
1187
+ The embedding dimension of caption embeddings.
1188
+ pooled_projection_dim (`int`, defaults to `2048`):
1189
+ The embedding dimension of pooled text projections.
1190
+ out_channels (`int`, defaults to `16`):
1191
+ The number of latent channels in the output.
1192
+ pos_embed_max_size (`int`, defaults to `96`):
1193
+ The maximum latent height/width of positional embeddings.
1194
+ dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
1195
+ The number of dual-stream transformer blocks to use.
1196
+ qk_norm (`str`, *optional*, defaults to `None`):
1197
+ The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
1198
+ """
1199
+
1200
+ _supports_gradient_checkpointing = True
1201
+ _no_split_modules = ["JointTransformerBlock"]
1202
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
1203
+
1204
+ @register_to_config
1205
+ def __init__(
1206
+ self,
1207
+ sample_size: int = 128,
1208
+ patch_size: int = 2,
1209
+ in_channels: int = 16,
1210
+ num_layers: int = 18,
1211
+ attention_head_dim: int = 64,
1212
+ num_attention_heads: int = 18,
1213
+ joint_attention_dim: int = 4096,
1214
+ caption_projection_dim: int = 1152,
1215
+ pooled_projection_dim: int = 2048,
1216
+ out_channels: int = 16,
1217
+ pos_embed_max_size: int = 96,
1218
+ dual_attention_layers: Tuple[
1219
+ int, ...
1220
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
1221
+ qk_norm: Optional[str] = None,
1222
+ ):
1223
+ super().__init__()
1224
+ self.out_channels = out_channels if out_channels is not None else in_channels
1225
+ self.inner_dim = num_attention_heads * attention_head_dim
1226
+
1227
+ self.pos_embed = PatchEmbed(
1228
+ height=sample_size,
1229
+ width=sample_size,
1230
+ patch_size=patch_size,
1231
+ in_channels=in_channels,
1232
+ embed_dim=self.inner_dim,
1233
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
1234
+ )
1235
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
1236
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
1237
+ )
1238
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
1239
+
1240
+ self.transformer_blocks = nn.ModuleList(
1241
+ [
1242
+ MapAwareTransformerBlock(
1243
+ dim=self.inner_dim,
1244
+ num_attention_heads=num_attention_heads,
1245
+ attention_head_dim=attention_head_dim,
1246
+ context_pre_only=i == num_layers - 1,
1247
+ qk_norm=qk_norm,
1248
+ use_dual_attention=True if i in dual_attention_layers else False,
1249
+ )
1250
+ for i in range(num_layers)
1251
+ ]
1252
+ )
1253
+
1254
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
1255
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
1256
+
1257
+ self.gradient_checkpointing = False
1258
+
1259
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
1260
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
1261
+ """
1262
+ Sets the attention processor to use [feed forward
1263
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
1264
+
1265
+ Parameters:
1266
+ chunk_size (`int`, *optional*):
1267
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
1268
+ over each tensor of dim=`dim`.
1269
+ dim (`int`, *optional*, defaults to `0`):
1270
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
1271
+ or dim=1 (sequence length).
1272
+ """
1273
+ if dim not in [0, 1]:
1274
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
1275
+
1276
+ # By default chunk size is 1
1277
+ chunk_size = chunk_size or 1
1278
+
1279
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
1280
+ if hasattr(module, "set_chunk_feed_forward"):
1281
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
1282
+
1283
+ for child in module.children():
1284
+ fn_recursive_feed_forward(child, chunk_size, dim)
1285
+
1286
+ for module in self.children():
1287
+ fn_recursive_feed_forward(module, chunk_size, dim)
1288
+
1289
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
1290
+ def disable_forward_chunking(self):
1291
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
1292
+ if hasattr(module, "set_chunk_feed_forward"):
1293
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
1294
+
1295
+ for child in module.children():
1296
+ fn_recursive_feed_forward(child, chunk_size, dim)
1297
+
1298
+ for module in self.children():
1299
+ fn_recursive_feed_forward(module, None, 0)
1300
+
1301
+ @property
1302
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1303
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
1304
+ r"""
1305
+ Returns:
1306
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
1307
+ indexed by its weight name.
1308
+ """
1309
+ # set recursively
1310
+ processors = {}
1311
+
1312
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1313
+ if hasattr(module, "get_processor"):
1314
+ processors[f"{name}.processor"] = module.get_processor()
1315
+
1316
+ for sub_name, child in module.named_children():
1317
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1318
+
1319
+ return processors
1320
+
1321
+ for name, module in self.named_children():
1322
+ fn_recursive_add_processors(name, module, processors)
1323
+
1324
+ return processors
1325
+
1326
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1327
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
1328
+ r"""
1329
+ Sets the attention processor to use to compute attention.
1330
+
1331
+ Parameters:
1332
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1333
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1334
+ for **all** `Attention` layers.
1335
+
1336
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1337
+ processor. This is strongly recommended when setting trainable attention processors.
1338
+
1339
+ """
1340
+ count = len(self.attn_processors.keys())
1341
+
1342
+ if isinstance(processor, dict) and len(processor) != count:
1343
+ raise ValueError(
1344
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1345
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1346
+ )
1347
+
1348
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1349
+ if hasattr(module, "set_processor"):
1350
+ if not isinstance(processor, dict):
1351
+ module.set_processor(processor)
1352
+ else:
1353
+ module.set_processor(processor.pop(f"{name}.processor"))
1354
+
1355
+ for sub_name, child in module.named_children():
1356
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1357
+
1358
+ for name, module in self.named_children():
1359
+ fn_recursive_attn_processor(name, module, processor)
1360
+
1361
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
1362
+ def fuse_qkv_projections(self):
1363
+ """
1364
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
1365
+ are fused. For cross-attention modules, key and value projection matrices are fused.
1366
+
1367
+ <Tip warning={true}>
1368
+
1369
+ This API is 🧪 experimental.
1370
+
1371
+ </Tip>
1372
+ """
1373
+ self.original_attn_processors = None
1374
+
1375
+ for _, attn_processor in self.attn_processors.items():
1376
+ if "Added" in str(attn_processor.__class__.__name__):
1377
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
1378
+
1379
+ self.original_attn_processors = self.attn_processors
1380
+
1381
+ for module in self.modules():
1382
+ if isinstance(module, Attention):
1383
+ module.fuse_projections(fuse=True)
1384
+
1385
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
1386
+
1387
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
1388
+ def unfuse_qkv_projections(self):
1389
+ """Disables the fused QKV projection if enabled.
1390
+
1391
+ <Tip warning={true}>
1392
+
1393
+ This API is 🧪 experimental.
1394
+
1395
+ </Tip>
1396
+
1397
+ """
1398
+ if self.original_attn_processors is not None:
1399
+ self.set_attn_processor(self.original_attn_processors)
1400
+
1401
+ def forward(
1402
+ self,
1403
+ hidden_states: torch.Tensor,
1404
+ encoder_hidden_states: torch.Tensor = None,
1405
+ pooled_projections: torch.Tensor = None,
1406
+ timestep: torch.LongTensor = None,
1407
+ block_controlnet_hidden_states: List = None,
1408
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1409
+ return_dict: bool = True,
1410
+ skip_layers: Optional[List[int]] = None,
1411
+ map_aware_mask: Optional[torch.FloatTensor] = None,
1412
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
1413
+ """
1414
+ The [`SD3Transformer2DModel`] forward method.
1415
+
1416
+ Args:
1417
+ hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
1418
+ Input `hidden_states`.
1419
+ encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
1420
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
1421
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
1422
+ Embeddings projected from the embeddings of input conditions.
1423
+ timestep (`torch.LongTensor`):
1424
+ Used to indicate denoising step.
1425
+ block_controlnet_hidden_states (`list` of `torch.Tensor`):
1426
+ A list of tensors that if specified are added to the residuals of transformer blocks.
1427
+ joint_attention_kwargs (`dict`, *optional*):
1428
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1429
+ `self.processor` in
1430
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1431
+ return_dict (`bool`, *optional*, defaults to `True`):
1432
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
1433
+ tuple.
1434
+ skip_layers (`list` of `int`, *optional*):
1435
+ A list of layer indices to skip during the forward pass.
1436
+
1437
+ Returns:
1438
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
1439
+ `tuple` where the first element is the sample tensor.
1440
+ """
1441
+ if joint_attention_kwargs is not None:
1442
+ joint_attention_kwargs = joint_attention_kwargs.copy()
1443
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
1444
+ else:
1445
+ lora_scale = 1.0
1446
+
1447
+ if USE_PEFT_BACKEND:
1448
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1449
+ scale_lora_layers(self, lora_scale)
1450
+ else:
1451
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
1452
+ logger.warning(
1453
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
1454
+ )
1455
+
1456
+ height, width = hidden_states.shape[-2:]
1457
+
1458
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
1459
+ temb = self.time_text_embed(timestep, pooled_projections)
1460
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1461
+
1462
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
1463
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
1464
+ ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
1465
+
1466
+ joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
1467
+
1468
+ for index_block, block in enumerate(self.transformer_blocks):
1469
+ # print("index: ", index_block)
1470
+
1471
+ # Skip specified layers
1472
+ is_skip = True if skip_layers is not None and index_block in skip_layers else False
1473
+
1474
+ if index_block >= (self.config.num_layers // 2) and map_aware_mask is not None:
1475
+ current_mask = map_aware_mask.to(hidden_states.device)
1476
+ else:
1477
+ current_mask = None
1478
+
1479
+ # print("transformer: map_aware_mask:", current_mask.shape if current_mask is not None else None)
1480
+
1481
+ if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
1482
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
1483
+ block,
1484
+ hidden_states,
1485
+ encoder_hidden_states,
1486
+ temb,
1487
+ current_mask,
1488
+ joint_attention_kwargs,
1489
+ )
1490
+ elif not is_skip:
1491
+ encoder_hidden_states, hidden_states = block(
1492
+ hidden_states=hidden_states,
1493
+ encoder_hidden_states=encoder_hidden_states,
1494
+ temb=temb,
1495
+ map_aware_mask=current_mask,
1496
+ joint_attention_kwargs=joint_attention_kwargs,
1497
+ )
1498
+
1499
+ # controlnet residual
1500
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
1501
+ interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
1502
+ hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
1503
+
1504
+ hidden_states = self.norm_out(hidden_states, temb)
1505
+ hidden_states = self.proj_out(hidden_states)
1506
+
1507
+ # unpatchify
1508
+ patch_size = self.config.patch_size
1509
+ height = height // patch_size
1510
+ width = width // patch_size
1511
+
1512
+ hidden_states = hidden_states.reshape(
1513
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
1514
+ )
1515
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
1516
+ output = hidden_states.reshape(
1517
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
1518
+ )
1519
+
1520
+ if USE_PEFT_BACKEND:
1521
+ # remove `lora_scale` from each PEFT layer
1522
+ unscale_lora_layers(self, lora_scale)
1523
+
1524
+ if not return_dict:
1525
+ return (output,)
1526
+
1527
+ return Transformer2DModelOutput(sample=output)
vae/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.29.0.dev0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 16,
20
+ "latents_mean": null,
21
+ "latents_std": null,
22
+ "layers_per_block": 2,
23
+ "norm_num_groups": 32,
24
+ "out_channels": 3,
25
+ "sample_size": 1024,
26
+ "scaling_factor": 1.5305,
27
+ "shift_factor": 0.0609,
28
+ "up_block_types": [
29
+ "UpDecoderBlock2D",
30
+ "UpDecoderBlock2D",
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D"
33
+ ],
34
+ "use_post_quant_conv": false,
35
+ "use_quant_conv": false
36
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9b67a279283625caee39d61eacb5324243848477b4eb535355eaaa8423d4e09
3
+ size 167666654