Instructions to use BiliSakura/IntrisicWeather-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/IntrisicWeather-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/IntrisicWeather-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- README.md +267 -0
- __init__.py +9 -0
- convert_forward_renderer.py +99 -0
- convert_inverse_renderer_512.py +118 -0
- dinov2/README.md +61 -0
- dinov2/config.json +24 -0
- dinov2/model.safetensors +3 -0
- dinov2/preprocessor_config.json +27 -0
- imaa/config.json +11 -0
- imaa/imaa.py +205 -0
- imaa/model.safetensors +3 -0
- model_index.json +39 -0
- pipeline_intrinsic_weather.py +486 -0
- pipeline_intrinsic_weather_forward.py +1191 -0
- pipeline_intrinsic_weather_inverse.py +1119 -0
- pipeline_utils.py +104 -0
- scheduler/scheduler_config.json +9 -0
- test_all_pipelines.py +141 -0
- text_encoder/config.json +24 -0
- text_encoder/model.safetensors +3 -0
- text_encoder_2/config.json +24 -0
- text_encoder_2/model.safetensors +3 -0
- text_encoder_3/config.json +31 -0
- text_encoder_3/model-00001-of-00002.safetensors +3 -0
- text_encoder_3/model-00002-of-00002.safetensors +3 -0
- text_encoder_3/model.safetensors.index.json +226 -0
- tokenizer/merges.txt +0 -0
- tokenizer/special_tokens_map.json +30 -0
- tokenizer/tokenizer_config.json +30 -0
- tokenizer/vocab.json +0 -0
- tokenizer_2/merges.txt +0 -0
- tokenizer_2/special_tokens_map.json +30 -0
- tokenizer_2/tokenizer_config.json +38 -0
- tokenizer_2/vocab.json +0 -0
- tokenizer_3/special_tokens_map.json +125 -0
- tokenizer_3/spiece.model +3 -0
- tokenizer_3/tokenizer.json +0 -0
- tokenizer_3/tokenizer_config.json +940 -0
- transformer/forward/config.json +31 -0
- transformer/forward/diffusion_pytorch_model.safetensors +3 -0
- transformer/forward/lora/pytorch_lora_weights.safetensors +3 -0
- transformer/inverse-512/config.json +31 -0
- transformer/inverse-512/diffusion_pytorch_model.safetensors +3 -0
- transformer/inverse-512/transformer_intrinsic_weather.py +1527 -0
- vae/config.json +36 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
README.md
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# IntrinsicWeather (Diffusers)
|
| 2 |
+
|
| 3 |
+
Diffusers-format checkpoint for **[IntrinsicWeather: Controllable Weather Editing in Intrinsic Space](https://arxiv.org/pdf/2508.06982v6)** (CVPR 2026 Highlight).
|
| 4 |
+
|
| 5 |
+
This repo bundles inverse rendering, forward weather rendering, and the IMAA gating module into a single Hugging Face–compatible layout. Shared Stable Diffusion 3 components (VAE, text encoders, tokenizers, scheduler) are stored once; task-specific transformers live under `transformer/<variant>/`.
|
| 6 |
+
|
| 7 |
+
## Model layout
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
IntrisicWeather-diffusers/
|
| 11 |
+
├── dinov2/ # bundled DINOv2 weights (for IMAA / decomposition)
|
| 12 |
+
├── imaa/ # Intrinsic Map-Aware Attention weights
|
| 13 |
+
├── text_encoder/, text_encoder_2/, text_encoder_3/
|
| 14 |
+
├── tokenizer/, tokenizer_2/, tokenizer_3/
|
| 15 |
+
├── vae/, scheduler/
|
| 16 |
+
├── transformer/
|
| 17 |
+
│ ├── inverse-512/ # IntrinsicWeatherSD3Transformer2DModel (in_channels=32)
|
| 18 |
+
│ │ └── transformer_intrinsic_weather.py
|
| 19 |
+
│ └── forward/ # SD3Transformer2DModel (in_channels=96)
|
| 20 |
+
│ └── lora/ # forward-renderer LoRA (loaded by default)
|
| 21 |
+
├── pipeline_intrinsic_weather.py # unified: RGB → maps → weather RGB
|
| 22 |
+
├── pipeline_intrinsic_weather_inverse.py # inverse only
|
| 23 |
+
├── pipeline_intrinsic_weather_forward.py # forward only
|
| 24 |
+
├── pipeline_utils.py
|
| 25 |
+
├── model_index.json
|
| 26 |
+
├── convert_inverse_renderer_512.py
|
| 27 |
+
├── convert_forward_renderer.py
|
| 28 |
+
└── test_all_pipelines.py
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
| Component | Source | Notes |
|
| 32 |
+
|-----------|--------|-------|
|
| 33 |
+
| Inverse transformer | [GilgameshYX/InverseRenderer-512](https://huggingface.co/GilgameshYX/InverseRenderer-512) | 512×512 decomposition |
|
| 34 |
+
| Forward transformer + LoRA | [GilgameshYX/ForwardRenderer](https://huggingface.co/GilgameshYX/ForwardRenderer) | LoRA in `transformer/forward/lora/` |
|
| 35 |
+
| IMAA | InverseRenderer-512 `imaa.pth` | Required for map-aware inverse attention |
|
| 36 |
+
| SD3 shared weights | `stabilityai/stable-diffusion-3-medium-diffusers` | VAE + text encoders only |
|
| 37 |
+
| Transformer config | `stabilityai/stable-diffusion-3.5-medium` | Architecture template for weight loading |
|
| 38 |
+
|
| 39 |
+
## Requirements
|
| 40 |
+
|
| 41 |
+
- Python 3.10+
|
| 42 |
+
- CUDA GPU recommended (~20 GB VRAM for full end-to-end inference at 512×512)
|
| 43 |
+
- `torch`, `diffusers>=0.38`, `transformers`, `safetensors`, `torchvision`, `Pillow`
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip install torch diffusers transformers safetensors torchvision pillow accelerate
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Quick start (end-to-end weather edit)
|
| 50 |
+
|
| 51 |
+
The unified pipeline decomposes an input RGB image into intrinsic maps, then renders a weather-conditioned result. **DINOv2** is required for decomposition (bundled under `dinov2/`, or use `facebook/dinov2-base` from Hugging Face).
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
from pathlib import Path
|
| 55 |
+
|
| 56 |
+
import torch
|
| 57 |
+
from PIL import Image
|
| 58 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 59 |
+
|
| 60 |
+
from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
|
| 61 |
+
|
| 62 |
+
repo_dir = Path(".").resolve() # path to this folder
|
| 63 |
+
device = "cuda"
|
| 64 |
+
dtype = torch.bfloat16
|
| 65 |
+
|
| 66 |
+
pipe = IntrinsicWeatherPipeline.from_pretrained(
|
| 67 |
+
repo_dir,
|
| 68 |
+
inverse_transformer_subfolder="inverse-512",
|
| 69 |
+
forward_transformer_subfolder="forward",
|
| 70 |
+
device=device,
|
| 71 |
+
local_files_only=True,
|
| 72 |
+
torch_dtype=dtype,
|
| 73 |
+
load_lora=True,
|
| 74 |
+
load_imaa=True,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
dino_path = repo_dir / "dinov2"
|
| 78 |
+
dino_processor = AutoImageProcessor.from_pretrained(dino_path, local_files_only=True)
|
| 79 |
+
dino_model = AutoModel.from_pretrained(dino_path, local_files_only=True).to(device)
|
| 80 |
+
dino_model.eval()
|
| 81 |
+
|
| 82 |
+
image = Image.open("input.png").convert("RGB")
|
| 83 |
+
result = pipe(
|
| 84 |
+
image=image,
|
| 85 |
+
weather="snowy", # rainy | sunny | snowy | foggy | overcast | night
|
| 86 |
+
dino_model=dino_model,
|
| 87 |
+
dino_processor=dino_processor,
|
| 88 |
+
image_size=512,
|
| 89 |
+
render_size=512,
|
| 90 |
+
num_inverse_steps=50,
|
| 91 |
+
num_forward_steps=50,
|
| 92 |
+
guidance_scale=6.0,
|
| 93 |
+
image_guidance_scale=1.5,
|
| 94 |
+
generator=torch.Generator(device=device).manual_seed(42),
|
| 95 |
+
)
|
| 96 |
+
result.images[0].save("output_snowy.png")
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
Run from inside this directory (or add it to `PYTHONPATH`) so `pipeline_intrinsic_weather.py` and `imaa/` resolve correctly.
|
| 100 |
+
|
| 101 |
+
## Pipelines
|
| 102 |
+
|
| 103 |
+
### 1. `IntrinsicWeatherPipeline` (unified)
|
| 104 |
+
|
| 105 |
+
Full pipeline: **RGB → intrinsic maps → weather RGB**.
|
| 106 |
+
|
| 107 |
+
```python
|
| 108 |
+
pipe = IntrinsicWeatherPipeline.from_pretrained(
|
| 109 |
+
repo_dir,
|
| 110 |
+
inverse_transformer_subfolder="inverse-512",
|
| 111 |
+
forward_transformer_subfolder="forward",
|
| 112 |
+
device="cuda",
|
| 113 |
+
torch_dtype=torch.bfloat16,
|
| 114 |
+
)
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
Useful kwargs:
|
| 118 |
+
|
| 119 |
+
| Argument | Default | Description |
|
| 120 |
+
|----------|---------|-------------|
|
| 121 |
+
| `inverse_transformer_subfolder` | `"inverse-512"` | Inverse transformer under `transformer/` |
|
| 122 |
+
| `forward_transformer_subfolder` | `"forward"` | Forward transformer under `transformer/` |
|
| 123 |
+
| `load_lora` | `True` | Load LoRA from `transformer/forward/lora/` |
|
| 124 |
+
| `load_imaa` | `True` | Load IMAA weights from `imaa/` |
|
| 125 |
+
| `device` | `None` | Moves all modules to device (IMAA stays float32) |
|
| 126 |
+
|
| 127 |
+
Sub-methods:
|
| 128 |
+
|
| 129 |
+
- `pipe.decompose(image, dino_model, dino_processor, ...)` → dict of intrinsic maps
|
| 130 |
+
- `pipe.render(maps, weather="rainy", ...)` → weather-conditioned RGB
|
| 131 |
+
|
| 132 |
+
### 2. `IntrinsicWeatherInversePipeline`
|
| 133 |
+
|
| 134 |
+
Inverse rendering only (single intrinsic map per call).
|
| 135 |
+
|
| 136 |
+
```python
|
| 137 |
+
from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
|
| 138 |
+
|
| 139 |
+
pipe = IntrinsicWeatherInversePipeline.from_pretrained(
|
| 140 |
+
repo_dir,
|
| 141 |
+
transformer_subfolder="inverse-512",
|
| 142 |
+
device="cuda",
|
| 143 |
+
torch_dtype=torch.bfloat16,
|
| 144 |
+
)
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
Load the transformer separately if needed:
|
| 148 |
+
|
| 149 |
+
```python
|
| 150 |
+
transformer = IntrinsicWeatherInversePipeline.load_transformer(
|
| 151 |
+
"inverse-512", repo_dir, device="cuda"
|
| 152 |
+
)
|
| 153 |
+
pipe = IntrinsicWeatherInversePipeline.from_pretrained(
|
| 154 |
+
repo_dir, transformer=transformer, device="cuda"
|
| 155 |
+
)
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
IMAA and DINO are used by the unified pipeline’s `decompose()` path; for standalone inverse calls, pass `map_aware_mask` from IMAA manually (see `test_all_pipelines.py`).
|
| 159 |
+
|
| 160 |
+
### 3. `IntrinsicWeatherForwardPipeline`
|
| 161 |
+
|
| 162 |
+
Forward weather rendering from intrinsic maps.
|
| 163 |
+
|
| 164 |
+
```python
|
| 165 |
+
from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
|
| 166 |
+
|
| 167 |
+
pipe = IntrinsicWeatherForwardPipeline.from_pretrained(
|
| 168 |
+
repo_dir,
|
| 169 |
+
transformer_subfolder="forward",
|
| 170 |
+
device="cuda",
|
| 171 |
+
torch_dtype=torch.bfloat16,
|
| 172 |
+
load_lora=True,
|
| 173 |
+
)
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
LoRA weights are read from `transformer/forward/lora/` when `load_lora=True`.
|
| 177 |
+
|
| 178 |
+
## Weather presets
|
| 179 |
+
|
| 180 |
+
Built-in weather keys (or pass a custom prompt string):
|
| 181 |
+
|
| 182 |
+
| Key | Prompt |
|
| 183 |
+
|-----|--------|
|
| 184 |
+
| `rainy` | A rainy day. |
|
| 185 |
+
| `sunny` | A sunny day. |
|
| 186 |
+
| `snowy` | A snowy day. |
|
| 187 |
+
| `foggy` | A foggy day. |
|
| 188 |
+
| `overcast` | An overcast day. |
|
| 189 |
+
| `night` | A night scene. |
|
| 190 |
+
|
| 191 |
+
## Intrinsic maps (AoVs)
|
| 192 |
+
|
| 193 |
+
The inverse renderer produces five appearance-of-variety maps:
|
| 194 |
+
|
| 195 |
+
`albedo`, `normal`, `roughness`, `metallic`, `irradiance`
|
| 196 |
+
|
| 197 |
+
## Loading transformers manually
|
| 198 |
+
|
| 199 |
+
Transformers are stored per variant under `transformer/<subfolder>/`. Use `pipeline_utils.load_transformer_from_subfolder`:
|
| 200 |
+
|
| 201 |
+
```python
|
| 202 |
+
from pipeline_utils import load_transformer_from_subfolder, load_transformer_lora
|
| 203 |
+
|
| 204 |
+
inverse = load_transformer_from_subfolder(repo_dir, "inverse-512", device="cuda")
|
| 205 |
+
forward = load_transformer_from_subfolder(repo_dir, "forward", device="cuda")
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
- `inverse-512` uses a custom `IntrinsicWeatherSD3Transformer2DModel` (`in_channels=32`).
|
| 209 |
+
- `forward` uses the standard `SD3Transformer2DModel` (`in_channels=96`).
|
| 210 |
+
|
| 211 |
+
## Dtype and device notes
|
| 212 |
+
|
| 213 |
+
- Default dtype is **`torch.bfloat16`** for transformers, VAE, and text encoders.
|
| 214 |
+
- **IMAA** stays in **float32** (DINO patch tokens are float32).
|
| 215 |
+
- Pass `device="cuda"` to `from_pretrained` on all three pipeline classes; the unified pipeline moves every registered module to the target device automatically.
|
| 216 |
+
|
| 217 |
+
## Testing
|
| 218 |
+
|
| 219 |
+
Smoke-test all pipelines on CUDA:
|
| 220 |
+
|
| 221 |
+
```bash
|
| 222 |
+
python test_all_pipelines.py
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
Runs 2-step inverse, forward (with LoRA), and unified load checks with `bfloat16`.
|
| 226 |
+
|
| 227 |
+
## Re-converting from original checkpoints
|
| 228 |
+
|
| 229 |
+
If you have the raw GilgameshYX checkpoints:
|
| 230 |
+
|
| 231 |
+
```bash
|
| 232 |
+
# Inverse renderer (512) + IMAA
|
| 233 |
+
python convert_inverse_renderer_512.py
|
| 234 |
+
|
| 235 |
+
# Forward renderer + LoRA
|
| 236 |
+
python convert_forward_renderer.py
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
See `conversion_metadata.json` and `conversion_metadata_forward.json` for source paths used during conversion.
|
| 240 |
+
|
| 241 |
+
## Hugging Face Hub loading
|
| 242 |
+
|
| 243 |
+
When published to the Hub, load with `trust_remote_code=True`:
|
| 244 |
+
|
| 245 |
+
```python
|
| 246 |
+
from diffusers import DiffusionPipeline
|
| 247 |
+
|
| 248 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 249 |
+
"BiliSakura/IntrisicWeather-diffusers",
|
| 250 |
+
custom_pipeline="pipeline_intrinsic_weather.py",
|
| 251 |
+
trust_remote_code=True,
|
| 252 |
+
torch_dtype=torch.bfloat16,
|
| 253 |
+
)
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
For local use, importing `IntrinsicWeatherPipeline` directly (as in Quick start) is simpler and avoids Hub cache path issues with custom modules.
|
| 257 |
+
|
| 258 |
+
## References
|
| 259 |
+
|
| 260 |
+
- **Paper:** [IntrinsicWeather (arXiv:2508.06982)](https://arxiv.org/pdf/2508.06982v6)
|
| 261 |
+
- **Project page:** https://yixinzhu042.github.io/IntrinsicWeather/
|
| 262 |
+
- **Upstream diffusers repo:** [IntrinsicWeather-diffusers](https://github.com/YixinZhu042/IntrinsicWeather)
|
| 263 |
+
- **Original weights:** [GilgameshYX/InverseRenderer-512](https://huggingface.co/GilgameshYX/InverseRenderer-512), [GilgameshYX/ForwardRenderer](https://huggingface.co/GilgameshYX/ForwardRenderer)
|
| 264 |
+
|
| 265 |
+
## License
|
| 266 |
+
|
| 267 |
+
Weights and code follow the licenses of the upstream IntrinsicWeather project and the Stable Diffusion 3 components used for shared modules.
|
__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
|
| 2 |
+
from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
|
| 3 |
+
from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"IntrinsicWeatherPipeline",
|
| 7 |
+
"IntrinsicWeatherForwardPipeline",
|
| 8 |
+
"IntrinsicWeatherInversePipeline",
|
| 9 |
+
]
|
convert_forward_renderer.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Convert GilgameshYX ForwardRenderer into BiliSakura IntrisicWeather-diffusers layout."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import shutil
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
| 12 |
+
|
| 13 |
+
COLLECTION_ROOT = Path(__file__).resolve().parent
|
| 14 |
+
INTRINSIC_REPO = Path("/data/projects/IntrinsicWeather-diffusers")
|
| 15 |
+
sys.path.insert(0, str(INTRINSIC_REPO / "src"))
|
| 16 |
+
sys.path.insert(0, str(INTRINSIC_REPO))
|
| 17 |
+
|
| 18 |
+
from scripts._conversion_utils import ( # noqa: E402
|
| 19 |
+
expand_sd3_input_projection,
|
| 20 |
+
load_torch,
|
| 21 |
+
write_scheduler_config,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
from _collection_setup import install_hub_pipelines # noqa: E402
|
| 25 |
+
|
| 26 |
+
SD3_PATH = Path(
|
| 27 |
+
"/data/projects/Visual-Generative-Foundation-Model-Collection/models/stabilityai/stable-diffusion-3-medium-diffusers"
|
| 28 |
+
)
|
| 29 |
+
SD35_TRANSFORMER_REPO = "stabilityai/stable-diffusion-3.5-medium"
|
| 30 |
+
CKPT_PATH = Path(
|
| 31 |
+
"/data/projects/Visual-Generative-Foundation-Model-Collection/models/GilgameshYX/ForwardRenderer"
|
| 32 |
+
)
|
| 33 |
+
OUTPUT_ROOT = COLLECTION_ROOT
|
| 34 |
+
TRANSFORMER_VARIANT = "forward"
|
| 35 |
+
SHARED_COMPONENTS = (
|
| 36 |
+
"text_encoder",
|
| 37 |
+
"text_encoder_2",
|
| 38 |
+
"text_encoder_3",
|
| 39 |
+
"tokenizer",
|
| 40 |
+
"tokenizer_2",
|
| 41 |
+
"tokenizer_3",
|
| 42 |
+
"vae",
|
| 43 |
+
"scheduler",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def copy_sd3_shared_components(sd3_path: Path, output_path: Path) -> None:
|
| 48 |
+
for name in SHARED_COMPONENTS:
|
| 49 |
+
src = sd3_path / name
|
| 50 |
+
dst = output_path / name
|
| 51 |
+
if dst.exists():
|
| 52 |
+
print(f"Skipping existing shared component: {dst}")
|
| 53 |
+
continue
|
| 54 |
+
print(f"Copying {name} ...")
|
| 55 |
+
shutil.copytree(src, dst)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main() -> None:
|
| 59 |
+
transformer_dir = OUTPUT_ROOT / "transformer" / TRANSFORMER_VARIANT
|
| 60 |
+
transformer_dir.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
print(f"Ensuring shared SD3 components from {SD3_PATH} ...")
|
| 63 |
+
copy_sd3_shared_components(SD3_PATH, OUTPUT_ROOT)
|
| 64 |
+
write_scheduler_config(OUTPUT_ROOT)
|
| 65 |
+
install_hub_pipelines(OUTPUT_ROOT)
|
| 66 |
+
|
| 67 |
+
print("Converting forward renderer transformer ...")
|
| 68 |
+
transformer = SD3Transformer2DModel.from_config(
|
| 69 |
+
SD3Transformer2DModel.load_config(SD35_TRANSFORMER_REPO, subfolder="transformer")
|
| 70 |
+
)
|
| 71 |
+
transformer = expand_sd3_input_projection(transformer, in_channels=96)
|
| 72 |
+
transformer.load_state_dict(load_torch(CKPT_PATH / "pytorch_model.bin"), strict=True)
|
| 73 |
+
transformer.save_pretrained(transformer_dir.as_posix(), safe_serialization=True)
|
| 74 |
+
|
| 75 |
+
print("Saving LoRA weights ...")
|
| 76 |
+
lora_dir = transformer_dir / "lora"
|
| 77 |
+
lora_dir.mkdir(parents=True, exist_ok=True)
|
| 78 |
+
shutil.copy2(CKPT_PATH / "pytorch_lora_weights.safetensors", lora_dir / "pytorch_lora_weights.safetensors")
|
| 79 |
+
|
| 80 |
+
conversion_metadata = {
|
| 81 |
+
"task": "forward_renderer",
|
| 82 |
+
"transformer_variant": TRANSFORMER_VARIANT,
|
| 83 |
+
"source_transformer_checkpoint": str((CKPT_PATH / "pytorch_model.bin").resolve()),
|
| 84 |
+
"source_lora_checkpoint": str((CKPT_PATH / "pytorch_lora_weights.safetensors").resolve()),
|
| 85 |
+
"lora_dir": str((lora_dir).resolve()),
|
| 86 |
+
"sd3_path": str(SD3_PATH.resolve()),
|
| 87 |
+
"sd35_transformer_repo": SD35_TRANSFORMER_REPO,
|
| 88 |
+
"in_channels": 96,
|
| 89 |
+
}
|
| 90 |
+
(OUTPUT_ROOT / "conversion_metadata_forward.json").write_text(
|
| 91 |
+
json.dumps(conversion_metadata, indent=2) + "\n",
|
| 92 |
+
encoding="utf-8",
|
| 93 |
+
)
|
| 94 |
+
print(f"Saved transformer to: {transformer_dir}")
|
| 95 |
+
print("Load with: load_forward_pipeline(transformer_subfolder='forward')")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
main()
|
convert_inverse_renderer_512.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Convert GilgameshYX InverseRenderer-512 into BiliSakura IntrisicWeather-diffusers layout."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import shutil
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
| 12 |
+
|
| 13 |
+
COLLECTION_ROOT = Path(__file__).resolve().parent
|
| 14 |
+
INTRINSIC_REPO = Path("/data/projects/IntrinsicWeather-diffusers")
|
| 15 |
+
sys.path.insert(0, str(INTRINSIC_REPO / "src"))
|
| 16 |
+
sys.path.insert(0, str(INTRINSIC_REPO))
|
| 17 |
+
|
| 18 |
+
from intrinsic_weather.models.transformers.transformer_intrinsic_weather import ( # noqa: E402
|
| 19 |
+
IntrinsicWeatherSD3Transformer2DModel,
|
| 20 |
+
)
|
| 21 |
+
from scripts._conversion_utils import ( # noqa: E402
|
| 22 |
+
ROOT as REPO_ROOT,
|
| 23 |
+
expand_sd3_input_projection,
|
| 24 |
+
merge_sharded_state_dict,
|
| 25 |
+
save_imaa_bundle,
|
| 26 |
+
write_scheduler_config,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from _collection_setup import install_hub_pipelines # noqa: E402
|
| 30 |
+
|
| 31 |
+
SD3_PATH = Path(
|
| 32 |
+
"/data/projects/Visual-Generative-Foundation-Model-Collection/models/stabilityai/stable-diffusion-3-medium-diffusers"
|
| 33 |
+
)
|
| 34 |
+
SD35_TRANSFORMER_REPO = "stabilityai/stable-diffusion-3.5-medium"
|
| 35 |
+
CKPT_PATH = Path(
|
| 36 |
+
"/data/projects/Visual-Generative-Foundation-Model-Collection/models/GilgameshYX/InverseRenderer-512"
|
| 37 |
+
)
|
| 38 |
+
OUTPUT_ROOT = COLLECTION_ROOT
|
| 39 |
+
TRANSFORMER_VARIANT = "inverse-512"
|
| 40 |
+
SHARED_COMPONENTS = (
|
| 41 |
+
"text_encoder",
|
| 42 |
+
"text_encoder_2",
|
| 43 |
+
"text_encoder_3",
|
| 44 |
+
"tokenizer",
|
| 45 |
+
"tokenizer_2",
|
| 46 |
+
"tokenizer_3",
|
| 47 |
+
"vae",
|
| 48 |
+
"scheduler",
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def copy_sd3_shared_components(sd3_path: Path, output_path: Path) -> None:
|
| 53 |
+
for name in SHARED_COMPONENTS:
|
| 54 |
+
src = sd3_path / name
|
| 55 |
+
dst = output_path / name
|
| 56 |
+
if dst.exists():
|
| 57 |
+
print(f"Skipping existing shared component: {dst}")
|
| 58 |
+
continue
|
| 59 |
+
print(f"Copying {name} ...")
|
| 60 |
+
shutil.copytree(src, dst)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main() -> None:
|
| 64 |
+
transformer_dir = OUTPUT_ROOT / "transformer" / TRANSFORMER_VARIANT
|
| 65 |
+
transformer_dir.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
print(f"Copying shared SD3 components from {SD3_PATH} ...")
|
| 68 |
+
copy_sd3_shared_components(SD3_PATH, OUTPUT_ROOT)
|
| 69 |
+
write_scheduler_config(OUTPUT_ROOT)
|
| 70 |
+
install_hub_pipelines(OUTPUT_ROOT)
|
| 71 |
+
|
| 72 |
+
print("Converting inverse renderer transformer ...")
|
| 73 |
+
base_transformer = SD3Transformer2DModel.from_config(
|
| 74 |
+
SD3Transformer2DModel.load_config(SD35_TRANSFORMER_REPO, subfolder="transformer")
|
| 75 |
+
)
|
| 76 |
+
base_transformer = expand_sd3_input_projection(base_transformer, in_channels=32)
|
| 77 |
+
custom_blocks = IntrinsicWeatherSD3Transformer2DModel.from_config(base_transformer.config)
|
| 78 |
+
custom_blocks.load_state_dict(
|
| 79 |
+
merge_sharded_state_dict(
|
| 80 |
+
[
|
| 81 |
+
CKPT_PATH / "pytorch_model-00001-of-00002.bin",
|
| 82 |
+
CKPT_PATH / "pytorch_model-00002-of-00002.bin",
|
| 83 |
+
]
|
| 84 |
+
),
|
| 85 |
+
strict=True,
|
| 86 |
+
)
|
| 87 |
+
custom_blocks.save_pretrained(transformer_dir.as_posix(), safe_serialization=True)
|
| 88 |
+
shutil.copy2(
|
| 89 |
+
REPO_ROOT / "src" / "intrinsic_weather" / "models" / "transformers" / "transformer_intrinsic_weather.py",
|
| 90 |
+
transformer_dir / "transformer_intrinsic_weather.py",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
print("Saving IMAA weights ...")
|
| 94 |
+
save_imaa_bundle(CKPT_PATH / "imaa.pth", OUTPUT_ROOT, safe_serialization=True)
|
| 95 |
+
|
| 96 |
+
conversion_metadata = {
|
| 97 |
+
"task": "inverse_renderer",
|
| 98 |
+
"resolution": 512,
|
| 99 |
+
"transformer_variant": TRANSFORMER_VARIANT,
|
| 100 |
+
"source_transformer_checkpoints": [
|
| 101 |
+
str((CKPT_PATH / "pytorch_model-00001-of-00002.bin").resolve()),
|
| 102 |
+
str((CKPT_PATH / "pytorch_model-00002-of-00002.bin").resolve()),
|
| 103 |
+
],
|
| 104 |
+
"source_imaa_checkpoint": str((CKPT_PATH / "imaa.pth").resolve()),
|
| 105 |
+
"sd3_path": str(SD3_PATH.resolve()),
|
| 106 |
+
"sd35_transformer_repo": SD35_TRANSFORMER_REPO,
|
| 107 |
+
"in_channels": 32,
|
| 108 |
+
}
|
| 109 |
+
(OUTPUT_ROOT / "conversion_metadata.json").write_text(
|
| 110 |
+
json.dumps(conversion_metadata, indent=2) + "\n",
|
| 111 |
+
encoding="utf-8",
|
| 112 |
+
)
|
| 113 |
+
print(f"Saved transformer to: {transformer_dir}")
|
| 114 |
+
print("Load with: load_inverse_pipeline(transformer_subfolder='inverse-512')")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
main()
|
dinov2/README.md
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- dino
|
| 5 |
+
- vision
|
| 6 |
+
inference: false
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Vision Transformer (base-sized model) trained using DINOv2
|
| 10 |
+
|
| 11 |
+
Vision Transformer (ViT) model trained using the DINOv2 method. It was introduced in the paper [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) by Oquab et al. and first released in [this repository](https://github.com/facebookresearch/dinov2).
|
| 12 |
+
|
| 13 |
+
Disclaimer: The team releasing DINOv2 did not write a model card for this model so this model card has been written by the Hugging Face team.
|
| 14 |
+
|
| 15 |
+
## Model description
|
| 16 |
+
|
| 17 |
+
The Vision Transformer (ViT) is a transformer encoder model (BERT-like) pretrained on a large collection of images in a self-supervised fashion.
|
| 18 |
+
|
| 19 |
+
Images are presented to the model as a sequence of fixed-size patches, which are linearly embedded. One also adds a [CLS] token to the beginning of a sequence to use it for classification tasks. One also adds absolute position embeddings before feeding the sequence to the layers of the Transformer encoder.
|
| 20 |
+
|
| 21 |
+
Note that this model does not include any fine-tuned heads.
|
| 22 |
+
|
| 23 |
+
By pre-training the model, it learns an inner representation of images that can then be used to extract features useful for downstream tasks: if you have a dataset of labeled images for instance, you can train a standard classifier by placing a linear layer on top of the pre-trained encoder. One typically places a linear layer on top of the [CLS] token, as the last hidden state of this token can be seen as a representation of an entire image.
|
| 24 |
+
|
| 25 |
+
## Intended uses & limitations
|
| 26 |
+
|
| 27 |
+
You can use the raw model for feature extraction. See the [model hub](https://huggingface.co/models?search=facebook/dinov2) to look for
|
| 28 |
+
fine-tuned versions on a task that interests you.
|
| 29 |
+
|
| 30 |
+
### How to use
|
| 31 |
+
|
| 32 |
+
Here is how to use this model:
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 36 |
+
from PIL import Image
|
| 37 |
+
import requests
|
| 38 |
+
|
| 39 |
+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
| 40 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 41 |
+
|
| 42 |
+
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
| 43 |
+
model = AutoModel.from_pretrained('facebook/dinov2-base')
|
| 44 |
+
|
| 45 |
+
inputs = processor(images=image, return_tensors="pt")
|
| 46 |
+
outputs = model(**inputs)
|
| 47 |
+
last_hidden_states = outputs.last_hidden_state
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### BibTeX entry and citation info
|
| 51 |
+
|
| 52 |
+
```bibtex
|
| 53 |
+
misc{oquab2023dinov2,
|
| 54 |
+
title={DINOv2: Learning Robust Visual Features without Supervision},
|
| 55 |
+
author={Maxime Oquab and Timothée Darcet and Théo Moutakanni and Huy Vo and Marc Szafraniec and Vasil Khalidov and Pierre Fernandez and Daniel Haziza and Francisco Massa and Alaaeldin El-Nouby and Mahmoud Assran and Nicolas Ballas and Wojciech Galuba and Russell Howes and Po-Yao Huang and Shang-Wen Li and Ishan Misra and Michael Rabbat and Vasu Sharma and Gabriel Synnaeve and Hu Xu and Hervé Jegou and Julien Mairal and Patrick Labatut and Armand Joulin and Piotr Bojanowski},
|
| 56 |
+
year={2023},
|
| 57 |
+
eprint={2304.07193},
|
| 58 |
+
archivePrefix={arXiv},
|
| 59 |
+
primaryClass={cs.CV}
|
| 60 |
+
}
|
| 61 |
+
```
|
dinov2/config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Dinov2Model"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.0,
|
| 6 |
+
"drop_path_rate": 0.0,
|
| 7 |
+
"hidden_act": "gelu",
|
| 8 |
+
"hidden_dropout_prob": 0.0,
|
| 9 |
+
"hidden_size": 768,
|
| 10 |
+
"image_size": 518,
|
| 11 |
+
"initializer_range": 0.02,
|
| 12 |
+
"layer_norm_eps": 1e-06,
|
| 13 |
+
"layerscale_value": 1.0,
|
| 14 |
+
"mlp_ratio": 4,
|
| 15 |
+
"model_type": "dinov2",
|
| 16 |
+
"num_attention_heads": 12,
|
| 17 |
+
"num_channels": 3,
|
| 18 |
+
"num_hidden_layers": 12,
|
| 19 |
+
"patch_size": 14,
|
| 20 |
+
"qkv_bias": true,
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"transformers_version": "4.31.0.dev0",
|
| 23 |
+
"use_swiglu_ffn": false
|
| 24 |
+
}
|
dinov2/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d73036b56966966d07975d696bde331762f37297e2f095de8cea0040c3aa0841
|
| 3 |
+
size 346345912
|
dinov2/preprocessor_config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": {
|
| 3 |
+
"height": 224,
|
| 4 |
+
"width": 224
|
| 5 |
+
},
|
| 6 |
+
"do_center_crop": true,
|
| 7 |
+
"do_convert_rgb": true,
|
| 8 |
+
"do_normalize": true,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"image_mean": [
|
| 12 |
+
0.485,
|
| 13 |
+
0.456,
|
| 14 |
+
0.406
|
| 15 |
+
],
|
| 16 |
+
"image_processor_type": "BitImageProcessor",
|
| 17 |
+
"image_std": [
|
| 18 |
+
0.229,
|
| 19 |
+
0.224,
|
| 20 |
+
0.225
|
| 21 |
+
],
|
| 22 |
+
"resample": 3,
|
| 23 |
+
"rescale_factor": 0.00392156862745098,
|
| 24 |
+
"size": {
|
| 25 |
+
"shortest_edge": 256
|
| 26 |
+
}
|
| 27 |
+
}
|
imaa/config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "IMAA",
|
| 3 |
+
"num_maps": 5,
|
| 4 |
+
"map_embedding_dim": 256,
|
| 5 |
+
"common_dim": 128,
|
| 6 |
+
"conv_channels": [
|
| 7 |
+
128,
|
| 8 |
+
64
|
| 9 |
+
],
|
| 10 |
+
"dino_patch_dim": 768
|
| 11 |
+
}
|
imaa/imaa.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def extract_patch_tokens_min_windows(
|
| 24 |
+
images: torch.Tensor,
|
| 25 |
+
model: nn.Module,
|
| 26 |
+
processor,
|
| 27 |
+
window_size: int = 224,
|
| 28 |
+
device: str | torch.device = "cuda",
|
| 29 |
+
) -> torch.Tensor:
|
| 30 |
+
r"""
|
| 31 |
+
Tile each image with a minimal window set and return averaged DINO patch tokens.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
images (`torch.Tensor`): Batch of RGB images `(B, C, H, W)`.
|
| 35 |
+
model: DINO vision transformer.
|
| 36 |
+
processor: Hugging Face image processor for DINO.
|
| 37 |
+
window_size (`int`): Sliding-window size in pixels.
|
| 38 |
+
device: Device for intermediate tensors.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
`torch.Tensor` of shape `(B, H//patch, W//patch, hidden_size)`.
|
| 42 |
+
"""
|
| 43 |
+
batch_size, _, height, width = images.shape
|
| 44 |
+
hidden_size = model.config.hidden_size
|
| 45 |
+
patch_size = model.config.patch_size
|
| 46 |
+
token_avgs = []
|
| 47 |
+
|
| 48 |
+
for batch_idx in range(batch_size):
|
| 49 |
+
image = images[batch_idx].float()
|
| 50 |
+
if image.max() <= 1.0:
|
| 51 |
+
image_np = (image.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype("uint8")
|
| 52 |
+
else:
|
| 53 |
+
image_np = image.permute(1, 2, 0).cpu().numpy().clip(0, 255).astype("uint8")
|
| 54 |
+
|
| 55 |
+
token_sum = torch.zeros((height // patch_size, width // patch_size, hidden_size), device=device)
|
| 56 |
+
token_count = torch.zeros((height // patch_size, width // patch_size, 1), device=device)
|
| 57 |
+
|
| 58 |
+
num_y = (height + window_size - 1) // window_size
|
| 59 |
+
num_x = (width + window_size - 1) // window_size
|
| 60 |
+
y_positions = [index * window_size for index in range(num_y - 1)] + [height - window_size]
|
| 61 |
+
x_positions = [index * window_size for index in range(num_x - 1)] + [width - window_size]
|
| 62 |
+
|
| 63 |
+
for y in y_positions:
|
| 64 |
+
for x in x_positions:
|
| 65 |
+
patch = image_np[y : y + window_size, x : x + window_size, :]
|
| 66 |
+
inputs = processor(images=patch, return_tensors="pt").to(device)
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
outputs = model(**inputs)
|
| 69 |
+
patch_tokens = outputs.last_hidden_state[:, 1:, :]
|
| 70 |
+
patch_tokens = patch_tokens.reshape(
|
| 71 |
+
1, window_size // patch_size, window_size // patch_size, hidden_size
|
| 72 |
+
).squeeze(0)
|
| 73 |
+
|
| 74 |
+
y0, x0 = y // patch_size, x // patch_size
|
| 75 |
+
y1, x1 = y0 + window_size // patch_size, x0 + window_size // patch_size
|
| 76 |
+
token_sum[y0:y1, x0:x1, :] += patch_tokens
|
| 77 |
+
token_count[y0:y1, x0:x1, 0] += 1
|
| 78 |
+
|
| 79 |
+
token_avgs.append(token_sum / token_count)
|
| 80 |
+
|
| 81 |
+
return torch.stack(token_avgs, dim=0)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class LayerNorm2d(nn.Module):
|
| 85 |
+
def __init__(self, channels: int) -> None:
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.norm = nn.LayerNorm([channels])
|
| 88 |
+
|
| 89 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
x = x.permute(0, 2, 3, 1)
|
| 91 |
+
x = self.norm(x)
|
| 92 |
+
return x.permute(0, 3, 1, 2)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class IMAA(nn.Module):
|
| 96 |
+
r"""
|
| 97 |
+
Intrinsic Map-Aware Attention (IMAA) gating module.
|
| 98 |
+
|
| 99 |
+
Produces per-map attention biases from DINO patch tokens and learnable map embeddings.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
def __init__(
|
| 103 |
+
self,
|
| 104 |
+
dino_model: Optional[nn.Module] = None,
|
| 105 |
+
processor=None,
|
| 106 |
+
num_maps: int = 5,
|
| 107 |
+
map_embedding_dim: int = 256,
|
| 108 |
+
common_dim: int = 128,
|
| 109 |
+
conv_channels: Optional[list[int]] = None,
|
| 110 |
+
dino_patch_dim: int = 768,
|
| 111 |
+
) -> None:
|
| 112 |
+
super().__init__()
|
| 113 |
+
conv_channels = conv_channels or [128, 64]
|
| 114 |
+
self.dino = dino_model
|
| 115 |
+
self.processor = processor
|
| 116 |
+
if self.dino is not None:
|
| 117 |
+
self.dino.eval()
|
| 118 |
+
for param in self.dino.parameters():
|
| 119 |
+
param.requires_grad = False
|
| 120 |
+
|
| 121 |
+
self.num_maps = num_maps
|
| 122 |
+
self.map_embedding_dim = map_embedding_dim
|
| 123 |
+
self.common_dim = common_dim
|
| 124 |
+
self.dino_patch_dim = dino_patch_dim
|
| 125 |
+
self.map_embedding = nn.Parameter(torch.randn(num_maps, map_embedding_dim))
|
| 126 |
+
self.dino_proj = nn.Conv2d(dino_patch_dim, common_dim, kernel_size=1)
|
| 127 |
+
self.map_proj = nn.Linear(map_embedding_dim, common_dim)
|
| 128 |
+
self.fusion_layer = nn.Sequential(
|
| 129 |
+
nn.Conv2d(common_dim * 2, common_dim, 1),
|
| 130 |
+
LayerNorm2d(common_dim),
|
| 131 |
+
nn.ReLU(),
|
| 132 |
+
nn.Conv2d(common_dim, common_dim, 3, padding=1),
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
conv_layers: list[nn.Module] = []
|
| 136 |
+
in_channels = common_dim
|
| 137 |
+
for out_channels in conv_channels:
|
| 138 |
+
conv_layers.extend([nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()])
|
| 139 |
+
in_channels = out_channels
|
| 140 |
+
conv_layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))
|
| 141 |
+
self.conv_head = nn.Sequential(*conv_layers)
|
| 142 |
+
|
| 143 |
+
def forward(
|
| 144 |
+
self,
|
| 145 |
+
image: Optional[torch.Tensor] = None,
|
| 146 |
+
patch_tokens: Optional[torch.Tensor] = None,
|
| 147 |
+
output_size: Optional[Tuple[int, int]] = None,
|
| 148 |
+
map_ids: Optional[torch.Tensor] = None,
|
| 149 |
+
) -> torch.Tensor:
|
| 150 |
+
if patch_tokens is None:
|
| 151 |
+
if self.dino is None or image is None:
|
| 152 |
+
raise ValueError("Either `patch_tokens` or (`image` and a frozen DINO model) must be provided.")
|
| 153 |
+
patch_tokens = extract_patch_tokens_min_windows(
|
| 154 |
+
image, self.dino, self.processor, window_size=224, device=image.device
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
dino_feat_map = patch_tokens.permute(0, 3, 1, 2)
|
| 158 |
+
dino_proj = self.dino_proj(dino_feat_map)
|
| 159 |
+
map_emb = self.map_embedding[map_ids]
|
| 160 |
+
map_proj = self.map_proj(map_emb).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, dino_proj.size(2), dino_proj.size(3))
|
| 161 |
+
fused_map = self.fusion_layer(torch.cat([dino_proj, map_proj], dim=1))
|
| 162 |
+
raw_gating_map = self.conv_head(fused_map)
|
| 163 |
+
aligned_map = (
|
| 164 |
+
F.interpolate(raw_gating_map, size=output_size, mode="bilinear", align_corners=False)
|
| 165 |
+
if output_size is not None
|
| 166 |
+
else raw_gating_map
|
| 167 |
+
)
|
| 168 |
+
return torch.sigmoid(aligned_map)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def build_attn_mask(
|
| 172 |
+
w_gating: torch.Tensor,
|
| 173 |
+
text_len: int,
|
| 174 |
+
img_len: int,
|
| 175 |
+
lam: float,
|
| 176 |
+
) -> torch.Tensor:
|
| 177 |
+
r"""
|
| 178 |
+
Build an additive attention mask from IMAA gating weights.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
w_gating (`torch.Tensor`): Gating map `[B, 1, H, W]` or flattened `[B, img_len]`.
|
| 182 |
+
text_len (`int`): Number of text tokens prepended to image tokens.
|
| 183 |
+
img_len (`int`): Expected number of image tokens.
|
| 184 |
+
lam (`float`): Mask scaling factor.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Attention bias tensor shaped for SD3 joint attention.
|
| 188 |
+
"""
|
| 189 |
+
batch_size = w_gating.shape[0]
|
| 190 |
+
total_len = text_len + img_len
|
| 191 |
+
if w_gating.dim() == 4:
|
| 192 |
+
w_gating = w_gating.view(batch_size, -1)
|
| 193 |
+
|
| 194 |
+
gating = lam * w_gating
|
| 195 |
+
actual_img_len = gating.shape[1]
|
| 196 |
+
if actual_img_len != img_len:
|
| 197 |
+
if actual_img_len > img_len:
|
| 198 |
+
gating = gating[:, :img_len]
|
| 199 |
+
else:
|
| 200 |
+
padding = torch.zeros(batch_size, img_len - actual_img_len, device=gating.device, dtype=gating.dtype)
|
| 201 |
+
gating = torch.cat([gating, padding], dim=1)
|
| 202 |
+
|
| 203 |
+
col_bias = torch.zeros(batch_size, total_len, device=w_gating.device, dtype=w_gating.dtype)
|
| 204 |
+
col_bias[:, text_len:] = gating
|
| 205 |
+
return col_bias.view(batch_size, 1, 1, total_len)
|
imaa/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5b9f0ff31b173c101a24bcf78cd951639acc72bebdd5231f7a8d500d41b0457
|
| 3 |
+
size 2140596
|
model_index.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": [
|
| 3 |
+
"pipeline",
|
| 4 |
+
"IntrinsicWeatherInversePipeline"
|
| 5 |
+
],
|
| 6 |
+
"_diffusers_version": "0.38.0",
|
| 7 |
+
"scheduler": [
|
| 8 |
+
"diffusers",
|
| 9 |
+
"FlowMatchEulerDiscreteScheduler"
|
| 10 |
+
],
|
| 11 |
+
"vae": [
|
| 12 |
+
"diffusers",
|
| 13 |
+
"AutoencoderKL"
|
| 14 |
+
],
|
| 15 |
+
"text_encoder": [
|
| 16 |
+
"transformers",
|
| 17 |
+
"CLIPTextModelWithProjection"
|
| 18 |
+
],
|
| 19 |
+
"text_encoder_2": [
|
| 20 |
+
"transformers",
|
| 21 |
+
"CLIPTextModelWithProjection"
|
| 22 |
+
],
|
| 23 |
+
"text_encoder_3": [
|
| 24 |
+
"transformers",
|
| 25 |
+
"T5EncoderModel"
|
| 26 |
+
],
|
| 27 |
+
"tokenizer": [
|
| 28 |
+
"transformers",
|
| 29 |
+
"CLIPTokenizer"
|
| 30 |
+
],
|
| 31 |
+
"tokenizer_2": [
|
| 32 |
+
"transformers",
|
| 33 |
+
"CLIPTokenizer"
|
| 34 |
+
],
|
| 35 |
+
"tokenizer_3": [
|
| 36 |
+
"transformers",
|
| 37 |
+
"T5TokenizerFast"
|
| 38 |
+
]
|
| 39 |
+
}
|
pipeline_intrinsic_weather.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: IntrinsicWeatherPipeline.
|
| 2 |
+
Inverse decomposition + forward weather rendering in one call.
|
| 3 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing import Dict, List, Optional, Union
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import PIL.Image
|
| 27 |
+
import torch
|
| 28 |
+
import torchvision.transforms as T
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from diffusers.image_processor import PipelineImageInput
|
| 31 |
+
from diffusers.loaders import SD3LoraLoaderMixin
|
| 32 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 33 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
| 34 |
+
from diffusers.utils import logging, replace_example_docstring
|
| 35 |
+
from transformers import (
|
| 36 |
+
CLIPTextModelWithProjection,
|
| 37 |
+
CLIPTokenizer,
|
| 38 |
+
PreTrainedModel,
|
| 39 |
+
PreTrainedTokenizer,
|
| 40 |
+
T5EncoderModel,
|
| 41 |
+
T5TokenizerFast,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
from imaa.imaa import IMAA, build_attn_mask, extract_patch_tokens_min_windows
|
| 45 |
+
from pipeline_utils import load_transformer_from_subfolder, load_transformer_lora, resolve_repo_dir
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
AOVS = ["albedo", "normal", "roughness", "metallic", "irradiance"]
|
| 50 |
+
INVERSE_PROMPTS = {
|
| 51 |
+
"albedo": "Albedo (diffuse basecolor)",
|
| 52 |
+
"normal": "Camera-space Normal",
|
| 53 |
+
"roughness": "Roughness",
|
| 54 |
+
"metallic": "Metallicness",
|
| 55 |
+
"irradiance": "Irradiance (lighting)",
|
| 56 |
+
}
|
| 57 |
+
WEATHER_PROMPTS = {
|
| 58 |
+
"rainy": "A rainy day.",
|
| 59 |
+
"sunny": "A sunny day.",
|
| 60 |
+
"snowy": "A snowy day.",
|
| 61 |
+
"foggy": "A foggy day.",
|
| 62 |
+
"overcast": "An overcast day.",
|
| 63 |
+
"night": "A night scene.",
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
EXAMPLE_DOC_STRING = """
|
| 68 |
+
Examples:
|
| 69 |
+
```py
|
| 70 |
+
>>> from pathlib import Path
|
| 71 |
+
>>> import torch
|
| 72 |
+
>>> from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
|
| 73 |
+
>>> from transformers import AutoImageProcessor, AutoModel
|
| 74 |
+
|
| 75 |
+
>>> repo_dir = Path("./IntrisicWeather-diffusers").resolve()
|
| 76 |
+
>>> pipe = IntrinsicWeatherPipeline.from_pretrained(
|
| 77 |
+
... repo_dir,
|
| 78 |
+
... inverse_transformer_subfolder="inverse-512",
|
| 79 |
+
... forward_transformer_subfolder="forward",
|
| 80 |
+
... local_files_only=True,
|
| 81 |
+
... torch_dtype=torch.float16,
|
| 82 |
+
... )
|
| 83 |
+
>>> pipe.to("cuda")
|
| 84 |
+
>>> dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
| 85 |
+
>>> dino_model = AutoModel.from_pretrained("facebook/dinov2-base").to("cuda")
|
| 86 |
+
|
| 87 |
+
>>> from PIL import Image
|
| 88 |
+
>>> image = Image.open("input.png")
|
| 89 |
+
>>> result = pipe(
|
| 90 |
+
... image=image,
|
| 91 |
+
... weather="rainy",
|
| 92 |
+
... dino_model=dino_model,
|
| 93 |
+
... dino_processor=dino_processor,
|
| 94 |
+
... )
|
| 95 |
+
>>> result.images[0].save("rainy.png")
|
| 96 |
+
```
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
|
| 101 |
+
class IntrinsicWeatherPipelineOutput(StableDiffusion3PipelineOutput):
|
| 102 |
+
maps: Optional[Dict[str, Union[PIL.Image.Image, np.ndarray]]] = None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class IntrinsicWeatherPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
| 106 |
+
r"""
|
| 107 |
+
End-to-end IntrinsicWeather pipeline: RGB → intrinsic maps → weather-conditioned RGB.
|
| 108 |
+
|
| 109 |
+
Parameters:
|
| 110 |
+
inverse_transformer ([`IntrinsicWeatherSD3Transformer2DModel`]):
|
| 111 |
+
Map-aware SD3 transformer for intrinsic decomposition.
|
| 112 |
+
forward_transformer ([`SD3Transformer2DModel`]):
|
| 113 |
+
SD3 transformer for forward weather rendering.
|
| 114 |
+
imaa ([`IMAA`]):
|
| 115 |
+
Map-aware attention gating module used during inverse rendering.
|
| 116 |
+
vae, text_encoder(s), tokenizer(s), scheduler:
|
| 117 |
+
Shared Stable Diffusion 3 components.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->inverse_transformer->forward_transformer->vae"
|
| 121 |
+
_optional_components: List[str] = []
|
| 122 |
+
|
| 123 |
+
@classmethod
|
| 124 |
+
def from_pretrained(
|
| 125 |
+
cls,
|
| 126 |
+
pretrained_model_name_or_path,
|
| 127 |
+
*args,
|
| 128 |
+
inverse_transformer_subfolder: str = "inverse-512",
|
| 129 |
+
forward_transformer_subfolder: str = "forward",
|
| 130 |
+
inverse_transformer=None,
|
| 131 |
+
forward_transformer=None,
|
| 132 |
+
load_lora: bool = True,
|
| 133 |
+
load_imaa: bool = True,
|
| 134 |
+
**kwargs,
|
| 135 |
+
):
|
| 136 |
+
repo_dir = resolve_repo_dir(pretrained_model_name_or_path)
|
| 137 |
+
dtype = kwargs.get("torch_dtype", torch.bfloat16)
|
| 138 |
+
device = kwargs.get("device")
|
| 139 |
+
if inverse_transformer is None:
|
| 140 |
+
inverse_transformer = load_transformer_from_subfolder(
|
| 141 |
+
repo_dir,
|
| 142 |
+
inverse_transformer_subfolder,
|
| 143 |
+
dtype=dtype,
|
| 144 |
+
device=device,
|
| 145 |
+
)
|
| 146 |
+
if forward_transformer is None:
|
| 147 |
+
forward_transformer = load_transformer_from_subfolder(
|
| 148 |
+
repo_dir,
|
| 149 |
+
forward_transformer_subfolder,
|
| 150 |
+
dtype=dtype,
|
| 151 |
+
device=device,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
imaa = IMAA(dino_model=None, processor=None, num_maps=5, map_embedding_dim=256, common_dim=128)
|
| 155 |
+
if device is not None:
|
| 156 |
+
imaa = imaa.to(device)
|
| 157 |
+
|
| 158 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 159 |
+
repo_dir.as_posix(),
|
| 160 |
+
*args,
|
| 161 |
+
inverse_transformer=inverse_transformer,
|
| 162 |
+
forward_transformer=forward_transformer,
|
| 163 |
+
imaa=imaa,
|
| 164 |
+
custom_pipeline=str(repo_dir / "pipeline_intrinsic_weather.py"),
|
| 165 |
+
trust_remote_code=True,
|
| 166 |
+
**kwargs,
|
| 167 |
+
)
|
| 168 |
+
if load_imaa:
|
| 169 |
+
pipe.load_imaa_weights(repo_dir / "imaa")
|
| 170 |
+
if load_lora:
|
| 171 |
+
load_transformer_lora(pipe._forward, repo_dir, forward_transformer_subfolder)
|
| 172 |
+
if device is not None:
|
| 173 |
+
for name in pipe.components.keys():
|
| 174 |
+
module = getattr(pipe, name, None)
|
| 175 |
+
if module is not None and hasattr(module, "to"):
|
| 176 |
+
if name == "imaa":
|
| 177 |
+
module.to(device=device)
|
| 178 |
+
else:
|
| 179 |
+
module.to(device=device, dtype=dtype)
|
| 180 |
+
return pipe
|
| 181 |
+
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
inverse_transformer,
|
| 185 |
+
forward_transformer,
|
| 186 |
+
imaa: IMAA,
|
| 187 |
+
scheduler,
|
| 188 |
+
vae,
|
| 189 |
+
text_encoder: CLIPTextModelWithProjection,
|
| 190 |
+
tokenizer: CLIPTokenizer,
|
| 191 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 192 |
+
tokenizer_2: CLIPTokenizer,
|
| 193 |
+
text_encoder_3: T5EncoderModel,
|
| 194 |
+
tokenizer_3: T5TokenizerFast,
|
| 195 |
+
) -> None:
|
| 196 |
+
super().__init__()
|
| 197 |
+
from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
|
| 198 |
+
from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
|
| 199 |
+
|
| 200 |
+
self.register_modules(
|
| 201 |
+
inverse_transformer=inverse_transformer,
|
| 202 |
+
forward_transformer=forward_transformer,
|
| 203 |
+
imaa=imaa,
|
| 204 |
+
scheduler=scheduler,
|
| 205 |
+
vae=vae,
|
| 206 |
+
text_encoder=text_encoder,
|
| 207 |
+
tokenizer=tokenizer,
|
| 208 |
+
text_encoder_2=text_encoder_2,
|
| 209 |
+
tokenizer_2=tokenizer_2,
|
| 210 |
+
text_encoder_3=text_encoder_3,
|
| 211 |
+
tokenizer_3=tokenizer_3,
|
| 212 |
+
)
|
| 213 |
+
shared = dict(
|
| 214 |
+
scheduler=scheduler,
|
| 215 |
+
vae=vae,
|
| 216 |
+
text_encoder=text_encoder,
|
| 217 |
+
tokenizer=tokenizer,
|
| 218 |
+
text_encoder_2=text_encoder_2,
|
| 219 |
+
tokenizer_2=tokenizer_2,
|
| 220 |
+
text_encoder_3=text_encoder_3,
|
| 221 |
+
tokenizer_3=tokenizer_3,
|
| 222 |
+
)
|
| 223 |
+
self._inverse = IntrinsicWeatherInversePipeline(transformer=inverse_transformer, **shared)
|
| 224 |
+
self._forward = IntrinsicWeatherForwardPipeline(transformer=forward_transformer, **shared)
|
| 225 |
+
|
| 226 |
+
def load_imaa_weights(self, imaa_dir: Union[str, Path]) -> None:
|
| 227 |
+
r"""Load IMAA weights from an `imaa/` subfolder produced by the conversion script."""
|
| 228 |
+
imaa_dir = Path(imaa_dir)
|
| 229 |
+
weights_path = imaa_dir / "model.safetensors"
|
| 230 |
+
if weights_path.exists():
|
| 231 |
+
from safetensors.torch import load_file
|
| 232 |
+
|
| 233 |
+
state_dict = load_file(weights_path.as_posix())
|
| 234 |
+
else:
|
| 235 |
+
payload = torch.load(imaa_dir / "pytorch_model.bin", map_location="cpu", weights_only=False)
|
| 236 |
+
state_dict = payload["model_state_dict"] if isinstance(payload, dict) and "model_state_dict" in payload else payload
|
| 237 |
+
self.imaa.load_state_dict(state_dict)
|
| 238 |
+
self.imaa.eval()
|
| 239 |
+
|
| 240 |
+
@staticmethod
|
| 241 |
+
def _resolve_weather_prompt(weather: str) -> str:
|
| 242 |
+
key = weather.lower().strip()
|
| 243 |
+
if key in WEATHER_PROMPTS:
|
| 244 |
+
return WEATHER_PROMPTS[key]
|
| 245 |
+
return weather
|
| 246 |
+
|
| 247 |
+
def _preprocess_rgb(self, image: PipelineImageInput, size: int) -> torch.Tensor:
|
| 248 |
+
if isinstance(image, PIL.Image.Image):
|
| 249 |
+
pil = image.convert("RGB")
|
| 250 |
+
elif isinstance(image, torch.Tensor):
|
| 251 |
+
if image.ndim == 3:
|
| 252 |
+
image = image.unsqueeze(0)
|
| 253 |
+
tensor = image
|
| 254 |
+
if tensor.shape[1] != 3:
|
| 255 |
+
raise ValueError(f"Expected 3 RGB channels, got shape {tuple(tensor.shape)}")
|
| 256 |
+
if tensor.min() < 0:
|
| 257 |
+
tensor = (tensor + 1.0) / 2.0
|
| 258 |
+
return tensor.to(device=self._execution_device, dtype=self.dtype)
|
| 259 |
+
else:
|
| 260 |
+
raise TypeError(f"`image` must be PIL.Image or torch.Tensor, got {type(image)}")
|
| 261 |
+
|
| 262 |
+
transform = T.Compose([T.Resize((size, size), interpolation=T.InterpolationMode.BILINEAR), T.ToTensor()])
|
| 263 |
+
tensor = transform(pil)
|
| 264 |
+
return tensor.unsqueeze(0).to(device=self._execution_device, dtype=self.dtype)
|
| 265 |
+
|
| 266 |
+
@staticmethod
|
| 267 |
+
def _map_array_to_tensor(
|
| 268 |
+
array: Union[np.ndarray, PIL.Image.Image, torch.Tensor],
|
| 269 |
+
size: int,
|
| 270 |
+
device: torch.device,
|
| 271 |
+
dtype: torch.dtype,
|
| 272 |
+
) -> torch.Tensor:
|
| 273 |
+
if isinstance(array, torch.Tensor):
|
| 274 |
+
tensor = array
|
| 275 |
+
if tensor.ndim == 3:
|
| 276 |
+
tensor = tensor.unsqueeze(0)
|
| 277 |
+
if tensor.min() < 0:
|
| 278 |
+
tensor = (tensor + 1.0) / 2.0
|
| 279 |
+
return tensor.to(device=device, dtype=dtype)
|
| 280 |
+
|
| 281 |
+
if isinstance(array, np.ndarray):
|
| 282 |
+
if array.dtype != np.uint8:
|
| 283 |
+
array = np.clip(array * 255.0, 0, 255).astype(np.uint8)
|
| 284 |
+
pil = PIL.Image.fromarray(array)
|
| 285 |
+
else:
|
| 286 |
+
pil = array.convert("RGB")
|
| 287 |
+
|
| 288 |
+
transform = T.Compose(
|
| 289 |
+
[
|
| 290 |
+
T.Resize((size, size), interpolation=T.InterpolationMode.BILINEAR),
|
| 291 |
+
T.ToTensor(),
|
| 292 |
+
]
|
| 293 |
+
)
|
| 294 |
+
return transform(pil).unsqueeze(0).to(device=device, dtype=dtype)
|
| 295 |
+
|
| 296 |
+
@torch.no_grad()
|
| 297 |
+
def decompose(
|
| 298 |
+
self,
|
| 299 |
+
image: PipelineImageInput,
|
| 300 |
+
dino_model: PreTrainedModel,
|
| 301 |
+
dino_processor: PreTrainedTokenizer,
|
| 302 |
+
num_inference_steps: int = 50,
|
| 303 |
+
image_size: Optional[int] = None,
|
| 304 |
+
output_type: str = "pil",
|
| 305 |
+
) -> Dict[str, Union[PIL.Image.Image, np.ndarray]]:
|
| 306 |
+
r"""
|
| 307 |
+
Run inverse rendering and return intrinsic maps.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
image: Input RGB image.
|
| 311 |
+
dino_model: Frozen DINOv2 vision model.
|
| 312 |
+
dino_processor: Matching image processor.
|
| 313 |
+
num_inference_steps: Denoising steps per intrinsic map.
|
| 314 |
+
image_size: Square resize before decomposition. Defaults to input size.
|
| 315 |
+
output_type: `"pil"` or `"np"` per map.
|
| 316 |
+
"""
|
| 317 |
+
if image_size is None:
|
| 318 |
+
if isinstance(image, PIL.Image.Image):
|
| 319 |
+
image_size = max(image.size)
|
| 320 |
+
elif isinstance(image, torch.Tensor):
|
| 321 |
+
image_size = image.shape[-1]
|
| 322 |
+
else:
|
| 323 |
+
image_size = 1024
|
| 324 |
+
|
| 325 |
+
image_tensor = self._preprocess_rgb(image, image_size)
|
| 326 |
+
patch_tokens = extract_patch_tokens_min_windows(
|
| 327 |
+
image_tensor,
|
| 328 |
+
dino_model,
|
| 329 |
+
dino_processor,
|
| 330 |
+
window_size=224,
|
| 331 |
+
device=image_tensor.device,
|
| 332 |
+
)
|
| 333 |
+
output_size = (image_tensor.shape[2] // 16, image_tensor.shape[3] // 16)
|
| 334 |
+
img_len = output_size[0] * output_size[1]
|
| 335 |
+
|
| 336 |
+
maps: Dict[str, Union[PIL.Image.Image, np.ndarray]] = {}
|
| 337 |
+
for map_index, aov_name in enumerate(AOVS):
|
| 338 |
+
prompt_embeds, _, pooled_prompt_embeds, _ = self._inverse.encode_prompt(
|
| 339 |
+
prompt=INVERSE_PROMPTS[aov_name],
|
| 340 |
+
prompt_2=None,
|
| 341 |
+
prompt_3=None,
|
| 342 |
+
do_classifier_free_guidance=False,
|
| 343 |
+
)
|
| 344 |
+
map_aware_mask = self.imaa(
|
| 345 |
+
patch_tokens=patch_tokens,
|
| 346 |
+
output_size=output_size,
|
| 347 |
+
map_ids=torch.tensor([map_index], device=image_tensor.device),
|
| 348 |
+
)
|
| 349 |
+
attn_mask = build_attn_mask(map_aware_mask, 154, img_len, 0.7)
|
| 350 |
+
result = self._inverse(
|
| 351 |
+
image=image_tensor,
|
| 352 |
+
prompt_embeds=prompt_embeds,
|
| 353 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 354 |
+
guidance_scale=0.0,
|
| 355 |
+
image_guidance_scale=0.0,
|
| 356 |
+
num_inference_steps=num_inference_steps,
|
| 357 |
+
output_type=output_type,
|
| 358 |
+
aov=[aov_name],
|
| 359 |
+
map_aware_mask=attn_mask.to(image_tensor.device),
|
| 360 |
+
)
|
| 361 |
+
maps[aov_name] = result.images[0] if output_type == "pil" else result.images[0]
|
| 362 |
+
return maps
|
| 363 |
+
|
| 364 |
+
@torch.no_grad()
|
| 365 |
+
def render(
|
| 366 |
+
self,
|
| 367 |
+
maps: Dict[str, PipelineImageInput],
|
| 368 |
+
weather: str = "rainy",
|
| 369 |
+
num_inference_steps: int = 50,
|
| 370 |
+
guidance_scale: float = 6.0,
|
| 371 |
+
image_guidance_scale: float = 1.5,
|
| 372 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 373 |
+
output_type: str = "pil",
|
| 374 |
+
render_size: Optional[int] = None,
|
| 375 |
+
) -> Union[PIL.Image.Image, np.ndarray]:
|
| 376 |
+
r"""
|
| 377 |
+
Run forward weather rendering from intrinsic maps.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
maps: Dict with keys `albedo`, `normal`, `roughness`, `metallic`, `irradiance`.
|
| 381 |
+
weather: Preset name (`rainy`, `sunny`, ...) or a custom prompt string.
|
| 382 |
+
"""
|
| 383 |
+
if render_size is None:
|
| 384 |
+
sample = next(iter(maps.values()))
|
| 385 |
+
if isinstance(sample, PIL.Image.Image):
|
| 386 |
+
render_size = max(sample.size)
|
| 387 |
+
elif isinstance(sample, np.ndarray):
|
| 388 |
+
render_size = sample.shape[0]
|
| 389 |
+
else:
|
| 390 |
+
render_size = sample.shape[-1]
|
| 391 |
+
|
| 392 |
+
device = self._execution_device
|
| 393 |
+
dtype = self.dtype
|
| 394 |
+
aov_tensors = {
|
| 395 |
+
name: self._map_array_to_tensor(maps[name], render_size, device, dtype) for name in AOVS if name in maps
|
| 396 |
+
}
|
| 397 |
+
for name in AOVS:
|
| 398 |
+
if name not in aov_tensors:
|
| 399 |
+
aov_tensors[name] = torch.zeros((1, 3, render_size, render_size), device=device, dtype=dtype)
|
| 400 |
+
|
| 401 |
+
result = self._forward(
|
| 402 |
+
albedo=aov_tensors["albedo"],
|
| 403 |
+
normal=aov_tensors["normal"],
|
| 404 |
+
roughness=aov_tensors["roughness"],
|
| 405 |
+
metallic=aov_tensors["metallic"],
|
| 406 |
+
irradiance=aov_tensors.get("irradiance"),
|
| 407 |
+
prompt=[self._resolve_weather_prompt(weather)],
|
| 408 |
+
guidance_scale=guidance_scale,
|
| 409 |
+
image_guidance_scale=image_guidance_scale,
|
| 410 |
+
num_inference_steps=num_inference_steps,
|
| 411 |
+
required_aovs=AOVS,
|
| 412 |
+
generator=generator,
|
| 413 |
+
output_type=output_type,
|
| 414 |
+
)
|
| 415 |
+
return result.images[0]
|
| 416 |
+
|
| 417 |
+
@torch.no_grad()
|
| 418 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 419 |
+
def __call__(
|
| 420 |
+
self,
|
| 421 |
+
image: PipelineImageInput,
|
| 422 |
+
weather: str = "rainy",
|
| 423 |
+
dino_model: Optional[PreTrainedModel] = None,
|
| 424 |
+
dino_processor: Optional[PreTrainedTokenizer] = None,
|
| 425 |
+
num_inverse_steps: int = 50,
|
| 426 |
+
num_forward_steps: int = 50,
|
| 427 |
+
guidance_scale: float = 6.0,
|
| 428 |
+
image_guidance_scale: float = 1.5,
|
| 429 |
+
image_size: Optional[int] = None,
|
| 430 |
+
render_size: Optional[int] = None,
|
| 431 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 432 |
+
output_type: str = "pil",
|
| 433 |
+
return_maps: bool = False,
|
| 434 |
+
return_dict: bool = True,
|
| 435 |
+
):
|
| 436 |
+
r"""
|
| 437 |
+
Decompose an RGB image into intrinsic maps, then render a weather-conditioned RGB image.
|
| 438 |
+
|
| 439 |
+
Examples:
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
image: Source RGB photograph or render.
|
| 443 |
+
weather: Weather preset (`rainy`, `sunny`, `snowy`, `foggy`, ...) or custom prompt.
|
| 444 |
+
dino_model (`transformers.PreTrainedModel`, *optional*):
|
| 445 |
+
DINOv2 model for IMAA gating. Required unless maps are passed via `maps=`.
|
| 446 |
+
dino_processor: DINO image processor paired with `dino_model`.
|
| 447 |
+
num_inverse_steps: Denoising steps for each intrinsic map.
|
| 448 |
+
num_forward_steps: Denoising steps for forward rendering.
|
| 449 |
+
guidance_scale: Text CFG for forward rendering.
|
| 450 |
+
image_guidance_scale: Image CFG for forward rendering.
|
| 451 |
+
image_size: Square size for inverse stage.
|
| 452 |
+
render_size: Square size for forward stage. Defaults to `image_size`.
|
| 453 |
+
return_maps: If `True`, also return decomposed intrinsic maps.
|
| 454 |
+
return_dict: Return [`IntrinsicWeatherPipelineOutput`] when `True`.
|
| 455 |
+
"""
|
| 456 |
+
if dino_model is None or dino_processor is None:
|
| 457 |
+
raise ValueError("`dino_model` and `dino_processor` are required for end-to-end weather editing.")
|
| 458 |
+
|
| 459 |
+
maps = self.decompose(
|
| 460 |
+
image=image,
|
| 461 |
+
dino_model=dino_model,
|
| 462 |
+
dino_processor=dino_processor,
|
| 463 |
+
num_inference_steps=num_inverse_steps,
|
| 464 |
+
image_size=image_size,
|
| 465 |
+
output_type="np",
|
| 466 |
+
)
|
| 467 |
+
rendered = self.render(
|
| 468 |
+
maps=maps,
|
| 469 |
+
weather=weather,
|
| 470 |
+
num_inference_steps=num_forward_steps,
|
| 471 |
+
guidance_scale=guidance_scale,
|
| 472 |
+
image_guidance_scale=image_guidance_scale,
|
| 473 |
+
generator=generator,
|
| 474 |
+
output_type=output_type,
|
| 475 |
+
render_size=render_size or image_size,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
self.maybe_free_model_hooks()
|
| 479 |
+
|
| 480 |
+
if not return_dict:
|
| 481 |
+
return (rendered, maps) if return_maps else (rendered,)
|
| 482 |
+
|
| 483 |
+
return IntrinsicWeatherPipelineOutput(
|
| 484 |
+
images=[rendered] if output_type == "pil" else rendered,
|
| 485 |
+
maps=maps if return_maps else None,
|
| 486 |
+
)
|
pipeline_intrinsic_weather_forward.py
ADDED
|
@@ -0,0 +1,1191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: IntrinsicWeatherForwardPipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import inspect
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 24 |
+
|
| 25 |
+
import PIL.Image
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torchvision
|
| 29 |
+
from transformers import (
|
| 30 |
+
CLIPTextModelWithProjection,
|
| 31 |
+
CLIPTokenizer,
|
| 32 |
+
T5EncoderModel,
|
| 33 |
+
T5TokenizerFast,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 37 |
+
from diffusers.loaders import SD3LoraLoaderMixin
|
| 38 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 39 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
| 40 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
| 41 |
+
from diffusers.utils import (
|
| 42 |
+
deprecate,
|
| 43 |
+
is_torch_xla_available,
|
| 44 |
+
logging,
|
| 45 |
+
replace_example_docstring,
|
| 46 |
+
)
|
| 47 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 48 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 49 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
| 50 |
+
|
| 51 |
+
from pipeline_utils import (
|
| 52 |
+
load_transformer_from_subfolder,
|
| 53 |
+
load_transformer_lora,
|
| 54 |
+
resolve_repo_dir,
|
| 55 |
+
resolve_transformer_lora_dir,
|
| 56 |
+
set_flow_timesteps,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if is_torch_xla_available():
|
| 61 |
+
import torch_xla.core.xla_model as xm
|
| 62 |
+
|
| 63 |
+
XLA_AVAILABLE = True
|
| 64 |
+
else:
|
| 65 |
+
XLA_AVAILABLE = False
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 69 |
+
|
| 70 |
+
EXAMPLE_DOC_STRING = """
|
| 71 |
+
Examples:
|
| 72 |
+
```py
|
| 73 |
+
>>> from pathlib import Path
|
| 74 |
+
>>> import torch
|
| 75 |
+
>>> from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
|
| 76 |
+
|
| 77 |
+
>>> repo_dir = Path("./IntrisicWeather-diffusers").resolve()
|
| 78 |
+
>>> pipe = IntrinsicWeatherForwardPipeline.from_pretrained(
|
| 79 |
+
... repo_dir,
|
| 80 |
+
... transformer_subfolder="forward",
|
| 81 |
+
... local_files_only=True,
|
| 82 |
+
... torch_dtype=torch.float16,
|
| 83 |
+
... load_lora=True,
|
| 84 |
+
... )
|
| 85 |
+
>>> pipe.to("cuda")
|
| 86 |
+
>>> image = pipe(
|
| 87 |
+
... prompt="A rainy day.",
|
| 88 |
+
... albedo=albedo_tensor,
|
| 89 |
+
... normal=normal_tensor,
|
| 90 |
+
... roughness=roughness_tensor,
|
| 91 |
+
... metallic=metallic_tensor,
|
| 92 |
+
... irradiance=irradiance_tensor,
|
| 93 |
+
... num_inference_steps=50,
|
| 94 |
+
... ).images[0]
|
| 95 |
+
```
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 102 |
+
def retrieve_latents(
|
| 103 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 104 |
+
):
|
| 105 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 106 |
+
return encoder_output.latent_dist.sample(generator)
|
| 107 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 108 |
+
return encoder_output.latent_dist.mode()
|
| 109 |
+
elif hasattr(encoder_output, "latents"):
|
| 110 |
+
return encoder_output.latents
|
| 111 |
+
else:
|
| 112 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 113 |
+
|
| 114 |
+
class VaeAoVProcessor(VaeImageProcessor):
|
| 115 |
+
def postprocess(
|
| 116 |
+
self,
|
| 117 |
+
image: torch.FloatTensor,
|
| 118 |
+
output_type: str = "pil",
|
| 119 |
+
do_denormalize: Optional[List[bool]] = None,
|
| 120 |
+
do_gamma_correction: bool = True,
|
| 121 |
+
):
|
| 122 |
+
if not isinstance(image, torch.Tensor):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
| 125 |
+
)
|
| 126 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
| 127 |
+
deprecation_message = (
|
| 128 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
| 129 |
+
"`pil`, `np`, `pt`, `latent`"
|
| 130 |
+
)
|
| 131 |
+
deprecate(
|
| 132 |
+
"Unsupported output_type",
|
| 133 |
+
"1.0.0",
|
| 134 |
+
deprecation_message,
|
| 135 |
+
standard_warn=False,
|
| 136 |
+
)
|
| 137 |
+
output_type = "np"
|
| 138 |
+
|
| 139 |
+
if output_type == "latent":
|
| 140 |
+
return image
|
| 141 |
+
|
| 142 |
+
if do_denormalize is None:
|
| 143 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
| 144 |
+
|
| 145 |
+
image = torch.stack(
|
| 146 |
+
[
|
| 147 |
+
self.denormalize(image[i]) if do_denormalize[i] else image[i]
|
| 148 |
+
for i in range(image.shape[0])
|
| 149 |
+
]
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# Gamma correction
|
| 153 |
+
if do_gamma_correction:
|
| 154 |
+
# image = torch.pow(image, 1.0 / 2.2)
|
| 155 |
+
image = image ** (1 / 2.2)
|
| 156 |
+
|
| 157 |
+
if output_type == "pt":
|
| 158 |
+
return image
|
| 159 |
+
|
| 160 |
+
image = self.pt_to_numpy(image)
|
| 161 |
+
|
| 162 |
+
if output_type == "np":
|
| 163 |
+
return image
|
| 164 |
+
|
| 165 |
+
if output_type == "pil":
|
| 166 |
+
return self.numpy_to_pil(image)
|
| 167 |
+
|
| 168 |
+
class IntrinsicWeatherForwardPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
| 169 |
+
r"""
|
| 170 |
+
Args:
|
| 171 |
+
transformer ([`SD3Transformer2DModel`]):
|
| 172 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 173 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 174 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 175 |
+
vae ([`AutoencoderKL`]):
|
| 176 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 177 |
+
text_encoder ([`CLIPTextModelWithProjection`]):
|
| 178 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 179 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
| 180 |
+
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
| 181 |
+
as its dimension.
|
| 182 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
| 183 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 184 |
+
specifically the
|
| 185 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
| 186 |
+
variant.
|
| 187 |
+
text_encoder_3 ([`T5EncoderModel`]):
|
| 188 |
+
Frozen text-encoder. Stable Diffusion 3 uses
|
| 189 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 190 |
+
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 191 |
+
tokenizer (`CLIPTokenizer`):
|
| 192 |
+
Tokenizer of class
|
| 193 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 194 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 195 |
+
Second Tokenizer of class
|
| 196 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 197 |
+
tokenizer_3 (`T5TokenizerFast`):
|
| 198 |
+
Tokenizer of class
|
| 199 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
| 203 |
+
_optional_components = []
|
| 204 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
| 205 |
+
|
| 206 |
+
@classmethod
|
| 207 |
+
def load_transformer(
|
| 208 |
+
cls,
|
| 209 |
+
transformer_subfolder: str,
|
| 210 |
+
pretrained_model_name_or_path: str | Path,
|
| 211 |
+
*,
|
| 212 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 213 |
+
device: str | torch.device | None = None,
|
| 214 |
+
):
|
| 215 |
+
return load_transformer_from_subfolder(
|
| 216 |
+
pretrained_model_name_or_path,
|
| 217 |
+
transformer_subfolder,
|
| 218 |
+
dtype=dtype,
|
| 219 |
+
device=device,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
@classmethod
|
| 223 |
+
def from_pretrained(
|
| 224 |
+
cls,
|
| 225 |
+
pretrained_model_name_or_path,
|
| 226 |
+
*args,
|
| 227 |
+
transformer_subfolder: str = "forward",
|
| 228 |
+
transformer=None,
|
| 229 |
+
load_lora: bool = True,
|
| 230 |
+
**kwargs,
|
| 231 |
+
):
|
| 232 |
+
repo_dir = resolve_repo_dir(pretrained_model_name_or_path)
|
| 233 |
+
dtype = kwargs.get("torch_dtype", torch.bfloat16)
|
| 234 |
+
device = kwargs.get("device")
|
| 235 |
+
if transformer is None:
|
| 236 |
+
transformer = cls.load_transformer(
|
| 237 |
+
transformer_subfolder,
|
| 238 |
+
repo_dir,
|
| 239 |
+
dtype=dtype,
|
| 240 |
+
device=device,
|
| 241 |
+
)
|
| 242 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 243 |
+
repo_dir.as_posix(),
|
| 244 |
+
*args,
|
| 245 |
+
transformer=transformer,
|
| 246 |
+
custom_pipeline=str(repo_dir / "pipeline_intrinsic_weather_forward.py"),
|
| 247 |
+
trust_remote_code=True,
|
| 248 |
+
**kwargs,
|
| 249 |
+
)
|
| 250 |
+
lora_dir = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
|
| 251 |
+
if load_lora and lora_dir is not None:
|
| 252 |
+
pipe.load_lora_weights(lora_dir.as_posix())
|
| 253 |
+
if device is not None:
|
| 254 |
+
pipe = pipe.to(device)
|
| 255 |
+
return pipe
|
| 256 |
+
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
transformer: SD3Transformer2DModel,
|
| 260 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 261 |
+
vae: AutoencoderKL,
|
| 262 |
+
text_encoder: CLIPTextModelWithProjection,
|
| 263 |
+
tokenizer: CLIPTokenizer,
|
| 264 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 265 |
+
tokenizer_2: CLIPTokenizer,
|
| 266 |
+
text_encoder_3: T5EncoderModel,
|
| 267 |
+
tokenizer_3: T5TokenizerFast,
|
| 268 |
+
):
|
| 269 |
+
super().__init__()
|
| 270 |
+
|
| 271 |
+
self.register_modules(
|
| 272 |
+
vae=vae,
|
| 273 |
+
text_encoder=text_encoder,
|
| 274 |
+
text_encoder_2=text_encoder_2,
|
| 275 |
+
text_encoder_3=text_encoder_3,
|
| 276 |
+
tokenizer=tokenizer,
|
| 277 |
+
tokenizer_2=tokenizer_2,
|
| 278 |
+
tokenizer_3=tokenizer_3,
|
| 279 |
+
transformer=transformer,
|
| 280 |
+
scheduler=scheduler,
|
| 281 |
+
)
|
| 282 |
+
self.vae_scale_factor = (
|
| 283 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 284 |
+
)
|
| 285 |
+
self.image_processor = VaeAoVProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 286 |
+
self.tokenizer_max_length = (
|
| 287 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 288 |
+
)
|
| 289 |
+
self.default_sample_size = (
|
| 290 |
+
self.transformer.config.sample_size
|
| 291 |
+
if hasattr(self, "transformer") and self.transformer is not None
|
| 292 |
+
else 128
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def _get_t5_prompt_embeds(
|
| 296 |
+
self,
|
| 297 |
+
prompt: Union[str, List[str]] = None,
|
| 298 |
+
num_images_per_prompt: int = 1,
|
| 299 |
+
device: Optional[torch.device] = None,
|
| 300 |
+
dtype: Optional[torch.dtype] = None,
|
| 301 |
+
):
|
| 302 |
+
device = device or self._execution_device
|
| 303 |
+
dtype = dtype or self.text_encoder.dtype
|
| 304 |
+
|
| 305 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 306 |
+
batch_size = len(prompt)
|
| 307 |
+
|
| 308 |
+
if self.text_encoder_3 is None:
|
| 309 |
+
return torch.zeros(
|
| 310 |
+
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
|
| 311 |
+
device=device,
|
| 312 |
+
dtype=dtype,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
text_inputs = self.tokenizer_3(
|
| 316 |
+
prompt,
|
| 317 |
+
padding="max_length",
|
| 318 |
+
max_length=self.tokenizer_max_length,
|
| 319 |
+
truncation=True,
|
| 320 |
+
add_special_tokens=True,
|
| 321 |
+
return_tensors="pt",
|
| 322 |
+
)
|
| 323 |
+
text_input_ids = text_inputs.input_ids
|
| 324 |
+
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
| 325 |
+
|
| 326 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 327 |
+
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 328 |
+
logger.warning(
|
| 329 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 330 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
| 334 |
+
|
| 335 |
+
dtype = self.text_encoder_3.dtype
|
| 336 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 337 |
+
|
| 338 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 339 |
+
|
| 340 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 341 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 342 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 343 |
+
|
| 344 |
+
return prompt_embeds
|
| 345 |
+
|
| 346 |
+
def _get_clip_prompt_embeds(
|
| 347 |
+
self,
|
| 348 |
+
prompt: Union[str, List[str]],
|
| 349 |
+
num_images_per_prompt: int = 1,
|
| 350 |
+
device: Optional[torch.device] = None,
|
| 351 |
+
clip_skip: Optional[int] = None,
|
| 352 |
+
clip_model_index: int = 0,
|
| 353 |
+
):
|
| 354 |
+
device = device or self._execution_device
|
| 355 |
+
|
| 356 |
+
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
| 357 |
+
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
| 358 |
+
|
| 359 |
+
tokenizer = clip_tokenizers[clip_model_index]
|
| 360 |
+
text_encoder = clip_text_encoders[clip_model_index]
|
| 361 |
+
|
| 362 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 363 |
+
batch_size = len(prompt)
|
| 364 |
+
|
| 365 |
+
text_inputs = tokenizer(
|
| 366 |
+
prompt,
|
| 367 |
+
padding="max_length",
|
| 368 |
+
max_length=self.tokenizer_max_length,
|
| 369 |
+
truncation=True,
|
| 370 |
+
return_tensors="pt",
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
text_input_ids = text_inputs.input_ids
|
| 374 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 375 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 376 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 377 |
+
logger.warning(
|
| 378 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 379 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 380 |
+
)
|
| 381 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
| 382 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 383 |
+
|
| 384 |
+
if clip_skip is None:
|
| 385 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 386 |
+
else:
|
| 387 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
| 388 |
+
|
| 389 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 390 |
+
|
| 391 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 392 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 393 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 394 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 395 |
+
|
| 396 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 397 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 398 |
+
|
| 399 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 400 |
+
|
| 401 |
+
def encode_prompt(
|
| 402 |
+
self,
|
| 403 |
+
prompt: Union[str, List[str]],
|
| 404 |
+
prompt_2: Union[str, List[str]],
|
| 405 |
+
prompt_3: Union[str, List[str]],
|
| 406 |
+
device: Optional[torch.device] = None,
|
| 407 |
+
num_images_per_prompt: int = 1,
|
| 408 |
+
do_classifier_free_guidance: bool = True,
|
| 409 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 410 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 411 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 412 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 413 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 414 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 415 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 416 |
+
clip_skip: Optional[int] = None,
|
| 417 |
+
):
|
| 418 |
+
r"""
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 422 |
+
prompt to be encoded
|
| 423 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 424 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 425 |
+
used in all text-encoders
|
| 426 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 427 |
+
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 428 |
+
used in all text-encoders
|
| 429 |
+
device: (`torch.device`):
|
| 430 |
+
torch device
|
| 431 |
+
num_images_per_prompt (`int`):
|
| 432 |
+
number of images that should be generated per prompt
|
| 433 |
+
do_classifier_free_guidance (`bool`):
|
| 434 |
+
whether to use classifier free guidance or not
|
| 435 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 436 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 437 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 438 |
+
less than `1`).
|
| 439 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 440 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 441 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 442 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 443 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 444 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
|
| 445 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 446 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 447 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 448 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 449 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 450 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 451 |
+
argument.
|
| 452 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 453 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 454 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 455 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 456 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 457 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 458 |
+
input argument.
|
| 459 |
+
clip_skip (`int`, *optional*):
|
| 460 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 461 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 462 |
+
"""
|
| 463 |
+
device = device or self._execution_device
|
| 464 |
+
|
| 465 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 466 |
+
if prompt is not None:
|
| 467 |
+
batch_size = len(prompt)
|
| 468 |
+
else:
|
| 469 |
+
batch_size = prompt_embeds.shape[0]
|
| 470 |
+
|
| 471 |
+
if prompt_embeds is None:
|
| 472 |
+
prompt_2 = prompt_2 or prompt
|
| 473 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 474 |
+
|
| 475 |
+
prompt_3 = prompt_3 or prompt
|
| 476 |
+
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
| 477 |
+
|
| 478 |
+
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 479 |
+
prompt=prompt,
|
| 480 |
+
device=device,
|
| 481 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 482 |
+
clip_skip=clip_skip,
|
| 483 |
+
clip_model_index=0,
|
| 484 |
+
)
|
| 485 |
+
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 486 |
+
prompt=prompt_2,
|
| 487 |
+
device=device,
|
| 488 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 489 |
+
clip_skip=clip_skip,
|
| 490 |
+
clip_model_index=1,
|
| 491 |
+
)
|
| 492 |
+
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
| 493 |
+
|
| 494 |
+
t5_prompt_embed = self._get_t5_prompt_embeds(
|
| 495 |
+
prompt=prompt_3,
|
| 496 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 497 |
+
device=device,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
| 501 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
| 505 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
| 506 |
+
|
| 507 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 508 |
+
negative_prompt = negative_prompt or ""
|
| 509 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
| 510 |
+
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
| 511 |
+
|
| 512 |
+
# normalize str to list
|
| 513 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 514 |
+
negative_prompt_2 = (
|
| 515 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
| 516 |
+
)
|
| 517 |
+
negative_prompt_3 = (
|
| 518 |
+
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 522 |
+
raise TypeError(
|
| 523 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 524 |
+
f" {type(prompt)}."
|
| 525 |
+
)
|
| 526 |
+
elif batch_size != len(negative_prompt):
|
| 527 |
+
raise ValueError(
|
| 528 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 529 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 530 |
+
" the batch size of `prompt`."
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 534 |
+
negative_prompt,
|
| 535 |
+
device=device,
|
| 536 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 537 |
+
clip_skip=None,
|
| 538 |
+
clip_model_index=0,
|
| 539 |
+
)
|
| 540 |
+
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 541 |
+
negative_prompt_2,
|
| 542 |
+
device=device,
|
| 543 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 544 |
+
clip_skip=None,
|
| 545 |
+
clip_model_index=1,
|
| 546 |
+
)
|
| 547 |
+
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
| 548 |
+
|
| 549 |
+
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
| 550 |
+
prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
| 554 |
+
negative_clip_prompt_embeds,
|
| 555 |
+
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
| 559 |
+
negative_pooled_prompt_embeds = torch.cat(
|
| 560 |
+
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
| 564 |
+
|
| 565 |
+
def check_inputs(
|
| 566 |
+
self,
|
| 567 |
+
prompt,
|
| 568 |
+
prompt_2,
|
| 569 |
+
prompt_3,
|
| 570 |
+
# height,
|
| 571 |
+
# width,
|
| 572 |
+
negative_prompt=None,
|
| 573 |
+
negative_prompt_2=None,
|
| 574 |
+
negative_prompt_3=None,
|
| 575 |
+
prompt_embeds=None,
|
| 576 |
+
negative_prompt_embeds=None,
|
| 577 |
+
pooled_prompt_embeds=None,
|
| 578 |
+
negative_pooled_prompt_embeds=None,
|
| 579 |
+
callback_on_step_end_tensor_inputs=None,
|
| 580 |
+
):
|
| 581 |
+
# if height % 8 != 0 or width % 8 != 0:
|
| 582 |
+
# raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 583 |
+
|
| 584 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 585 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 586 |
+
):
|
| 587 |
+
raise ValueError(
|
| 588 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
if prompt is not None and prompt_embeds is not None:
|
| 592 |
+
raise ValueError(
|
| 593 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 594 |
+
" only forward one of the two."
|
| 595 |
+
)
|
| 596 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 597 |
+
raise ValueError(
|
| 598 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 599 |
+
" only forward one of the two."
|
| 600 |
+
)
|
| 601 |
+
elif prompt_3 is not None and prompt_embeds is not None:
|
| 602 |
+
raise ValueError(
|
| 603 |
+
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 604 |
+
" only forward one of the two."
|
| 605 |
+
)
|
| 606 |
+
elif prompt is None and prompt_embeds is None:
|
| 607 |
+
raise ValueError(
|
| 608 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 609 |
+
)
|
| 610 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 611 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 612 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 613 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 614 |
+
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
| 615 |
+
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
| 616 |
+
|
| 617 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 618 |
+
raise ValueError(
|
| 619 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 620 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 621 |
+
)
|
| 622 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 623 |
+
raise ValueError(
|
| 624 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 625 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 626 |
+
)
|
| 627 |
+
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
| 628 |
+
raise ValueError(
|
| 629 |
+
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
| 630 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 634 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 635 |
+
raise ValueError(
|
| 636 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 637 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 638 |
+
f" {negative_prompt_embeds.shape}."
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 642 |
+
raise ValueError(
|
| 643 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 647 |
+
raise ValueError(
|
| 648 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
def prepare_latents(
|
| 652 |
+
self,
|
| 653 |
+
batch_size,
|
| 654 |
+
num_channels_latents,
|
| 655 |
+
height,
|
| 656 |
+
width,
|
| 657 |
+
dtype,
|
| 658 |
+
device,
|
| 659 |
+
generator,
|
| 660 |
+
latents=None,
|
| 661 |
+
):
|
| 662 |
+
if latents is not None:
|
| 663 |
+
return latents.to(device=device, dtype=dtype)
|
| 664 |
+
|
| 665 |
+
shape = (
|
| 666 |
+
batch_size,
|
| 667 |
+
num_channels_latents,
|
| 668 |
+
int(height) // self.vae_scale_factor,
|
| 669 |
+
int(width) // self.vae_scale_factor,
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 673 |
+
raise ValueError(
|
| 674 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 675 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 679 |
+
|
| 680 |
+
return latents
|
| 681 |
+
|
| 682 |
+
def prepare_image_latents(
|
| 683 |
+
self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
|
| 684 |
+
):
|
| 685 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
| 686 |
+
raise ValueError(
|
| 687 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
image = image.to(device=device, dtype=dtype)
|
| 691 |
+
|
| 692 |
+
batch_size = batch_size * num_images_per_prompt
|
| 693 |
+
|
| 694 |
+
if image.shape[1] == self.vae.config.latent_channels:
|
| 695 |
+
image_latents = image
|
| 696 |
+
else:
|
| 697 |
+
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
| 698 |
+
# ? normalize image latents
|
| 699 |
+
# image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 700 |
+
|
| 701 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 702 |
+
# expand image_latents for batch_size
|
| 703 |
+
deprecation_message = (
|
| 704 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
|
| 705 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
| 706 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
| 707 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
| 708 |
+
)
|
| 709 |
+
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
| 710 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 711 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 712 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 713 |
+
raise ValueError(
|
| 714 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 715 |
+
)
|
| 716 |
+
else:
|
| 717 |
+
image_latents = torch.cat([image_latents], dim=0)
|
| 718 |
+
|
| 719 |
+
if do_classifier_free_guidance:
|
| 720 |
+
uncond_image_latents = torch.zeros_like(image_latents)
|
| 721 |
+
image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
|
| 722 |
+
|
| 723 |
+
return image_latents
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
@property
|
| 727 |
+
def guidance_scale(self):
|
| 728 |
+
return self._guidance_scale
|
| 729 |
+
@property
|
| 730 |
+
def image_guidance_scale(self):
|
| 731 |
+
return self._image_guidance_scale
|
| 732 |
+
|
| 733 |
+
@property
|
| 734 |
+
def clip_skip(self):
|
| 735 |
+
return self._clip_skip
|
| 736 |
+
|
| 737 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 738 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 739 |
+
# corresponds to doing no classifier free guidance.
|
| 740 |
+
@property
|
| 741 |
+
def do_classifier_free_guidance(self):
|
| 742 |
+
return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
|
| 743 |
+
|
| 744 |
+
@property
|
| 745 |
+
def joint_attention_kwargs(self):
|
| 746 |
+
return self._joint_attention_kwargs
|
| 747 |
+
|
| 748 |
+
@property
|
| 749 |
+
def num_timesteps(self):
|
| 750 |
+
return self._num_timesteps
|
| 751 |
+
|
| 752 |
+
@property
|
| 753 |
+
def interrupt(self):
|
| 754 |
+
return self._interrupt
|
| 755 |
+
|
| 756 |
+
@torch.no_grad()
|
| 757 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 758 |
+
def __call__(
|
| 759 |
+
self,
|
| 760 |
+
prompt: Union[str, List[str]] = None,
|
| 761 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 762 |
+
prompt_3: Optional[Union[str, List[str]]] = None,
|
| 763 |
+
albedo: PipelineImageInput = None,
|
| 764 |
+
normal : PipelineImageInput = None,
|
| 765 |
+
metallic : PipelineImageInput = None,
|
| 766 |
+
roughness : PipelineImageInput = None,
|
| 767 |
+
irradiance : PipelineImageInput = None,
|
| 768 |
+
height: Optional[int] = None,
|
| 769 |
+
width: Optional[int] = None,
|
| 770 |
+
num_inference_steps: int = 28,
|
| 771 |
+
timesteps: List[int] = None,
|
| 772 |
+
guidance_scale: float = 7.0,
|
| 773 |
+
image_guidance_scale: float = 1.5,
|
| 774 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 775 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 776 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 777 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 778 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 779 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 780 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 781 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 782 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 783 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 784 |
+
output_type: Optional[str] = "pil",
|
| 785 |
+
return_dict: bool = True,
|
| 786 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 787 |
+
clip_skip: Optional[int] = None,
|
| 788 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 789 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 790 |
+
mask_img: Optional[PipelineImageInput] = None,
|
| 791 |
+
required_aovs: List[str] = ["albedo"],
|
| 792 |
+
**kwargs
|
| 793 |
+
):
|
| 794 |
+
r"""
|
| 795 |
+
Function invoked when calling the pipeline for generation.
|
| 796 |
+
|
| 797 |
+
Args:
|
| 798 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 799 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 800 |
+
instead.
|
| 801 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 802 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 803 |
+
will be used instead
|
| 804 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 805 |
+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 806 |
+
will be used instead
|
| 807 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 808 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 809 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 810 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 811 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 812 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 813 |
+
expense of slower inference.
|
| 814 |
+
timesteps (`List[int]`, *optional*):
|
| 815 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 816 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 817 |
+
passed will be used. Must be in descending order.
|
| 818 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 819 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 820 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 821 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 822 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 823 |
+
usually at the expense of lower image quality.
|
| 824 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 825 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 826 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 827 |
+
less than `1`).
|
| 828 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 829 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 830 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
| 831 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 832 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 833 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
| 834 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 835 |
+
The number of images to generate per prompt.
|
| 836 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 837 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 838 |
+
to make generation deterministic.
|
| 839 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 840 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 841 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 842 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 843 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 844 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 845 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 846 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 847 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 848 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 849 |
+
argument.
|
| 850 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 851 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 852 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 853 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 854 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 855 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 856 |
+
input argument.
|
| 857 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 858 |
+
The output format of the generate image. Choose between
|
| 859 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 860 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 861 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 862 |
+
of a plain tuple.
|
| 863 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 864 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 865 |
+
`self.processor` in
|
| 866 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 867 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 868 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 869 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 870 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 871 |
+
`callback_on_step_end_tensor_inputs`.
|
| 872 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 873 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 874 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 875 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 876 |
+
|
| 877 |
+
Examples:
|
| 878 |
+
|
| 879 |
+
Returns:
|
| 880 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
| 881 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
| 882 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 883 |
+
"""
|
| 884 |
+
|
| 885 |
+
# height = height or self.default_sample_size * self.vae_scale_factor
|
| 886 |
+
# width = width or self.default_sample_size * self.vae_scale_factor
|
| 887 |
+
|
| 888 |
+
# 1. Check inputs. Raise error if not correct
|
| 889 |
+
self.check_inputs(
|
| 890 |
+
prompt,
|
| 891 |
+
prompt_2,
|
| 892 |
+
prompt_3,
|
| 893 |
+
negative_prompt=negative_prompt,
|
| 894 |
+
negative_prompt_2=negative_prompt_2,
|
| 895 |
+
negative_prompt_3=negative_prompt_3,
|
| 896 |
+
prompt_embeds=prompt_embeds,
|
| 897 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 898 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 899 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 900 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
self._guidance_scale = guidance_scale
|
| 904 |
+
self._image_guidance_scale = image_guidance_scale
|
| 905 |
+
self._clip_skip = clip_skip
|
| 906 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 907 |
+
self._interrupt = False
|
| 908 |
+
|
| 909 |
+
# 2. Define call parameters
|
| 910 |
+
if prompt is not None and isinstance(prompt, str):
|
| 911 |
+
batch_size = 1
|
| 912 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 913 |
+
batch_size = len(prompt)
|
| 914 |
+
else:
|
| 915 |
+
batch_size = prompt_embeds.shape[0]
|
| 916 |
+
|
| 917 |
+
device = self._execution_device
|
| 918 |
+
|
| 919 |
+
(
|
| 920 |
+
prompt_embeds,
|
| 921 |
+
negative_prompt_embeds,
|
| 922 |
+
pooled_prompt_embeds,
|
| 923 |
+
negative_pooled_prompt_embeds,
|
| 924 |
+
) = self.encode_prompt(
|
| 925 |
+
prompt=prompt,
|
| 926 |
+
prompt_2=prompt_2,
|
| 927 |
+
prompt_3=prompt_3,
|
| 928 |
+
negative_prompt=negative_prompt,
|
| 929 |
+
negative_prompt_2=negative_prompt_2,
|
| 930 |
+
negative_prompt_3=negative_prompt_3,
|
| 931 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 932 |
+
prompt_embeds=prompt_embeds,
|
| 933 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 934 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 935 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 936 |
+
device=device,
|
| 937 |
+
clip_skip=self.clip_skip,
|
| 938 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 939 |
+
)
|
| 940 |
+
# print("prompt:", prompt_embeds.shape)
|
| 941 |
+
if self.do_classifier_free_guidance:
|
| 942 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 943 |
+
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
|
| 944 |
+
|
| 945 |
+
# Similiarly
|
| 946 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0)
|
| 947 |
+
|
| 948 |
+
# if self.do_classifier_free_guidance:
|
| 949 |
+
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 950 |
+
# pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
| 951 |
+
|
| 952 |
+
# 3. Preprocess image
|
| 953 |
+
# Get image dimensions from the first available AOV (albedo, normal, roughness, metallic, irradiance)
|
| 954 |
+
old_height, old_width = None, None
|
| 955 |
+
aov_dict = {"albedo": albedo, "normal": normal, "roughness": roughness, "metallic": metallic, "irradiance": irradiance}
|
| 956 |
+
for aov_name in required_aovs:
|
| 957 |
+
aov_value = aov_dict.get(aov_name)
|
| 958 |
+
if aov_value is not None:
|
| 959 |
+
if isinstance(aov_value, PIL.Image.Image):
|
| 960 |
+
old_width, old_height = aov_value.size
|
| 961 |
+
break
|
| 962 |
+
elif isinstance(aov_value, torch.Tensor):
|
| 963 |
+
if len(aov_value.shape) == 4:
|
| 964 |
+
old_height = aov_value.shape[2]
|
| 965 |
+
old_width = aov_value.shape[3]
|
| 966 |
+
else:
|
| 967 |
+
old_height = aov_value.shape[1]
|
| 968 |
+
old_width = aov_value.shape[2]
|
| 969 |
+
break
|
| 970 |
+
|
| 971 |
+
# Fallback: if no AOV available, use default dimensions
|
| 972 |
+
if old_height is None or old_width is None:
|
| 973 |
+
old_height, old_width = 1024, 1024 # default fallback
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
# For all aovs, the preprocessing remap the values to [-1, 1]
|
| 977 |
+
preprocessed_aovs = {}
|
| 978 |
+
for aov_name in required_aovs:
|
| 979 |
+
if aov_name == "albedo":
|
| 980 |
+
if albedo is not None:
|
| 981 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
| 982 |
+
albedo
|
| 983 |
+
)
|
| 984 |
+
else:
|
| 985 |
+
preprocessed_aovs[aov_name] = None
|
| 986 |
+
|
| 987 |
+
if aov_name == "normal":
|
| 988 |
+
if normal is not None:
|
| 989 |
+
preprocessed_aovs[aov_name] = (
|
| 990 |
+
self.image_processor.preprocess(normal)
|
| 991 |
+
)
|
| 992 |
+
else:
|
| 993 |
+
preprocessed_aovs[aov_name] = None
|
| 994 |
+
|
| 995 |
+
if aov_name == "roughness":
|
| 996 |
+
if roughness is not None:
|
| 997 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
| 998 |
+
roughness
|
| 999 |
+
)
|
| 1000 |
+
else:
|
| 1001 |
+
preprocessed_aovs[aov_name] = None
|
| 1002 |
+
if aov_name == "metallic":
|
| 1003 |
+
if metallic is not None:
|
| 1004 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
| 1005 |
+
metallic
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
preprocessed_aovs[aov_name] = None
|
| 1009 |
+
if aov_name == "irradiance":
|
| 1010 |
+
if irradiance is not None:
|
| 1011 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
| 1012 |
+
irradiance
|
| 1013 |
+
)
|
| 1014 |
+
else:
|
| 1015 |
+
preprocessed_aovs[aov_name] = None
|
| 1016 |
+
# print("aovs:", preprocessed_aovs.keys())
|
| 1017 |
+
# 4. Prepare latent variables
|
| 1018 |
+
num_channels_latents = self.vae.config.latent_channels
|
| 1019 |
+
# height, width = image_latents.shape[-2:]
|
| 1020 |
+
height = old_height / 8
|
| 1021 |
+
width = old_width / 8
|
| 1022 |
+
height = height * self.vae_scale_factor
|
| 1023 |
+
width = width * self.vae_scale_factor
|
| 1024 |
+
latents = self.prepare_latents(
|
| 1025 |
+
batch_size * num_images_per_prompt,
|
| 1026 |
+
num_channels_latents,
|
| 1027 |
+
height,
|
| 1028 |
+
width,
|
| 1029 |
+
prompt_embeds.dtype,
|
| 1030 |
+
device,
|
| 1031 |
+
generator,
|
| 1032 |
+
latents,
|
| 1033 |
+
)
|
| 1034 |
+
height_latent, width_latent = latents.shape[-2:]
|
| 1035 |
+
|
| 1036 |
+
# 5. Prepare timesteps
|
| 1037 |
+
set_flow_timesteps(
|
| 1038 |
+
self.scheduler,
|
| 1039 |
+
self.transformer,
|
| 1040 |
+
num_inference_steps,
|
| 1041 |
+
height_latent,
|
| 1042 |
+
width_latent,
|
| 1043 |
+
device,
|
| 1044 |
+
)
|
| 1045 |
+
timesteps = self.scheduler.timesteps
|
| 1046 |
+
|
| 1047 |
+
# timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 1048 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 1049 |
+
self._num_timesteps = len(timesteps)
|
| 1050 |
+
|
| 1051 |
+
# 6. Prepare Image latent
|
| 1052 |
+
image_latents = []
|
| 1053 |
+
for aov_name, aov in preprocessed_aovs.items():
|
| 1054 |
+
# print(aov_name)
|
| 1055 |
+
if aov is None:
|
| 1056 |
+
image_latent = torch.zeros(
|
| 1057 |
+
batch_size,
|
| 1058 |
+
num_channels_latents,
|
| 1059 |
+
height_latent,
|
| 1060 |
+
width_latent,
|
| 1061 |
+
dtype=prompt_embeds.dtype,
|
| 1062 |
+
device=device,
|
| 1063 |
+
)
|
| 1064 |
+
# print(image_latent.shape)
|
| 1065 |
+
if self.do_classifier_free_guidance:
|
| 1066 |
+
image_latents.append(
|
| 1067 |
+
torch.cat([image_latent, image_latent, image_latent], dim=0)
|
| 1068 |
+
)
|
| 1069 |
+
else:
|
| 1070 |
+
image_latents.append(image_latent)
|
| 1071 |
+
else:
|
| 1072 |
+
image_latent = (
|
| 1073 |
+
self.prepare_image_latents(
|
| 1074 |
+
aov,
|
| 1075 |
+
batch_size,
|
| 1076 |
+
num_images_per_prompt,
|
| 1077 |
+
prompt_embeds.dtype,
|
| 1078 |
+
device,
|
| 1079 |
+
self.do_classifier_free_guidance,
|
| 1080 |
+
generator,
|
| 1081 |
+
)
|
| 1082 |
+
)
|
| 1083 |
+
image_latents.append(image_latent)
|
| 1084 |
+
image_latents = torch.cat(image_latents, dim=1)
|
| 1085 |
+
|
| 1086 |
+
# 7. Check that shapes of latents and image match the DIT in_channels
|
| 1087 |
+
num_channels_image = image_latents.shape[1]
|
| 1088 |
+
|
| 1089 |
+
if num_channels_latents + num_channels_image != self.transformer.config.in_channels:
|
| 1090 |
+
raise ValueError(
|
| 1091 |
+
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
| 1092 |
+
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
| 1093 |
+
f" `num_channels_image`: {num_channels_image} "
|
| 1094 |
+
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
| 1095 |
+
" `pipeline.transformer` or your `image` input."
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
+
# 8. Denoising loop
|
| 1099 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1100 |
+
for i, t in enumerate(timesteps):
|
| 1101 |
+
if self.interrupt:
|
| 1102 |
+
continue
|
| 1103 |
+
|
| 1104 |
+
# expand the latents if we are doing classifier free guidance
|
| 1105 |
+
latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
|
| 1106 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1107 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 1108 |
+
|
| 1109 |
+
scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
|
| 1110 |
+
|
| 1111 |
+
# if "mask_index" in kwargs and kwargs['mask_index'] is not None:
|
| 1112 |
+
# mask_index = kwargs['mask_index']
|
| 1113 |
+
# else:
|
| 1114 |
+
# mask_index = None
|
| 1115 |
+
noise_pred = self.transformer(
|
| 1116 |
+
hidden_states=scaled_latent_model_input,
|
| 1117 |
+
timestep=timestep,
|
| 1118 |
+
encoder_hidden_states=prompt_embeds,
|
| 1119 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1120 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1121 |
+
return_dict=False,
|
| 1122 |
+
# mask_index= mask_index,
|
| 1123 |
+
)[0]
|
| 1124 |
+
|
| 1125 |
+
# perform guidance
|
| 1126 |
+
if self.do_classifier_free_guidance:
|
| 1127 |
+
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
|
| 1128 |
+
noise_pred = (
|
| 1129 |
+
noise_pred_uncond
|
| 1130 |
+
+ self.guidance_scale * (noise_pred_text - noise_pred_image)
|
| 1131 |
+
+ self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
| 1132 |
+
)
|
| 1133 |
+
# noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # neg, prompt
|
| 1134 |
+
# noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1135 |
+
|
| 1136 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1137 |
+
latents_dtype = latents.dtype
|
| 1138 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1139 |
+
|
| 1140 |
+
if latents.dtype != latents_dtype:
|
| 1141 |
+
if torch.backends.mps.is_available():
|
| 1142 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1143 |
+
latents = latents.to(latents_dtype)
|
| 1144 |
+
|
| 1145 |
+
if callback_on_step_end is not None:
|
| 1146 |
+
callback_kwargs = {}
|
| 1147 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1148 |
+
callback_kwargs[k] = locals()[k]
|
| 1149 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1150 |
+
|
| 1151 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1152 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1153 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1154 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
| 1155 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
| 1156 |
+
)
|
| 1157 |
+
image_latents = callback_outputs.pop("image_latents", image_latents)
|
| 1158 |
+
if mask_img is not None:
|
| 1159 |
+
mask_image_latents = callback_outputs.pop("mask_image_latents", mask_image_latents)
|
| 1160 |
+
# call the callback, if provided
|
| 1161 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1162 |
+
progress_bar.update()
|
| 1163 |
+
|
| 1164 |
+
if XLA_AVAILABLE:
|
| 1165 |
+
xm.mark_step()
|
| 1166 |
+
|
| 1167 |
+
if output_type == "latent":
|
| 1168 |
+
image = latents
|
| 1169 |
+
|
| 1170 |
+
else:
|
| 1171 |
+
# latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1172 |
+
latents = latents / self.vae.config.scaling_factor
|
| 1173 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1174 |
+
|
| 1175 |
+
do_denormalize = [True] * image.shape[0]
|
| 1176 |
+
|
| 1177 |
+
image = torchvision.transforms.Resize((old_height, old_width),interpolation=PIL.Image.BICUBIC)(image)
|
| 1178 |
+
image = self.image_processor.postprocess(
|
| 1179 |
+
image,
|
| 1180 |
+
output_type=output_type,
|
| 1181 |
+
do_denormalize=do_denormalize,
|
| 1182 |
+
do_gamma_correction=False
|
| 1183 |
+
)
|
| 1184 |
+
|
| 1185 |
+
# Offload all models
|
| 1186 |
+
self.maybe_free_model_hooks()
|
| 1187 |
+
|
| 1188 |
+
if not return_dict:
|
| 1189 |
+
return (image,)
|
| 1190 |
+
|
| 1191 |
+
return StableDiffusion3PipelineOutput(images=image)
|
pipeline_intrinsic_weather_inverse.py
ADDED
|
@@ -0,0 +1,1119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub custom pipeline: IntrinsicWeatherInversePipeline.
|
| 2 |
+
Load with native Hugging Face diffusers and trust_remote_code=True.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import inspect
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 24 |
+
|
| 25 |
+
import PIL.Image
|
| 26 |
+
import torch
|
| 27 |
+
import torchvision
|
| 28 |
+
from transformers import (
|
| 29 |
+
CLIPTextModelWithProjection,
|
| 30 |
+
CLIPTokenizer,
|
| 31 |
+
T5EncoderModel,
|
| 32 |
+
T5TokenizerFast,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 36 |
+
from diffusers.loaders import SD3LoraLoaderMixin
|
| 37 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 38 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
|
| 39 |
+
from diffusers.utils import (
|
| 40 |
+
deprecate,
|
| 41 |
+
is_torch_xla_available,
|
| 42 |
+
logging,
|
| 43 |
+
replace_example_docstring,
|
| 44 |
+
)
|
| 45 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 46 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 47 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
| 48 |
+
|
| 49 |
+
from pipeline_utils import load_transformer_from_subfolder, resolve_repo_dir, set_flow_timesteps
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if is_torch_xla_available():
|
| 53 |
+
import torch_xla.core.xla_model as xm
|
| 54 |
+
|
| 55 |
+
XLA_AVAILABLE = True
|
| 56 |
+
else:
|
| 57 |
+
XLA_AVAILABLE = False
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 61 |
+
|
| 62 |
+
EXAMPLE_DOC_STRING = """
|
| 63 |
+
Examples:
|
| 64 |
+
```py
|
| 65 |
+
>>> from pathlib import Path
|
| 66 |
+
>>> import torch
|
| 67 |
+
>>> from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
|
| 68 |
+
|
| 69 |
+
>>> repo_dir = Path("./IntrisicWeather-diffusers").resolve()
|
| 70 |
+
>>> transformer = IntrinsicWeatherInversePipeline.load_transformer("inverse-512", repo_dir)
|
| 71 |
+
>>> pipe = IntrinsicWeatherInversePipeline.from_pretrained(
|
| 72 |
+
... repo_dir,
|
| 73 |
+
... transformer_subfolder="inverse-512",
|
| 74 |
+
... local_files_only=True,
|
| 75 |
+
... torch_dtype=torch.float16,
|
| 76 |
+
... )
|
| 77 |
+
>>> pipe.to("cuda")
|
| 78 |
+
>>> output = pipe(
|
| 79 |
+
... image=input_rgb_tensor,
|
| 80 |
+
... prompt="Albedo (diffuse basecolor)",
|
| 81 |
+
... aov=["albedo"],
|
| 82 |
+
... num_inference_steps=50,
|
| 83 |
+
... )
|
| 84 |
+
```
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 91 |
+
def retrieve_latents(
|
| 92 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 93 |
+
):
|
| 94 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 95 |
+
return encoder_output.latent_dist.sample(generator)
|
| 96 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 97 |
+
return encoder_output.latent_dist.mode()
|
| 98 |
+
elif hasattr(encoder_output, "latents"):
|
| 99 |
+
return encoder_output.latents
|
| 100 |
+
else:
|
| 101 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 102 |
+
|
| 103 |
+
class VaeAoVProcessor(VaeImageProcessor):
|
| 104 |
+
def postprocess(
|
| 105 |
+
self,
|
| 106 |
+
image: torch.FloatTensor,
|
| 107 |
+
output_type: str = "pil",
|
| 108 |
+
do_denormalize: Optional[List[bool]] = None,
|
| 109 |
+
do_gamma_correction: bool = True,
|
| 110 |
+
):
|
| 111 |
+
if not isinstance(image, torch.Tensor):
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
| 114 |
+
)
|
| 115 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
| 116 |
+
deprecation_message = (
|
| 117 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
| 118 |
+
"`pil`, `np`, `pt`, `latent`"
|
| 119 |
+
)
|
| 120 |
+
deprecate(
|
| 121 |
+
"Unsupported output_type",
|
| 122 |
+
"1.0.0",
|
| 123 |
+
deprecation_message,
|
| 124 |
+
standard_warn=False,
|
| 125 |
+
)
|
| 126 |
+
output_type = "np"
|
| 127 |
+
|
| 128 |
+
if output_type == "latent":
|
| 129 |
+
return image
|
| 130 |
+
|
| 131 |
+
if do_denormalize is None:
|
| 132 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
| 133 |
+
|
| 134 |
+
image = torch.stack(
|
| 135 |
+
[
|
| 136 |
+
self.denormalize(image[i]) if do_denormalize[i] else image[i]
|
| 137 |
+
for i in range(image.shape[0])
|
| 138 |
+
]
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Gamma correction
|
| 142 |
+
if do_gamma_correction:
|
| 143 |
+
# image = torch.pow(image, 1.0 / 2.2)
|
| 144 |
+
image = image ** (1 / 2.2)
|
| 145 |
+
|
| 146 |
+
if output_type == "pt":
|
| 147 |
+
return image
|
| 148 |
+
|
| 149 |
+
image = self.pt_to_numpy(image)
|
| 150 |
+
|
| 151 |
+
if output_type == "np":
|
| 152 |
+
return image
|
| 153 |
+
|
| 154 |
+
if output_type == "pil":
|
| 155 |
+
return self.numpy_to_pil(image)
|
| 156 |
+
|
| 157 |
+
class IntrinsicWeatherInversePipeline(DiffusionPipeline, SD3LoraLoaderMixin):
|
| 158 |
+
r"""
|
| 159 |
+
Args:
|
| 160 |
+
transformer ([`SD3Transformer2DModel`]):
|
| 161 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 162 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 163 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 164 |
+
vae ([`AutoencoderKL`]):
|
| 165 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 166 |
+
text_encoder ([`CLIPTextModelWithProjection`]):
|
| 167 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 168 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
| 169 |
+
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
| 170 |
+
as its dimension.
|
| 171 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
| 172 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
| 173 |
+
specifically the
|
| 174 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
| 175 |
+
variant.
|
| 176 |
+
text_encoder_3 ([`T5EncoderModel`]):
|
| 177 |
+
Frozen text-encoder. Stable Diffusion 3 uses
|
| 178 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
| 179 |
+
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 180 |
+
tokenizer (`CLIPTokenizer`):
|
| 181 |
+
Tokenizer of class
|
| 182 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 183 |
+
tokenizer_2 (`CLIPTokenizer`):
|
| 184 |
+
Second Tokenizer of class
|
| 185 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 186 |
+
tokenizer_3 (`T5TokenizerFast`):
|
| 187 |
+
Tokenizer of class
|
| 188 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
| 192 |
+
_optional_components = []
|
| 193 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
| 194 |
+
|
| 195 |
+
@classmethod
|
| 196 |
+
def load_transformer(
|
| 197 |
+
cls,
|
| 198 |
+
transformer_subfolder: str,
|
| 199 |
+
pretrained_model_name_or_path: str | Path,
|
| 200 |
+
*,
|
| 201 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 202 |
+
device: str | torch.device | None = None,
|
| 203 |
+
):
|
| 204 |
+
return load_transformer_from_subfolder(
|
| 205 |
+
pretrained_model_name_or_path,
|
| 206 |
+
transformer_subfolder,
|
| 207 |
+
dtype=dtype,
|
| 208 |
+
device=device,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
@classmethod
|
| 212 |
+
def from_pretrained(
|
| 213 |
+
cls,
|
| 214 |
+
pretrained_model_name_or_path,
|
| 215 |
+
*args,
|
| 216 |
+
transformer_subfolder: str = "inverse-512",
|
| 217 |
+
transformer=None,
|
| 218 |
+
**kwargs,
|
| 219 |
+
):
|
| 220 |
+
repo_dir = resolve_repo_dir(pretrained_model_name_or_path)
|
| 221 |
+
dtype = kwargs.get("torch_dtype", torch.bfloat16)
|
| 222 |
+
device = kwargs.get("device")
|
| 223 |
+
if transformer is None:
|
| 224 |
+
transformer = cls.load_transformer(
|
| 225 |
+
transformer_subfolder,
|
| 226 |
+
repo_dir,
|
| 227 |
+
dtype=dtype,
|
| 228 |
+
device=device,
|
| 229 |
+
)
|
| 230 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 231 |
+
repo_dir.as_posix(),
|
| 232 |
+
*args,
|
| 233 |
+
transformer=transformer,
|
| 234 |
+
custom_pipeline=str(repo_dir / "pipeline_intrinsic_weather_inverse.py"),
|
| 235 |
+
trust_remote_code=True,
|
| 236 |
+
**kwargs,
|
| 237 |
+
)
|
| 238 |
+
if device is not None:
|
| 239 |
+
pipe = pipe.to(device)
|
| 240 |
+
return pipe
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
transformer,
|
| 245 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 246 |
+
vae: AutoencoderKL,
|
| 247 |
+
text_encoder: CLIPTextModelWithProjection,
|
| 248 |
+
tokenizer: CLIPTokenizer,
|
| 249 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
| 250 |
+
tokenizer_2: CLIPTokenizer,
|
| 251 |
+
text_encoder_3: T5EncoderModel,
|
| 252 |
+
tokenizer_3: T5TokenizerFast,
|
| 253 |
+
):
|
| 254 |
+
super().__init__()
|
| 255 |
+
|
| 256 |
+
self.register_modules(
|
| 257 |
+
vae=vae,
|
| 258 |
+
text_encoder=text_encoder,
|
| 259 |
+
text_encoder_2=text_encoder_2,
|
| 260 |
+
text_encoder_3=text_encoder_3,
|
| 261 |
+
tokenizer=tokenizer,
|
| 262 |
+
tokenizer_2=tokenizer_2,
|
| 263 |
+
tokenizer_3=tokenizer_3,
|
| 264 |
+
transformer=transformer,
|
| 265 |
+
scheduler=scheduler,
|
| 266 |
+
)
|
| 267 |
+
self.vae_scale_factor = (
|
| 268 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 269 |
+
)
|
| 270 |
+
self.image_processor = VaeAoVProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 271 |
+
self.tokenizer_max_length = (
|
| 272 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 273 |
+
)
|
| 274 |
+
self.default_sample_size = (
|
| 275 |
+
self.transformer.config.sample_size
|
| 276 |
+
if hasattr(self, "transformer") and self.transformer is not None
|
| 277 |
+
else 128
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
def _get_t5_prompt_embeds(
|
| 281 |
+
self,
|
| 282 |
+
prompt: Union[str, List[str]] = None,
|
| 283 |
+
num_images_per_prompt: int = 1,
|
| 284 |
+
device: Optional[torch.device] = None,
|
| 285 |
+
dtype: Optional[torch.dtype] = None,
|
| 286 |
+
):
|
| 287 |
+
device = device or self._execution_device
|
| 288 |
+
dtype = dtype or self.text_encoder.dtype
|
| 289 |
+
|
| 290 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 291 |
+
batch_size = len(prompt)
|
| 292 |
+
|
| 293 |
+
if self.text_encoder_3 is None:
|
| 294 |
+
return torch.zeros(
|
| 295 |
+
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
|
| 296 |
+
device=device,
|
| 297 |
+
dtype=dtype,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
text_inputs = self.tokenizer_3(
|
| 301 |
+
prompt,
|
| 302 |
+
padding="max_length",
|
| 303 |
+
max_length=self.tokenizer_max_length,
|
| 304 |
+
truncation=True,
|
| 305 |
+
add_special_tokens=True,
|
| 306 |
+
return_tensors="pt",
|
| 307 |
+
)
|
| 308 |
+
text_input_ids = text_inputs.input_ids
|
| 309 |
+
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
| 310 |
+
|
| 311 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 312 |
+
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 313 |
+
logger.warning(
|
| 314 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 315 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
| 319 |
+
|
| 320 |
+
dtype = self.text_encoder_3.dtype
|
| 321 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 322 |
+
|
| 323 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 324 |
+
|
| 325 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 326 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 327 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 328 |
+
|
| 329 |
+
return prompt_embeds
|
| 330 |
+
|
| 331 |
+
def _get_clip_prompt_embeds(
|
| 332 |
+
self,
|
| 333 |
+
prompt: Union[str, List[str]],
|
| 334 |
+
num_images_per_prompt: int = 1,
|
| 335 |
+
device: Optional[torch.device] = None,
|
| 336 |
+
clip_skip: Optional[int] = None,
|
| 337 |
+
clip_model_index: int = 0,
|
| 338 |
+
):
|
| 339 |
+
device = device or self._execution_device
|
| 340 |
+
|
| 341 |
+
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
| 342 |
+
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
| 343 |
+
|
| 344 |
+
tokenizer = clip_tokenizers[clip_model_index]
|
| 345 |
+
text_encoder = clip_text_encoders[clip_model_index]
|
| 346 |
+
|
| 347 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 348 |
+
batch_size = len(prompt)
|
| 349 |
+
|
| 350 |
+
text_inputs = tokenizer(
|
| 351 |
+
prompt,
|
| 352 |
+
padding="max_length",
|
| 353 |
+
max_length=self.tokenizer_max_length,
|
| 354 |
+
truncation=True,
|
| 355 |
+
return_tensors="pt",
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
text_input_ids = text_inputs.input_ids
|
| 359 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 360 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 361 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 362 |
+
logger.warning(
|
| 363 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 364 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 365 |
+
)
|
| 366 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
| 367 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
| 368 |
+
|
| 369 |
+
if clip_skip is None:
|
| 370 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 371 |
+
else:
|
| 372 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
| 373 |
+
|
| 374 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 375 |
+
|
| 376 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 377 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 378 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 379 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 380 |
+
|
| 381 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 382 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 383 |
+
|
| 384 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 385 |
+
|
| 386 |
+
def encode_prompt(
|
| 387 |
+
self,
|
| 388 |
+
prompt: Union[str, List[str]],
|
| 389 |
+
prompt_2: Union[str, List[str]] = None,
|
| 390 |
+
prompt_3: Union[str, List[str]] = None,
|
| 391 |
+
device: Optional[torch.device] = None,
|
| 392 |
+
num_images_per_prompt: int = 1,
|
| 393 |
+
do_classifier_free_guidance: bool = True,
|
| 394 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 395 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 396 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 397 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 398 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 399 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 400 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 401 |
+
clip_skip: Optional[int] = None,
|
| 402 |
+
):
|
| 403 |
+
r"""
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 407 |
+
prompt to be encoded
|
| 408 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 409 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 410 |
+
used in all text-encoders
|
| 411 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 412 |
+
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 413 |
+
used in all text-encoders
|
| 414 |
+
device: (`torch.device`):
|
| 415 |
+
torch device
|
| 416 |
+
num_images_per_prompt (`int`):
|
| 417 |
+
number of images that should be generated per prompt
|
| 418 |
+
do_classifier_free_guidance (`bool`):
|
| 419 |
+
whether to use classifier free guidance or not
|
| 420 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 421 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 422 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 423 |
+
less than `1`).
|
| 424 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 425 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 426 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 427 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 428 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 429 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
|
| 430 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 431 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 432 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 433 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 434 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 435 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 436 |
+
argument.
|
| 437 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 438 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 439 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 440 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 441 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 442 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 443 |
+
input argument.
|
| 444 |
+
clip_skip (`int`, *optional*):
|
| 445 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 446 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 447 |
+
"""
|
| 448 |
+
device = device or self._execution_device
|
| 449 |
+
|
| 450 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 451 |
+
if prompt is not None:
|
| 452 |
+
batch_size = len(prompt)
|
| 453 |
+
else:
|
| 454 |
+
batch_size = prompt_embeds.shape[0]
|
| 455 |
+
|
| 456 |
+
if prompt_embeds is None:
|
| 457 |
+
prompt_2 = prompt_2 or prompt
|
| 458 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 459 |
+
|
| 460 |
+
prompt_3 = prompt_3 or prompt
|
| 461 |
+
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
| 462 |
+
|
| 463 |
+
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 464 |
+
prompt=prompt,
|
| 465 |
+
device=device,
|
| 466 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 467 |
+
clip_skip=clip_skip,
|
| 468 |
+
clip_model_index=0,
|
| 469 |
+
)
|
| 470 |
+
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 471 |
+
prompt=prompt_2,
|
| 472 |
+
device=device,
|
| 473 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 474 |
+
clip_skip=clip_skip,
|
| 475 |
+
clip_model_index=1,
|
| 476 |
+
)
|
| 477 |
+
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
| 478 |
+
|
| 479 |
+
t5_prompt_embed = self._get_t5_prompt_embeds(
|
| 480 |
+
prompt=prompt_3,
|
| 481 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 482 |
+
device=device,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
| 486 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
| 490 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
| 491 |
+
|
| 492 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 493 |
+
negative_prompt = negative_prompt or ""
|
| 494 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
| 495 |
+
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
| 496 |
+
|
| 497 |
+
# normalize str to list
|
| 498 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 499 |
+
negative_prompt_2 = (
|
| 500 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
| 501 |
+
)
|
| 502 |
+
negative_prompt_3 = (
|
| 503 |
+
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 507 |
+
raise TypeError(
|
| 508 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 509 |
+
f" {type(prompt)}."
|
| 510 |
+
)
|
| 511 |
+
elif batch_size != len(negative_prompt):
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 514 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 515 |
+
" the batch size of `prompt`."
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
| 519 |
+
negative_prompt,
|
| 520 |
+
device=device,
|
| 521 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 522 |
+
clip_skip=None,
|
| 523 |
+
clip_model_index=0,
|
| 524 |
+
)
|
| 525 |
+
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
| 526 |
+
negative_prompt_2,
|
| 527 |
+
device=device,
|
| 528 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 529 |
+
clip_skip=None,
|
| 530 |
+
clip_model_index=1,
|
| 531 |
+
)
|
| 532 |
+
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
| 533 |
+
|
| 534 |
+
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
| 535 |
+
prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
| 539 |
+
negative_clip_prompt_embeds,
|
| 540 |
+
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
| 544 |
+
negative_pooled_prompt_embeds = torch.cat(
|
| 545 |
+
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
| 549 |
+
|
| 550 |
+
def check_inputs(
|
| 551 |
+
self,
|
| 552 |
+
prompt,
|
| 553 |
+
prompt_2,
|
| 554 |
+
prompt_3,
|
| 555 |
+
# height,
|
| 556 |
+
# width,
|
| 557 |
+
negative_prompt=None,
|
| 558 |
+
negative_prompt_2=None,
|
| 559 |
+
negative_prompt_3=None,
|
| 560 |
+
prompt_embeds=None,
|
| 561 |
+
negative_prompt_embeds=None,
|
| 562 |
+
pooled_prompt_embeds=None,
|
| 563 |
+
negative_pooled_prompt_embeds=None,
|
| 564 |
+
callback_on_step_end_tensor_inputs=None,
|
| 565 |
+
):
|
| 566 |
+
# if height % 8 != 0 or width % 8 != 0:
|
| 567 |
+
# raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 568 |
+
|
| 569 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 570 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 571 |
+
):
|
| 572 |
+
raise ValueError(
|
| 573 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
if prompt is not None and prompt_embeds is not None:
|
| 577 |
+
raise ValueError(
|
| 578 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 579 |
+
" only forward one of the two."
|
| 580 |
+
)
|
| 581 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 582 |
+
raise ValueError(
|
| 583 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 584 |
+
" only forward one of the two."
|
| 585 |
+
)
|
| 586 |
+
elif prompt_3 is not None and prompt_embeds is not None:
|
| 587 |
+
raise ValueError(
|
| 588 |
+
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 589 |
+
" only forward one of the two."
|
| 590 |
+
)
|
| 591 |
+
elif prompt is None and prompt_embeds is None:
|
| 592 |
+
raise ValueError(
|
| 593 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 594 |
+
)
|
| 595 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 596 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 597 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 598 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 599 |
+
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
| 600 |
+
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
| 601 |
+
|
| 602 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 603 |
+
raise ValueError(
|
| 604 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 605 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 606 |
+
)
|
| 607 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 608 |
+
raise ValueError(
|
| 609 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 610 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 611 |
+
)
|
| 612 |
+
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
| 613 |
+
raise ValueError(
|
| 614 |
+
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
| 615 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 619 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 620 |
+
raise ValueError(
|
| 621 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 622 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 623 |
+
f" {negative_prompt_embeds.shape}."
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 627 |
+
raise ValueError(
|
| 628 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 632 |
+
raise ValueError(
|
| 633 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
def prepare_latents(
|
| 637 |
+
self,
|
| 638 |
+
batch_size,
|
| 639 |
+
num_channels_latents,
|
| 640 |
+
height,
|
| 641 |
+
width,
|
| 642 |
+
dtype,
|
| 643 |
+
device,
|
| 644 |
+
generator,
|
| 645 |
+
latents=None,
|
| 646 |
+
):
|
| 647 |
+
if latents is not None:
|
| 648 |
+
return latents.to(device=device, dtype=dtype)
|
| 649 |
+
|
| 650 |
+
shape = (
|
| 651 |
+
batch_size,
|
| 652 |
+
num_channels_latents,
|
| 653 |
+
int(height) // self.vae_scale_factor,
|
| 654 |
+
int(width) // self.vae_scale_factor,
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 658 |
+
raise ValueError(
|
| 659 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 660 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 664 |
+
|
| 665 |
+
return latents
|
| 666 |
+
|
| 667 |
+
def prepare_image_latents(
|
| 668 |
+
self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
|
| 669 |
+
):
|
| 670 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
| 671 |
+
raise ValueError(
|
| 672 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
image = image.to(device=device, dtype=dtype)
|
| 676 |
+
|
| 677 |
+
batch_size = batch_size * num_images_per_prompt
|
| 678 |
+
|
| 679 |
+
if image.shape[1] == self.vae.config.latent_channels:
|
| 680 |
+
image_latents = image
|
| 681 |
+
else:
|
| 682 |
+
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
| 683 |
+
# ? normalize image latents
|
| 684 |
+
# image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 685 |
+
|
| 686 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 687 |
+
# expand image_latents for batch_size
|
| 688 |
+
deprecation_message = (
|
| 689 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
|
| 690 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
| 691 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
| 692 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
| 693 |
+
)
|
| 694 |
+
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
| 695 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 696 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 697 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 698 |
+
raise ValueError(
|
| 699 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
| 700 |
+
)
|
| 701 |
+
else:
|
| 702 |
+
image_latents = torch.cat([image_latents], dim=0)
|
| 703 |
+
|
| 704 |
+
if do_classifier_free_guidance:
|
| 705 |
+
uncond_image_latents = torch.zeros_like(image_latents)
|
| 706 |
+
image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
|
| 707 |
+
|
| 708 |
+
return image_latents
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
@property
|
| 718 |
+
def guidance_scale(self):
|
| 719 |
+
return self._guidance_scale
|
| 720 |
+
@property
|
| 721 |
+
def image_guidance_scale(self):
|
| 722 |
+
return self._image_guidance_scale
|
| 723 |
+
|
| 724 |
+
@property
|
| 725 |
+
def clip_skip(self):
|
| 726 |
+
return self._clip_skip
|
| 727 |
+
|
| 728 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 729 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 730 |
+
# corresponds to doing no classifier free guidance.
|
| 731 |
+
@property
|
| 732 |
+
def do_classifier_free_guidance(self):
|
| 733 |
+
return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
|
| 734 |
+
|
| 735 |
+
@property
|
| 736 |
+
def joint_attention_kwargs(self):
|
| 737 |
+
return self._joint_attention_kwargs
|
| 738 |
+
|
| 739 |
+
@property
|
| 740 |
+
def num_timesteps(self):
|
| 741 |
+
return self._num_timesteps
|
| 742 |
+
|
| 743 |
+
@property
|
| 744 |
+
def interrupt(self):
|
| 745 |
+
return self._interrupt
|
| 746 |
+
|
| 747 |
+
@torch.no_grad()
|
| 748 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 749 |
+
def __call__(
|
| 750 |
+
self,
|
| 751 |
+
prompt: Union[str, List[str]] = None,
|
| 752 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 753 |
+
prompt_3: Optional[Union[str, List[str]]] = None,
|
| 754 |
+
image: PipelineImageInput = None,
|
| 755 |
+
height: Optional[int] = None,
|
| 756 |
+
width: Optional[int] = None,
|
| 757 |
+
num_inference_steps: int = 28,
|
| 758 |
+
timesteps: List[int] = None,
|
| 759 |
+
guidance_scale: float = 7.0,
|
| 760 |
+
image_guidance_scale: float = 1.5,
|
| 761 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 762 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 763 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 764 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 765 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 766 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 767 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 768 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 769 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 770 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 771 |
+
output_type: Optional[str] = "pil",
|
| 772 |
+
return_dict: bool = True,
|
| 773 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 774 |
+
clip_skip: Optional[int] = None,
|
| 775 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 776 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 777 |
+
mask_img: Optional[PipelineImageInput] = None,
|
| 778 |
+
aov: List[str] = ["albedo"],
|
| 779 |
+
map_aware_mask: Optional[torch.FloatTensor] = None,
|
| 780 |
+
**kwargs
|
| 781 |
+
):
|
| 782 |
+
r"""
|
| 783 |
+
Function invoked when calling the pipeline for generation.
|
| 784 |
+
|
| 785 |
+
Args:
|
| 786 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 787 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 788 |
+
instead.
|
| 789 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 790 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 791 |
+
will be used instead
|
| 792 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 793 |
+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 794 |
+
will be used instead
|
| 795 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 796 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 797 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 798 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 799 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 800 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 801 |
+
expense of slower inference.
|
| 802 |
+
timesteps (`List[int]`, *optional*):
|
| 803 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 804 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 805 |
+
passed will be used. Must be in descending order.
|
| 806 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 807 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 808 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 809 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 810 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 811 |
+
usually at the expense of lower image quality.
|
| 812 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 813 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 814 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 815 |
+
less than `1`).
|
| 816 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 817 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 818 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
| 819 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 820 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 821 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
| 822 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 823 |
+
The number of images to generate per prompt.
|
| 824 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 825 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 826 |
+
to make generation deterministic.
|
| 827 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 828 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 829 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 830 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 831 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 832 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 833 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 834 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 835 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 836 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 837 |
+
argument.
|
| 838 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 839 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 840 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 841 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 842 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 843 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 844 |
+
input argument.
|
| 845 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 846 |
+
The output format of the generate image. Choose between
|
| 847 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 848 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 849 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 850 |
+
of a plain tuple.
|
| 851 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 852 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 853 |
+
`self.processor` in
|
| 854 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 855 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 856 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 857 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 858 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 859 |
+
`callback_on_step_end_tensor_inputs`.
|
| 860 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 861 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 862 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 863 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 864 |
+
|
| 865 |
+
Examples:
|
| 866 |
+
|
| 867 |
+
Returns:
|
| 868 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
| 869 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
| 870 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 871 |
+
"""
|
| 872 |
+
|
| 873 |
+
# height = height or self.default_sample_size * self.vae_scale_factor
|
| 874 |
+
# width = width or self.default_sample_size * self.vae_scale_factor
|
| 875 |
+
|
| 876 |
+
# 1. Check inputs. Raise error if not correct
|
| 877 |
+
self.check_inputs(
|
| 878 |
+
prompt,
|
| 879 |
+
prompt_2,
|
| 880 |
+
prompt_3,
|
| 881 |
+
negative_prompt=negative_prompt,
|
| 882 |
+
negative_prompt_2=negative_prompt_2,
|
| 883 |
+
negative_prompt_3=negative_prompt_3,
|
| 884 |
+
prompt_embeds=prompt_embeds,
|
| 885 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 886 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 887 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 888 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
self._guidance_scale = guidance_scale
|
| 892 |
+
self._image_guidance_scale = image_guidance_scale
|
| 893 |
+
self._clip_skip = clip_skip
|
| 894 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 895 |
+
self._interrupt = False
|
| 896 |
+
|
| 897 |
+
# 2. Define call parameters
|
| 898 |
+
if prompt is not None and isinstance(prompt, str):
|
| 899 |
+
batch_size = 1
|
| 900 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 901 |
+
batch_size = len(prompt)
|
| 902 |
+
else:
|
| 903 |
+
batch_size = prompt_embeds.shape[0]
|
| 904 |
+
|
| 905 |
+
device = self._execution_device
|
| 906 |
+
|
| 907 |
+
(
|
| 908 |
+
prompt_embeds,
|
| 909 |
+
negative_prompt_embeds,
|
| 910 |
+
pooled_prompt_embeds,
|
| 911 |
+
negative_pooled_prompt_embeds,
|
| 912 |
+
) = self.encode_prompt(
|
| 913 |
+
prompt=prompt,
|
| 914 |
+
prompt_2=prompt_2,
|
| 915 |
+
prompt_3=prompt_3,
|
| 916 |
+
negative_prompt=negative_prompt,
|
| 917 |
+
negative_prompt_2=negative_prompt_2,
|
| 918 |
+
negative_prompt_3=negative_prompt_3,
|
| 919 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 920 |
+
prompt_embeds=prompt_embeds,
|
| 921 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 922 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 923 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 924 |
+
device=device,
|
| 925 |
+
clip_skip=self.clip_skip,
|
| 926 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 927 |
+
)
|
| 928 |
+
# print("prompt:", prompt_embeds.shape)
|
| 929 |
+
if self.do_classifier_free_guidance:
|
| 930 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 931 |
+
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
|
| 932 |
+
|
| 933 |
+
# Similiarly
|
| 934 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0)
|
| 935 |
+
|
| 936 |
+
# if self.do_classifier_free_guidance:
|
| 937 |
+
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 938 |
+
# pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
| 939 |
+
|
| 940 |
+
# 3. Preprocess image
|
| 941 |
+
if (isinstance(image, PIL.Image.Image) and len(image.size)==4) or (isinstance(image, torch.Tensor) and len(image.shape) == 4):
|
| 942 |
+
old_height = image.shape[2]
|
| 943 |
+
old_width = image.shape[3]
|
| 944 |
+
else:
|
| 945 |
+
old_height = image.shape[1]
|
| 946 |
+
old_width = image.shape[2]
|
| 947 |
+
|
| 948 |
+
image = self.image_processor.preprocess(image)
|
| 949 |
+
# image = torchvision.transforms.Resize((512,512),interpolation=PIL.Image.BICUBIC)(image)
|
| 950 |
+
|
| 951 |
+
# 4. Prepare Image latent
|
| 952 |
+
image_latents = self.prepare_image_latents(
|
| 953 |
+
image,
|
| 954 |
+
batch_size,
|
| 955 |
+
num_images_per_prompt,
|
| 956 |
+
prompt_embeds.dtype,
|
| 957 |
+
device,
|
| 958 |
+
self.do_classifier_free_guidance,
|
| 959 |
+
)
|
| 960 |
+
|
| 961 |
+
# 5. Prepare timesteps
|
| 962 |
+
set_flow_timesteps(
|
| 963 |
+
self.scheduler,
|
| 964 |
+
self.transformer,
|
| 965 |
+
num_inference_steps,
|
| 966 |
+
image_latents.shape[-2],
|
| 967 |
+
image_latents.shape[-1],
|
| 968 |
+
device,
|
| 969 |
+
)
|
| 970 |
+
timesteps = self.scheduler.timesteps
|
| 971 |
+
|
| 972 |
+
# timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 973 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 974 |
+
self._num_timesteps = len(timesteps)
|
| 975 |
+
|
| 976 |
+
height, width = image_latents.shape[-2:]
|
| 977 |
+
height = height * self.vae_scale_factor
|
| 978 |
+
width = width * self.vae_scale_factor
|
| 979 |
+
# 6. Prepare latent variables
|
| 980 |
+
num_channels_latents = self.vae.config.latent_channels
|
| 981 |
+
latents = self.prepare_latents(
|
| 982 |
+
batch_size * num_images_per_prompt,
|
| 983 |
+
num_channels_latents,
|
| 984 |
+
height,
|
| 985 |
+
width,
|
| 986 |
+
prompt_embeds.dtype,
|
| 987 |
+
device,
|
| 988 |
+
generator,
|
| 989 |
+
latents,
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
# 7. Check that shapes of latents and image match the DIT in_channels
|
| 993 |
+
num_channels_image = image_latents.shape[1]
|
| 994 |
+
if mask_img is not None:
|
| 995 |
+
mask_img = self.image_processor.preprocess(mask_img)
|
| 996 |
+
mask_image_latents = self.prepare_image_latents(
|
| 997 |
+
mask_img,
|
| 998 |
+
batch_size,
|
| 999 |
+
num_images_per_prompt,
|
| 1000 |
+
prompt_embeds.dtype,
|
| 1001 |
+
device,
|
| 1002 |
+
self.do_classifier_free_guidance,
|
| 1003 |
+
)
|
| 1004 |
+
num_channels_image += mask_image_latents.shape[1]
|
| 1005 |
+
|
| 1006 |
+
if num_channels_latents + num_channels_image != self.transformer.config.in_channels:
|
| 1007 |
+
raise ValueError(
|
| 1008 |
+
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
| 1009 |
+
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
| 1010 |
+
f" `num_channels_image`: {num_channels_image} "
|
| 1011 |
+
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
| 1012 |
+
" `pipeline.transformer` or your `image` input."
|
| 1013 |
+
)
|
| 1014 |
+
|
| 1015 |
+
# 8. Denoising loop
|
| 1016 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 1017 |
+
for i, t in enumerate(timesteps):
|
| 1018 |
+
if self.interrupt:
|
| 1019 |
+
continue
|
| 1020 |
+
|
| 1021 |
+
# expand the latents if we are doing classifier free guidance
|
| 1022 |
+
latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
|
| 1023 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1024 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 1025 |
+
|
| 1026 |
+
scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
|
| 1027 |
+
if mask_img is not None:
|
| 1028 |
+
scaled_latent_model_input = torch.cat([scaled_latent_model_input, mask_image_latents], dim=1)
|
| 1029 |
+
# if "mask_index" in kwargs and kwargs['mask_index'] is not None:
|
| 1030 |
+
# mask_index = kwargs['mask_index']
|
| 1031 |
+
# else:
|
| 1032 |
+
# mask_index = None
|
| 1033 |
+
noise_pred = self.transformer(
|
| 1034 |
+
hidden_states=scaled_latent_model_input,
|
| 1035 |
+
timestep=timestep,
|
| 1036 |
+
encoder_hidden_states=prompt_embeds,
|
| 1037 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1038 |
+
map_aware_mask=map_aware_mask,
|
| 1039 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1040 |
+
return_dict=False,
|
| 1041 |
+
# mask_index= mask_index,
|
| 1042 |
+
)[0]
|
| 1043 |
+
|
| 1044 |
+
# perform guidance
|
| 1045 |
+
if self.do_classifier_free_guidance:
|
| 1046 |
+
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
|
| 1047 |
+
noise_pred = (
|
| 1048 |
+
noise_pred_uncond
|
| 1049 |
+
+ self.guidance_scale * (noise_pred_text - noise_pred_image)
|
| 1050 |
+
+ self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
| 1051 |
+
)
|
| 1052 |
+
# noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # neg, prompt
|
| 1053 |
+
# noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1054 |
+
|
| 1055 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1056 |
+
latents_dtype = latents.dtype
|
| 1057 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1058 |
+
|
| 1059 |
+
if latents.dtype != latents_dtype:
|
| 1060 |
+
if torch.backends.mps.is_available():
|
| 1061 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1062 |
+
latents = latents.to(latents_dtype)
|
| 1063 |
+
|
| 1064 |
+
if callback_on_step_end is not None:
|
| 1065 |
+
callback_kwargs = {}
|
| 1066 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1067 |
+
callback_kwargs[k] = locals()[k]
|
| 1068 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1069 |
+
|
| 1070 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1071 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1072 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1073 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
| 1074 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
| 1075 |
+
)
|
| 1076 |
+
image_latents = callback_outputs.pop("image_latents", image_latents)
|
| 1077 |
+
if mask_img is not None:
|
| 1078 |
+
mask_image_latents = callback_outputs.pop("mask_image_latents", mask_image_latents)
|
| 1079 |
+
# call the callback, if provided
|
| 1080 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1081 |
+
progress_bar.update()
|
| 1082 |
+
|
| 1083 |
+
if XLA_AVAILABLE:
|
| 1084 |
+
xm.mark_step()
|
| 1085 |
+
|
| 1086 |
+
if output_type == "latent":
|
| 1087 |
+
image = latents
|
| 1088 |
+
|
| 1089 |
+
else:
|
| 1090 |
+
# latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1091 |
+
latents = latents / self.vae.config.scaling_factor
|
| 1092 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1093 |
+
|
| 1094 |
+
do_denormalize = [True] * image.shape[0]
|
| 1095 |
+
aov_name = aov[0]
|
| 1096 |
+
if aov_name == "albedo" or aov_name == "irradiance":
|
| 1097 |
+
do_gamma_correction = True
|
| 1098 |
+
else:
|
| 1099 |
+
do_gamma_correction = False
|
| 1100 |
+
|
| 1101 |
+
if aov_name == "roughness" or aov_name == "metallic":
|
| 1102 |
+
image = image[:, 0:1].repeat(1, 3, 1, 1)
|
| 1103 |
+
# print(image.shape)
|
| 1104 |
+
# print(old_height, old_width)
|
| 1105 |
+
image = torchvision.transforms.Resize((old_height, old_width),interpolation=PIL.Image.BICUBIC)(image)
|
| 1106 |
+
image = self.image_processor.postprocess(
|
| 1107 |
+
image,
|
| 1108 |
+
output_type=output_type,
|
| 1109 |
+
do_denormalize=do_denormalize,
|
| 1110 |
+
do_gamma_correction=do_gamma_correction,
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
# Offload all models
|
| 1114 |
+
self.maybe_free_model_hooks()
|
| 1115 |
+
|
| 1116 |
+
if not return_dict:
|
| 1117 |
+
return (image,)
|
| 1118 |
+
|
| 1119 |
+
return StableDiffusion3PipelineOutput(images=image)
|
pipeline_utils.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for loading transformer variants from ``transformer/<subfolder>/``."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import importlib.util
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def calculate_shift(
|
| 13 |
+
image_seq_len: int,
|
| 14 |
+
base_seq_len: int = 256,
|
| 15 |
+
max_seq_len: int = 4096,
|
| 16 |
+
base_shift: float = 0.5,
|
| 17 |
+
max_shift: float = 1.15,
|
| 18 |
+
) -> float:
|
| 19 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 20 |
+
b = base_shift - m * base_seq_len
|
| 21 |
+
return image_seq_len * m + b
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_flow_timesteps(
|
| 25 |
+
scheduler,
|
| 26 |
+
transformer,
|
| 27 |
+
num_inference_steps: int,
|
| 28 |
+
latent_height: int,
|
| 29 |
+
latent_width: int,
|
| 30 |
+
device: torch.device,
|
| 31 |
+
) -> None:
|
| 32 |
+
if scheduler.config.get("use_dynamic_shifting", False):
|
| 33 |
+
image_seq_len = (latent_height // transformer.config.patch_size) * (
|
| 34 |
+
latent_width // transformer.config.patch_size
|
| 35 |
+
)
|
| 36 |
+
mu = calculate_shift(
|
| 37 |
+
image_seq_len,
|
| 38 |
+
scheduler.config.get("base_image_seq_len", 256),
|
| 39 |
+
scheduler.config.get("max_image_seq_len", 4096),
|
| 40 |
+
scheduler.config.get("base_shift", 0.5),
|
| 41 |
+
scheduler.config.get("max_shift", 1.15),
|
| 42 |
+
)
|
| 43 |
+
scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
|
| 44 |
+
else:
|
| 45 |
+
scheduler.set_timesteps(num_inference_steps, device=device)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def resolve_repo_dir(pretrained_model_name_or_path: str | Path) -> Path:
|
| 49 |
+
return Path(pretrained_model_name_or_path).resolve()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_transformer_from_subfolder(
|
| 53 |
+
repo_dir: str | Path,
|
| 54 |
+
transformer_subfolder: str,
|
| 55 |
+
*,
|
| 56 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 57 |
+
device: str | torch.device | None = None,
|
| 58 |
+
):
|
| 59 |
+
"""Load a transformer checkpoint from ``<repo_dir>/transformer/<transformer_subfolder>/``."""
|
| 60 |
+
repo_dir = resolve_repo_dir(repo_dir)
|
| 61 |
+
transformer_path = repo_dir / "transformer" / transformer_subfolder
|
| 62 |
+
if not transformer_path.is_dir():
|
| 63 |
+
raise FileNotFoundError(f"Transformer folder not found: {transformer_path}")
|
| 64 |
+
|
| 65 |
+
custom_module = transformer_path / "transformer_intrinsic_weather.py"
|
| 66 |
+
if custom_module.exists():
|
| 67 |
+
spec = importlib.util.spec_from_file_location("transformer_intrinsic_weather", custom_module)
|
| 68 |
+
if spec is None or spec.loader is None:
|
| 69 |
+
raise ImportError(f"Cannot import custom transformer module: {custom_module}")
|
| 70 |
+
module = importlib.util.module_from_spec(spec)
|
| 71 |
+
spec.loader.exec_module(module)
|
| 72 |
+
cls = module.IntrinsicWeatherSD3Transformer2DModel
|
| 73 |
+
transformer = cls.from_pretrained(
|
| 74 |
+
transformer_path.as_posix(),
|
| 75 |
+
torch_dtype=dtype,
|
| 76 |
+
local_files_only=True,
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
transformer = SD3Transformer2DModel.from_pretrained(
|
| 80 |
+
transformer_path.as_posix(),
|
| 81 |
+
torch_dtype=dtype,
|
| 82 |
+
local_files_only=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if device is not None:
|
| 86 |
+
transformer = transformer.to(device)
|
| 87 |
+
return transformer
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def resolve_transformer_lora_dir(repo_dir: str | Path, transformer_subfolder: str) -> Path | None:
|
| 91 |
+
"""Return ``transformer/<subfolder>/lora`` when present."""
|
| 92 |
+
lora_dir = resolve_repo_dir(repo_dir) / "transformer" / transformer_subfolder / "lora"
|
| 93 |
+
if lora_dir.is_dir() and any(lora_dir.glob("*.safetensors")):
|
| 94 |
+
return lora_dir
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def load_transformer_lora(pipe, repo_dir: str | Path, transformer_subfolder: str) -> bool:
|
| 99 |
+
"""Load LoRA weights bundled with a transformer variant. Returns True if loaded."""
|
| 100 |
+
lora_dir = resolve_transformer_lora_dir(repo_dir, transformer_subfolder)
|
| 101 |
+
if lora_dir is None:
|
| 102 |
+
return False
|
| 103 |
+
pipe.load_lora_weights(lora_dir.as_posix())
|
| 104 |
+
return True
|
scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.35.1",
|
| 4 |
+
"num_train_timesteps": 1000,
|
| 5 |
+
"shift": 3.0,
|
| 6 |
+
"use_dynamic_shifting": true,
|
| 7 |
+
"base_shift": 0.5,
|
| 8 |
+
"max_shift": 1.15
|
| 9 |
+
}
|
test_all_pipelines.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Smoke-test all IntrinsicWeather pipelines on CUDA with bfloat16."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
REPO = Path(__file__).resolve().parent
|
| 13 |
+
sys.path.insert(0, str(REPO))
|
| 14 |
+
|
| 15 |
+
DTYPE = torch.bfloat16
|
| 16 |
+
DEVICE = "cuda"
|
| 17 |
+
IMAGE_SIZE = 512
|
| 18 |
+
STEPS = 2
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _clear():
|
| 22 |
+
gc.collect()
|
| 23 |
+
if torch.cuda.is_available():
|
| 24 |
+
torch.cuda.empty_cache()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_inverse() -> None:
|
| 28 |
+
from imaa.imaa import IMAA
|
| 29 |
+
from pipeline_intrinsic_weather_inverse import IntrinsicWeatherInversePipeline
|
| 30 |
+
from safetensors.torch import load_file
|
| 31 |
+
|
| 32 |
+
print("[inverse] loading pipeline ...")
|
| 33 |
+
pipe = IntrinsicWeatherInversePipeline.from_pretrained(
|
| 34 |
+
REPO,
|
| 35 |
+
transformer_subfolder="inverse-512",
|
| 36 |
+
device=DEVICE,
|
| 37 |
+
local_files_only=True,
|
| 38 |
+
torch_dtype=DTYPE,
|
| 39 |
+
)
|
| 40 |
+
assert next(pipe.transformer.parameters()).dtype == DTYPE
|
| 41 |
+
assert next(pipe.transformer.parameters()).device.type == "cuda"
|
| 42 |
+
print(f"[inverse] transformer in_channels={pipe.transformer.config.in_channels}")
|
| 43 |
+
|
| 44 |
+
imaa = IMAA(dino_model=None, processor=None, num_maps=5, map_embedding_dim=256, common_dim=128).to(DEVICE)
|
| 45 |
+
imaa.load_state_dict(load_file((REPO / "imaa" / "model.safetensors").as_posix()))
|
| 46 |
+
imaa.eval()
|
| 47 |
+
|
| 48 |
+
image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE, dtype=DTYPE)
|
| 49 |
+
prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
|
| 50 |
+
prompt="Albedo (diffuse basecolor)",
|
| 51 |
+
prompt_2=None,
|
| 52 |
+
prompt_3=None,
|
| 53 |
+
do_classifier_free_guidance=False,
|
| 54 |
+
)
|
| 55 |
+
print("[inverse] running 2-step inference ...")
|
| 56 |
+
out = pipe(
|
| 57 |
+
image=image,
|
| 58 |
+
prompt_embeds=prompt_embeds,
|
| 59 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 60 |
+
guidance_scale=0.0,
|
| 61 |
+
image_guidance_scale=0.0,
|
| 62 |
+
num_inference_steps=STEPS,
|
| 63 |
+
output_type="pt",
|
| 64 |
+
aov=["albedo"],
|
| 65 |
+
map_aware_mask=None,
|
| 66 |
+
)
|
| 67 |
+
print(f"[inverse] output shape={tuple(out.images[0].shape)} dtype={out.images[0].dtype}")
|
| 68 |
+
del pipe, imaa, out
|
| 69 |
+
_clear()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_forward() -> None:
|
| 73 |
+
from pipeline_intrinsic_weather_forward import IntrinsicWeatherForwardPipeline
|
| 74 |
+
|
| 75 |
+
print("[forward] loading pipeline ...")
|
| 76 |
+
pipe = IntrinsicWeatherForwardPipeline.from_pretrained(
|
| 77 |
+
REPO,
|
| 78 |
+
transformer_subfolder="forward",
|
| 79 |
+
device=DEVICE,
|
| 80 |
+
local_files_only=True,
|
| 81 |
+
torch_dtype=DTYPE,
|
| 82 |
+
load_lora=True,
|
| 83 |
+
)
|
| 84 |
+
assert next(pipe.transformer.parameters()).dtype == DTYPE
|
| 85 |
+
assert next(pipe.transformer.parameters()).device.type == "cuda"
|
| 86 |
+
print(f"[forward] transformer in_channels={pipe.transformer.config.in_channels}")
|
| 87 |
+
|
| 88 |
+
aov = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE, dtype=DTYPE)
|
| 89 |
+
print("[forward] running 2-step inference ...")
|
| 90 |
+
out = pipe(
|
| 91 |
+
albedo=aov,
|
| 92 |
+
normal=aov,
|
| 93 |
+
roughness=aov,
|
| 94 |
+
metallic=aov,
|
| 95 |
+
irradiance=aov,
|
| 96 |
+
prompt=["A rainy day."],
|
| 97 |
+
guidance_scale=6.0,
|
| 98 |
+
image_guidance_scale=1.5,
|
| 99 |
+
num_inference_steps=STEPS,
|
| 100 |
+
required_aovs=["albedo", "normal", "roughness", "metallic", "irradiance"],
|
| 101 |
+
generator=torch.Generator(device=DEVICE).manual_seed(0),
|
| 102 |
+
)
|
| 103 |
+
print(f"[forward] output size={out.images[0].size}")
|
| 104 |
+
del pipe, out
|
| 105 |
+
_clear()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def test_unified() -> None:
|
| 109 |
+
from pipeline_intrinsic_weather import IntrinsicWeatherPipeline
|
| 110 |
+
|
| 111 |
+
print("[unified] loading pipeline ...")
|
| 112 |
+
pipe = IntrinsicWeatherPipeline.from_pretrained(
|
| 113 |
+
REPO,
|
| 114 |
+
inverse_transformer_subfolder="inverse-512",
|
| 115 |
+
forward_transformer_subfolder="forward",
|
| 116 |
+
device=DEVICE,
|
| 117 |
+
local_files_only=True,
|
| 118 |
+
torch_dtype=DTYPE,
|
| 119 |
+
load_lora=True,
|
| 120 |
+
load_imaa=True,
|
| 121 |
+
)
|
| 122 |
+
assert next(pipe.inverse_transformer.parameters()).dtype == DTYPE
|
| 123 |
+
assert next(pipe.forward_transformer.parameters()).dtype == DTYPE
|
| 124 |
+
print("[unified] loaded inverse + forward transformers and IMAA")
|
| 125 |
+
del pipe
|
| 126 |
+
_clear()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def main() -> None:
|
| 130 |
+
if not torch.cuda.is_available():
|
| 131 |
+
raise SystemExit("CUDA is required for this test.")
|
| 132 |
+
|
| 133 |
+
print(f"device={torch.cuda.get_device_name(0)} dtype={DTYPE}")
|
| 134 |
+
test_inverse()
|
| 135 |
+
test_forward()
|
| 136 |
+
test_unified()
|
| 137 |
+
print("All pipeline tests passed.")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
main()
|
text_encoder/config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"CLIPTextModelWithProjection"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"bos_token_id": 0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"hidden_act": "quick_gelu",
|
| 10 |
+
"hidden_size": 768,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"max_position_embeddings": 77,
|
| 16 |
+
"model_type": "clip_text_model",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 12,
|
| 19 |
+
"pad_token_id": 1,
|
| 20 |
+
"projection_dim": 768,
|
| 21 |
+
"torch_dtype": "float16",
|
| 22 |
+
"transformers_version": "4.41.2",
|
| 23 |
+
"vocab_size": 49408
|
| 24 |
+
}
|
text_encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71e183d11db0c6b6282a4d9e0abb74125edc8692393e89ed8ee5571005f35cb1
|
| 3 |
+
size 247323896
|
text_encoder_2/config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"CLIPTextModelWithProjection"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"bos_token_id": 0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_size": 1280,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 5120,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"max_position_embeddings": 77,
|
| 16 |
+
"model_type": "clip_text_model",
|
| 17 |
+
"num_attention_heads": 20,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"pad_token_id": 1,
|
| 20 |
+
"projection_dim": 1280,
|
| 21 |
+
"torch_dtype": "float16",
|
| 22 |
+
"transformers_version": "4.41.2",
|
| 23 |
+
"vocab_size": 49408
|
| 24 |
+
}
|
text_encoder_2/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec310df2af79c318e24d20511b601a591ca8cd4f1fce1d8dff822a356bcdb1f4
|
| 3 |
+
size 1389382176
|
text_encoder_3/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"T5EncoderModel"
|
| 4 |
+
],
|
| 5 |
+
"classifier_dropout": 0.0,
|
| 6 |
+
"d_ff": 10240,
|
| 7 |
+
"d_kv": 64,
|
| 8 |
+
"d_model": 4096,
|
| 9 |
+
"decoder_start_token_id": 0,
|
| 10 |
+
"dense_act_fn": "gelu_new",
|
| 11 |
+
"dropout_rate": 0.1,
|
| 12 |
+
"eos_token_id": 1,
|
| 13 |
+
"feed_forward_proj": "gated-gelu",
|
| 14 |
+
"initializer_factor": 1.0,
|
| 15 |
+
"is_encoder_decoder": true,
|
| 16 |
+
"is_gated_act": true,
|
| 17 |
+
"layer_norm_epsilon": 1e-06,
|
| 18 |
+
"model_type": "t5",
|
| 19 |
+
"num_decoder_layers": 24,
|
| 20 |
+
"num_heads": 64,
|
| 21 |
+
"num_layers": 24,
|
| 22 |
+
"output_past": true,
|
| 23 |
+
"pad_token_id": 0,
|
| 24 |
+
"relative_attention_max_distance": 128,
|
| 25 |
+
"relative_attention_num_buckets": 32,
|
| 26 |
+
"tie_word_embeddings": false,
|
| 27 |
+
"torch_dtype": "float16",
|
| 28 |
+
"transformers_version": "4.41.2",
|
| 29 |
+
"use_cache": true,
|
| 30 |
+
"vocab_size": 32128
|
| 31 |
+
}
|
text_encoder_3/model-00001-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f2751ceeb2a96edd693e539dc5d6bba0b8d3814f49a9b3798403a0cec4b2e3d
|
| 3 |
+
size 4994582104
|
text_encoder_3/model-00002-of-00002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f63154532130422309532ff56f11945fbea8266c958e3133e8e5aef85c6293c7
|
| 3 |
+
size 4530066248
|
text_encoder_3/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 9524621312
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"encoder.block.0.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 7 |
+
"encoder.block.0.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 8 |
+
"encoder.block.0.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 9 |
+
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": "model-00001-of-00002.safetensors",
|
| 10 |
+
"encoder.block.0.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 11 |
+
"encoder.block.0.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 12 |
+
"encoder.block.0.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 13 |
+
"encoder.block.0.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 14 |
+
"encoder.block.0.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 15 |
+
"encoder.block.0.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 16 |
+
"encoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 17 |
+
"encoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 18 |
+
"encoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 19 |
+
"encoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 20 |
+
"encoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 21 |
+
"encoder.block.1.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 22 |
+
"encoder.block.1.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 23 |
+
"encoder.block.1.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 24 |
+
"encoder.block.1.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 25 |
+
"encoder.block.10.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 26 |
+
"encoder.block.10.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 27 |
+
"encoder.block.10.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 28 |
+
"encoder.block.10.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 29 |
+
"encoder.block.10.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 30 |
+
"encoder.block.10.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 31 |
+
"encoder.block.10.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 32 |
+
"encoder.block.10.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 33 |
+
"encoder.block.10.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 34 |
+
"encoder.block.11.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 35 |
+
"encoder.block.11.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 36 |
+
"encoder.block.11.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 37 |
+
"encoder.block.11.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 38 |
+
"encoder.block.11.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 39 |
+
"encoder.block.11.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 40 |
+
"encoder.block.11.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 41 |
+
"encoder.block.11.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 42 |
+
"encoder.block.11.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 43 |
+
"encoder.block.12.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 44 |
+
"encoder.block.12.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 45 |
+
"encoder.block.12.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 46 |
+
"encoder.block.12.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 47 |
+
"encoder.block.12.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 48 |
+
"encoder.block.12.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 49 |
+
"encoder.block.12.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 50 |
+
"encoder.block.12.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 51 |
+
"encoder.block.12.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 52 |
+
"encoder.block.13.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 53 |
+
"encoder.block.13.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 54 |
+
"encoder.block.13.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 55 |
+
"encoder.block.13.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 56 |
+
"encoder.block.13.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 57 |
+
"encoder.block.13.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 58 |
+
"encoder.block.13.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 59 |
+
"encoder.block.13.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 60 |
+
"encoder.block.13.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 61 |
+
"encoder.block.14.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 62 |
+
"encoder.block.14.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 63 |
+
"encoder.block.14.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 64 |
+
"encoder.block.14.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 65 |
+
"encoder.block.14.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 66 |
+
"encoder.block.14.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 67 |
+
"encoder.block.14.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 68 |
+
"encoder.block.14.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 69 |
+
"encoder.block.14.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 70 |
+
"encoder.block.15.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 71 |
+
"encoder.block.15.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 72 |
+
"encoder.block.15.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 73 |
+
"encoder.block.15.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 74 |
+
"encoder.block.15.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 75 |
+
"encoder.block.15.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 76 |
+
"encoder.block.15.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 77 |
+
"encoder.block.15.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 78 |
+
"encoder.block.15.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 79 |
+
"encoder.block.16.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 80 |
+
"encoder.block.16.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 81 |
+
"encoder.block.16.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 82 |
+
"encoder.block.16.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 83 |
+
"encoder.block.16.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 84 |
+
"encoder.block.16.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 85 |
+
"encoder.block.16.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 86 |
+
"encoder.block.16.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 87 |
+
"encoder.block.16.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 88 |
+
"encoder.block.17.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 89 |
+
"encoder.block.17.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 90 |
+
"encoder.block.17.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 91 |
+
"encoder.block.17.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 92 |
+
"encoder.block.17.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 93 |
+
"encoder.block.17.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 94 |
+
"encoder.block.17.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 95 |
+
"encoder.block.17.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 96 |
+
"encoder.block.17.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 97 |
+
"encoder.block.18.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 98 |
+
"encoder.block.18.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 99 |
+
"encoder.block.18.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 100 |
+
"encoder.block.18.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 101 |
+
"encoder.block.18.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 102 |
+
"encoder.block.18.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 103 |
+
"encoder.block.18.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 104 |
+
"encoder.block.18.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 105 |
+
"encoder.block.18.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 106 |
+
"encoder.block.19.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 107 |
+
"encoder.block.19.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 108 |
+
"encoder.block.19.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 109 |
+
"encoder.block.19.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 110 |
+
"encoder.block.19.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 111 |
+
"encoder.block.19.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 112 |
+
"encoder.block.19.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 113 |
+
"encoder.block.19.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 114 |
+
"encoder.block.19.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 115 |
+
"encoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 116 |
+
"encoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 117 |
+
"encoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 118 |
+
"encoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 119 |
+
"encoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 120 |
+
"encoder.block.2.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 121 |
+
"encoder.block.2.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 122 |
+
"encoder.block.2.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 123 |
+
"encoder.block.2.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 124 |
+
"encoder.block.20.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 125 |
+
"encoder.block.20.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 126 |
+
"encoder.block.20.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 127 |
+
"encoder.block.20.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 128 |
+
"encoder.block.20.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 129 |
+
"encoder.block.20.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 130 |
+
"encoder.block.20.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 131 |
+
"encoder.block.20.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 132 |
+
"encoder.block.20.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 133 |
+
"encoder.block.21.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 134 |
+
"encoder.block.21.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 135 |
+
"encoder.block.21.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 136 |
+
"encoder.block.21.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 137 |
+
"encoder.block.21.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 138 |
+
"encoder.block.21.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 139 |
+
"encoder.block.21.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 140 |
+
"encoder.block.21.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 141 |
+
"encoder.block.21.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 142 |
+
"encoder.block.22.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 143 |
+
"encoder.block.22.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 144 |
+
"encoder.block.22.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 145 |
+
"encoder.block.22.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 146 |
+
"encoder.block.22.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 147 |
+
"encoder.block.22.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 148 |
+
"encoder.block.22.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 149 |
+
"encoder.block.22.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 150 |
+
"encoder.block.22.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 151 |
+
"encoder.block.23.layer.0.SelfAttention.k.weight": "model-00002-of-00002.safetensors",
|
| 152 |
+
"encoder.block.23.layer.0.SelfAttention.o.weight": "model-00002-of-00002.safetensors",
|
| 153 |
+
"encoder.block.23.layer.0.SelfAttention.q.weight": "model-00002-of-00002.safetensors",
|
| 154 |
+
"encoder.block.23.layer.0.SelfAttention.v.weight": "model-00002-of-00002.safetensors",
|
| 155 |
+
"encoder.block.23.layer.0.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 156 |
+
"encoder.block.23.layer.1.DenseReluDense.wi_0.weight": "model-00002-of-00002.safetensors",
|
| 157 |
+
"encoder.block.23.layer.1.DenseReluDense.wi_1.weight": "model-00002-of-00002.safetensors",
|
| 158 |
+
"encoder.block.23.layer.1.DenseReluDense.wo.weight": "model-00002-of-00002.safetensors",
|
| 159 |
+
"encoder.block.23.layer.1.layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 160 |
+
"encoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 161 |
+
"encoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 162 |
+
"encoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 163 |
+
"encoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 164 |
+
"encoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 165 |
+
"encoder.block.3.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 166 |
+
"encoder.block.3.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 167 |
+
"encoder.block.3.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 168 |
+
"encoder.block.3.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 169 |
+
"encoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 170 |
+
"encoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 171 |
+
"encoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 172 |
+
"encoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 173 |
+
"encoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 174 |
+
"encoder.block.4.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 175 |
+
"encoder.block.4.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 176 |
+
"encoder.block.4.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 177 |
+
"encoder.block.4.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 178 |
+
"encoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 179 |
+
"encoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 180 |
+
"encoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 181 |
+
"encoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 182 |
+
"encoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 183 |
+
"encoder.block.5.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 184 |
+
"encoder.block.5.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 185 |
+
"encoder.block.5.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 186 |
+
"encoder.block.5.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 187 |
+
"encoder.block.6.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 188 |
+
"encoder.block.6.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 189 |
+
"encoder.block.6.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 190 |
+
"encoder.block.6.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 191 |
+
"encoder.block.6.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 192 |
+
"encoder.block.6.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 193 |
+
"encoder.block.6.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 194 |
+
"encoder.block.6.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 195 |
+
"encoder.block.6.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 196 |
+
"encoder.block.7.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 197 |
+
"encoder.block.7.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 198 |
+
"encoder.block.7.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 199 |
+
"encoder.block.7.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 200 |
+
"encoder.block.7.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 201 |
+
"encoder.block.7.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 202 |
+
"encoder.block.7.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 203 |
+
"encoder.block.7.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 204 |
+
"encoder.block.7.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 205 |
+
"encoder.block.8.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 206 |
+
"encoder.block.8.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 207 |
+
"encoder.block.8.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 208 |
+
"encoder.block.8.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 209 |
+
"encoder.block.8.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 210 |
+
"encoder.block.8.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 211 |
+
"encoder.block.8.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 212 |
+
"encoder.block.8.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 213 |
+
"encoder.block.8.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 214 |
+
"encoder.block.9.layer.0.SelfAttention.k.weight": "model-00001-of-00002.safetensors",
|
| 215 |
+
"encoder.block.9.layer.0.SelfAttention.o.weight": "model-00001-of-00002.safetensors",
|
| 216 |
+
"encoder.block.9.layer.0.SelfAttention.q.weight": "model-00001-of-00002.safetensors",
|
| 217 |
+
"encoder.block.9.layer.0.SelfAttention.v.weight": "model-00001-of-00002.safetensors",
|
| 218 |
+
"encoder.block.9.layer.0.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 219 |
+
"encoder.block.9.layer.1.DenseReluDense.wi_0.weight": "model-00001-of-00002.safetensors",
|
| 220 |
+
"encoder.block.9.layer.1.DenseReluDense.wi_1.weight": "model-00001-of-00002.safetensors",
|
| 221 |
+
"encoder.block.9.layer.1.DenseReluDense.wo.weight": "model-00001-of-00002.safetensors",
|
| 222 |
+
"encoder.block.9.layer.1.layer_norm.weight": "model-00001-of-00002.safetensors",
|
| 223 |
+
"encoder.final_layer_norm.weight": "model-00002-of-00002.safetensors",
|
| 224 |
+
"shared.weight": "model-00001-of-00002.safetensors"
|
| 225 |
+
}
|
| 226 |
+
}
|
tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|startoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<|endoftext|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<|endoftext|>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"49406": {
|
| 5 |
+
"content": "<|startoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"49407": {
|
| 13 |
+
"content": "<|endoftext|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
}
|
| 20 |
+
},
|
| 21 |
+
"bos_token": "<|startoftext|>",
|
| 22 |
+
"clean_up_tokenization_spaces": true,
|
| 23 |
+
"do_lower_case": true,
|
| 24 |
+
"eos_token": "<|endoftext|>",
|
| 25 |
+
"errors": "replace",
|
| 26 |
+
"model_max_length": 77,
|
| 27 |
+
"pad_token": "<|endoftext|>",
|
| 28 |
+
"tokenizer_class": "CLIPTokenizer",
|
| 29 |
+
"unk_token": "<|endoftext|>"
|
| 30 |
+
}
|
tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_2/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_2/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|startoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "!",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<|endoftext|>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
tokenizer_2/tokenizer_config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"0": {
|
| 5 |
+
"content": "!",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"49406": {
|
| 13 |
+
"content": "<|startoftext|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": true,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"49407": {
|
| 21 |
+
"content": "<|endoftext|>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"bos_token": "<|startoftext|>",
|
| 30 |
+
"clean_up_tokenization_spaces": true,
|
| 31 |
+
"do_lower_case": true,
|
| 32 |
+
"eos_token": "<|endoftext|>",
|
| 33 |
+
"errors": "replace",
|
| 34 |
+
"model_max_length": 77,
|
| 35 |
+
"pad_token": "!",
|
| 36 |
+
"tokenizer_class": "CLIPTokenizer",
|
| 37 |
+
"unk_token": "<|endoftext|>"
|
| 38 |
+
}
|
tokenizer_2/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_3/special_tokens_map.json
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<extra_id_0>",
|
| 4 |
+
"<extra_id_1>",
|
| 5 |
+
"<extra_id_2>",
|
| 6 |
+
"<extra_id_3>",
|
| 7 |
+
"<extra_id_4>",
|
| 8 |
+
"<extra_id_5>",
|
| 9 |
+
"<extra_id_6>",
|
| 10 |
+
"<extra_id_7>",
|
| 11 |
+
"<extra_id_8>",
|
| 12 |
+
"<extra_id_9>",
|
| 13 |
+
"<extra_id_10>",
|
| 14 |
+
"<extra_id_11>",
|
| 15 |
+
"<extra_id_12>",
|
| 16 |
+
"<extra_id_13>",
|
| 17 |
+
"<extra_id_14>",
|
| 18 |
+
"<extra_id_15>",
|
| 19 |
+
"<extra_id_16>",
|
| 20 |
+
"<extra_id_17>",
|
| 21 |
+
"<extra_id_18>",
|
| 22 |
+
"<extra_id_19>",
|
| 23 |
+
"<extra_id_20>",
|
| 24 |
+
"<extra_id_21>",
|
| 25 |
+
"<extra_id_22>",
|
| 26 |
+
"<extra_id_23>",
|
| 27 |
+
"<extra_id_24>",
|
| 28 |
+
"<extra_id_25>",
|
| 29 |
+
"<extra_id_26>",
|
| 30 |
+
"<extra_id_27>",
|
| 31 |
+
"<extra_id_28>",
|
| 32 |
+
"<extra_id_29>",
|
| 33 |
+
"<extra_id_30>",
|
| 34 |
+
"<extra_id_31>",
|
| 35 |
+
"<extra_id_32>",
|
| 36 |
+
"<extra_id_33>",
|
| 37 |
+
"<extra_id_34>",
|
| 38 |
+
"<extra_id_35>",
|
| 39 |
+
"<extra_id_36>",
|
| 40 |
+
"<extra_id_37>",
|
| 41 |
+
"<extra_id_38>",
|
| 42 |
+
"<extra_id_39>",
|
| 43 |
+
"<extra_id_40>",
|
| 44 |
+
"<extra_id_41>",
|
| 45 |
+
"<extra_id_42>",
|
| 46 |
+
"<extra_id_43>",
|
| 47 |
+
"<extra_id_44>",
|
| 48 |
+
"<extra_id_45>",
|
| 49 |
+
"<extra_id_46>",
|
| 50 |
+
"<extra_id_47>",
|
| 51 |
+
"<extra_id_48>",
|
| 52 |
+
"<extra_id_49>",
|
| 53 |
+
"<extra_id_50>",
|
| 54 |
+
"<extra_id_51>",
|
| 55 |
+
"<extra_id_52>",
|
| 56 |
+
"<extra_id_53>",
|
| 57 |
+
"<extra_id_54>",
|
| 58 |
+
"<extra_id_55>",
|
| 59 |
+
"<extra_id_56>",
|
| 60 |
+
"<extra_id_57>",
|
| 61 |
+
"<extra_id_58>",
|
| 62 |
+
"<extra_id_59>",
|
| 63 |
+
"<extra_id_60>",
|
| 64 |
+
"<extra_id_61>",
|
| 65 |
+
"<extra_id_62>",
|
| 66 |
+
"<extra_id_63>",
|
| 67 |
+
"<extra_id_64>",
|
| 68 |
+
"<extra_id_65>",
|
| 69 |
+
"<extra_id_66>",
|
| 70 |
+
"<extra_id_67>",
|
| 71 |
+
"<extra_id_68>",
|
| 72 |
+
"<extra_id_69>",
|
| 73 |
+
"<extra_id_70>",
|
| 74 |
+
"<extra_id_71>",
|
| 75 |
+
"<extra_id_72>",
|
| 76 |
+
"<extra_id_73>",
|
| 77 |
+
"<extra_id_74>",
|
| 78 |
+
"<extra_id_75>",
|
| 79 |
+
"<extra_id_76>",
|
| 80 |
+
"<extra_id_77>",
|
| 81 |
+
"<extra_id_78>",
|
| 82 |
+
"<extra_id_79>",
|
| 83 |
+
"<extra_id_80>",
|
| 84 |
+
"<extra_id_81>",
|
| 85 |
+
"<extra_id_82>",
|
| 86 |
+
"<extra_id_83>",
|
| 87 |
+
"<extra_id_84>",
|
| 88 |
+
"<extra_id_85>",
|
| 89 |
+
"<extra_id_86>",
|
| 90 |
+
"<extra_id_87>",
|
| 91 |
+
"<extra_id_88>",
|
| 92 |
+
"<extra_id_89>",
|
| 93 |
+
"<extra_id_90>",
|
| 94 |
+
"<extra_id_91>",
|
| 95 |
+
"<extra_id_92>",
|
| 96 |
+
"<extra_id_93>",
|
| 97 |
+
"<extra_id_94>",
|
| 98 |
+
"<extra_id_95>",
|
| 99 |
+
"<extra_id_96>",
|
| 100 |
+
"<extra_id_97>",
|
| 101 |
+
"<extra_id_98>",
|
| 102 |
+
"<extra_id_99>"
|
| 103 |
+
],
|
| 104 |
+
"eos_token": {
|
| 105 |
+
"content": "</s>",
|
| 106 |
+
"lstrip": false,
|
| 107 |
+
"normalized": false,
|
| 108 |
+
"rstrip": false,
|
| 109 |
+
"single_word": false
|
| 110 |
+
},
|
| 111 |
+
"pad_token": {
|
| 112 |
+
"content": "<pad>",
|
| 113 |
+
"lstrip": false,
|
| 114 |
+
"normalized": false,
|
| 115 |
+
"rstrip": false,
|
| 116 |
+
"single_word": false
|
| 117 |
+
},
|
| 118 |
+
"unk_token": {
|
| 119 |
+
"content": "<unk>",
|
| 120 |
+
"lstrip": false,
|
| 121 |
+
"normalized": false,
|
| 122 |
+
"rstrip": false,
|
| 123 |
+
"single_word": false
|
| 124 |
+
}
|
| 125 |
+
}
|
tokenizer_3/spiece.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
|
| 3 |
+
size 791656
|
tokenizer_3/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_3/tokenizer_config.json
ADDED
|
@@ -0,0 +1,940 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": true,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"0": {
|
| 5 |
+
"content": "<pad>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"1": {
|
| 13 |
+
"content": "</s>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"2": {
|
| 21 |
+
"content": "<unk>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"32000": {
|
| 29 |
+
"content": "<extra_id_99>",
|
| 30 |
+
"lstrip": true,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": true,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"32001": {
|
| 37 |
+
"content": "<extra_id_98>",
|
| 38 |
+
"lstrip": true,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": true,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
},
|
| 44 |
+
"32002": {
|
| 45 |
+
"content": "<extra_id_97>",
|
| 46 |
+
"lstrip": true,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": true,
|
| 49 |
+
"single_word": false,
|
| 50 |
+
"special": true
|
| 51 |
+
},
|
| 52 |
+
"32003": {
|
| 53 |
+
"content": "<extra_id_96>",
|
| 54 |
+
"lstrip": true,
|
| 55 |
+
"normalized": false,
|
| 56 |
+
"rstrip": true,
|
| 57 |
+
"single_word": false,
|
| 58 |
+
"special": true
|
| 59 |
+
},
|
| 60 |
+
"32004": {
|
| 61 |
+
"content": "<extra_id_95>",
|
| 62 |
+
"lstrip": true,
|
| 63 |
+
"normalized": false,
|
| 64 |
+
"rstrip": true,
|
| 65 |
+
"single_word": false,
|
| 66 |
+
"special": true
|
| 67 |
+
},
|
| 68 |
+
"32005": {
|
| 69 |
+
"content": "<extra_id_94>",
|
| 70 |
+
"lstrip": true,
|
| 71 |
+
"normalized": false,
|
| 72 |
+
"rstrip": true,
|
| 73 |
+
"single_word": false,
|
| 74 |
+
"special": true
|
| 75 |
+
},
|
| 76 |
+
"32006": {
|
| 77 |
+
"content": "<extra_id_93>",
|
| 78 |
+
"lstrip": true,
|
| 79 |
+
"normalized": false,
|
| 80 |
+
"rstrip": true,
|
| 81 |
+
"single_word": false,
|
| 82 |
+
"special": true
|
| 83 |
+
},
|
| 84 |
+
"32007": {
|
| 85 |
+
"content": "<extra_id_92>",
|
| 86 |
+
"lstrip": true,
|
| 87 |
+
"normalized": false,
|
| 88 |
+
"rstrip": true,
|
| 89 |
+
"single_word": false,
|
| 90 |
+
"special": true
|
| 91 |
+
},
|
| 92 |
+
"32008": {
|
| 93 |
+
"content": "<extra_id_91>",
|
| 94 |
+
"lstrip": true,
|
| 95 |
+
"normalized": false,
|
| 96 |
+
"rstrip": true,
|
| 97 |
+
"single_word": false,
|
| 98 |
+
"special": true
|
| 99 |
+
},
|
| 100 |
+
"32009": {
|
| 101 |
+
"content": "<extra_id_90>",
|
| 102 |
+
"lstrip": true,
|
| 103 |
+
"normalized": false,
|
| 104 |
+
"rstrip": true,
|
| 105 |
+
"single_word": false,
|
| 106 |
+
"special": true
|
| 107 |
+
},
|
| 108 |
+
"32010": {
|
| 109 |
+
"content": "<extra_id_89>",
|
| 110 |
+
"lstrip": true,
|
| 111 |
+
"normalized": false,
|
| 112 |
+
"rstrip": true,
|
| 113 |
+
"single_word": false,
|
| 114 |
+
"special": true
|
| 115 |
+
},
|
| 116 |
+
"32011": {
|
| 117 |
+
"content": "<extra_id_88>",
|
| 118 |
+
"lstrip": true,
|
| 119 |
+
"normalized": false,
|
| 120 |
+
"rstrip": true,
|
| 121 |
+
"single_word": false,
|
| 122 |
+
"special": true
|
| 123 |
+
},
|
| 124 |
+
"32012": {
|
| 125 |
+
"content": "<extra_id_87>",
|
| 126 |
+
"lstrip": true,
|
| 127 |
+
"normalized": false,
|
| 128 |
+
"rstrip": true,
|
| 129 |
+
"single_word": false,
|
| 130 |
+
"special": true
|
| 131 |
+
},
|
| 132 |
+
"32013": {
|
| 133 |
+
"content": "<extra_id_86>",
|
| 134 |
+
"lstrip": true,
|
| 135 |
+
"normalized": false,
|
| 136 |
+
"rstrip": true,
|
| 137 |
+
"single_word": false,
|
| 138 |
+
"special": true
|
| 139 |
+
},
|
| 140 |
+
"32014": {
|
| 141 |
+
"content": "<extra_id_85>",
|
| 142 |
+
"lstrip": true,
|
| 143 |
+
"normalized": false,
|
| 144 |
+
"rstrip": true,
|
| 145 |
+
"single_word": false,
|
| 146 |
+
"special": true
|
| 147 |
+
},
|
| 148 |
+
"32015": {
|
| 149 |
+
"content": "<extra_id_84>",
|
| 150 |
+
"lstrip": true,
|
| 151 |
+
"normalized": false,
|
| 152 |
+
"rstrip": true,
|
| 153 |
+
"single_word": false,
|
| 154 |
+
"special": true
|
| 155 |
+
},
|
| 156 |
+
"32016": {
|
| 157 |
+
"content": "<extra_id_83>",
|
| 158 |
+
"lstrip": true,
|
| 159 |
+
"normalized": false,
|
| 160 |
+
"rstrip": true,
|
| 161 |
+
"single_word": false,
|
| 162 |
+
"special": true
|
| 163 |
+
},
|
| 164 |
+
"32017": {
|
| 165 |
+
"content": "<extra_id_82>",
|
| 166 |
+
"lstrip": true,
|
| 167 |
+
"normalized": false,
|
| 168 |
+
"rstrip": true,
|
| 169 |
+
"single_word": false,
|
| 170 |
+
"special": true
|
| 171 |
+
},
|
| 172 |
+
"32018": {
|
| 173 |
+
"content": "<extra_id_81>",
|
| 174 |
+
"lstrip": true,
|
| 175 |
+
"normalized": false,
|
| 176 |
+
"rstrip": true,
|
| 177 |
+
"single_word": false,
|
| 178 |
+
"special": true
|
| 179 |
+
},
|
| 180 |
+
"32019": {
|
| 181 |
+
"content": "<extra_id_80>",
|
| 182 |
+
"lstrip": true,
|
| 183 |
+
"normalized": false,
|
| 184 |
+
"rstrip": true,
|
| 185 |
+
"single_word": false,
|
| 186 |
+
"special": true
|
| 187 |
+
},
|
| 188 |
+
"32020": {
|
| 189 |
+
"content": "<extra_id_79>",
|
| 190 |
+
"lstrip": true,
|
| 191 |
+
"normalized": false,
|
| 192 |
+
"rstrip": true,
|
| 193 |
+
"single_word": false,
|
| 194 |
+
"special": true
|
| 195 |
+
},
|
| 196 |
+
"32021": {
|
| 197 |
+
"content": "<extra_id_78>",
|
| 198 |
+
"lstrip": true,
|
| 199 |
+
"normalized": false,
|
| 200 |
+
"rstrip": true,
|
| 201 |
+
"single_word": false,
|
| 202 |
+
"special": true
|
| 203 |
+
},
|
| 204 |
+
"32022": {
|
| 205 |
+
"content": "<extra_id_77>",
|
| 206 |
+
"lstrip": true,
|
| 207 |
+
"normalized": false,
|
| 208 |
+
"rstrip": true,
|
| 209 |
+
"single_word": false,
|
| 210 |
+
"special": true
|
| 211 |
+
},
|
| 212 |
+
"32023": {
|
| 213 |
+
"content": "<extra_id_76>",
|
| 214 |
+
"lstrip": true,
|
| 215 |
+
"normalized": false,
|
| 216 |
+
"rstrip": true,
|
| 217 |
+
"single_word": false,
|
| 218 |
+
"special": true
|
| 219 |
+
},
|
| 220 |
+
"32024": {
|
| 221 |
+
"content": "<extra_id_75>",
|
| 222 |
+
"lstrip": true,
|
| 223 |
+
"normalized": false,
|
| 224 |
+
"rstrip": true,
|
| 225 |
+
"single_word": false,
|
| 226 |
+
"special": true
|
| 227 |
+
},
|
| 228 |
+
"32025": {
|
| 229 |
+
"content": "<extra_id_74>",
|
| 230 |
+
"lstrip": true,
|
| 231 |
+
"normalized": false,
|
| 232 |
+
"rstrip": true,
|
| 233 |
+
"single_word": false,
|
| 234 |
+
"special": true
|
| 235 |
+
},
|
| 236 |
+
"32026": {
|
| 237 |
+
"content": "<extra_id_73>",
|
| 238 |
+
"lstrip": true,
|
| 239 |
+
"normalized": false,
|
| 240 |
+
"rstrip": true,
|
| 241 |
+
"single_word": false,
|
| 242 |
+
"special": true
|
| 243 |
+
},
|
| 244 |
+
"32027": {
|
| 245 |
+
"content": "<extra_id_72>",
|
| 246 |
+
"lstrip": true,
|
| 247 |
+
"normalized": false,
|
| 248 |
+
"rstrip": true,
|
| 249 |
+
"single_word": false,
|
| 250 |
+
"special": true
|
| 251 |
+
},
|
| 252 |
+
"32028": {
|
| 253 |
+
"content": "<extra_id_71>",
|
| 254 |
+
"lstrip": true,
|
| 255 |
+
"normalized": false,
|
| 256 |
+
"rstrip": true,
|
| 257 |
+
"single_word": false,
|
| 258 |
+
"special": true
|
| 259 |
+
},
|
| 260 |
+
"32029": {
|
| 261 |
+
"content": "<extra_id_70>",
|
| 262 |
+
"lstrip": true,
|
| 263 |
+
"normalized": false,
|
| 264 |
+
"rstrip": true,
|
| 265 |
+
"single_word": false,
|
| 266 |
+
"special": true
|
| 267 |
+
},
|
| 268 |
+
"32030": {
|
| 269 |
+
"content": "<extra_id_69>",
|
| 270 |
+
"lstrip": true,
|
| 271 |
+
"normalized": false,
|
| 272 |
+
"rstrip": true,
|
| 273 |
+
"single_word": false,
|
| 274 |
+
"special": true
|
| 275 |
+
},
|
| 276 |
+
"32031": {
|
| 277 |
+
"content": "<extra_id_68>",
|
| 278 |
+
"lstrip": true,
|
| 279 |
+
"normalized": false,
|
| 280 |
+
"rstrip": true,
|
| 281 |
+
"single_word": false,
|
| 282 |
+
"special": true
|
| 283 |
+
},
|
| 284 |
+
"32032": {
|
| 285 |
+
"content": "<extra_id_67>",
|
| 286 |
+
"lstrip": true,
|
| 287 |
+
"normalized": false,
|
| 288 |
+
"rstrip": true,
|
| 289 |
+
"single_word": false,
|
| 290 |
+
"special": true
|
| 291 |
+
},
|
| 292 |
+
"32033": {
|
| 293 |
+
"content": "<extra_id_66>",
|
| 294 |
+
"lstrip": true,
|
| 295 |
+
"normalized": false,
|
| 296 |
+
"rstrip": true,
|
| 297 |
+
"single_word": false,
|
| 298 |
+
"special": true
|
| 299 |
+
},
|
| 300 |
+
"32034": {
|
| 301 |
+
"content": "<extra_id_65>",
|
| 302 |
+
"lstrip": true,
|
| 303 |
+
"normalized": false,
|
| 304 |
+
"rstrip": true,
|
| 305 |
+
"single_word": false,
|
| 306 |
+
"special": true
|
| 307 |
+
},
|
| 308 |
+
"32035": {
|
| 309 |
+
"content": "<extra_id_64>",
|
| 310 |
+
"lstrip": true,
|
| 311 |
+
"normalized": false,
|
| 312 |
+
"rstrip": true,
|
| 313 |
+
"single_word": false,
|
| 314 |
+
"special": true
|
| 315 |
+
},
|
| 316 |
+
"32036": {
|
| 317 |
+
"content": "<extra_id_63>",
|
| 318 |
+
"lstrip": true,
|
| 319 |
+
"normalized": false,
|
| 320 |
+
"rstrip": true,
|
| 321 |
+
"single_word": false,
|
| 322 |
+
"special": true
|
| 323 |
+
},
|
| 324 |
+
"32037": {
|
| 325 |
+
"content": "<extra_id_62>",
|
| 326 |
+
"lstrip": true,
|
| 327 |
+
"normalized": false,
|
| 328 |
+
"rstrip": true,
|
| 329 |
+
"single_word": false,
|
| 330 |
+
"special": true
|
| 331 |
+
},
|
| 332 |
+
"32038": {
|
| 333 |
+
"content": "<extra_id_61>",
|
| 334 |
+
"lstrip": true,
|
| 335 |
+
"normalized": false,
|
| 336 |
+
"rstrip": true,
|
| 337 |
+
"single_word": false,
|
| 338 |
+
"special": true
|
| 339 |
+
},
|
| 340 |
+
"32039": {
|
| 341 |
+
"content": "<extra_id_60>",
|
| 342 |
+
"lstrip": true,
|
| 343 |
+
"normalized": false,
|
| 344 |
+
"rstrip": true,
|
| 345 |
+
"single_word": false,
|
| 346 |
+
"special": true
|
| 347 |
+
},
|
| 348 |
+
"32040": {
|
| 349 |
+
"content": "<extra_id_59>",
|
| 350 |
+
"lstrip": true,
|
| 351 |
+
"normalized": false,
|
| 352 |
+
"rstrip": true,
|
| 353 |
+
"single_word": false,
|
| 354 |
+
"special": true
|
| 355 |
+
},
|
| 356 |
+
"32041": {
|
| 357 |
+
"content": "<extra_id_58>",
|
| 358 |
+
"lstrip": true,
|
| 359 |
+
"normalized": false,
|
| 360 |
+
"rstrip": true,
|
| 361 |
+
"single_word": false,
|
| 362 |
+
"special": true
|
| 363 |
+
},
|
| 364 |
+
"32042": {
|
| 365 |
+
"content": "<extra_id_57>",
|
| 366 |
+
"lstrip": true,
|
| 367 |
+
"normalized": false,
|
| 368 |
+
"rstrip": true,
|
| 369 |
+
"single_word": false,
|
| 370 |
+
"special": true
|
| 371 |
+
},
|
| 372 |
+
"32043": {
|
| 373 |
+
"content": "<extra_id_56>",
|
| 374 |
+
"lstrip": true,
|
| 375 |
+
"normalized": false,
|
| 376 |
+
"rstrip": true,
|
| 377 |
+
"single_word": false,
|
| 378 |
+
"special": true
|
| 379 |
+
},
|
| 380 |
+
"32044": {
|
| 381 |
+
"content": "<extra_id_55>",
|
| 382 |
+
"lstrip": true,
|
| 383 |
+
"normalized": false,
|
| 384 |
+
"rstrip": true,
|
| 385 |
+
"single_word": false,
|
| 386 |
+
"special": true
|
| 387 |
+
},
|
| 388 |
+
"32045": {
|
| 389 |
+
"content": "<extra_id_54>",
|
| 390 |
+
"lstrip": true,
|
| 391 |
+
"normalized": false,
|
| 392 |
+
"rstrip": true,
|
| 393 |
+
"single_word": false,
|
| 394 |
+
"special": true
|
| 395 |
+
},
|
| 396 |
+
"32046": {
|
| 397 |
+
"content": "<extra_id_53>",
|
| 398 |
+
"lstrip": true,
|
| 399 |
+
"normalized": false,
|
| 400 |
+
"rstrip": true,
|
| 401 |
+
"single_word": false,
|
| 402 |
+
"special": true
|
| 403 |
+
},
|
| 404 |
+
"32047": {
|
| 405 |
+
"content": "<extra_id_52>",
|
| 406 |
+
"lstrip": true,
|
| 407 |
+
"normalized": false,
|
| 408 |
+
"rstrip": true,
|
| 409 |
+
"single_word": false,
|
| 410 |
+
"special": true
|
| 411 |
+
},
|
| 412 |
+
"32048": {
|
| 413 |
+
"content": "<extra_id_51>",
|
| 414 |
+
"lstrip": true,
|
| 415 |
+
"normalized": false,
|
| 416 |
+
"rstrip": true,
|
| 417 |
+
"single_word": false,
|
| 418 |
+
"special": true
|
| 419 |
+
},
|
| 420 |
+
"32049": {
|
| 421 |
+
"content": "<extra_id_50>",
|
| 422 |
+
"lstrip": true,
|
| 423 |
+
"normalized": false,
|
| 424 |
+
"rstrip": true,
|
| 425 |
+
"single_word": false,
|
| 426 |
+
"special": true
|
| 427 |
+
},
|
| 428 |
+
"32050": {
|
| 429 |
+
"content": "<extra_id_49>",
|
| 430 |
+
"lstrip": true,
|
| 431 |
+
"normalized": false,
|
| 432 |
+
"rstrip": true,
|
| 433 |
+
"single_word": false,
|
| 434 |
+
"special": true
|
| 435 |
+
},
|
| 436 |
+
"32051": {
|
| 437 |
+
"content": "<extra_id_48>",
|
| 438 |
+
"lstrip": true,
|
| 439 |
+
"normalized": false,
|
| 440 |
+
"rstrip": true,
|
| 441 |
+
"single_word": false,
|
| 442 |
+
"special": true
|
| 443 |
+
},
|
| 444 |
+
"32052": {
|
| 445 |
+
"content": "<extra_id_47>",
|
| 446 |
+
"lstrip": true,
|
| 447 |
+
"normalized": false,
|
| 448 |
+
"rstrip": true,
|
| 449 |
+
"single_word": false,
|
| 450 |
+
"special": true
|
| 451 |
+
},
|
| 452 |
+
"32053": {
|
| 453 |
+
"content": "<extra_id_46>",
|
| 454 |
+
"lstrip": true,
|
| 455 |
+
"normalized": false,
|
| 456 |
+
"rstrip": true,
|
| 457 |
+
"single_word": false,
|
| 458 |
+
"special": true
|
| 459 |
+
},
|
| 460 |
+
"32054": {
|
| 461 |
+
"content": "<extra_id_45>",
|
| 462 |
+
"lstrip": true,
|
| 463 |
+
"normalized": false,
|
| 464 |
+
"rstrip": true,
|
| 465 |
+
"single_word": false,
|
| 466 |
+
"special": true
|
| 467 |
+
},
|
| 468 |
+
"32055": {
|
| 469 |
+
"content": "<extra_id_44>",
|
| 470 |
+
"lstrip": true,
|
| 471 |
+
"normalized": false,
|
| 472 |
+
"rstrip": true,
|
| 473 |
+
"single_word": false,
|
| 474 |
+
"special": true
|
| 475 |
+
},
|
| 476 |
+
"32056": {
|
| 477 |
+
"content": "<extra_id_43>",
|
| 478 |
+
"lstrip": true,
|
| 479 |
+
"normalized": false,
|
| 480 |
+
"rstrip": true,
|
| 481 |
+
"single_word": false,
|
| 482 |
+
"special": true
|
| 483 |
+
},
|
| 484 |
+
"32057": {
|
| 485 |
+
"content": "<extra_id_42>",
|
| 486 |
+
"lstrip": true,
|
| 487 |
+
"normalized": false,
|
| 488 |
+
"rstrip": true,
|
| 489 |
+
"single_word": false,
|
| 490 |
+
"special": true
|
| 491 |
+
},
|
| 492 |
+
"32058": {
|
| 493 |
+
"content": "<extra_id_41>",
|
| 494 |
+
"lstrip": true,
|
| 495 |
+
"normalized": false,
|
| 496 |
+
"rstrip": true,
|
| 497 |
+
"single_word": false,
|
| 498 |
+
"special": true
|
| 499 |
+
},
|
| 500 |
+
"32059": {
|
| 501 |
+
"content": "<extra_id_40>",
|
| 502 |
+
"lstrip": true,
|
| 503 |
+
"normalized": false,
|
| 504 |
+
"rstrip": true,
|
| 505 |
+
"single_word": false,
|
| 506 |
+
"special": true
|
| 507 |
+
},
|
| 508 |
+
"32060": {
|
| 509 |
+
"content": "<extra_id_39>",
|
| 510 |
+
"lstrip": true,
|
| 511 |
+
"normalized": false,
|
| 512 |
+
"rstrip": true,
|
| 513 |
+
"single_word": false,
|
| 514 |
+
"special": true
|
| 515 |
+
},
|
| 516 |
+
"32061": {
|
| 517 |
+
"content": "<extra_id_38>",
|
| 518 |
+
"lstrip": true,
|
| 519 |
+
"normalized": false,
|
| 520 |
+
"rstrip": true,
|
| 521 |
+
"single_word": false,
|
| 522 |
+
"special": true
|
| 523 |
+
},
|
| 524 |
+
"32062": {
|
| 525 |
+
"content": "<extra_id_37>",
|
| 526 |
+
"lstrip": true,
|
| 527 |
+
"normalized": false,
|
| 528 |
+
"rstrip": true,
|
| 529 |
+
"single_word": false,
|
| 530 |
+
"special": true
|
| 531 |
+
},
|
| 532 |
+
"32063": {
|
| 533 |
+
"content": "<extra_id_36>",
|
| 534 |
+
"lstrip": true,
|
| 535 |
+
"normalized": false,
|
| 536 |
+
"rstrip": true,
|
| 537 |
+
"single_word": false,
|
| 538 |
+
"special": true
|
| 539 |
+
},
|
| 540 |
+
"32064": {
|
| 541 |
+
"content": "<extra_id_35>",
|
| 542 |
+
"lstrip": true,
|
| 543 |
+
"normalized": false,
|
| 544 |
+
"rstrip": true,
|
| 545 |
+
"single_word": false,
|
| 546 |
+
"special": true
|
| 547 |
+
},
|
| 548 |
+
"32065": {
|
| 549 |
+
"content": "<extra_id_34>",
|
| 550 |
+
"lstrip": true,
|
| 551 |
+
"normalized": false,
|
| 552 |
+
"rstrip": true,
|
| 553 |
+
"single_word": false,
|
| 554 |
+
"special": true
|
| 555 |
+
},
|
| 556 |
+
"32066": {
|
| 557 |
+
"content": "<extra_id_33>",
|
| 558 |
+
"lstrip": true,
|
| 559 |
+
"normalized": false,
|
| 560 |
+
"rstrip": true,
|
| 561 |
+
"single_word": false,
|
| 562 |
+
"special": true
|
| 563 |
+
},
|
| 564 |
+
"32067": {
|
| 565 |
+
"content": "<extra_id_32>",
|
| 566 |
+
"lstrip": true,
|
| 567 |
+
"normalized": false,
|
| 568 |
+
"rstrip": true,
|
| 569 |
+
"single_word": false,
|
| 570 |
+
"special": true
|
| 571 |
+
},
|
| 572 |
+
"32068": {
|
| 573 |
+
"content": "<extra_id_31>",
|
| 574 |
+
"lstrip": true,
|
| 575 |
+
"normalized": false,
|
| 576 |
+
"rstrip": true,
|
| 577 |
+
"single_word": false,
|
| 578 |
+
"special": true
|
| 579 |
+
},
|
| 580 |
+
"32069": {
|
| 581 |
+
"content": "<extra_id_30>",
|
| 582 |
+
"lstrip": true,
|
| 583 |
+
"normalized": false,
|
| 584 |
+
"rstrip": true,
|
| 585 |
+
"single_word": false,
|
| 586 |
+
"special": true
|
| 587 |
+
},
|
| 588 |
+
"32070": {
|
| 589 |
+
"content": "<extra_id_29>",
|
| 590 |
+
"lstrip": true,
|
| 591 |
+
"normalized": false,
|
| 592 |
+
"rstrip": true,
|
| 593 |
+
"single_word": false,
|
| 594 |
+
"special": true
|
| 595 |
+
},
|
| 596 |
+
"32071": {
|
| 597 |
+
"content": "<extra_id_28>",
|
| 598 |
+
"lstrip": true,
|
| 599 |
+
"normalized": false,
|
| 600 |
+
"rstrip": true,
|
| 601 |
+
"single_word": false,
|
| 602 |
+
"special": true
|
| 603 |
+
},
|
| 604 |
+
"32072": {
|
| 605 |
+
"content": "<extra_id_27>",
|
| 606 |
+
"lstrip": true,
|
| 607 |
+
"normalized": false,
|
| 608 |
+
"rstrip": true,
|
| 609 |
+
"single_word": false,
|
| 610 |
+
"special": true
|
| 611 |
+
},
|
| 612 |
+
"32073": {
|
| 613 |
+
"content": "<extra_id_26>",
|
| 614 |
+
"lstrip": true,
|
| 615 |
+
"normalized": false,
|
| 616 |
+
"rstrip": true,
|
| 617 |
+
"single_word": false,
|
| 618 |
+
"special": true
|
| 619 |
+
},
|
| 620 |
+
"32074": {
|
| 621 |
+
"content": "<extra_id_25>",
|
| 622 |
+
"lstrip": true,
|
| 623 |
+
"normalized": false,
|
| 624 |
+
"rstrip": true,
|
| 625 |
+
"single_word": false,
|
| 626 |
+
"special": true
|
| 627 |
+
},
|
| 628 |
+
"32075": {
|
| 629 |
+
"content": "<extra_id_24>",
|
| 630 |
+
"lstrip": true,
|
| 631 |
+
"normalized": false,
|
| 632 |
+
"rstrip": true,
|
| 633 |
+
"single_word": false,
|
| 634 |
+
"special": true
|
| 635 |
+
},
|
| 636 |
+
"32076": {
|
| 637 |
+
"content": "<extra_id_23>",
|
| 638 |
+
"lstrip": true,
|
| 639 |
+
"normalized": false,
|
| 640 |
+
"rstrip": true,
|
| 641 |
+
"single_word": false,
|
| 642 |
+
"special": true
|
| 643 |
+
},
|
| 644 |
+
"32077": {
|
| 645 |
+
"content": "<extra_id_22>",
|
| 646 |
+
"lstrip": true,
|
| 647 |
+
"normalized": false,
|
| 648 |
+
"rstrip": true,
|
| 649 |
+
"single_word": false,
|
| 650 |
+
"special": true
|
| 651 |
+
},
|
| 652 |
+
"32078": {
|
| 653 |
+
"content": "<extra_id_21>",
|
| 654 |
+
"lstrip": true,
|
| 655 |
+
"normalized": false,
|
| 656 |
+
"rstrip": true,
|
| 657 |
+
"single_word": false,
|
| 658 |
+
"special": true
|
| 659 |
+
},
|
| 660 |
+
"32079": {
|
| 661 |
+
"content": "<extra_id_20>",
|
| 662 |
+
"lstrip": true,
|
| 663 |
+
"normalized": false,
|
| 664 |
+
"rstrip": true,
|
| 665 |
+
"single_word": false,
|
| 666 |
+
"special": true
|
| 667 |
+
},
|
| 668 |
+
"32080": {
|
| 669 |
+
"content": "<extra_id_19>",
|
| 670 |
+
"lstrip": true,
|
| 671 |
+
"normalized": false,
|
| 672 |
+
"rstrip": true,
|
| 673 |
+
"single_word": false,
|
| 674 |
+
"special": true
|
| 675 |
+
},
|
| 676 |
+
"32081": {
|
| 677 |
+
"content": "<extra_id_18>",
|
| 678 |
+
"lstrip": true,
|
| 679 |
+
"normalized": false,
|
| 680 |
+
"rstrip": true,
|
| 681 |
+
"single_word": false,
|
| 682 |
+
"special": true
|
| 683 |
+
},
|
| 684 |
+
"32082": {
|
| 685 |
+
"content": "<extra_id_17>",
|
| 686 |
+
"lstrip": true,
|
| 687 |
+
"normalized": false,
|
| 688 |
+
"rstrip": true,
|
| 689 |
+
"single_word": false,
|
| 690 |
+
"special": true
|
| 691 |
+
},
|
| 692 |
+
"32083": {
|
| 693 |
+
"content": "<extra_id_16>",
|
| 694 |
+
"lstrip": true,
|
| 695 |
+
"normalized": false,
|
| 696 |
+
"rstrip": true,
|
| 697 |
+
"single_word": false,
|
| 698 |
+
"special": true
|
| 699 |
+
},
|
| 700 |
+
"32084": {
|
| 701 |
+
"content": "<extra_id_15>",
|
| 702 |
+
"lstrip": true,
|
| 703 |
+
"normalized": false,
|
| 704 |
+
"rstrip": true,
|
| 705 |
+
"single_word": false,
|
| 706 |
+
"special": true
|
| 707 |
+
},
|
| 708 |
+
"32085": {
|
| 709 |
+
"content": "<extra_id_14>",
|
| 710 |
+
"lstrip": true,
|
| 711 |
+
"normalized": false,
|
| 712 |
+
"rstrip": true,
|
| 713 |
+
"single_word": false,
|
| 714 |
+
"special": true
|
| 715 |
+
},
|
| 716 |
+
"32086": {
|
| 717 |
+
"content": "<extra_id_13>",
|
| 718 |
+
"lstrip": true,
|
| 719 |
+
"normalized": false,
|
| 720 |
+
"rstrip": true,
|
| 721 |
+
"single_word": false,
|
| 722 |
+
"special": true
|
| 723 |
+
},
|
| 724 |
+
"32087": {
|
| 725 |
+
"content": "<extra_id_12>",
|
| 726 |
+
"lstrip": true,
|
| 727 |
+
"normalized": false,
|
| 728 |
+
"rstrip": true,
|
| 729 |
+
"single_word": false,
|
| 730 |
+
"special": true
|
| 731 |
+
},
|
| 732 |
+
"32088": {
|
| 733 |
+
"content": "<extra_id_11>",
|
| 734 |
+
"lstrip": true,
|
| 735 |
+
"normalized": false,
|
| 736 |
+
"rstrip": true,
|
| 737 |
+
"single_word": false,
|
| 738 |
+
"special": true
|
| 739 |
+
},
|
| 740 |
+
"32089": {
|
| 741 |
+
"content": "<extra_id_10>",
|
| 742 |
+
"lstrip": true,
|
| 743 |
+
"normalized": false,
|
| 744 |
+
"rstrip": true,
|
| 745 |
+
"single_word": false,
|
| 746 |
+
"special": true
|
| 747 |
+
},
|
| 748 |
+
"32090": {
|
| 749 |
+
"content": "<extra_id_9>",
|
| 750 |
+
"lstrip": true,
|
| 751 |
+
"normalized": false,
|
| 752 |
+
"rstrip": true,
|
| 753 |
+
"single_word": false,
|
| 754 |
+
"special": true
|
| 755 |
+
},
|
| 756 |
+
"32091": {
|
| 757 |
+
"content": "<extra_id_8>",
|
| 758 |
+
"lstrip": true,
|
| 759 |
+
"normalized": false,
|
| 760 |
+
"rstrip": true,
|
| 761 |
+
"single_word": false,
|
| 762 |
+
"special": true
|
| 763 |
+
},
|
| 764 |
+
"32092": {
|
| 765 |
+
"content": "<extra_id_7>",
|
| 766 |
+
"lstrip": true,
|
| 767 |
+
"normalized": false,
|
| 768 |
+
"rstrip": true,
|
| 769 |
+
"single_word": false,
|
| 770 |
+
"special": true
|
| 771 |
+
},
|
| 772 |
+
"32093": {
|
| 773 |
+
"content": "<extra_id_6>",
|
| 774 |
+
"lstrip": true,
|
| 775 |
+
"normalized": false,
|
| 776 |
+
"rstrip": true,
|
| 777 |
+
"single_word": false,
|
| 778 |
+
"special": true
|
| 779 |
+
},
|
| 780 |
+
"32094": {
|
| 781 |
+
"content": "<extra_id_5>",
|
| 782 |
+
"lstrip": true,
|
| 783 |
+
"normalized": false,
|
| 784 |
+
"rstrip": true,
|
| 785 |
+
"single_word": false,
|
| 786 |
+
"special": true
|
| 787 |
+
},
|
| 788 |
+
"32095": {
|
| 789 |
+
"content": "<extra_id_4>",
|
| 790 |
+
"lstrip": true,
|
| 791 |
+
"normalized": false,
|
| 792 |
+
"rstrip": true,
|
| 793 |
+
"single_word": false,
|
| 794 |
+
"special": true
|
| 795 |
+
},
|
| 796 |
+
"32096": {
|
| 797 |
+
"content": "<extra_id_3>",
|
| 798 |
+
"lstrip": true,
|
| 799 |
+
"normalized": false,
|
| 800 |
+
"rstrip": true,
|
| 801 |
+
"single_word": false,
|
| 802 |
+
"special": true
|
| 803 |
+
},
|
| 804 |
+
"32097": {
|
| 805 |
+
"content": "<extra_id_2>",
|
| 806 |
+
"lstrip": true,
|
| 807 |
+
"normalized": false,
|
| 808 |
+
"rstrip": true,
|
| 809 |
+
"single_word": false,
|
| 810 |
+
"special": true
|
| 811 |
+
},
|
| 812 |
+
"32098": {
|
| 813 |
+
"content": "<extra_id_1>",
|
| 814 |
+
"lstrip": true,
|
| 815 |
+
"normalized": false,
|
| 816 |
+
"rstrip": true,
|
| 817 |
+
"single_word": false,
|
| 818 |
+
"special": true
|
| 819 |
+
},
|
| 820 |
+
"32099": {
|
| 821 |
+
"content": "<extra_id_0>",
|
| 822 |
+
"lstrip": true,
|
| 823 |
+
"normalized": false,
|
| 824 |
+
"rstrip": true,
|
| 825 |
+
"single_word": false,
|
| 826 |
+
"special": true
|
| 827 |
+
}
|
| 828 |
+
},
|
| 829 |
+
"additional_special_tokens": [
|
| 830 |
+
"<extra_id_0>",
|
| 831 |
+
"<extra_id_1>",
|
| 832 |
+
"<extra_id_2>",
|
| 833 |
+
"<extra_id_3>",
|
| 834 |
+
"<extra_id_4>",
|
| 835 |
+
"<extra_id_5>",
|
| 836 |
+
"<extra_id_6>",
|
| 837 |
+
"<extra_id_7>",
|
| 838 |
+
"<extra_id_8>",
|
| 839 |
+
"<extra_id_9>",
|
| 840 |
+
"<extra_id_10>",
|
| 841 |
+
"<extra_id_11>",
|
| 842 |
+
"<extra_id_12>",
|
| 843 |
+
"<extra_id_13>",
|
| 844 |
+
"<extra_id_14>",
|
| 845 |
+
"<extra_id_15>",
|
| 846 |
+
"<extra_id_16>",
|
| 847 |
+
"<extra_id_17>",
|
| 848 |
+
"<extra_id_18>",
|
| 849 |
+
"<extra_id_19>",
|
| 850 |
+
"<extra_id_20>",
|
| 851 |
+
"<extra_id_21>",
|
| 852 |
+
"<extra_id_22>",
|
| 853 |
+
"<extra_id_23>",
|
| 854 |
+
"<extra_id_24>",
|
| 855 |
+
"<extra_id_25>",
|
| 856 |
+
"<extra_id_26>",
|
| 857 |
+
"<extra_id_27>",
|
| 858 |
+
"<extra_id_28>",
|
| 859 |
+
"<extra_id_29>",
|
| 860 |
+
"<extra_id_30>",
|
| 861 |
+
"<extra_id_31>",
|
| 862 |
+
"<extra_id_32>",
|
| 863 |
+
"<extra_id_33>",
|
| 864 |
+
"<extra_id_34>",
|
| 865 |
+
"<extra_id_35>",
|
| 866 |
+
"<extra_id_36>",
|
| 867 |
+
"<extra_id_37>",
|
| 868 |
+
"<extra_id_38>",
|
| 869 |
+
"<extra_id_39>",
|
| 870 |
+
"<extra_id_40>",
|
| 871 |
+
"<extra_id_41>",
|
| 872 |
+
"<extra_id_42>",
|
| 873 |
+
"<extra_id_43>",
|
| 874 |
+
"<extra_id_44>",
|
| 875 |
+
"<extra_id_45>",
|
| 876 |
+
"<extra_id_46>",
|
| 877 |
+
"<extra_id_47>",
|
| 878 |
+
"<extra_id_48>",
|
| 879 |
+
"<extra_id_49>",
|
| 880 |
+
"<extra_id_50>",
|
| 881 |
+
"<extra_id_51>",
|
| 882 |
+
"<extra_id_52>",
|
| 883 |
+
"<extra_id_53>",
|
| 884 |
+
"<extra_id_54>",
|
| 885 |
+
"<extra_id_55>",
|
| 886 |
+
"<extra_id_56>",
|
| 887 |
+
"<extra_id_57>",
|
| 888 |
+
"<extra_id_58>",
|
| 889 |
+
"<extra_id_59>",
|
| 890 |
+
"<extra_id_60>",
|
| 891 |
+
"<extra_id_61>",
|
| 892 |
+
"<extra_id_62>",
|
| 893 |
+
"<extra_id_63>",
|
| 894 |
+
"<extra_id_64>",
|
| 895 |
+
"<extra_id_65>",
|
| 896 |
+
"<extra_id_66>",
|
| 897 |
+
"<extra_id_67>",
|
| 898 |
+
"<extra_id_68>",
|
| 899 |
+
"<extra_id_69>",
|
| 900 |
+
"<extra_id_70>",
|
| 901 |
+
"<extra_id_71>",
|
| 902 |
+
"<extra_id_72>",
|
| 903 |
+
"<extra_id_73>",
|
| 904 |
+
"<extra_id_74>",
|
| 905 |
+
"<extra_id_75>",
|
| 906 |
+
"<extra_id_76>",
|
| 907 |
+
"<extra_id_77>",
|
| 908 |
+
"<extra_id_78>",
|
| 909 |
+
"<extra_id_79>",
|
| 910 |
+
"<extra_id_80>",
|
| 911 |
+
"<extra_id_81>",
|
| 912 |
+
"<extra_id_82>",
|
| 913 |
+
"<extra_id_83>",
|
| 914 |
+
"<extra_id_84>",
|
| 915 |
+
"<extra_id_85>",
|
| 916 |
+
"<extra_id_86>",
|
| 917 |
+
"<extra_id_87>",
|
| 918 |
+
"<extra_id_88>",
|
| 919 |
+
"<extra_id_89>",
|
| 920 |
+
"<extra_id_90>",
|
| 921 |
+
"<extra_id_91>",
|
| 922 |
+
"<extra_id_92>",
|
| 923 |
+
"<extra_id_93>",
|
| 924 |
+
"<extra_id_94>",
|
| 925 |
+
"<extra_id_95>",
|
| 926 |
+
"<extra_id_96>",
|
| 927 |
+
"<extra_id_97>",
|
| 928 |
+
"<extra_id_98>",
|
| 929 |
+
"<extra_id_99>"
|
| 930 |
+
],
|
| 931 |
+
"clean_up_tokenization_spaces": true,
|
| 932 |
+
"eos_token": "</s>",
|
| 933 |
+
"extra_ids": 100,
|
| 934 |
+
"legacy": true,
|
| 935 |
+
"model_max_length": 512,
|
| 936 |
+
"pad_token": "<pad>",
|
| 937 |
+
"sp_model_kwargs": {},
|
| 938 |
+
"tokenizer_class": "T5Tokenizer",
|
| 939 |
+
"unk_token": "<unk>"
|
| 940 |
+
}
|
transformer/forward/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "SD3Transformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.38.0",
|
| 4 |
+
"attention_head_dim": 64,
|
| 5 |
+
"caption_projection_dim": 1536,
|
| 6 |
+
"dual_attention_layers": [
|
| 7 |
+
0,
|
| 8 |
+
1,
|
| 9 |
+
2,
|
| 10 |
+
3,
|
| 11 |
+
4,
|
| 12 |
+
5,
|
| 13 |
+
6,
|
| 14 |
+
7,
|
| 15 |
+
8,
|
| 16 |
+
9,
|
| 17 |
+
10,
|
| 18 |
+
11,
|
| 19 |
+
12
|
| 20 |
+
],
|
| 21 |
+
"in_channels": 96,
|
| 22 |
+
"joint_attention_dim": 4096,
|
| 23 |
+
"num_attention_heads": 24,
|
| 24 |
+
"num_layers": 24,
|
| 25 |
+
"out_channels": 16,
|
| 26 |
+
"patch_size": 2,
|
| 27 |
+
"pooled_projection_dim": 2048,
|
| 28 |
+
"pos_embed_max_size": 384,
|
| 29 |
+
"qk_norm": "rms_norm",
|
| 30 |
+
"sample_size": 128
|
| 31 |
+
}
|
transformer/forward/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e1f917f113024a1a10ad868d578d522639296062f937e0f7f8b8b8b31ec9de38
|
| 3 |
+
size 9880726944
|
transformer/forward/lora/pytorch_lora_weights.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7947c779fa0b3bba7124a21a5ffa29507ca09cef45ffd8e84f6a8eb738afa7fc
|
| 3 |
+
size 75154632
|
transformer/inverse-512/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "IntrinsicWeatherSD3Transformer2DModel",
|
| 3 |
+
"_diffusers_version": "0.38.0",
|
| 4 |
+
"attention_head_dim": 64,
|
| 5 |
+
"caption_projection_dim": 1536,
|
| 6 |
+
"dual_attention_layers": [
|
| 7 |
+
0,
|
| 8 |
+
1,
|
| 9 |
+
2,
|
| 10 |
+
3,
|
| 11 |
+
4,
|
| 12 |
+
5,
|
| 13 |
+
6,
|
| 14 |
+
7,
|
| 15 |
+
8,
|
| 16 |
+
9,
|
| 17 |
+
10,
|
| 18 |
+
11,
|
| 19 |
+
12
|
| 20 |
+
],
|
| 21 |
+
"in_channels": 32,
|
| 22 |
+
"joint_attention_dim": 4096,
|
| 23 |
+
"num_attention_heads": 24,
|
| 24 |
+
"num_layers": 24,
|
| 25 |
+
"out_channels": 16,
|
| 26 |
+
"patch_size": 2,
|
| 27 |
+
"pooled_projection_dim": 2048,
|
| 28 |
+
"pos_embed_max_size": 384,
|
| 29 |
+
"qk_norm": "rms_norm",
|
| 30 |
+
"sample_size": 128
|
| 31 |
+
}
|
transformer/inverse-512/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1cdcd40fc9770b2397b533ff4fd3132bfee99990ae59f31f22895fbdae76d797
|
| 3 |
+
size 9879154080
|
transformer/inverse-512/transformer_intrinsic_weather.py
ADDED
|
@@ -0,0 +1,1527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch import nn as torch_nn
|
| 24 |
+
|
| 25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
| 27 |
+
from diffusers.models.attention import FeedForward, JointTransformerBlock, _chunked_feed_forward
|
| 28 |
+
from diffusers.models.attention_processor import (
|
| 29 |
+
Attention,
|
| 30 |
+
AttentionProcessor,
|
| 31 |
+
AttnProcessor,
|
| 32 |
+
AttnProcessor2_0,
|
| 33 |
+
FusedJointAttnProcessor2_0,
|
| 34 |
+
JointAttnProcessor2_0,
|
| 35 |
+
SpatialNorm,
|
| 36 |
+
)
|
| 37 |
+
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed, SinusoidalPositionalEmbedding
|
| 38 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 39 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 40 |
+
from diffusers.models.normalization import (
|
| 41 |
+
AdaLayerNorm,
|
| 42 |
+
AdaLayerNormContinuous,
|
| 43 |
+
AdaLayerNormZero,
|
| 44 |
+
RMSNorm,
|
| 45 |
+
SD35AdaLayerNormZeroX,
|
| 46 |
+
)
|
| 47 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
| 48 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MapAwareAttention(nn.Module):
|
| 54 |
+
r"""
|
| 55 |
+
A cross attention layer.
|
| 56 |
+
|
| 57 |
+
Parameters:
|
| 58 |
+
query_dim (`int`):
|
| 59 |
+
The number of channels in the query.
|
| 60 |
+
cross_attention_dim (`int`, *optional*):
|
| 61 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
| 62 |
+
heads (`int`, *optional*, defaults to 8):
|
| 63 |
+
The number of heads to use for multi-head attention.
|
| 64 |
+
kv_heads (`int`, *optional*, defaults to `None`):
|
| 65 |
+
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
|
| 66 |
+
`kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
|
| 67 |
+
Query Attention (MQA) otherwise GQA is used.
|
| 68 |
+
dim_head (`int`, *optional*, defaults to 64):
|
| 69 |
+
The number of channels in each head.
|
| 70 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 71 |
+
The dropout probability to use.
|
| 72 |
+
bias (`bool`, *optional*, defaults to False):
|
| 73 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
| 74 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
| 75 |
+
Set to `True` to upcast the attention computation to `float32`.
|
| 76 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
| 77 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
| 78 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
| 79 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
| 80 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
| 81 |
+
The number of groups to use for the group norm in the cross attention.
|
| 82 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
| 83 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
| 84 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
| 85 |
+
The number of groups to use for the group norm in the attention.
|
| 86 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
| 87 |
+
The number of channels to use for the spatial normalization.
|
| 88 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
| 89 |
+
Set to `True` to use a bias in the output linear layer.
|
| 90 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
| 91 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
| 92 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
| 93 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
| 94 |
+
`added_kv_proj_dim` is not `None`.
|
| 95 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
| 96 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
| 97 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
| 98 |
+
A factor to rescale the output by dividing it with this value.
|
| 99 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
| 100 |
+
Set to `True` to add the residual connection to the output.
|
| 101 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
| 102 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
| 103 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
| 104 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
| 105 |
+
`AttnProcessor` otherwise.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
query_dim: int,
|
| 111 |
+
cross_attention_dim: Optional[int] = None,
|
| 112 |
+
heads: int = 8,
|
| 113 |
+
kv_heads: Optional[int] = None,
|
| 114 |
+
dim_head: int = 64,
|
| 115 |
+
dropout: float = 0.0,
|
| 116 |
+
bias: bool = False,
|
| 117 |
+
upcast_attention: bool = False,
|
| 118 |
+
upcast_softmax: bool = False,
|
| 119 |
+
cross_attention_norm: Optional[str] = None,
|
| 120 |
+
cross_attention_norm_num_groups: int = 32,
|
| 121 |
+
qk_norm: Optional[str] = None,
|
| 122 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 123 |
+
added_proj_bias: Optional[bool] = True,
|
| 124 |
+
norm_num_groups: Optional[int] = None,
|
| 125 |
+
spatial_norm_dim: Optional[int] = None,
|
| 126 |
+
out_bias: bool = True,
|
| 127 |
+
scale_qk: bool = True,
|
| 128 |
+
only_cross_attention: bool = False,
|
| 129 |
+
eps: float = 1e-5,
|
| 130 |
+
rescale_output_factor: float = 1.0,
|
| 131 |
+
residual_connection: bool = False,
|
| 132 |
+
_from_deprecated_attn_block: bool = False,
|
| 133 |
+
processor: Optional["AttnProcessor"] = None,
|
| 134 |
+
out_dim: int = None,
|
| 135 |
+
out_context_dim: int = None,
|
| 136 |
+
context_pre_only=None,
|
| 137 |
+
pre_only=False,
|
| 138 |
+
elementwise_affine: bool = True,
|
| 139 |
+
is_causal: bool = False,
|
| 140 |
+
):
|
| 141 |
+
super().__init__()
|
| 142 |
+
|
| 143 |
+
# To prevent circular import.
|
| 144 |
+
from diffusers.models.normalization import FP32LayerNorm, LpNorm, RMSNorm
|
| 145 |
+
|
| 146 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 147 |
+
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
| 148 |
+
self.query_dim = query_dim
|
| 149 |
+
self.use_bias = bias
|
| 150 |
+
self.is_cross_attention = cross_attention_dim is not None
|
| 151 |
+
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 152 |
+
self.upcast_attention = upcast_attention
|
| 153 |
+
self.upcast_softmax = upcast_softmax
|
| 154 |
+
self.rescale_output_factor = rescale_output_factor
|
| 155 |
+
self.residual_connection = residual_connection
|
| 156 |
+
self.dropout = dropout
|
| 157 |
+
self.fused_projections = False
|
| 158 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 159 |
+
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
| 160 |
+
self.context_pre_only = context_pre_only
|
| 161 |
+
self.pre_only = pre_only
|
| 162 |
+
self.is_causal = is_causal
|
| 163 |
+
|
| 164 |
+
# we make use of this private variable to know whether this class is loaded
|
| 165 |
+
# with an deprecated state dict so that we can convert it on the fly
|
| 166 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
| 167 |
+
|
| 168 |
+
self.scale_qk = scale_qk
|
| 169 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
| 170 |
+
|
| 171 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 172 |
+
# for slice_size > 0 the attention score computation
|
| 173 |
+
# is split across the batch axis to save memory
|
| 174 |
+
# You can set slice_size with `set_attention_slice`
|
| 175 |
+
self.sliceable_head_dim = heads
|
| 176 |
+
|
| 177 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 178 |
+
self.only_cross_attention = only_cross_attention
|
| 179 |
+
|
| 180 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
| 181 |
+
raise ValueError(
|
| 182 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if norm_num_groups is not None:
|
| 186 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
| 187 |
+
else:
|
| 188 |
+
self.group_norm = None
|
| 189 |
+
|
| 190 |
+
if spatial_norm_dim is not None:
|
| 191 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
| 192 |
+
else:
|
| 193 |
+
self.spatial_norm = None
|
| 194 |
+
|
| 195 |
+
if qk_norm is None:
|
| 196 |
+
self.norm_q = None
|
| 197 |
+
self.norm_k = None
|
| 198 |
+
elif qk_norm == "layer_norm":
|
| 199 |
+
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 200 |
+
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 201 |
+
elif qk_norm == "fp32_layer_norm":
|
| 202 |
+
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
| 203 |
+
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
| 204 |
+
elif qk_norm == "layer_norm_across_heads":
|
| 205 |
+
# Lumina applies qk norm across all heads
|
| 206 |
+
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
| 207 |
+
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
| 208 |
+
elif qk_norm == "rms_norm":
|
| 209 |
+
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 210 |
+
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 211 |
+
elif qk_norm == "rms_norm_across_heads":
|
| 212 |
+
# LTX applies qk norm across all heads
|
| 213 |
+
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
|
| 214 |
+
self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
| 215 |
+
elif qk_norm == "l2":
|
| 216 |
+
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
|
| 217 |
+
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
|
| 218 |
+
else:
|
| 219 |
+
raise ValueError(
|
| 220 |
+
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if cross_attention_norm is None:
|
| 224 |
+
self.norm_cross = None
|
| 225 |
+
elif cross_attention_norm == "layer_norm":
|
| 226 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
| 227 |
+
elif cross_attention_norm == "group_norm":
|
| 228 |
+
if self.added_kv_proj_dim is not None:
|
| 229 |
+
# The given `encoder_hidden_states` are initially of shape
|
| 230 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
| 231 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
| 232 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
| 233 |
+
# the number of channels for the group norm.
|
| 234 |
+
norm_cross_num_channels = added_kv_proj_dim
|
| 235 |
+
else:
|
| 236 |
+
norm_cross_num_channels = self.cross_attention_dim
|
| 237 |
+
|
| 238 |
+
self.norm_cross = nn.GroupNorm(
|
| 239 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
| 240 |
+
)
|
| 241 |
+
else:
|
| 242 |
+
raise ValueError(
|
| 243 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 247 |
+
|
| 248 |
+
if not self.only_cross_attention:
|
| 249 |
+
# only relevant for the `AddedKVProcessor` classes
|
| 250 |
+
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
| 251 |
+
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
| 252 |
+
else:
|
| 253 |
+
self.to_k = None
|
| 254 |
+
self.to_v = None
|
| 255 |
+
|
| 256 |
+
self.added_proj_bias = added_proj_bias
|
| 257 |
+
if self.added_kv_proj_dim is not None:
|
| 258 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
| 259 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
| 260 |
+
if self.context_pre_only is not None:
|
| 261 |
+
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 262 |
+
else:
|
| 263 |
+
self.add_q_proj = None
|
| 264 |
+
self.add_k_proj = None
|
| 265 |
+
self.add_v_proj = None
|
| 266 |
+
|
| 267 |
+
if not self.pre_only:
|
| 268 |
+
self.to_out = nn.ModuleList([])
|
| 269 |
+
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
| 270 |
+
self.to_out.append(nn.Dropout(dropout))
|
| 271 |
+
else:
|
| 272 |
+
self.to_out = None
|
| 273 |
+
|
| 274 |
+
if self.context_pre_only is not None and not self.context_pre_only:
|
| 275 |
+
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
|
| 276 |
+
else:
|
| 277 |
+
self.to_add_out = None
|
| 278 |
+
|
| 279 |
+
if qk_norm is not None and added_kv_proj_dim is not None:
|
| 280 |
+
if qk_norm == "layer_norm":
|
| 281 |
+
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 282 |
+
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 283 |
+
elif qk_norm == "fp32_layer_norm":
|
| 284 |
+
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
| 285 |
+
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
| 286 |
+
elif qk_norm == "rms_norm":
|
| 287 |
+
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
| 288 |
+
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
| 289 |
+
elif qk_norm == "rms_norm_across_heads":
|
| 290 |
+
# Wan applies qk norm across all heads
|
| 291 |
+
# Wan also doesn't apply a q norm
|
| 292 |
+
self.norm_added_q = None
|
| 293 |
+
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
| 294 |
+
else:
|
| 295 |
+
raise ValueError(
|
| 296 |
+
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
self.norm_added_q = None
|
| 300 |
+
self.norm_added_k = None
|
| 301 |
+
|
| 302 |
+
# set attention processor
|
| 303 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
| 304 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
| 305 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
| 306 |
+
if processor is None:
|
| 307 |
+
processor = (
|
| 308 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
| 309 |
+
)
|
| 310 |
+
self.set_processor(processor)
|
| 311 |
+
|
| 312 |
+
# def set_use_xla_flash_attention(
|
| 313 |
+
# self,
|
| 314 |
+
# use_xla_flash_attention: bool,
|
| 315 |
+
# partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
| 316 |
+
# is_flux=False,
|
| 317 |
+
# ) -> None:
|
| 318 |
+
# r"""
|
| 319 |
+
# Set whether to use xla flash attention from `torch_xla` or not.
|
| 320 |
+
|
| 321 |
+
# Args:
|
| 322 |
+
# use_xla_flash_attention (`bool`):
|
| 323 |
+
# Whether to use pallas flash attention kernel from `torch_xla` or not.
|
| 324 |
+
# partition_spec (`Tuple[]`, *optional*):
|
| 325 |
+
# Specify the partition specification if using SPMD. Otherwise None.
|
| 326 |
+
# """
|
| 327 |
+
# if use_xla_flash_attention:
|
| 328 |
+
# if not is_torch_xla_available:
|
| 329 |
+
# raise "torch_xla is not available"
|
| 330 |
+
# elif is_torch_xla_version("<", "2.3"):
|
| 331 |
+
# raise "flash attention pallas kernel is supported from torch_xla version 2.3"
|
| 332 |
+
# elif is_spmd() and is_torch_xla_version("<", "2.4"):
|
| 333 |
+
# raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
|
| 334 |
+
# else:
|
| 335 |
+
# if is_flux:
|
| 336 |
+
# processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
|
| 337 |
+
# else:
|
| 338 |
+
# processor = XLAFlashAttnProcessor2_0(partition_spec)
|
| 339 |
+
# else:
|
| 340 |
+
# processor = (
|
| 341 |
+
# AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
| 342 |
+
# )
|
| 343 |
+
# self.set_processor(processor)
|
| 344 |
+
|
| 345 |
+
# def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
|
| 346 |
+
# r"""
|
| 347 |
+
# Set whether to use npu flash attention from `torch_npu` or not.
|
| 348 |
+
|
| 349 |
+
# """
|
| 350 |
+
# if use_npu_flash_attention:
|
| 351 |
+
# processor = AttnProcessorNPU()
|
| 352 |
+
# else:
|
| 353 |
+
# # set attention processor
|
| 354 |
+
# # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
| 355 |
+
# # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
| 356 |
+
# # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
| 357 |
+
# processor = (
|
| 358 |
+
# AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
| 359 |
+
# )
|
| 360 |
+
# self.set_processor(processor)
|
| 361 |
+
|
| 362 |
+
# def set_use_memory_efficient_attention_xformers(
|
| 363 |
+
# self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
| 364 |
+
# ) -> None:
|
| 365 |
+
# r"""
|
| 366 |
+
# Set whether to use memory efficient attention from `xformers` or not.
|
| 367 |
+
|
| 368 |
+
# Args:
|
| 369 |
+
# use_memory_efficient_attention_xformers (`bool`):
|
| 370 |
+
# Whether to use memory efficient attention from `xformers` or not.
|
| 371 |
+
# attention_op (`Callable`, *optional*):
|
| 372 |
+
# The attention operation to use. Defaults to `None` which uses the default attention operation from
|
| 373 |
+
# `xformers`.
|
| 374 |
+
# """
|
| 375 |
+
# is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
| 376 |
+
# self.processor,
|
| 377 |
+
# (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
|
| 378 |
+
# )
|
| 379 |
+
# is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
| 380 |
+
# self.processor,
|
| 381 |
+
# (
|
| 382 |
+
# AttnAddedKVProcessor,
|
| 383 |
+
# AttnAddedKVProcessor2_0,
|
| 384 |
+
# SlicedAttnAddedKVProcessor,
|
| 385 |
+
# XFormersAttnAddedKVProcessor,
|
| 386 |
+
# ),
|
| 387 |
+
# )
|
| 388 |
+
# is_ip_adapter = hasattr(self, "processor") and isinstance(
|
| 389 |
+
# self.processor,
|
| 390 |
+
# (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
|
| 391 |
+
# )
|
| 392 |
+
# is_joint_processor = hasattr(self, "processor") and isinstance(
|
| 393 |
+
# self.processor,
|
| 394 |
+
# (
|
| 395 |
+
# JointAttnProcessor2_0,
|
| 396 |
+
# XFormersJointAttnProcessor,
|
| 397 |
+
# ),
|
| 398 |
+
# )
|
| 399 |
+
|
| 400 |
+
# if use_memory_efficient_attention_xformers:
|
| 401 |
+
# if is_added_kv_processor and is_custom_diffusion:
|
| 402 |
+
# raise NotImplementedError(
|
| 403 |
+
# f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
|
| 404 |
+
# )
|
| 405 |
+
# if not is_xformers_available():
|
| 406 |
+
# raise ModuleNotFoundError(
|
| 407 |
+
# (
|
| 408 |
+
# "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
| 409 |
+
# " xformers"
|
| 410 |
+
# ),
|
| 411 |
+
# name="xformers",
|
| 412 |
+
# )
|
| 413 |
+
# elif not torch.cuda.is_available():
|
| 414 |
+
# raise ValueError(
|
| 415 |
+
# "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
| 416 |
+
# " only available for GPU "
|
| 417 |
+
# )
|
| 418 |
+
# else:
|
| 419 |
+
# try:
|
| 420 |
+
# # Make sure we can run the memory efficient attention
|
| 421 |
+
# dtype = None
|
| 422 |
+
# if attention_op is not None:
|
| 423 |
+
# op_fw, op_bw = attention_op
|
| 424 |
+
# dtype, *_ = op_fw.SUPPORTED_DTYPES
|
| 425 |
+
# q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
| 426 |
+
# _ = xformers.ops.memory_efficient_attention(q, q, q)
|
| 427 |
+
# except Exception as e:
|
| 428 |
+
# raise e
|
| 429 |
+
|
| 430 |
+
# if is_custom_diffusion:
|
| 431 |
+
# processor = CustomDiffusionXFormersAttnProcessor(
|
| 432 |
+
# train_kv=self.processor.train_kv,
|
| 433 |
+
# train_q_out=self.processor.train_q_out,
|
| 434 |
+
# hidden_size=self.processor.hidden_size,
|
| 435 |
+
# cross_attention_dim=self.processor.cross_attention_dim,
|
| 436 |
+
# attention_op=attention_op,
|
| 437 |
+
# )
|
| 438 |
+
# processor.load_state_dict(self.processor.state_dict())
|
| 439 |
+
# if hasattr(self.processor, "to_k_custom_diffusion"):
|
| 440 |
+
# processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
| 441 |
+
# elif is_added_kv_processor:
|
| 442 |
+
# # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
| 443 |
+
# # which uses this type of cross attention ONLY because the attention mask of format
|
| 444 |
+
# # [0, ..., -10.000, ..., 0, ...,] is not supported
|
| 445 |
+
# # throw warning
|
| 446 |
+
# logger.info(
|
| 447 |
+
# "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
| 448 |
+
# )
|
| 449 |
+
# processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
| 450 |
+
# elif is_ip_adapter:
|
| 451 |
+
# processor = IPAdapterXFormersAttnProcessor(
|
| 452 |
+
# hidden_size=self.processor.hidden_size,
|
| 453 |
+
# cross_attention_dim=self.processor.cross_attention_dim,
|
| 454 |
+
# num_tokens=self.processor.num_tokens,
|
| 455 |
+
# scale=self.processor.scale,
|
| 456 |
+
# attention_op=attention_op,
|
| 457 |
+
# )
|
| 458 |
+
# processor.load_state_dict(self.processor.state_dict())
|
| 459 |
+
# if hasattr(self.processor, "to_k_ip"):
|
| 460 |
+
# processor.to(
|
| 461 |
+
# device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
|
| 462 |
+
# )
|
| 463 |
+
# elif is_joint_processor:
|
| 464 |
+
# processor = XFormersJointAttnProcessor(attention_op=attention_op)
|
| 465 |
+
# else:
|
| 466 |
+
# processor = XFormersAttnProcessor(attention_op=attention_op)
|
| 467 |
+
# else:
|
| 468 |
+
# if is_custom_diffusion:
|
| 469 |
+
# attn_processor_class = (
|
| 470 |
+
# CustomDiffusionAttnProcessor2_0
|
| 471 |
+
# if hasattr(F, "scaled_dot_product_attention")
|
| 472 |
+
# else CustomDiffusionAttnProcessor
|
| 473 |
+
# )
|
| 474 |
+
# processor = attn_processor_class(
|
| 475 |
+
# train_kv=self.processor.train_kv,
|
| 476 |
+
# train_q_out=self.processor.train_q_out,
|
| 477 |
+
# hidden_size=self.processor.hidden_size,
|
| 478 |
+
# cross_attention_dim=self.processor.cross_attention_dim,
|
| 479 |
+
# )
|
| 480 |
+
# processor.load_state_dict(self.processor.state_dict())
|
| 481 |
+
# if hasattr(self.processor, "to_k_custom_diffusion"):
|
| 482 |
+
# processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
| 483 |
+
# elif is_ip_adapter:
|
| 484 |
+
# processor = IPAdapterAttnProcessor2_0(
|
| 485 |
+
# hidden_size=self.processor.hidden_size,
|
| 486 |
+
# cross_attention_dim=self.processor.cross_attention_dim,
|
| 487 |
+
# num_tokens=self.processor.num_tokens,
|
| 488 |
+
# scale=self.processor.scale,
|
| 489 |
+
# )
|
| 490 |
+
# processor.load_state_dict(self.processor.state_dict())
|
| 491 |
+
# if hasattr(self.processor, "to_k_ip"):
|
| 492 |
+
# processor.to(
|
| 493 |
+
# device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
|
| 494 |
+
# )
|
| 495 |
+
# else:
|
| 496 |
+
# # set attention processor
|
| 497 |
+
# # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
| 498 |
+
# # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
| 499 |
+
# # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
| 500 |
+
# processor = (
|
| 501 |
+
# AttnProcessor2_0()
|
| 502 |
+
# if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
| 503 |
+
# else AttnProcessor()
|
| 504 |
+
# )
|
| 505 |
+
|
| 506 |
+
# self.set_processor(processor)
|
| 507 |
+
|
| 508 |
+
# def set_attention_slice(self, slice_size: int) -> None:
|
| 509 |
+
# r"""
|
| 510 |
+
# Set the slice size for attention computation.
|
| 511 |
+
|
| 512 |
+
# Args:
|
| 513 |
+
# slice_size (`int`):
|
| 514 |
+
# The slice size for attention computation.
|
| 515 |
+
# """
|
| 516 |
+
# if slice_size is not None and slice_size > self.sliceable_head_dim:
|
| 517 |
+
# raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
| 518 |
+
|
| 519 |
+
# if slice_size is not None and self.added_kv_proj_dim is not None:
|
| 520 |
+
# processor = SlicedAttnAddedKVProcessor(slice_size)
|
| 521 |
+
# elif slice_size is not None:
|
| 522 |
+
# processor = SlicedAttnProcessor(slice_size)
|
| 523 |
+
# elif self.added_kv_proj_dim is not None:
|
| 524 |
+
# processor = AttnAddedKVProcessor()
|
| 525 |
+
# else:
|
| 526 |
+
# # set attention processor
|
| 527 |
+
# # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
| 528 |
+
# # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
| 529 |
+
# # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
| 530 |
+
# processor = (
|
| 531 |
+
# AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
| 532 |
+
# )
|
| 533 |
+
|
| 534 |
+
# self.set_processor(processor)
|
| 535 |
+
|
| 536 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
| 537 |
+
r"""
|
| 538 |
+
Set the attention processor to use.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
processor (`AttnProcessor`):
|
| 542 |
+
The attention processor to use.
|
| 543 |
+
"""
|
| 544 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
| 545 |
+
# pop `processor` from `self._modules`
|
| 546 |
+
if (
|
| 547 |
+
hasattr(self, "processor")
|
| 548 |
+
and isinstance(self.processor, torch.nn.Module)
|
| 549 |
+
and not isinstance(processor, torch.nn.Module)
|
| 550 |
+
):
|
| 551 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
| 552 |
+
self._modules.pop("processor")
|
| 553 |
+
|
| 554 |
+
self.processor = processor
|
| 555 |
+
|
| 556 |
+
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
| 557 |
+
r"""
|
| 558 |
+
Get the attention processor in use.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
| 562 |
+
Set to `True` to return the deprecated LoRA attention processor.
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
"AttentionProcessor": The attention processor in use.
|
| 566 |
+
"""
|
| 567 |
+
if not return_deprecated_lora:
|
| 568 |
+
return self.processor
|
| 569 |
+
|
| 570 |
+
def forward(
|
| 571 |
+
self,
|
| 572 |
+
hidden_states: torch.Tensor,
|
| 573 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 574 |
+
map_aware_mask: Optional[torch.FloatTensor] = None,
|
| 575 |
+
**cross_attention_kwargs,
|
| 576 |
+
) -> torch.Tensor:
|
| 577 |
+
r"""
|
| 578 |
+
The forward method of the `Attention` class.
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
hidden_states (`torch.Tensor`):
|
| 582 |
+
The hidden states of the query.
|
| 583 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
| 584 |
+
The hidden states of the encoder.
|
| 585 |
+
map_aware_mask (`torch.Tensor`, *optional*):
|
| 586 |
+
The attention mask to use. If `None`, no mask is applied.
|
| 587 |
+
**cross_attention_kwargs:
|
| 588 |
+
Additional keyword arguments to pass along to the cross attention.
|
| 589 |
+
|
| 590 |
+
Returns:
|
| 591 |
+
`torch.Tensor`: The output of the attention layer.
|
| 592 |
+
"""
|
| 593 |
+
# The `Attention` class can call different attention processors / attention functions
|
| 594 |
+
# here we simply pass along all tensors to the selected processor class
|
| 595 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
| 596 |
+
|
| 597 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
| 598 |
+
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
| 599 |
+
unused_kwargs = [
|
| 600 |
+
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
|
| 601 |
+
]
|
| 602 |
+
if len(unused_kwargs) > 0:
|
| 603 |
+
logger.warning(
|
| 604 |
+
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
| 605 |
+
)
|
| 606 |
+
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
| 607 |
+
|
| 608 |
+
return self.processor(
|
| 609 |
+
self,
|
| 610 |
+
hidden_states,
|
| 611 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 612 |
+
attention_mask=map_aware_mask,
|
| 613 |
+
**cross_attention_kwargs,
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 617 |
+
r"""
|
| 618 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
| 619 |
+
is the number of heads initialized while constructing the `Attention` class.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
| 623 |
+
|
| 624 |
+
Returns:
|
| 625 |
+
`torch.Tensor`: The reshaped tensor.
|
| 626 |
+
"""
|
| 627 |
+
head_size = self.heads
|
| 628 |
+
batch_size, seq_len, dim = tensor.shape
|
| 629 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
| 630 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
| 631 |
+
return tensor
|
| 632 |
+
|
| 633 |
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
| 634 |
+
r"""
|
| 635 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
|
| 636 |
+
the number of heads initialized while constructing the `Attention` class.
|
| 637 |
+
|
| 638 |
+
Args:
|
| 639 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
| 640 |
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
| 641 |
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
| 642 |
+
|
| 643 |
+
Returns:
|
| 644 |
+
`torch.Tensor`: The reshaped tensor.
|
| 645 |
+
"""
|
| 646 |
+
head_size = self.heads
|
| 647 |
+
if tensor.ndim == 3:
|
| 648 |
+
batch_size, seq_len, dim = tensor.shape
|
| 649 |
+
extra_dim = 1
|
| 650 |
+
else:
|
| 651 |
+
batch_size, extra_dim, seq_len, dim = tensor.shape
|
| 652 |
+
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
|
| 653 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
| 654 |
+
|
| 655 |
+
if out_dim == 3:
|
| 656 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
|
| 657 |
+
|
| 658 |
+
return tensor
|
| 659 |
+
|
| 660 |
+
def get_attention_scores(
|
| 661 |
+
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
| 662 |
+
) -> torch.Tensor:
|
| 663 |
+
r"""
|
| 664 |
+
Compute the attention scores.
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
query (`torch.Tensor`): The query tensor.
|
| 668 |
+
key (`torch.Tensor`): The key tensor.
|
| 669 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
| 670 |
+
|
| 671 |
+
Returns:
|
| 672 |
+
`torch.Tensor`: The attention probabilities/scores.
|
| 673 |
+
"""
|
| 674 |
+
dtype = query.dtype
|
| 675 |
+
if self.upcast_attention:
|
| 676 |
+
query = query.float()
|
| 677 |
+
key = key.float()
|
| 678 |
+
|
| 679 |
+
if attention_mask is None:
|
| 680 |
+
baddbmm_input = torch.empty(
|
| 681 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
| 682 |
+
)
|
| 683 |
+
beta = 0
|
| 684 |
+
else:
|
| 685 |
+
baddbmm_input = attention_mask
|
| 686 |
+
beta = 1
|
| 687 |
+
|
| 688 |
+
attention_scores = torch.baddbmm(
|
| 689 |
+
baddbmm_input,
|
| 690 |
+
query,
|
| 691 |
+
key.transpose(-1, -2),
|
| 692 |
+
beta=beta,
|
| 693 |
+
alpha=self.scale,
|
| 694 |
+
)
|
| 695 |
+
del baddbmm_input
|
| 696 |
+
|
| 697 |
+
if self.upcast_softmax:
|
| 698 |
+
attention_scores = attention_scores.float()
|
| 699 |
+
|
| 700 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 701 |
+
del attention_scores
|
| 702 |
+
|
| 703 |
+
attention_probs = attention_probs.to(dtype)
|
| 704 |
+
|
| 705 |
+
return attention_probs
|
| 706 |
+
|
| 707 |
+
def prepare_attention_mask(
|
| 708 |
+
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
|
| 709 |
+
) -> torch.Tensor:
|
| 710 |
+
r"""
|
| 711 |
+
Prepare the attention mask for the attention computation.
|
| 712 |
+
|
| 713 |
+
Args:
|
| 714 |
+
attention_mask (`torch.Tensor`):
|
| 715 |
+
The attention mask to prepare.
|
| 716 |
+
target_length (`int`):
|
| 717 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
| 718 |
+
batch_size (`int`):
|
| 719 |
+
The batch size, which is used to repeat the attention mask.
|
| 720 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
| 721 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
| 722 |
+
|
| 723 |
+
Returns:
|
| 724 |
+
`torch.Tensor`: The prepared attention mask.
|
| 725 |
+
"""
|
| 726 |
+
head_size = self.heads
|
| 727 |
+
if attention_mask is None:
|
| 728 |
+
return attention_mask
|
| 729 |
+
|
| 730 |
+
current_length: int = attention_mask.shape[-1]
|
| 731 |
+
if current_length != target_length:
|
| 732 |
+
if attention_mask.device.type == "mps":
|
| 733 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
| 734 |
+
# Instead, we can manually construct the padding tensor.
|
| 735 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
| 736 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
| 737 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
| 738 |
+
else:
|
| 739 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
| 740 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
| 741 |
+
# remaining_length: int = target_length - current_length
|
| 742 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
| 743 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
| 744 |
+
|
| 745 |
+
if out_dim == 3:
|
| 746 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
| 747 |
+
attention_mask = attention_mask.repeat_interleave(
|
| 748 |
+
head_size, dim=0, output_size=attention_mask.shape[0] * head_size
|
| 749 |
+
)
|
| 750 |
+
elif out_dim == 4:
|
| 751 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 752 |
+
attention_mask = attention_mask.repeat_interleave(
|
| 753 |
+
head_size, dim=1, output_size=attention_mask.shape[1] * head_size
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
return attention_mask
|
| 757 |
+
|
| 758 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
| 759 |
+
r"""
|
| 760 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
| 761 |
+
`Attention` class.
|
| 762 |
+
|
| 763 |
+
Args:
|
| 764 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
| 765 |
+
|
| 766 |
+
Returns:
|
| 767 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
| 768 |
+
"""
|
| 769 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
| 770 |
+
|
| 771 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
| 772 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
| 773 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
| 774 |
+
# Group norm norms along the channels dimension and expects
|
| 775 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
| 776 |
+
# to norm along the hidden dimension, so we need to move
|
| 777 |
+
# (batch_size, sequence_length, hidden_size) ->
|
| 778 |
+
# (batch_size, hidden_size, sequence_length)
|
| 779 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
| 780 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
| 781 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
| 782 |
+
else:
|
| 783 |
+
assert False
|
| 784 |
+
|
| 785 |
+
return encoder_hidden_states
|
| 786 |
+
|
| 787 |
+
@torch.no_grad()
|
| 788 |
+
def fuse_projections(self, fuse=True):
|
| 789 |
+
device = self.to_q.weight.data.device
|
| 790 |
+
dtype = self.to_q.weight.data.dtype
|
| 791 |
+
|
| 792 |
+
if not self.is_cross_attention:
|
| 793 |
+
# fetch weight matrices.
|
| 794 |
+
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
| 795 |
+
in_features = concatenated_weights.shape[1]
|
| 796 |
+
out_features = concatenated_weights.shape[0]
|
| 797 |
+
|
| 798 |
+
# create a new single projection layer and copy over the weights.
|
| 799 |
+
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
| 800 |
+
self.to_qkv.weight.copy_(concatenated_weights)
|
| 801 |
+
if self.use_bias:
|
| 802 |
+
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
| 803 |
+
self.to_qkv.bias.copy_(concatenated_bias)
|
| 804 |
+
|
| 805 |
+
else:
|
| 806 |
+
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
| 807 |
+
in_features = concatenated_weights.shape[1]
|
| 808 |
+
out_features = concatenated_weights.shape[0]
|
| 809 |
+
|
| 810 |
+
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
| 811 |
+
self.to_kv.weight.copy_(concatenated_weights)
|
| 812 |
+
if self.use_bias:
|
| 813 |
+
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
| 814 |
+
self.to_kv.bias.copy_(concatenated_bias)
|
| 815 |
+
|
| 816 |
+
# handle added projections for SD3 and others.
|
| 817 |
+
if (
|
| 818 |
+
getattr(self, "add_q_proj", None) is not None
|
| 819 |
+
and getattr(self, "add_k_proj", None) is not None
|
| 820 |
+
and getattr(self, "add_v_proj", None) is not None
|
| 821 |
+
):
|
| 822 |
+
concatenated_weights = torch.cat(
|
| 823 |
+
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
|
| 824 |
+
)
|
| 825 |
+
in_features = concatenated_weights.shape[1]
|
| 826 |
+
out_features = concatenated_weights.shape[0]
|
| 827 |
+
|
| 828 |
+
self.to_added_qkv = nn.Linear(
|
| 829 |
+
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
|
| 830 |
+
)
|
| 831 |
+
self.to_added_qkv.weight.copy_(concatenated_weights)
|
| 832 |
+
if self.added_proj_bias:
|
| 833 |
+
concatenated_bias = torch.cat(
|
| 834 |
+
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
| 835 |
+
)
|
| 836 |
+
self.to_added_qkv.bias.copy_(concatenated_bias)
|
| 837 |
+
|
| 838 |
+
self.fused_projections = fuse
|
| 839 |
+
|
| 840 |
+
class MapAwareAttnProcessor2_0:
|
| 841 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
| 842 |
+
|
| 843 |
+
def __init__(self):
|
| 844 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 845 |
+
raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 846 |
+
|
| 847 |
+
def __call__(
|
| 848 |
+
self,
|
| 849 |
+
attn: Attention,
|
| 850 |
+
hidden_states: torch.FloatTensor,
|
| 851 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
| 852 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 853 |
+
*args,
|
| 854 |
+
**kwargs,
|
| 855 |
+
) -> torch.FloatTensor:
|
| 856 |
+
# print("attention_mask: ", attention_mask)
|
| 857 |
+
residual = hidden_states
|
| 858 |
+
|
| 859 |
+
batch_size = hidden_states.shape[0]
|
| 860 |
+
|
| 861 |
+
# `sample` projections.
|
| 862 |
+
query = attn.to_q(hidden_states)
|
| 863 |
+
key = attn.to_k(hidden_states)
|
| 864 |
+
value = attn.to_v(hidden_states)
|
| 865 |
+
|
| 866 |
+
inner_dim = key.shape[-1]
|
| 867 |
+
head_dim = inner_dim // attn.heads
|
| 868 |
+
|
| 869 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 870 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 871 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 872 |
+
|
| 873 |
+
if attn.norm_q is not None:
|
| 874 |
+
query = attn.norm_q(query)
|
| 875 |
+
if attn.norm_k is not None:
|
| 876 |
+
key = attn.norm_k(key)
|
| 877 |
+
|
| 878 |
+
# `context` projections.
|
| 879 |
+
if encoder_hidden_states is not None:
|
| 880 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
| 881 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
| 882 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
| 883 |
+
|
| 884 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
| 885 |
+
batch_size, -1, attn.heads, head_dim
|
| 886 |
+
).transpose(1, 2)
|
| 887 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
| 888 |
+
batch_size, -1, attn.heads, head_dim
|
| 889 |
+
).transpose(1, 2)
|
| 890 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
| 891 |
+
batch_size, -1, attn.heads, head_dim
|
| 892 |
+
).transpose(1, 2)
|
| 893 |
+
|
| 894 |
+
if attn.norm_added_q is not None:
|
| 895 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
| 896 |
+
if attn.norm_added_k is not None:
|
| 897 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
| 898 |
+
|
| 899 |
+
# print(f"image q: {query.shape}, image k: {key.shape}, image v: {value.shape}") # [B, 24, 1024, 64]
|
| 900 |
+
# print(f"text q: {encoder_hidden_states_query_proj.shape}, text k: {encoder_hidden_states_key_proj.shape}, text v: {encoder_hidden_states_value_proj.shape}")
|
| 901 |
+
# [B, 24, 154, 64]
|
| 902 |
+
|
| 903 |
+
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
|
| 904 |
+
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
|
| 905 |
+
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
|
| 906 |
+
|
| 907 |
+
# print(f"Joint - query shape: {query.shape}, key shape: {key.shape}, value shape: {value.shape}")
|
| 908 |
+
# [B, 24, 1178, 64]
|
| 909 |
+
map_aware_mask = attention_mask
|
| 910 |
+
else:
|
| 911 |
+
map_aware_mask = None
|
| 912 |
+
# print(
|
| 913 |
+
# "map_aware_mask:",
|
| 914 |
+
# None if map_aware_mask is None else (map_aware_mask.shape, map_aware_mask.dtype)
|
| 915 |
+
# )
|
| 916 |
+
|
| 917 |
+
# print("query: ", query.shape, query.dtype)
|
| 918 |
+
if map_aware_mask is not None:
|
| 919 |
+
map_aware_mask = map_aware_mask.to(query.dtype)
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value,
|
| 923 |
+
attn_mask=map_aware_mask,
|
| 924 |
+
dropout_p=0.0,
|
| 925 |
+
is_causal=False)
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 929 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 930 |
+
|
| 931 |
+
if encoder_hidden_states is not None:
|
| 932 |
+
# Split the attention outputs.
|
| 933 |
+
hidden_states, encoder_hidden_states = (
|
| 934 |
+
hidden_states[:, : residual.shape[1]],
|
| 935 |
+
hidden_states[:, residual.shape[1] :],
|
| 936 |
+
)
|
| 937 |
+
if not attn.context_pre_only:
|
| 938 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 939 |
+
|
| 940 |
+
# linear proj
|
| 941 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 942 |
+
# dropout
|
| 943 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 944 |
+
|
| 945 |
+
if encoder_hidden_states is not None:
|
| 946 |
+
return hidden_states, encoder_hidden_states
|
| 947 |
+
else:
|
| 948 |
+
return hidden_states
|
| 949 |
+
|
| 950 |
+
class MapAwareTransformerBlock(nn.Module):
|
| 951 |
+
r"""
|
| 952 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
| 953 |
+
|
| 954 |
+
Reference: https://huggingface.co/papers/2403.03206
|
| 955 |
+
|
| 956 |
+
Parameters:
|
| 957 |
+
dim (`int`): The number of channels in the input and output.
|
| 958 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 959 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 960 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
| 961 |
+
processing of `context` conditions.
|
| 962 |
+
"""
|
| 963 |
+
|
| 964 |
+
def __init__(
|
| 965 |
+
self,
|
| 966 |
+
dim: int,
|
| 967 |
+
num_attention_heads: int,
|
| 968 |
+
attention_head_dim: int,
|
| 969 |
+
context_pre_only: bool = False,
|
| 970 |
+
qk_norm: Optional[str] = None,
|
| 971 |
+
use_dual_attention: bool = False,
|
| 972 |
+
):
|
| 973 |
+
super().__init__()
|
| 974 |
+
|
| 975 |
+
self.use_dual_attention = use_dual_attention
|
| 976 |
+
self.context_pre_only = context_pre_only
|
| 977 |
+
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
|
| 978 |
+
|
| 979 |
+
if use_dual_attention:
|
| 980 |
+
self.norm1 = SD35AdaLayerNormZeroX(dim)
|
| 981 |
+
else:
|
| 982 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 983 |
+
|
| 984 |
+
if context_norm_type == "ada_norm_continous":
|
| 985 |
+
self.norm1_context = AdaLayerNormContinuous(
|
| 986 |
+
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
|
| 987 |
+
)
|
| 988 |
+
elif context_norm_type == "ada_norm_zero":
|
| 989 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 990 |
+
else:
|
| 991 |
+
raise ValueError(
|
| 992 |
+
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
# if hasattr(F, "scaled_dot_product_attention"):
|
| 996 |
+
# processor = JointAttnProcessor2_0()
|
| 997 |
+
# else:
|
| 998 |
+
# raise ValueError(
|
| 999 |
+
# "The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
| 1000 |
+
# )
|
| 1001 |
+
|
| 1002 |
+
self.attn = MapAwareAttention(
|
| 1003 |
+
query_dim=dim,
|
| 1004 |
+
cross_attention_dim=None,
|
| 1005 |
+
added_kv_proj_dim=dim,
|
| 1006 |
+
dim_head=attention_head_dim,
|
| 1007 |
+
heads=num_attention_heads,
|
| 1008 |
+
out_dim=dim,
|
| 1009 |
+
context_pre_only=context_pre_only,
|
| 1010 |
+
bias=True,
|
| 1011 |
+
processor=MapAwareAttnProcessor2_0(),
|
| 1012 |
+
qk_norm=qk_norm,
|
| 1013 |
+
eps=1e-6,
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
if use_dual_attention:
|
| 1017 |
+
self.attn2 = Attention(
|
| 1018 |
+
query_dim=dim,
|
| 1019 |
+
cross_attention_dim=None,
|
| 1020 |
+
dim_head=attention_head_dim,
|
| 1021 |
+
heads=num_attention_heads,
|
| 1022 |
+
out_dim=dim,
|
| 1023 |
+
bias=True,
|
| 1024 |
+
processor=JointAttnProcessor2_0(),
|
| 1025 |
+
qk_norm=qk_norm,
|
| 1026 |
+
eps=1e-6,
|
| 1027 |
+
)
|
| 1028 |
+
else:
|
| 1029 |
+
self.attn2 = None
|
| 1030 |
+
|
| 1031 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 1032 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 1033 |
+
|
| 1034 |
+
if not context_pre_only:
|
| 1035 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 1036 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 1037 |
+
else:
|
| 1038 |
+
self.norm2_context = None
|
| 1039 |
+
self.ff_context = None
|
| 1040 |
+
|
| 1041 |
+
# let chunk size default to None
|
| 1042 |
+
self._chunk_size = None
|
| 1043 |
+
self._chunk_dim = 0
|
| 1044 |
+
|
| 1045 |
+
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
| 1046 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
| 1047 |
+
# Sets chunk feed-forward
|
| 1048 |
+
self._chunk_size = chunk_size
|
| 1049 |
+
self._chunk_dim = dim
|
| 1050 |
+
|
| 1051 |
+
def forward(
|
| 1052 |
+
self,
|
| 1053 |
+
hidden_states: torch.FloatTensor,
|
| 1054 |
+
encoder_hidden_states: torch.FloatTensor,
|
| 1055 |
+
temb: torch.FloatTensor,
|
| 1056 |
+
map_aware_mask: Optional[torch.FloatTensor] = None,
|
| 1057 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1058 |
+
):
|
| 1059 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 1060 |
+
if self.use_dual_attention:
|
| 1061 |
+
# print(f"hidden_states: {type(hidden_states)}")
|
| 1062 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
| 1063 |
+
hidden_states, emb=temb
|
| 1064 |
+
)
|
| 1065 |
+
else:
|
| 1066 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 1067 |
+
|
| 1068 |
+
if self.context_pre_only:
|
| 1069 |
+
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
| 1070 |
+
else:
|
| 1071 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 1072 |
+
encoder_hidden_states, emb=temb
|
| 1073 |
+
)
|
| 1074 |
+
|
| 1075 |
+
# Attention.
|
| 1076 |
+
attn_output, context_attn_output = self.attn(
|
| 1077 |
+
hidden_states=norm_hidden_states,
|
| 1078 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 1079 |
+
map_aware_mask=map_aware_mask,
|
| 1080 |
+
**joint_attention_kwargs,
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
# Process attention outputs for the `hidden_states`.
|
| 1084 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 1085 |
+
hidden_states = hidden_states + attn_output
|
| 1086 |
+
|
| 1087 |
+
if self.use_dual_attention:
|
| 1088 |
+
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
|
| 1089 |
+
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
|
| 1090 |
+
hidden_states = hidden_states + attn_output2
|
| 1091 |
+
|
| 1092 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 1093 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 1094 |
+
if self._chunk_size is not None:
|
| 1095 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 1096 |
+
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
| 1097 |
+
else:
|
| 1098 |
+
ff_output = self.ff(norm_hidden_states)
|
| 1099 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 1100 |
+
|
| 1101 |
+
hidden_states = hidden_states + ff_output
|
| 1102 |
+
|
| 1103 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 1104 |
+
if self.context_pre_only:
|
| 1105 |
+
encoder_hidden_states = None
|
| 1106 |
+
else:
|
| 1107 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 1108 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 1109 |
+
|
| 1110 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 1111 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 1112 |
+
if self._chunk_size is not None:
|
| 1113 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 1114 |
+
context_ff_output = _chunked_feed_forward(
|
| 1115 |
+
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
|
| 1116 |
+
)
|
| 1117 |
+
else:
|
| 1118 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 1119 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 1120 |
+
|
| 1121 |
+
return encoder_hidden_states, hidden_states
|
| 1122 |
+
|
| 1123 |
+
@maybe_allow_in_graph
|
| 1124 |
+
class SD3SingleTransformerBlock(nn.Module):
|
| 1125 |
+
def __init__(
|
| 1126 |
+
self,
|
| 1127 |
+
dim: int,
|
| 1128 |
+
num_attention_heads: int,
|
| 1129 |
+
attention_head_dim: int,
|
| 1130 |
+
):
|
| 1131 |
+
super().__init__()
|
| 1132 |
+
|
| 1133 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 1134 |
+
self.attn = Attention(
|
| 1135 |
+
query_dim=dim,
|
| 1136 |
+
dim_head=attention_head_dim,
|
| 1137 |
+
heads=num_attention_heads,
|
| 1138 |
+
out_dim=dim,
|
| 1139 |
+
bias=True,
|
| 1140 |
+
processor=JointAttnProcessor2_0(),
|
| 1141 |
+
eps=1e-6,
|
| 1142 |
+
)
|
| 1143 |
+
|
| 1144 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 1145 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 1146 |
+
|
| 1147 |
+
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
|
| 1148 |
+
# 1. Attention
|
| 1149 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 1150 |
+
attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
|
| 1151 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 1152 |
+
hidden_states = hidden_states + attn_output
|
| 1153 |
+
|
| 1154 |
+
# 2. Feed Forward
|
| 1155 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 1156 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
| 1157 |
+
ff_output = self.ff(norm_hidden_states)
|
| 1158 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 1159 |
+
hidden_states = hidden_states + ff_output
|
| 1160 |
+
|
| 1161 |
+
return hidden_states
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
class IntrinsicWeatherSD3Transformer2DModel(
|
| 1165 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
|
| 1166 |
+
):
|
| 1167 |
+
"""
|
| 1168 |
+
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
|
| 1169 |
+
|
| 1170 |
+
Parameters:
|
| 1171 |
+
sample_size (`int`, defaults to `128`):
|
| 1172 |
+
The width/height of the latents. This is fixed during training since it is used to learn a number of
|
| 1173 |
+
position embeddings.
|
| 1174 |
+
patch_size (`int`, defaults to `2`):
|
| 1175 |
+
Patch size to turn the input data into small patches.
|
| 1176 |
+
in_channels (`int`, defaults to `16`):
|
| 1177 |
+
The number of latent channels in the input.
|
| 1178 |
+
num_layers (`int`, defaults to `18`):
|
| 1179 |
+
The number of layers of transformer blocks to use.
|
| 1180 |
+
attention_head_dim (`int`, defaults to `64`):
|
| 1181 |
+
The number of channels in each head.
|
| 1182 |
+
num_attention_heads (`int`, defaults to `18`):
|
| 1183 |
+
The number of heads to use for multi-head attention.
|
| 1184 |
+
joint_attention_dim (`int`, defaults to `4096`):
|
| 1185 |
+
The embedding dimension to use for joint text-image attention.
|
| 1186 |
+
caption_projection_dim (`int`, defaults to `1152`):
|
| 1187 |
+
The embedding dimension of caption embeddings.
|
| 1188 |
+
pooled_projection_dim (`int`, defaults to `2048`):
|
| 1189 |
+
The embedding dimension of pooled text projections.
|
| 1190 |
+
out_channels (`int`, defaults to `16`):
|
| 1191 |
+
The number of latent channels in the output.
|
| 1192 |
+
pos_embed_max_size (`int`, defaults to `96`):
|
| 1193 |
+
The maximum latent height/width of positional embeddings.
|
| 1194 |
+
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
|
| 1195 |
+
The number of dual-stream transformer blocks to use.
|
| 1196 |
+
qk_norm (`str`, *optional*, defaults to `None`):
|
| 1197 |
+
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
|
| 1198 |
+
"""
|
| 1199 |
+
|
| 1200 |
+
_supports_gradient_checkpointing = True
|
| 1201 |
+
_no_split_modules = ["JointTransformerBlock"]
|
| 1202 |
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
| 1203 |
+
|
| 1204 |
+
@register_to_config
|
| 1205 |
+
def __init__(
|
| 1206 |
+
self,
|
| 1207 |
+
sample_size: int = 128,
|
| 1208 |
+
patch_size: int = 2,
|
| 1209 |
+
in_channels: int = 16,
|
| 1210 |
+
num_layers: int = 18,
|
| 1211 |
+
attention_head_dim: int = 64,
|
| 1212 |
+
num_attention_heads: int = 18,
|
| 1213 |
+
joint_attention_dim: int = 4096,
|
| 1214 |
+
caption_projection_dim: int = 1152,
|
| 1215 |
+
pooled_projection_dim: int = 2048,
|
| 1216 |
+
out_channels: int = 16,
|
| 1217 |
+
pos_embed_max_size: int = 96,
|
| 1218 |
+
dual_attention_layers: Tuple[
|
| 1219 |
+
int, ...
|
| 1220 |
+
] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
|
| 1221 |
+
qk_norm: Optional[str] = None,
|
| 1222 |
+
):
|
| 1223 |
+
super().__init__()
|
| 1224 |
+
self.out_channels = out_channels if out_channels is not None else in_channels
|
| 1225 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 1226 |
+
|
| 1227 |
+
self.pos_embed = PatchEmbed(
|
| 1228 |
+
height=sample_size,
|
| 1229 |
+
width=sample_size,
|
| 1230 |
+
patch_size=patch_size,
|
| 1231 |
+
in_channels=in_channels,
|
| 1232 |
+
embed_dim=self.inner_dim,
|
| 1233 |
+
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
|
| 1234 |
+
)
|
| 1235 |
+
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
| 1236 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
| 1237 |
+
)
|
| 1238 |
+
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
| 1239 |
+
|
| 1240 |
+
self.transformer_blocks = nn.ModuleList(
|
| 1241 |
+
[
|
| 1242 |
+
MapAwareTransformerBlock(
|
| 1243 |
+
dim=self.inner_dim,
|
| 1244 |
+
num_attention_heads=num_attention_heads,
|
| 1245 |
+
attention_head_dim=attention_head_dim,
|
| 1246 |
+
context_pre_only=i == num_layers - 1,
|
| 1247 |
+
qk_norm=qk_norm,
|
| 1248 |
+
use_dual_attention=True if i in dual_attention_layers else False,
|
| 1249 |
+
)
|
| 1250 |
+
for i in range(num_layers)
|
| 1251 |
+
]
|
| 1252 |
+
)
|
| 1253 |
+
|
| 1254 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 1255 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 1256 |
+
|
| 1257 |
+
self.gradient_checkpointing = False
|
| 1258 |
+
|
| 1259 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
| 1260 |
+
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
| 1261 |
+
"""
|
| 1262 |
+
Sets the attention processor to use [feed forward
|
| 1263 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
| 1264 |
+
|
| 1265 |
+
Parameters:
|
| 1266 |
+
chunk_size (`int`, *optional*):
|
| 1267 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
| 1268 |
+
over each tensor of dim=`dim`.
|
| 1269 |
+
dim (`int`, *optional*, defaults to `0`):
|
| 1270 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
| 1271 |
+
or dim=1 (sequence length).
|
| 1272 |
+
"""
|
| 1273 |
+
if dim not in [0, 1]:
|
| 1274 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
| 1275 |
+
|
| 1276 |
+
# By default chunk size is 1
|
| 1277 |
+
chunk_size = chunk_size or 1
|
| 1278 |
+
|
| 1279 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
| 1280 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
| 1281 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
| 1282 |
+
|
| 1283 |
+
for child in module.children():
|
| 1284 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
| 1285 |
+
|
| 1286 |
+
for module in self.children():
|
| 1287 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
| 1288 |
+
|
| 1289 |
+
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
| 1290 |
+
def disable_forward_chunking(self):
|
| 1291 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
| 1292 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
| 1293 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
| 1294 |
+
|
| 1295 |
+
for child in module.children():
|
| 1296 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
| 1297 |
+
|
| 1298 |
+
for module in self.children():
|
| 1299 |
+
fn_recursive_feed_forward(module, None, 0)
|
| 1300 |
+
|
| 1301 |
+
@property
|
| 1302 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 1303 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 1304 |
+
r"""
|
| 1305 |
+
Returns:
|
| 1306 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 1307 |
+
indexed by its weight name.
|
| 1308 |
+
"""
|
| 1309 |
+
# set recursively
|
| 1310 |
+
processors = {}
|
| 1311 |
+
|
| 1312 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 1313 |
+
if hasattr(module, "get_processor"):
|
| 1314 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 1315 |
+
|
| 1316 |
+
for sub_name, child in module.named_children():
|
| 1317 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 1318 |
+
|
| 1319 |
+
return processors
|
| 1320 |
+
|
| 1321 |
+
for name, module in self.named_children():
|
| 1322 |
+
fn_recursive_add_processors(name, module, processors)
|
| 1323 |
+
|
| 1324 |
+
return processors
|
| 1325 |
+
|
| 1326 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 1327 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 1328 |
+
r"""
|
| 1329 |
+
Sets the attention processor to use to compute attention.
|
| 1330 |
+
|
| 1331 |
+
Parameters:
|
| 1332 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 1333 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 1334 |
+
for **all** `Attention` layers.
|
| 1335 |
+
|
| 1336 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 1337 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 1338 |
+
|
| 1339 |
+
"""
|
| 1340 |
+
count = len(self.attn_processors.keys())
|
| 1341 |
+
|
| 1342 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 1343 |
+
raise ValueError(
|
| 1344 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 1345 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 1349 |
+
if hasattr(module, "set_processor"):
|
| 1350 |
+
if not isinstance(processor, dict):
|
| 1351 |
+
module.set_processor(processor)
|
| 1352 |
+
else:
|
| 1353 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 1354 |
+
|
| 1355 |
+
for sub_name, child in module.named_children():
|
| 1356 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 1357 |
+
|
| 1358 |
+
for name, module in self.named_children():
|
| 1359 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 1360 |
+
|
| 1361 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
|
| 1362 |
+
def fuse_qkv_projections(self):
|
| 1363 |
+
"""
|
| 1364 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 1365 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 1366 |
+
|
| 1367 |
+
<Tip warning={true}>
|
| 1368 |
+
|
| 1369 |
+
This API is 🧪 experimental.
|
| 1370 |
+
|
| 1371 |
+
</Tip>
|
| 1372 |
+
"""
|
| 1373 |
+
self.original_attn_processors = None
|
| 1374 |
+
|
| 1375 |
+
for _, attn_processor in self.attn_processors.items():
|
| 1376 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 1377 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 1378 |
+
|
| 1379 |
+
self.original_attn_processors = self.attn_processors
|
| 1380 |
+
|
| 1381 |
+
for module in self.modules():
|
| 1382 |
+
if isinstance(module, Attention):
|
| 1383 |
+
module.fuse_projections(fuse=True)
|
| 1384 |
+
|
| 1385 |
+
self.set_attn_processor(FusedJointAttnProcessor2_0())
|
| 1386 |
+
|
| 1387 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 1388 |
+
def unfuse_qkv_projections(self):
|
| 1389 |
+
"""Disables the fused QKV projection if enabled.
|
| 1390 |
+
|
| 1391 |
+
<Tip warning={true}>
|
| 1392 |
+
|
| 1393 |
+
This API is 🧪 experimental.
|
| 1394 |
+
|
| 1395 |
+
</Tip>
|
| 1396 |
+
|
| 1397 |
+
"""
|
| 1398 |
+
if self.original_attn_processors is not None:
|
| 1399 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 1400 |
+
|
| 1401 |
+
def forward(
|
| 1402 |
+
self,
|
| 1403 |
+
hidden_states: torch.Tensor,
|
| 1404 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 1405 |
+
pooled_projections: torch.Tensor = None,
|
| 1406 |
+
timestep: torch.LongTensor = None,
|
| 1407 |
+
block_controlnet_hidden_states: List = None,
|
| 1408 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1409 |
+
return_dict: bool = True,
|
| 1410 |
+
skip_layers: Optional[List[int]] = None,
|
| 1411 |
+
map_aware_mask: Optional[torch.FloatTensor] = None,
|
| 1412 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 1413 |
+
"""
|
| 1414 |
+
The [`SD3Transformer2DModel`] forward method.
|
| 1415 |
+
|
| 1416 |
+
Args:
|
| 1417 |
+
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
|
| 1418 |
+
Input `hidden_states`.
|
| 1419 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
|
| 1420 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 1421 |
+
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
|
| 1422 |
+
Embeddings projected from the embeddings of input conditions.
|
| 1423 |
+
timestep (`torch.LongTensor`):
|
| 1424 |
+
Used to indicate denoising step.
|
| 1425 |
+
block_controlnet_hidden_states (`list` of `torch.Tensor`):
|
| 1426 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 1427 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 1428 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 1429 |
+
`self.processor` in
|
| 1430 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 1431 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1432 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 1433 |
+
tuple.
|
| 1434 |
+
skip_layers (`list` of `int`, *optional*):
|
| 1435 |
+
A list of layer indices to skip during the forward pass.
|
| 1436 |
+
|
| 1437 |
+
Returns:
|
| 1438 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 1439 |
+
`tuple` where the first element is the sample tensor.
|
| 1440 |
+
"""
|
| 1441 |
+
if joint_attention_kwargs is not None:
|
| 1442 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 1443 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 1444 |
+
else:
|
| 1445 |
+
lora_scale = 1.0
|
| 1446 |
+
|
| 1447 |
+
if USE_PEFT_BACKEND:
|
| 1448 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 1449 |
+
scale_lora_layers(self, lora_scale)
|
| 1450 |
+
else:
|
| 1451 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
| 1452 |
+
logger.warning(
|
| 1453 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 1454 |
+
)
|
| 1455 |
+
|
| 1456 |
+
height, width = hidden_states.shape[-2:]
|
| 1457 |
+
|
| 1458 |
+
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
| 1459 |
+
temb = self.time_text_embed(timestep, pooled_projections)
|
| 1460 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 1461 |
+
|
| 1462 |
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
| 1463 |
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
| 1464 |
+
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
|
| 1465 |
+
|
| 1466 |
+
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
|
| 1467 |
+
|
| 1468 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 1469 |
+
# print("index: ", index_block)
|
| 1470 |
+
|
| 1471 |
+
# Skip specified layers
|
| 1472 |
+
is_skip = True if skip_layers is not None and index_block in skip_layers else False
|
| 1473 |
+
|
| 1474 |
+
if index_block >= (self.config.num_layers // 2) and map_aware_mask is not None:
|
| 1475 |
+
current_mask = map_aware_mask.to(hidden_states.device)
|
| 1476 |
+
else:
|
| 1477 |
+
current_mask = None
|
| 1478 |
+
|
| 1479 |
+
# print("transformer: map_aware_mask:", current_mask.shape if current_mask is not None else None)
|
| 1480 |
+
|
| 1481 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
|
| 1482 |
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
| 1483 |
+
block,
|
| 1484 |
+
hidden_states,
|
| 1485 |
+
encoder_hidden_states,
|
| 1486 |
+
temb,
|
| 1487 |
+
current_mask,
|
| 1488 |
+
joint_attention_kwargs,
|
| 1489 |
+
)
|
| 1490 |
+
elif not is_skip:
|
| 1491 |
+
encoder_hidden_states, hidden_states = block(
|
| 1492 |
+
hidden_states=hidden_states,
|
| 1493 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1494 |
+
temb=temb,
|
| 1495 |
+
map_aware_mask=current_mask,
|
| 1496 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1497 |
+
)
|
| 1498 |
+
|
| 1499 |
+
# controlnet residual
|
| 1500 |
+
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
|
| 1501 |
+
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
|
| 1502 |
+
hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
|
| 1503 |
+
|
| 1504 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1505 |
+
hidden_states = self.proj_out(hidden_states)
|
| 1506 |
+
|
| 1507 |
+
# unpatchify
|
| 1508 |
+
patch_size = self.config.patch_size
|
| 1509 |
+
height = height // patch_size
|
| 1510 |
+
width = width // patch_size
|
| 1511 |
+
|
| 1512 |
+
hidden_states = hidden_states.reshape(
|
| 1513 |
+
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
|
| 1514 |
+
)
|
| 1515 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
| 1516 |
+
output = hidden_states.reshape(
|
| 1517 |
+
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
| 1518 |
+
)
|
| 1519 |
+
|
| 1520 |
+
if USE_PEFT_BACKEND:
|
| 1521 |
+
# remove `lora_scale` from each PEFT layer
|
| 1522 |
+
unscale_lora_layers(self, lora_scale)
|
| 1523 |
+
|
| 1524 |
+
if not return_dict:
|
| 1525 |
+
return (output,)
|
| 1526 |
+
|
| 1527 |
+
return Transformer2DModelOutput(sample=output)
|
vae/config.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.29.0.dev0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
128,
|
| 7 |
+
256,
|
| 8 |
+
512,
|
| 9 |
+
512
|
| 10 |
+
],
|
| 11 |
+
"down_block_types": [
|
| 12 |
+
"DownEncoderBlock2D",
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D"
|
| 16 |
+
],
|
| 17 |
+
"force_upcast": true,
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 16,
|
| 20 |
+
"latents_mean": null,
|
| 21 |
+
"latents_std": null,
|
| 22 |
+
"layers_per_block": 2,
|
| 23 |
+
"norm_num_groups": 32,
|
| 24 |
+
"out_channels": 3,
|
| 25 |
+
"sample_size": 1024,
|
| 26 |
+
"scaling_factor": 1.5305,
|
| 27 |
+
"shift_factor": 0.0609,
|
| 28 |
+
"up_block_types": [
|
| 29 |
+
"UpDecoderBlock2D",
|
| 30 |
+
"UpDecoderBlock2D",
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D"
|
| 33 |
+
],
|
| 34 |
+
"use_post_quant_conv": false,
|
| 35 |
+
"use_quant_conv": false
|
| 36 |
+
}
|
vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9b67a279283625caee39d61eacb5324243848477b4eb535355eaaa8423d4e09
|
| 3 |
+
size 167666654
|