Upload folder using huggingface_hub
Browse files- README.md +87 -3
- config.json +52 -0
- controlnet/config.json +51 -0
- controlnet/controlnet.py +238 -0
- controlnet/diffusion_pytorch_model.safetensors +3 -0
- feature_extractor/preprocessor_config.json +28 -0
- model_index.json +37 -0
- pipeline_diffusionsat.py +303 -0
- pipeline_diffusionsat_controlnet.py +425 -0
- scheduler/scheduler_config.json +20 -0
- text_encoder/config.json +25 -0
- text_encoder/model.safetensors +3 -0
- tokenizer/config.json +52 -0
- tokenizer/merges.txt +0 -0
- tokenizer/special_tokens_map.json +24 -0
- tokenizer/tokenizer_config.json +33 -0
- tokenizer/vocab.json +0 -0
- unet/config.json +56 -0
- unet/diffusion_pytorch_model.safetensors +3 -0
- unet/sat_unet.py +265 -0
- vae/config.json +31 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
README.md
CHANGED
|
@@ -1,3 +1,87 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DiffusionSat Custom Pipelines
|
| 2 |
+
|
| 3 |
+
Custom community pipelines for loading DiffusionSat checkpoints directly with `diffusers.DiffusionPipeline.from_pretrained()`.
|
| 4 |
+
|
| 5 |
+
> See [Diffusers Community Pipeline Documentation](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
|
| 6 |
+
|
| 7 |
+
## Available Pipelines
|
| 8 |
+
|
| 9 |
+
This directory contains two custom pipelines:
|
| 10 |
+
|
| 11 |
+
1. **`pipeline_diffusionsat.py`**: Standard text-to-image pipeline with DiffusionSat metadata support.
|
| 12 |
+
2. **`pipeline_diffusionsat_controlnet.py`**: ControlNet pipeline with DiffusionSat metadata and conditional metadata support.
|
| 13 |
+
|
| 14 |
+
## Setup
|
| 15 |
+
|
| 16 |
+
The checkpoint folder (`ckpt/diffusionsat/`) should contain the standard diffusers components (unet, vae, scheduler, etc.). You can reference these pipeline files directly from this directory or copy them to your checkpoint folder.
|
| 17 |
+
|
| 18 |
+
## Usage
|
| 19 |
+
|
| 20 |
+
### 1. Text-to-Image Pipeline
|
| 21 |
+
|
| 22 |
+
Use `pipeline_diffusionsat.py` for standard generation.
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
import torch
|
| 26 |
+
from diffusers import DiffusionPipeline
|
| 27 |
+
|
| 28 |
+
# Load pipeline
|
| 29 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 30 |
+
"path/to/ckpt/diffusionsat",
|
| 31 |
+
custom_pipeline="./custom_pipelines/pipeline_diffusionsat.py", # Path to this file
|
| 32 |
+
torch_dtype=torch.float16,
|
| 33 |
+
trust_remote_code=True,
|
| 34 |
+
)
|
| 35 |
+
pipe = pipe.to("cuda")
|
| 36 |
+
|
| 37 |
+
# Optional: Metadata (normalized lat, lon, timestamp, GSD, etc.)
|
| 38 |
+
# metadata = [0.5, -0.3, 0.7, 0.2, 0.1, 0.0, 0.5]
|
| 39 |
+
|
| 40 |
+
# Generate
|
| 41 |
+
image = pipe(
|
| 42 |
+
"satellite image of farmland",
|
| 43 |
+
metadata=None, # Optional
|
| 44 |
+
num_inference_steps=30,
|
| 45 |
+
).images[0]
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### 2. ControlNet Pipeline
|
| 49 |
+
|
| 50 |
+
Use `pipeline_diffusionsat_controlnet.py` for ControlNet generation.
|
| 51 |
+
|
| 52 |
+
```python
|
| 53 |
+
import torch
|
| 54 |
+
from diffusers import DiffusionPipeline, ControlNetModel
|
| 55 |
+
from diffusers.utils import load_image
|
| 56 |
+
|
| 57 |
+
# 1. Load ControlNet
|
| 58 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 59 |
+
"path/to/ckpt/diffusionsat/controlnet",
|
| 60 |
+
torch_dtype=torch.float16
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# 2. Load Pipeline with ControlNet
|
| 64 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 65 |
+
"path/to/ckpt/diffusionsat",
|
| 66 |
+
controlnet=controlnet,
|
| 67 |
+
custom_pipeline="./custom_pipelines/pipeline_diffusionsat_controlnet.py", # Path to this file
|
| 68 |
+
torch_dtype=torch.float16,
|
| 69 |
+
trust_remote_code=True,
|
| 70 |
+
)
|
| 71 |
+
pipe = pipe.to("cuda")
|
| 72 |
+
|
| 73 |
+
# 3. Prepare Control Image
|
| 74 |
+
control_image = load_image("path/to/conditioning_image.png")
|
| 75 |
+
|
| 76 |
+
# 4. Generate
|
| 77 |
+
# metadata: Target image metadata (optional)
|
| 78 |
+
# cond_metadata: Conditioning image metadata (optional)
|
| 79 |
+
|
| 80 |
+
image = pipe(
|
| 81 |
+
"satellite image of farmland",
|
| 82 |
+
image=control_image,
|
| 83 |
+
metadata=None,
|
| 84 |
+
cond_metadata=None,
|
| 85 |
+
num_inference_steps=30,
|
| 86 |
+
).images[0]
|
| 87 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "ControlNetModel",
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"_name_or_path": "/data/jiabo/diffusionsat/testoutput/checkpoint-1",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"attention_head_dim": [
|
| 7 |
+
5,
|
| 8 |
+
10,
|
| 9 |
+
20,
|
| 10 |
+
20
|
| 11 |
+
],
|
| 12 |
+
"block_out_channels": [
|
| 13 |
+
320,
|
| 14 |
+
640,
|
| 15 |
+
1280,
|
| 16 |
+
1280
|
| 17 |
+
],
|
| 18 |
+
"class_embed_type": null,
|
| 19 |
+
"conditioning_embedding_out_channels": [
|
| 20 |
+
16,
|
| 21 |
+
32,
|
| 22 |
+
96,
|
| 23 |
+
256
|
| 24 |
+
],
|
| 25 |
+
"conditioning_in_channels": 3,
|
| 26 |
+
"conditioning_scale": 1,
|
| 27 |
+
"controlnet_conditioning_channel_order": "rgb",
|
| 28 |
+
"cross_attention_dim": 1024,
|
| 29 |
+
"down_block_types": [
|
| 30 |
+
"CrossAttnDownBlock2D",
|
| 31 |
+
"CrossAttnDownBlock2D",
|
| 32 |
+
"CrossAttnDownBlock2D",
|
| 33 |
+
"DownBlock2D"
|
| 34 |
+
],
|
| 35 |
+
"downsample_padding": 1,
|
| 36 |
+
"flip_sin_to_cos": true,
|
| 37 |
+
"freq_shift": 0,
|
| 38 |
+
"global_pool_conditions": false,
|
| 39 |
+
"in_channels": 4,
|
| 40 |
+
"layers_per_block": 2,
|
| 41 |
+
"mid_block_scale_factor": 1,
|
| 42 |
+
"norm_eps": 1e-05,
|
| 43 |
+
"norm_num_groups": 32,
|
| 44 |
+
"num_class_embeds": null,
|
| 45 |
+
"num_metadata": 7,
|
| 46 |
+
"only_cross_attention": false,
|
| 47 |
+
"projection_class_embeddings_input_dim": null,
|
| 48 |
+
"resnet_time_scale_shift": "default",
|
| 49 |
+
"upcast_attention": true,
|
| 50 |
+
"use_linear_projection": true,
|
| 51 |
+
"use_metadata": true
|
| 52 |
+
}
|
controlnet/config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": ["controlnet", "ControlNetModel"],
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"attention_head_dim": [
|
| 6 |
+
5,
|
| 7 |
+
10,
|
| 8 |
+
20,
|
| 9 |
+
20
|
| 10 |
+
],
|
| 11 |
+
"block_out_channels": [
|
| 12 |
+
320,
|
| 13 |
+
640,
|
| 14 |
+
1280,
|
| 15 |
+
1280
|
| 16 |
+
],
|
| 17 |
+
"class_embed_type": null,
|
| 18 |
+
"conditioning_embedding_out_channels": [
|
| 19 |
+
16,
|
| 20 |
+
32,
|
| 21 |
+
96,
|
| 22 |
+
256
|
| 23 |
+
],
|
| 24 |
+
"conditioning_in_channels": 3,
|
| 25 |
+
"conditioning_scale": 1,
|
| 26 |
+
"controlnet_conditioning_channel_order": "rgb",
|
| 27 |
+
"cross_attention_dim": 1024,
|
| 28 |
+
"down_block_types": [
|
| 29 |
+
"CrossAttnDownBlock2D",
|
| 30 |
+
"CrossAttnDownBlock2D",
|
| 31 |
+
"CrossAttnDownBlock2D",
|
| 32 |
+
"DownBlock2D"
|
| 33 |
+
],
|
| 34 |
+
"downsample_padding": 1,
|
| 35 |
+
"flip_sin_to_cos": true,
|
| 36 |
+
"freq_shift": 0,
|
| 37 |
+
"global_pool_conditions": false,
|
| 38 |
+
"in_channels": 4,
|
| 39 |
+
"layers_per_block": 2,
|
| 40 |
+
"mid_block_scale_factor": 1,
|
| 41 |
+
"norm_eps": 1e-05,
|
| 42 |
+
"norm_num_groups": 32,
|
| 43 |
+
"num_class_embeds": null,
|
| 44 |
+
"num_metadata": 7,
|
| 45 |
+
"only_cross_attention": false,
|
| 46 |
+
"projection_class_embeddings_input_dim": null,
|
| 47 |
+
"resnet_time_scale_shift": "default",
|
| 48 |
+
"upcast_attention": true,
|
| 49 |
+
"use_linear_projection": true,
|
| 50 |
+
"use_metadata": true
|
| 51 |
+
}
|
controlnet/controlnet.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ControlNet wrapper that reuses diffusers implementation and adds metadata."""
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from diffusers.models.controlnets.controlnet import (
|
| 9 |
+
ControlNetConditioningEmbedding as HFConditioningEmbedding,
|
| 10 |
+
ControlNetModel as HFControlNetModel,
|
| 11 |
+
ControlNetOutput,
|
| 12 |
+
zero_module,
|
| 13 |
+
)
|
| 14 |
+
from diffusers.utils import logging
|
| 15 |
+
|
| 16 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ControlNetConditioningEmbedding(HFConditioningEmbedding):
    """Conditioning embedding with a configurable downsample ``scale``.

    Behaves exactly like the upstream diffusers embedding when ``scale == 1``.
    For any other scale the downsampling stack is rebuilt so that strided
    convolutions stop once the running scale reaches 8.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        scale: int = 1,
    ):
        # Build the stock layers first; only `self.blocks` is replaced below.
        super().__init__(
            conditioning_embedding_channels=conditioning_embedding_channels,
            conditioning_channels=conditioning_channels,
            block_out_channels=block_out_channels,
        )
        if scale == 1:
            return  # upstream layout already matches

        rebuilt = nn.ModuleList()
        running_scale = scale
        for ch_in, ch_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            rebuilt.append(nn.Conv2d(ch_in, ch_in, kernel_size=3, padding=1))
            # Downsample (stride 2) only while the accumulated scale is below 8.
            down_stride = 2 if running_scale < 8 else 1
            rebuilt.append(nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1, stride=down_stride))
            if running_scale != 8:
                running_scale = int(running_scale * 2)
        self.blocks = rebuilt
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ControlNetModel(HFControlNetModel):
    """Thin wrapper around `diffusers.ControlNetModel` with metadata embeddings.

    Each of the ``num_metadata`` scalar metadata values (normalized lat/lon, GSD,
    timestamp, ...) is projected through the shared sinusoidal ``time_proj`` and a
    dedicated ``TimestepEmbedding``; the summed result is added to the timestep
    embedding before the down/mid blocks run.
    """

    def __init__(
        self,
        *args,
        conditioning_in_channels: int = 3,
        conditioning_scale: int = 1,
        use_metadata: bool = True,
        num_metadata: int = 7,
        **kwargs,
    ):
        # `conditioning_in_channels` is the DiffusionSat alias for the upstream
        # `conditioning_channels` argument; map it unless the caller already set it.
        kwargs.setdefault("conditioning_channels", conditioning_in_channels)

        super().__init__(*args, **kwargs)

        # Track custom config entries for save/load parity.
        self.register_to_config(
            use_metadata=use_metadata, num_metadata=num_metadata, conditioning_scale=conditioning_scale
        )

        self.use_metadata = use_metadata
        self.num_metadata = num_metadata

        if use_metadata:
            # Mirror the dims of the existing timestep-embedding MLP.
            timestep_input_dim = self.time_embedding.linear_1.in_features
            time_embed_dim = self.time_embedding.linear_2.out_features
            self.metadata_embedding = nn.ModuleList(
                [
                    self._build_metadata_embedding(timestep_input_dim, time_embed_dim)
                    for _ in range(num_metadata)
                ]
            )
        else:
            self.metadata_embedding = None

        # Optionally replace conditioning embedding to honor `conditioning_scale` stride tweaks.
        if conditioning_scale != 1:
            # The upstream embedding holds two convs per stage; blocks[1::2] are the
            # stage-transition convs whose out_channels are block_out_channels[1:].
            # BUGFIX: the first stage width (conv_in.out_channels ==
            # block_out_channels[0]) must be prepended — the previous code passed
            # only block_out_channels[1:], building the replacement embedding with
            # a shifted channel spec that no longer matched the checkpoint weights.
            self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
                conditioning_embedding_channels=self.controlnet_cond_embedding.conv_out.out_channels,
                conditioning_channels=conditioning_in_channels,
                block_out_channels=(
                    self.controlnet_cond_embedding.conv_in.out_channels,
                    *(layer.out_channels for layer in self.controlnet_cond_embedding.blocks[1::2]),
                ),
                scale=conditioning_scale,
            )

    @staticmethod
    def _build_metadata_embedding(timestep_input_dim: int, time_embed_dim: int) -> nn.Module:
        """Create one per-field metadata MLP with the timestep embedding's shape."""
        from diffusers.models.embeddings import TimestepEmbedding

        return TimestepEmbedding(timestep_input_dim, time_embed_dim)

    def _encode_metadata(
        self, metadata: Optional[torch.Tensor], dtype: torch.dtype
    ) -> Optional[torch.Tensor]:
        """Project ``(batch, num_metadata)`` scalars into one summed embedding.

        Returns ``None`` when the model was built without metadata support;
        raises if metadata is missing or mis-shaped while ``use_metadata=True``.
        """
        if self.metadata_embedding is None:
            return None
        if metadata is None:
            raise ValueError("metadata must be provided when use_metadata=True")
        if metadata.dim() != 2 or metadata.shape[1] != self.num_metadata:
            raise ValueError(f"Invalid metadata shape {metadata.shape}, expected (batch, {self.num_metadata})")

        md_bsz = metadata.shape[0]
        # Sinusoidally project every scalar, then run each field through its own MLP.
        projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)

        # BUGFIX: size the accumulator by the embedding *output* width
        # (time_embed_dim), not by the sinusoidal projection width
        # (projected.shape[-1] == timestep_input_dim). The old zeros tensor made
        # the sum broadcast-fail whenever the two dims differ (e.g. 320 vs 1280).
        time_embed_dim = self.time_embedding.linear_2.out_features
        md_emb = projected.new_zeros((md_bsz, time_embed_dim))
        for idx, md_embed in enumerate(self.metadata_embedding):
            md_emb = md_emb + md_embed(projected[:, idx, :])
        return md_emb

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        metadata: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
        """Upstream ControlNet forward pass with metadata folded into the time embedding.

        Identical to `diffusers.ControlNetModel.forward` except that the summed
        metadata embedding (see `_encode_metadata`) is added to `emb` before the
        down/mid blocks, and `metadata` is accepted as an extra keyword.
        """
        # 1. Check/convert the channel order of the conditioning image.
        channel_order = self.config.controlnet_conditioning_channel_order
        if channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        elif channel_order != "rgb":
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # Convert a 0/1 mask into an additive attention bias.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 2. Broadcast the timestep to a per-sample tensor; mps/npu lack float64/int64.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embed_type == "timestep":
                class_emb = class_emb.to(dtype=sample.dtype)
            emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs or {}
        )
        if aug_emb is not None:
            emb = emb + aug_emb

        # DiffusionSat addition: metadata contributes to the time embedding.
        md_emb = self._encode_metadata(metadata=metadata, dtype=sample.dtype)
        if md_emb is not None:
            emb = emb + md_emb

        # 3. Stem + injected conditioning-image features.
        sample = self.conv_in(sample)
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample = sample + controlnet_cond

        # 4. Down blocks, collecting residuals for the zero-conv projections.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        # 5. Mid block.
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)

        # 6. Zero-conv projections of each residual.
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 7. Apply conditioning scale (guess mode ramps scales logarithmically).
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
|
controlnet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3bd5f6b9aea04714f331cd94d721c8adb8b378a2774a9805e6f0a369e33aacd7
|
| 3 |
+
size 1514372328
|
feature_extractor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": {
|
| 3 |
+
"height": 224,
|
| 4 |
+
"width": 224
|
| 5 |
+
},
|
| 6 |
+
"do_center_crop": true,
|
| 7 |
+
"do_convert_rgb": true,
|
| 8 |
+
"do_normalize": true,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 12 |
+
"image_mean": [
|
| 13 |
+
0.48145466,
|
| 14 |
+
0.4578275,
|
| 15 |
+
0.40821073
|
| 16 |
+
],
|
| 17 |
+
"image_processor_type": "CLIPImageProcessor",
|
| 18 |
+
"image_std": [
|
| 19 |
+
0.26862954,
|
| 20 |
+
0.26130258,
|
| 21 |
+
0.27577711
|
| 22 |
+
],
|
| 23 |
+
"resample": 3,
|
| 24 |
+
"rescale_factor": 0.00392156862745098,
|
| 25 |
+
"size": {
|
| 26 |
+
"shortest_edge": 224
|
| 27 |
+
}
|
| 28 |
+
}
|
model_index.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": ["pipeline_diffusionsat_controlnet", "DiffusionSatControlNetPipeline"],
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"controlnet": [
|
| 5 |
+
"controlnet",
|
| 6 |
+
"ControlNetModel"
|
| 7 |
+
],
|
| 8 |
+
"feature_extractor": [
|
| 9 |
+
"transformers",
|
| 10 |
+
"CLIPImageProcessor"
|
| 11 |
+
],
|
| 12 |
+
"requires_safety_checker": false,
|
| 13 |
+
"safety_checker": [
|
| 14 |
+
null,
|
| 15 |
+
null
|
| 16 |
+
],
|
| 17 |
+
"scheduler": [
|
| 18 |
+
"diffusers",
|
| 19 |
+
"DDIMScheduler"
|
| 20 |
+
],
|
| 21 |
+
"text_encoder": [
|
| 22 |
+
"transformers",
|
| 23 |
+
"CLIPTextModel"
|
| 24 |
+
],
|
| 25 |
+
"tokenizer": [
|
| 26 |
+
"transformers",
|
| 27 |
+
"CLIPTokenizer"
|
| 28 |
+
],
|
| 29 |
+
"unet": [
|
| 30 |
+
"sat_unet",
|
| 31 |
+
"SatUNet"
|
| 32 |
+
],
|
| 33 |
+
"vae": [
|
| 34 |
+
"diffusers",
|
| 35 |
+
"AutoencoderKL"
|
| 36 |
+
]
|
| 37 |
+
}
|
pipeline_diffusionsat.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Self-contained DiffusionSat text-to-image pipeline that can be loaded directly
|
| 3 |
+
from the checkpoint folder without importing the project package.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from packaging import version
|
| 12 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
| 13 |
+
|
| 14 |
+
from diffusers.configuration_utils import FrozenDict
|
| 15 |
+
from diffusers.models import AutoencoderKL
|
| 16 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 17 |
+
from diffusers.utils import (
|
| 18 |
+
deprecate,
|
| 19 |
+
logging,
|
| 20 |
+
randn_tensor,
|
| 21 |
+
replace_example_docstring,
|
| 22 |
+
is_accelerate_available,
|
| 23 |
+
)
|
| 24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 25 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
| 26 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 27 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
| 28 |
+
StableDiffusionPipeline as DiffusersStableDiffusionPipeline,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 32 |
+
|
| 33 |
+
EXAMPLE_DOC_STRING = """
|
| 34 |
+
Examples:
|
| 35 |
+
```py
|
| 36 |
+
>>> import torch
|
| 37 |
+
>>> from diffusers import DiffusionPipeline
|
| 38 |
+
|
| 39 |
+
>>> pipe = DiffusionPipeline.from_pretrained("path/to/ckpt/diffusionsat", torch_dtype=torch.float16)
|
| 40 |
+
>>> pipe = pipe.to("cuda")
|
| 41 |
+
|
| 42 |
+
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
| 43 |
+
>>> image = pipe(prompt).images[0]
|
| 44 |
+
```
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DiffusionSatPipeline(DiffusionPipeline):
|
| 49 |
+
"""
|
| 50 |
+
Pipeline for text-to-image generation using the DiffusionSat UNet with optional metadata.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
vae: AutoencoderKL,
|
| 58 |
+
text_encoder: CLIPTextModel,
|
| 59 |
+
tokenizer: CLIPTokenizer,
|
| 60 |
+
unet: Any,
|
| 61 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 62 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 63 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 64 |
+
requires_safety_checker: bool = True,
|
| 65 |
+
):
|
| 66 |
+
super().__init__()
|
| 67 |
+
|
| 68 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
| 69 |
+
deprecation_message = (
|
| 70 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
| 71 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
| 72 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
| 73 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
| 74 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
| 75 |
+
" file"
|
| 76 |
+
)
|
| 77 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
| 78 |
+
new_config = dict(scheduler.config)
|
| 79 |
+
new_config["steps_offset"] = 1
|
| 80 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 81 |
+
|
| 82 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
| 83 |
+
deprecation_message = (
|
| 84 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
| 85 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
| 86 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
| 87 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
| 88 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
| 89 |
+
)
|
| 90 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
| 91 |
+
new_config = dict(scheduler.config)
|
| 92 |
+
new_config["clip_sample"] = False
|
| 93 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 94 |
+
|
| 95 |
+
if safety_checker is None and requires_safety_checker:
|
| 96 |
+
logger.warning(
|
| 97 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 98 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| 99 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| 100 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 101 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 102 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if safety_checker is not None and feature_extractor is None:
|
| 106 |
+
raise ValueError(
|
| 107 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 108 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
| 112 |
+
version.parse(unet.config._diffusers_version).base_version
|
| 113 |
+
) < version.parse("0.9.0.dev0")
|
| 114 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
| 115 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
| 116 |
+
deprecation_message = (
|
| 117 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
| 118 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
| 119 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
| 120 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
| 121 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
| 122 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
| 123 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
| 124 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
| 125 |
+
" the `unet/config.json` file"
|
| 126 |
+
)
|
| 127 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
| 128 |
+
new_config = dict(unet.config)
|
| 129 |
+
new_config["sample_size"] = 64
|
| 130 |
+
unet._internal_dict = FrozenDict(new_config)
|
| 131 |
+
|
| 132 |
+
self.register_modules(
|
| 133 |
+
vae=vae,
|
| 134 |
+
text_encoder=text_encoder,
|
| 135 |
+
tokenizer=tokenizer,
|
| 136 |
+
unet=unet,
|
| 137 |
+
scheduler=scheduler,
|
| 138 |
+
safety_checker=safety_checker,
|
| 139 |
+
feature_extractor=feature_extractor,
|
| 140 |
+
)
|
| 141 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 142 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 143 |
+
|
| 144 |
+
# Borrow helper implementations from diffusers' StableDiffusionPipeline for convenience.
|
| 145 |
+
enable_vae_slicing = DiffusersStableDiffusionPipeline.enable_vae_slicing
|
| 146 |
+
disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
|
| 147 |
+
enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
|
| 148 |
+
_execution_device = DiffusersStableDiffusionPipeline._execution_device
|
| 149 |
+
_encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
|
| 150 |
+
run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
|
| 151 |
+
decode_latents = DiffusersStableDiffusionPipeline.decode_latents
|
| 152 |
+
prepare_extra_step_kwargs = DiffusersStableDiffusionPipeline.prepare_extra_step_kwargs
|
| 153 |
+
check_inputs = DiffusersStableDiffusionPipeline.check_inputs
|
| 154 |
+
prepare_latents = DiffusersStableDiffusionPipeline.prepare_latents
|
| 155 |
+
|
| 156 |
+
def prepare_metadata(
|
| 157 |
+
self, batch_size, metadata, do_classifier_free_guidance, device, dtype,
|
| 158 |
+
):
|
| 159 |
+
has_metadata = getattr(self.unet.config, "use_metadata", False)
|
| 160 |
+
num_metadata = getattr(self.unet.config, "num_metadata", 0)
|
| 161 |
+
|
| 162 |
+
if metadata is None and has_metadata and num_metadata > 0:
|
| 163 |
+
metadata = torch.zeros((batch_size, num_metadata), device=device, dtype=dtype)
|
| 164 |
+
|
| 165 |
+
if metadata is None:
|
| 166 |
+
return None
|
| 167 |
+
|
| 168 |
+
md = torch.tensor(metadata) if not torch.is_tensor(metadata) else metadata
|
| 169 |
+
if len(md.shape) == 1:
|
| 170 |
+
md = md.unsqueeze(0).expand(batch_size, -1)
|
| 171 |
+
md = md.to(device=device, dtype=dtype)
|
| 172 |
+
|
| 173 |
+
if do_classifier_free_guidance:
|
| 174 |
+
md = torch.cat([torch.zeros_like(md), md])
|
| 175 |
+
|
| 176 |
+
return md
|
| 177 |
+
|
| 178 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        metadata: Optional[List[float]] = None,
    ):
        """
        Generate images from text, optionally conditioned on a DiffusionSat
        metadata vector.

        Mirrors ``StableDiffusionPipeline.__call__`` except that ``metadata``
        is normalized via ``prepare_metadata`` and forwarded to the UNet on
        every denoising step.

        Returns a `StableDiffusionPipelineOutput` (or an ``(images,
        nsfw_flags)`` tuple when ``return_dict=False``).
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters — batch size comes from the prompt(s) or,
        # when only embeddings are provided, from their leading dimension.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # CFG is enabled whenever the guidance scale exceeds 1 (standard SD convention).
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt (returns the uncond+cond embeddings stacked
        # along the batch axis when CFG is active).
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        # `in_channels` moved onto `.config` in newer diffusers; the hasattr
        # check keeps compatibility with older versions.
        num_channels_latents = self.unet.in_channels if hasattr(self.unet, "in_channels") else self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.5: Prepare metadata (auto-zero filled when missing)
        # NOTE(review): metadata is prepared with `batch_size`, not
        # `batch_size * num_images_per_prompt`; the shape assert below would
        # presumably fail for num_images_per_prompt > 1 — confirm intent.
        input_metadata = self.prepare_metadata(
            batch_size, metadata, do_classifier_free_guidance, device, prompt_embeds.dtype
        )
        if input_metadata is not None:
            # Soft sanity checks: the getattr fallback makes the first assert
            # a no-op when the UNet config lacks `num_metadata`.
            assert input_metadata.shape[-1] == getattr(self.unet.config, "num_metadata", input_metadata.shape[-1])
            assert input_metadata.shape[0] == prompt_embeds.shape[0]

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Duplicate latents so one UNet pass covers uncond + cond.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    metadata=input_metadata,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                if do_classifier_free_guidance:
                    # Standard CFG: extrapolate from uncond toward cond prediction.
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Update the bar on the last step, or once per scheduler-order
                # group of steps after warmup.
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Post-processing: latent passthrough, PIL images, or raw numpy.
        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            image = self.numpy_to_pil(image)
        else:
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# Public API of this module.
__all__ = ["DiffusionSatPipeline"]
|
pipeline_diffusionsat_controlnet.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Self-contained DiffusionSat ControlNet pipeline that can be loaded directly from
|
| 3 |
+
the checkpoint folder without importing the project package.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import einops
|
| 12 |
+
import numpy as np
|
| 13 |
+
import PIL.Image
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch import nn
|
| 17 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
| 18 |
+
|
| 19 |
+
from diffusers.loaders import TextualInversionLoaderMixin
|
| 20 |
+
from diffusers.models import AutoencoderKL
|
| 21 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 22 |
+
from diffusers.utils import (
|
| 23 |
+
PIL_INTERPOLATION,
|
| 24 |
+
logging,
|
| 25 |
+
randn_tensor,
|
| 26 |
+
replace_example_docstring,
|
| 27 |
+
is_accelerate_available,
|
| 28 |
+
is_accelerate_version,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 31 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
| 32 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 33 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
| 34 |
+
StableDiffusionPipeline as DiffusersStableDiffusionPipeline,
|
| 35 |
+
)
|
| 36 |
+
from diffusers.pipelines.controlnet.pipeline_controlnet import (
|
| 37 |
+
StableDiffusionControlNetPipeline as DiffusersControlNetPipeline,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example spliced into `__call__`'s docstring via @replace_example_docstring.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> # !pip install opencv-python transformers accelerate
        >>> from diffusers import DiffusionPipeline
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch
        >>> import cv2
        >>> from PIL import Image
        >>>
        >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
        >>> image = np.array(image)
        >>> image = cv2.Canny(image, 100, 200)
        >>> image = image[:, :, None]
        >>> image = np.concatenate([image, image, image], axis=2)
        >>> canny_image = Image.fromarray(image)
        >>>
        >>> pipe = DiffusionPipeline.from_pretrained("path/to/ckpt/diffusionsat", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")
        >>> pipe.enable_xformers_memory_efficient_attention()
        >>> generator = torch.manual_seed(0)
        >>> image = pipe(
        ...     "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
        ... ).images[0]
        ```
"""
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class DiffusionSatControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    """
    ControlNet-aware pipeline for DiffusionSat. This is a mostly direct copy of
    the project pipeline to avoid importing the `diffusionsat` package when
    loading from the checkpoint folder. Minimal tweaks:
      - auto-fills metadata/cond_metadata with zeros when the model expects them.
    """

    # Components that may be None (e.g. when the safety checker is disabled).
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Any,
        controlnet: Any,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """Register all sub-models; wraps a list/tuple of controlnets in a
        `MultiControlNetModel` so downstream code sees a single module.

        Raises:
            ValueError: if a safety checker is supplied without a feature
                extractor (the checker needs preprocessed inputs).
        """
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results."
            )

        if safety_checker is not None and feature_extractor is None:
            # BUGFIX: the original string was missing the `f` prefix, so the
            # literal text "{self.__class__}" was emitted instead of the class name.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Support MultiControlNetModel-like objects without importing the project module.
        if isinstance(controlnet, (list, tuple)):
            # defer to diffusers' MultiControlNetModel if available
            from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Spatial downscale factor between pixel space and VAE latent space
        # (8 for the standard SD VAE with 4 block levels).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 129 |
+
|
| 130 |
+
    # Reuse helpers from diffusers baseline pipelines.
    # VAE + prompt helpers come from StableDiffusionPipeline; ControlNet-specific
    # input validation/preparation and offload helpers come from the diffusers
    # ControlNet pipeline. Assigned as class attributes, they behave as regular
    # bound methods of this class.
    # NOTE(review): tied to the internals of the installed diffusers version.
    enable_vae_slicing = DiffusersStableDiffusionPipeline.enable_vae_slicing
    disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
    enable_vae_tiling = DiffusersStableDiffusionPipeline.enable_vae_tiling
    disable_vae_tiling = DiffusersStableDiffusionPipeline.disable_vae_tiling
    enable_sequential_cpu_offload = DiffusersControlNetPipeline.enable_sequential_cpu_offload
    enable_model_cpu_offload = DiffusersControlNetPipeline.enable_model_cpu_offload
    _execution_device = DiffusersStableDiffusionPipeline._execution_device
    _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
    run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
    decode_latents = DiffusersStableDiffusionPipeline.decode_latents
    prepare_extra_step_kwargs = DiffusersStableDiffusionPipeline.prepare_extra_step_kwargs
    check_inputs = DiffusersControlNetPipeline.check_inputs
    check_image = DiffusersControlNetPipeline.check_image
    prepare_image = DiffusersControlNetPipeline.prepare_image
    prepare_latents = DiffusersStableDiffusionPipeline.prepare_latents
|
| 146 |
+
|
| 147 |
+
def prepare_metadata(self, batch_size, metadata, ndims, do_classifier_free_guidance, device, dtype):
|
| 148 |
+
has_metadata = getattr(self.unet.config, "use_metadata", False)
|
| 149 |
+
num_metadata = getattr(self.unet.config, "num_metadata", 0)
|
| 150 |
+
|
| 151 |
+
if metadata is None and has_metadata and num_metadata > 0:
|
| 152 |
+
shape = (batch_size, num_metadata) if ndims == 2 else (batch_size, num_metadata, 1)
|
| 153 |
+
metadata = torch.zeros(shape, device=device, dtype=dtype)
|
| 154 |
+
|
| 155 |
+
if metadata is None:
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
md = torch.as_tensor(metadata)
|
| 159 |
+
if ndims == 2:
|
| 160 |
+
assert (len(md.shape) == 1 and batch_size == 1) or (len(md.shape) == 2 and batch_size > 1)
|
| 161 |
+
if len(md.shape) == 1:
|
| 162 |
+
md = md.unsqueeze(0).expand(batch_size, -1)
|
| 163 |
+
elif ndims == 3:
|
| 164 |
+
assert (len(md.shape) == 2 and batch_size == 1) or (len(md.shape) == 3 and batch_size > 1)
|
| 165 |
+
if len(md.shape) == 2:
|
| 166 |
+
md = md.unsqueeze(0).expand(batch_size, -1, -1)
|
| 167 |
+
|
| 168 |
+
if do_classifier_free_guidance:
|
| 169 |
+
md = torch.cat([torch.zeros_like(md), md])
|
| 170 |
+
|
| 171 |
+
md = md.to(device=device, dtype=dtype)
|
| 172 |
+
return md
|
| 173 |
+
|
| 174 |
+
def _default_height_width(self, height, width, image):
|
| 175 |
+
while isinstance(image, list):
|
| 176 |
+
image = image[0]
|
| 177 |
+
|
| 178 |
+
if height is None:
|
| 179 |
+
if isinstance(image, PIL.Image.Image):
|
| 180 |
+
height = image.height
|
| 181 |
+
elif isinstance(image, torch.Tensor):
|
| 182 |
+
height = image.shape[2]
|
| 183 |
+
height = (height // 8) * 8
|
| 184 |
+
|
| 185 |
+
if width is None:
|
| 186 |
+
if isinstance(image, PIL.Image.Image):
|
| 187 |
+
width = image.width
|
| 188 |
+
elif isinstance(image, torch.Tensor):
|
| 189 |
+
width = image.shape[3]
|
| 190 |
+
width = (width // 8) * 8
|
| 191 |
+
|
| 192 |
+
return height, width
|
| 193 |
+
|
| 194 |
+
    # override DiffusionPipeline
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        safe_serialization: bool = False,
        variant: Optional[str] = None,
    ):
        """Save all pipeline components under `save_directory`.

        Thin override kept as an explicit extension point; currently it just
        delegates to `DiffusionPipeline.save_pretrained`, which also handles
        single- and multi-controlnet components.
        """
        # For single or multi controlnet, rely on default save logic.
        super().save_pretrained(save_directory, safe_serialization=safe_serialization, variant=variant)
|
| 203 |
+
|
| 204 |
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        metadata: Optional[List[float]] = None,
        cond_metadata: Optional[List[float]] = None,
        is_temporal: bool = False,
        conditioning_downsample: bool = True,
    ):
        """
        ControlNet-conditioned generation with optional DiffusionSat metadata
        (`metadata` for the UNet, `cond_metadata` for the conditioning images).

        `is_temporal` reshapes a channel-stacked conditioning tensor into a
        (b, c, t, h, w) temporal layout; a MultiControlNet with `cond_metadata`
        instead splits it into one image/metadata pair per controlnet.

        Returns a `StableDiffusionPipelineOutput` (or `(images, nsfw_flags)`
        when `return_dict=False`).
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, image)
        cond_height, cond_width = height, width
        # NOTE(review): when `conditioning_downsample` is False the conditioning
        # image is prepared at 1/8 (latent) resolution — the flag name suggests
        # the opposite mapping; confirm against the trained controlnet.
        if not conditioning_downsample:
            cond_height, cond_width = height // 8, width // 8

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
        )

        # 2. Define call parameters — batch size from prompt(s) or embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        # Imported lazily to keep module import light.
        from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

        # A scalar conditioning scale is broadcast to one scale per controlnet.
        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare image
        # NOTE(review): relies on the private `torch._dynamo.eval_frame` API to
        # detect a torch.compile-wrapped controlnet — may break across torch versions.
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        is_multi_cond = isinstance(image, list)

        # Only preprocess when the controlnet has its own conditioning embedder
        # (checking `_orig_mod` for compiled modules); otherwise `image` is
        # passed through untouched.
        if (
            hasattr(self.controlnet, "controlnet_cond_embedding")
            or is_compiled
            and hasattr(self.controlnet._orig_mod, "controlnet_cond_embedding")
        ):
            image = self.prepare_image(
                image=image,
                width=cond_width,
                height=cond_height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # CUSTOM metadata handling (auto-zero filled)
        # NOTE(review): prepared with `batch_size` only — the shape assert below
        # would presumably fail for num_images_per_prompt > 1; confirm intent.
        input_metadata = self.prepare_metadata(batch_size, metadata, 2, do_classifier_free_guidance, device, prompt_embeds.dtype)
        # Multi-condition metadata is 3-D: (batch, num_metadata, num_cond).
        ndims_cond = 3 if is_multi_cond else 2
        # Rebinds the `cond_metadata` parameter to its normalized tensor form.
        cond_metadata = self.prepare_metadata(
            batch_size, cond_metadata, ndims_cond, do_classifier_free_guidance, device, prompt_embeds.dtype
        )
        if input_metadata is not None:
            # Soft checks: getattr fallbacks make them no-ops when the configs
            # lack `num_metadata`.
            assert len(input_metadata.shape) == 2 and input_metadata.shape[-1] == getattr(self.unet.config, "num_metadata", input_metadata.shape[-1])
        if cond_metadata is not None:
            assert len(cond_metadata.shape) == ndims_cond and cond_metadata.shape[1] == getattr(self.unet.config, "num_metadata", cond_metadata.shape[1])
            if is_multi_cond and not is_temporal and not isinstance(self.controlnet, MultiControlNetModel):
                # Conditioning embedder packs RGB triples: one group of 3 input
                # channels per conditioning image.
                assert cond_metadata.shape[2] == self.controlnet.controlnet_cond_embedding.conv_in.in_channels / 3

        if input_metadata is not None:
            assert input_metadata.shape[0] == prompt_embeds.shape[0]

        if is_temporal:
            # Channel-stacked frames -> explicit temporal axis for a temporal controlnet.
            num_cond = cond_metadata.shape[-1] if cond_metadata is not None else image.shape[1] // self.controlnet.config.conditioning_in_channels
            image = einops.rearrange(image, 'b (t c) h w -> b c t h w', t=num_cond)
        elif isinstance(self.controlnet, MultiControlNetModel) and cond_metadata is not None:
            # Split stacked conditions into one (image, metadata) pair per controlnet.
            num_cond = cond_metadata.shape[-1] if cond_metadata is not None else image.shape[1] // self.controlnet.config.conditioning_in_channels
            image = einops.rearrange(image, 'b (t c) h w -> t b c h w', t=num_cond)
            image = [im for im in image]
            cond_metadata = einops.rearrange(cond_metadata, 'b m t -> t b m')
            cond_metadata = [cond_md for cond_md in cond_metadata]

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # In guess mode the controlnet only sees the conditional half.
                if guess_mode and do_classifier_free_guidance:
                    controlnet_latent_model_input = latents
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    controlnet_latent_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    controlnet_latent_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    metadata=input_metadata,
                    cond_metadata=cond_metadata,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and do_classifier_free_guidance:
                    # Zero residuals for the uncond half so batch shapes match the UNet input.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    metadata=input_metadata,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    # Standard CFG extrapolation.
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Free GPU memory before VAE decode when model offloading is active.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        # Post-processing: latent passthrough, PIL images, or raw numpy.
        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            image = self.numpy_to_pil(image)
        else:
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload the remaining components (text encoder/VAE) if hooked.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# Public API of this module.
__all__ = ["DiffusionSatControlNetPipeline"]
|
scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "DDIMScheduler",
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"beta_end": 0.012,
|
| 5 |
+
"beta_schedule": "scaled_linear",
|
| 6 |
+
"beta_start": 0.00085,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"clip_sample_range": 1.0,
|
| 9 |
+
"dynamic_thresholding_ratio": 0.995,
|
| 10 |
+
"num_train_timesteps": 1000,
|
| 11 |
+
"prediction_type": "v_prediction",
|
| 12 |
+
"rescale_betas_zero_snr": false,
|
| 13 |
+
"sample_max_value": 1.0,
|
| 14 |
+
"set_alpha_to_one": false,
|
| 15 |
+
"skip_prk_steps": true,
|
| 16 |
+
"steps_offset": 1,
|
| 17 |
+
"thresholding": false,
|
| 18 |
+
"timestep_spacing": "leading",
|
| 19 |
+
"trained_betas": null
|
| 20 |
+
}
|
text_encoder/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "stabilityai/stable-diffusion-2-1",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPTextModel"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 0,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"eos_token_id": 2,
|
| 10 |
+
"hidden_act": "gelu",
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_factor": 1.0,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 4096,
|
| 15 |
+
"layer_norm_eps": 1e-05,
|
| 16 |
+
"max_position_embeddings": 77,
|
| 17 |
+
"model_type": "clip_text_model",
|
| 18 |
+
"num_attention_heads": 16,
|
| 19 |
+
"num_hidden_layers": 23,
|
| 20 |
+
"pad_token_id": 1,
|
| 21 |
+
"projection_dim": 512,
|
| 22 |
+
"torch_dtype": "float16",
|
| 23 |
+
"transformers_version": "4.31.0",
|
| 24 |
+
"vocab_size": 49408
|
| 25 |
+
}
|
text_encoder/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
|
| 3 |
+
size 680820392
|
tokenizer/config.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "ControlNetModel",
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"_name_or_path": "/data/jiabo/diffusionsat/testoutput/checkpoint-1",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"attention_head_dim": [
|
| 7 |
+
5,
|
| 8 |
+
10,
|
| 9 |
+
20,
|
| 10 |
+
20
|
| 11 |
+
],
|
| 12 |
+
"block_out_channels": [
|
| 13 |
+
320,
|
| 14 |
+
640,
|
| 15 |
+
1280,
|
| 16 |
+
1280
|
| 17 |
+
],
|
| 18 |
+
"class_embed_type": null,
|
| 19 |
+
"conditioning_embedding_out_channels": [
|
| 20 |
+
16,
|
| 21 |
+
32,
|
| 22 |
+
96,
|
| 23 |
+
256
|
| 24 |
+
],
|
| 25 |
+
"conditioning_in_channels": 3,
|
| 26 |
+
"conditioning_scale": 1,
|
| 27 |
+
"controlnet_conditioning_channel_order": "rgb",
|
| 28 |
+
"cross_attention_dim": 1024,
|
| 29 |
+
"down_block_types": [
|
| 30 |
+
"CrossAttnDownBlock2D",
|
| 31 |
+
"CrossAttnDownBlock2D",
|
| 32 |
+
"CrossAttnDownBlock2D",
|
| 33 |
+
"DownBlock2D"
|
| 34 |
+
],
|
| 35 |
+
"downsample_padding": 1,
|
| 36 |
+
"flip_sin_to_cos": true,
|
| 37 |
+
"freq_shift": 0,
|
| 38 |
+
"global_pool_conditions": false,
|
| 39 |
+
"in_channels": 4,
|
| 40 |
+
"layers_per_block": 2,
|
| 41 |
+
"mid_block_scale_factor": 1,
|
| 42 |
+
"norm_eps": 1e-05,
|
| 43 |
+
"norm_num_groups": 32,
|
| 44 |
+
"num_class_embeds": null,
|
| 45 |
+
"num_metadata": 7,
|
| 46 |
+
"only_cross_attention": false,
|
| 47 |
+
"projection_class_embeddings_input_dim": null,
|
| 48 |
+
"resnet_time_scale_shift": "default",
|
| 49 |
+
"upcast_attention": true,
|
| 50 |
+
"use_linear_projection": true,
|
| 51 |
+
"use_metadata": true
|
| 52 |
+
}
|
tokenizer/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|startoftext|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|endoftext|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "!",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<|endoftext|>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": true,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"__type": "AddedToken",
|
| 5 |
+
"content": "<|startoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": true,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false
|
| 10 |
+
},
|
| 11 |
+
"clean_up_tokenization_spaces": true,
|
| 12 |
+
"do_lower_case": true,
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"__type": "AddedToken",
|
| 15 |
+
"content": "<|endoftext|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": true,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"errors": "replace",
|
| 22 |
+
"model_max_length": 77,
|
| 23 |
+
"pad_token": "<|endoftext|>",
|
| 24 |
+
"tokenizer_class": "CLIPTokenizer",
|
| 25 |
+
"unk_token": {
|
| 26 |
+
"__type": "AddedToken",
|
| 27 |
+
"content": "<|endoftext|>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": true,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
tokenizer/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unet/config.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": ["sat_unet", "SatUNet"],
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"attention_head_dim": [
|
| 6 |
+
5,
|
| 7 |
+
10,
|
| 8 |
+
20,
|
| 9 |
+
20
|
| 10 |
+
],
|
| 11 |
+
"block_out_channels": [
|
| 12 |
+
320,
|
| 13 |
+
640,
|
| 14 |
+
1280,
|
| 15 |
+
1280
|
| 16 |
+
],
|
| 17 |
+
"center_input_sample": false,
|
| 18 |
+
"class_embed_type": null,
|
| 19 |
+
"conv_in_kernel": 3,
|
| 20 |
+
"conv_out_kernel": 3,
|
| 21 |
+
"cross_attention_dim": 1024,
|
| 22 |
+
"down_block_types": [
|
| 23 |
+
"CrossAttnDownBlock2D",
|
| 24 |
+
"CrossAttnDownBlock2D",
|
| 25 |
+
"CrossAttnDownBlock2D",
|
| 26 |
+
"DownBlock2D"
|
| 27 |
+
],
|
| 28 |
+
"downsample_padding": 1,
|
| 29 |
+
"dual_cross_attention": false,
|
| 30 |
+
"flip_sin_to_cos": true,
|
| 31 |
+
"freq_shift": 0,
|
| 32 |
+
"in_channels": 4,
|
| 33 |
+
"layers_per_block": 2,
|
| 34 |
+
"mid_block_scale_factor": 1,
|
| 35 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
| 36 |
+
"norm_eps": 1e-05,
|
| 37 |
+
"norm_num_groups": 32,
|
| 38 |
+
"num_class_embeds": null,
|
| 39 |
+
"num_metadata": 7,
|
| 40 |
+
"only_cross_attention": false,
|
| 41 |
+
"out_channels": 4,
|
| 42 |
+
"resnet_time_scale_shift": "default",
|
| 43 |
+
"sample_size": 96,
|
| 44 |
+
"time_cond_proj_dim": null,
|
| 45 |
+
"time_embedding_type": "positional",
|
| 46 |
+
"timestep_post_act": null,
|
| 47 |
+
"up_block_types": [
|
| 48 |
+
"UpBlock2D",
|
| 49 |
+
"CrossAttnUpBlock2D",
|
| 50 |
+
"CrossAttnUpBlock2D",
|
| 51 |
+
"CrossAttnUpBlock2D"
|
| 52 |
+
],
|
| 53 |
+
"upcast_attention": true,
|
| 54 |
+
"use_linear_projection": true,
|
| 55 |
+
"use_metadata": true
|
| 56 |
+
}
|
unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ef6c0264f8eb5085b08e5f16631e1ad8ba078f28d94c902276f8dfc603e3eb80
|
| 3 |
+
size 1760615624
|
unet/sat_unet.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Satellite UNet wrapper with metadata support on top of diffusers."""
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from diffusers.models.unets.unet_2d_condition import (
|
| 8 |
+
UNet2DConditionModel,
|
| 9 |
+
UNet2DConditionOutput,
|
| 10 |
+
)
|
| 11 |
+
from diffusers.utils import logging
|
| 12 |
+
|
| 13 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SatUNet(UNet2DConditionModel):
|
| 17 |
+
"""Thin wrapper around `diffusers.UNet2DConditionModel` with metadata embeddings."""
|
| 18 |
+
|
| 19 |
+
_supports_gradient_checkpointing = True
|
| 20 |
+
|
| 21 |
+
def __init__(self, *args, use_metadata: bool = True, num_metadata: int = 7, **kwargs):
|
| 22 |
+
super().__init__(*args, **kwargs)
|
| 23 |
+
|
| 24 |
+
# Track custom config entries for save/load parity with the base model.
|
| 25 |
+
self.register_to_config(use_metadata=use_metadata, num_metadata=num_metadata)
|
| 26 |
+
|
| 27 |
+
self.use_metadata = use_metadata
|
| 28 |
+
self.num_metadata = num_metadata
|
| 29 |
+
|
| 30 |
+
if use_metadata:
|
| 31 |
+
# Re-use the same dimensions as the base time embedding.
|
| 32 |
+
timestep_input_dim = self.time_embedding.linear_1.in_features
|
| 33 |
+
time_embed_dim = self.time_embedding.linear_2.out_features
|
| 34 |
+
self.metadata_embedding = nn.ModuleList(
|
| 35 |
+
[self._build_metadata_embedding(timestep_input_dim, time_embed_dim) for _ in range(num_metadata)]
|
| 36 |
+
)
|
| 37 |
+
else:
|
| 38 |
+
self.metadata_embedding = None
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
def _build_metadata_embedding(timestep_input_dim: int, time_embed_dim: int) -> nn.Module:
|
| 42 |
+
from diffusers.models.embeddings import TimestepEmbedding
|
| 43 |
+
|
| 44 |
+
return TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 45 |
+
|
| 46 |
+
def _encode_metadata(
|
| 47 |
+
self, metadata: Optional[torch.Tensor], dtype: torch.dtype
|
| 48 |
+
) -> Optional[torch.Tensor]:
|
| 49 |
+
if self.metadata_embedding is None:
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
if metadata is None:
|
| 53 |
+
raise ValueError("metadata must be provided when use_metadata=True")
|
| 54 |
+
|
| 55 |
+
if metadata.dim() != 2 or metadata.shape[1] != self.num_metadata:
|
| 56 |
+
raise ValueError(f"Invalid metadata shape {metadata.shape}, expected (batch, {self.num_metadata})")
|
| 57 |
+
|
| 58 |
+
md_bsz = metadata.shape[0]
|
| 59 |
+
# Reuse the same projection used for timestep encoding to stay aligned with base embeddings.
|
| 60 |
+
projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)
|
| 61 |
+
|
| 62 |
+
md_emb = projected.new_zeros((md_bsz, projected.shape[-1]))
|
| 63 |
+
for idx, md_embed in enumerate(self.metadata_embedding):
|
| 64 |
+
md_emb = md_emb + md_embed(projected[:, idx, :])
|
| 65 |
+
|
| 66 |
+
return md_emb
|
| 67 |
+
|
| 68 |
+
def forward(
|
| 69 |
+
self,
|
| 70 |
+
sample: torch.Tensor,
|
| 71 |
+
timestep: Union[torch.Tensor, float, int],
|
| 72 |
+
encoder_hidden_states: torch.Tensor,
|
| 73 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 74 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 75 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 76 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 77 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 78 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 79 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 80 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 81 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 82 |
+
metadata: Optional[torch.Tensor] = None,
|
| 83 |
+
return_dict: bool = True,
|
| 84 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
| 85 |
+
# Largely mirrors `UNet2DConditionModel.forward` with a metadata injection on the timestep embedding.
|
| 86 |
+
|
| 87 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 88 |
+
forward_upsample_size = False
|
| 89 |
+
upsample_size = None
|
| 90 |
+
|
| 91 |
+
for dim in sample.shape[-2:]:
|
| 92 |
+
if dim % default_overall_up_factor != 0:
|
| 93 |
+
forward_upsample_size = True
|
| 94 |
+
break
|
| 95 |
+
|
| 96 |
+
if attention_mask is not None:
|
| 97 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 98 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 99 |
+
|
| 100 |
+
if encoder_attention_mask is not None:
|
| 101 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
| 102 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
| 103 |
+
|
| 104 |
+
if self.config.center_input_sample:
|
| 105 |
+
sample = 2 * sample - 1.0
|
| 106 |
+
|
| 107 |
+
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
| 108 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 109 |
+
|
| 110 |
+
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
| 111 |
+
if class_emb is not None:
|
| 112 |
+
if self.config.class_embeddings_concat:
|
| 113 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
| 114 |
+
else:
|
| 115 |
+
emb = emb + class_emb
|
| 116 |
+
|
| 117 |
+
aug_emb = self.get_aug_embed(
|
| 118 |
+
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs or {}
|
| 119 |
+
)
|
| 120 |
+
if self.config.addition_embed_type == "image_hint" and aug_emb is not None:
|
| 121 |
+
aug_emb, hint = aug_emb
|
| 122 |
+
sample = torch.cat([sample, hint], dim=1)
|
| 123 |
+
|
| 124 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 125 |
+
|
| 126 |
+
md_emb = self._encode_metadata(metadata=metadata, dtype=sample.dtype)
|
| 127 |
+
if md_emb is not None:
|
| 128 |
+
emb = emb + md_emb
|
| 129 |
+
|
| 130 |
+
if self.time_embed_act is not None:
|
| 131 |
+
emb = self.time_embed_act(emb)
|
| 132 |
+
|
| 133 |
+
encoder_hidden_states = self.process_encoder_hidden_states(
|
| 134 |
+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs or {}
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
sample = self.conv_in(sample)
|
| 138 |
+
|
| 139 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
| 140 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
| 141 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
| 142 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
| 143 |
+
|
| 144 |
+
if cross_attention_kwargs is not None:
|
| 145 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
| 146 |
+
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
| 147 |
+
else:
|
| 148 |
+
lora_scale = 1.0
|
| 149 |
+
|
| 150 |
+
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, deprecate
|
| 151 |
+
|
| 152 |
+
if USE_PEFT_BACKEND:
|
| 153 |
+
scale_lora_layers(self, lora_scale)
|
| 154 |
+
|
| 155 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
| 156 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
| 157 |
+
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
| 158 |
+
deprecate(
|
| 159 |
+
"T2I should not use down_block_additional_residuals",
|
| 160 |
+
"1.3.0",
|
| 161 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
|
| 162 |
+
"and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
|
| 163 |
+
"for ControlNet. Please use `down_intrablock_additional_residuals` instead.",
|
| 164 |
+
standard_warn=False,
|
| 165 |
+
)
|
| 166 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
| 167 |
+
is_adapter = True
|
| 168 |
+
|
| 169 |
+
down_block_res_samples = (sample,)
|
| 170 |
+
for downsample_block in self.down_blocks:
|
| 171 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 172 |
+
additional_residuals: Dict[str, torch.Tensor] = {}
|
| 173 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
| 174 |
+
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
| 175 |
+
|
| 176 |
+
sample, res_samples = downsample_block(
|
| 177 |
+
hidden_states=sample,
|
| 178 |
+
temb=emb,
|
| 179 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 180 |
+
attention_mask=attention_mask,
|
| 181 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 182 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 183 |
+
**additional_residuals,
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 187 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
| 188 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
| 189 |
+
|
| 190 |
+
down_block_res_samples += res_samples
|
| 191 |
+
|
| 192 |
+
if is_controlnet:
|
| 193 |
+
new_down_block_res_samples = ()
|
| 194 |
+
|
| 195 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
| 196 |
+
down_block_res_samples, down_block_additional_residuals
|
| 197 |
+
):
|
| 198 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
| 199 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
| 200 |
+
|
| 201 |
+
down_block_res_samples = new_down_block_res_samples
|
| 202 |
+
|
| 203 |
+
if self.mid_block is not None:
|
| 204 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
| 205 |
+
sample = self.mid_block(
|
| 206 |
+
sample,
|
| 207 |
+
emb,
|
| 208 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 209 |
+
attention_mask=attention_mask,
|
| 210 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 211 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 212 |
+
)
|
| 213 |
+
else:
|
| 214 |
+
sample = self.mid_block(sample, emb)
|
| 215 |
+
|
| 216 |
+
if (
|
| 217 |
+
is_adapter
|
| 218 |
+
and len(down_intrablock_additional_residuals) > 0
|
| 219 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
| 220 |
+
):
|
| 221 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
| 222 |
+
|
| 223 |
+
if is_controlnet:
|
| 224 |
+
sample = sample + mid_block_additional_residual
|
| 225 |
+
|
| 226 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 227 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 228 |
+
|
| 229 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 230 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 231 |
+
|
| 232 |
+
if not is_final_block and forward_upsample_size:
|
| 233 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 234 |
+
|
| 235 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 236 |
+
sample = upsample_block(
|
| 237 |
+
hidden_states=sample,
|
| 238 |
+
temb=emb,
|
| 239 |
+
res_hidden_states_tuple=res_samples,
|
| 240 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 241 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 242 |
+
upsample_size=upsample_size,
|
| 243 |
+
attention_mask=attention_mask,
|
| 244 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
sample = upsample_block(
|
| 248 |
+
hidden_states=sample,
|
| 249 |
+
temb=emb,
|
| 250 |
+
res_hidden_states_tuple=res_samples,
|
| 251 |
+
upsample_size=upsample_size,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if self.conv_norm_out:
|
| 255 |
+
sample = self.conv_norm_out(sample)
|
| 256 |
+
sample = self.conv_act(sample)
|
| 257 |
+
sample = self.conv_out(sample)
|
| 258 |
+
|
| 259 |
+
if USE_PEFT_BACKEND:
|
| 260 |
+
unscale_lora_layers(self, lora_scale)
|
| 261 |
+
|
| 262 |
+
if not return_dict:
|
| 263 |
+
return (sample,)
|
| 264 |
+
|
| 265 |
+
return UNet2DConditionOutput(sample=sample)
|
vae/config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.17.0",
|
| 4 |
+
"_name_or_path": "stabilityai/stable-diffusion-2-1",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 4,
|
| 20 |
+
"layers_per_block": 2,
|
| 21 |
+
"norm_num_groups": 32,
|
| 22 |
+
"out_channels": 3,
|
| 23 |
+
"sample_size": 768,
|
| 24 |
+
"scaling_factor": 0.18215,
|
| 25 |
+
"up_block_types": [
|
| 26 |
+
"UpDecoderBlock2D",
|
| 27 |
+
"UpDecoderBlock2D",
|
| 28 |
+
"UpDecoderBlock2D",
|
| 29 |
+
"UpDecoderBlock2D"
|
| 30 |
+
]
|
| 31 |
+
}
|
vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
|
| 3 |
+
size 167335342
|