BiliSakura committed on
Commit d295ca1 · verified · 1 Parent(s): f265a49

Add files using upload-large-folder tool
README.md ADDED
@@ -0,0 +1,88 @@
+ ---
+ license: apache-2.0
+ library_name: diffusers
+ tags:
+ - aerogen
+ - remote-sensing
+ - object-detection
+ - latent-diffusion
+ - bounding-box
+ - arxiv:2411.15497
+ pipeline_tag: text-to-image
+ language:
+ - en
+ ---
+
+ # BiliSakura/AeroGen
+
+ **Aerial image generation** conditioned on bounding boxes (horizontal or rotated) and object categories. AeroGen is the first model to simultaneously support horizontal and rotated bounding-box conditions for remote sensing image generation.
+
+ Converted to diffusers format. **Self-contained**: no external code repo is needed; all required code is bundled.
+
+ ## Model Details
+
+ - **Model type**: Latent diffusion with UNet + VAE + CLIP text encoder + RBoxEncoder (condition encoder)
+ - **Conditioning**: Bounding boxes (8 coords for rotated, 4 for axis-aligned), category CLIP embeddings, spatial masks
+ - **Scheduler**: DDIMScheduler, 1000 steps, scaled_linear
+ - **Output**: 512×512 RGB aerial images
+ - **License**: Apache 2.0
+
+ ### Repository Structure
+
+ | Component         | Path                 |
+ |-------------------|----------------------|
+ | Pipeline          | `pipeline.py`        |
+ | UNet              | `unet/`              |
+ | VAE               | `vae/`               |
+ | Text encoder      | `text_encoder/`      |
+ | Condition encoder | `condition_encoder/` |
+ | Scheduler         | `scheduler/`         |
+ | Config            | `model_index.json`   |
+
+ ## Inference
+
+ **Dependencies:** `pip install diffusers transformers torch einops safetensors pyyaml`
+
+ ```python
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "BiliSakura/AeroGen",
+     trust_remote_code=True,
+ )
+ pipe = pipe.to("cuda")
+ ```
+
+ ### Conditioning Format
+
+ | Input                 | Shape          | Description                                               |
+ |-----------------------|----------------|-----------------------------------------------------------|
+ | `bboxes`              | (B, N, 8)      | Rotated box corners [x1,y1,x2,y2,x3,y3,x4,y4], normalized |
+ | `bboxes`              | (B, N, 4)      | Axis-aligned [x1,y1,x2,y2], normalized                    |
+ | `category_conditions` | (B, N, 768)    | CLIP text embeddings per object (e.g. encode class name)  |
+ | `mask_conditions`     | (B, N, 64, 64) | Spatial mask per object (64×64 for 512px output)          |
+ | `mask_vector`         | (B, N)         | 1 = valid object, 0 = padding                             |
+
+ For layout preparation and DIOR-R format, see the [original AeroGen repo](https://github.com/Sonettoo/AeroGen).
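
A minimal sketch of the padding scheme the table above implies: boxes are normalized to [0, 1] and padded out to a fixed slot count, with `mask_vector` marking which slots hold real objects. The `prepare_boxes` helper and the `max_objects` value are illustrative, not part of the released pipeline.

```python
def prepare_boxes(pixel_boxes, image_size=512, max_objects=10):
    """Normalize rotated boxes (8 pixel coords each) and pad to max_objects.

    Returns nested lists; wrap with torch.tensor(...).unsqueeze(0) to get the
    (1, N, 8) and (1, N) shapes the pipeline expects.
    """
    bboxes, mask_vector = [], []
    for box in pixel_boxes[:max_objects]:
        bboxes.append([c / image_size for c in box])  # normalize to [0, 1]
        mask_vector.append(1.0)                       # real object
    while len(bboxes) < max_objects:
        bboxes.append([0.0] * 8)                      # padded slot
        mask_vector.append(0.0)                       # ignored via mask_vector
    return bboxes, mask_vector

boxes, valid = prepare_boxes([[64, 64, 192, 64, 192, 128, 64, 128]])
```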
+
+ ## Model Sources
+
+ - **Source**: [Sonetto702/AeroGen](https://huggingface.co/Sonetto702/AeroGen)
+ - **Paper**: [AeroGen: Enhancing Remote Sensing Object Detection with Diffusion-Driven Data Generation](https://arxiv.org/abs/2411.15497)
+ - **Original repo**: [Sonettoo/AeroGen](https://github.com/Sonettoo/AeroGen)
+ - **Conversion**: Checkpoint converted to diffusers format (self-contained, no external repo)
+
+ ## Citation
+
+ ```bibtex
+ @inproceedings{tangAeroGenEnhancingRemote2025,
+   title = {{{AeroGen}}: {{Enhancing Remote Sensing Object Detection}} with {{Diffusion-Driven Data Generation}}},
+   shorttitle = {{{AeroGen}}},
+   booktitle = {{{CVPR}}},
+   author = {Tang, Datao and Cao, Xiangyong and Wu, Xuan and Li, Jialin and Yao, Jing and Bai, Xueru and Jiang, Dongsheng and Li, Yin and Meng, Deyu},
+   year = 2025,
+   pages = {3614--3624}
+ }
+ ```
condition_encoder/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """AeroGen condition encoder (RBoxEncoder)."""
+
+ from .rbox_encoder import RBoxEncoder
+
+ __all__ = ["RBoxEncoder"]
condition_encoder/config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "target": "condition_encoder.rbox_encoder.RBoxEncoder",
+   "params": {
+     "in_dim": 768,
+     "out_dim": 768
+   }
+ }
condition_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb07185ec37c39776b9ab1bd2ebeb6483bf70ec441161b0661f5ceed3b7b5972
+ size 4467872
condition_encoder/rbox_encoder.py ADDED
@@ -0,0 +1,67 @@
+ """
+ RBoxEncoder - pure PyTorch, no ldm/bldm dependency.
+
+ Encodes rotated bounding boxes (8 coords) with Fourier embedding and text embeddings.
+ """
+
+ import torch
+ import torch.nn as nn
+
+
+ class FourierEmbedder:
+     def __init__(self, num_freqs=64, temperature=100):
+         self.num_freqs = num_freqs
+         self.temperature = temperature
+         self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
+
+     @torch.no_grad()
+     def __call__(self, x, cat_dim=-1):
+         out = []
+         for freq in self.freq_bands:
+             out.append(torch.sin(freq * x))
+             out.append(torch.cos(freq * x))
+         return torch.cat(out, cat_dim)
+
+
+ class RBoxEncoder(nn.Module):
+     """Encoder for rotated bounding boxes (8 coords) with text embeddings."""
+
+     def __init__(self, in_dim, out_dim, fourier_freqs=8):
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+         self.position_dim = fourier_freqs * 2 * 8  # 2 is sin&cos, 8 is xyxyxyxy
+
+         self.linears = nn.Sequential(
+             nn.Linear(self.in_dim + self.position_dim, 512),
+             nn.SiLU(),
+             nn.Linear(512, 512),
+             nn.SiLU(),
+             nn.Linear(512, out_dim),
+         )
+
+         self.null_text_feature = nn.Parameter(torch.zeros([self.in_dim]))
+         self.null_position_feature = nn.Parameter(torch.zeros([self.position_dim]))
+
+     def forward(self, boxes=None, masks=None, text_embeddings=None, **kwargs):
+         # Pipeline passes boxes=[bboxes], masks=[mask_vector], text_embeddings=[category_conditions]
+         boxes = (boxes or kwargs.get("boxes", [[]]))[0]
+         masks = (masks or kwargs.get("masks", [[]]))[0]
+         text_embeddings = (text_embeddings or kwargs.get("text_embeddings", [[]]))[0]
+
+         B, N, _ = boxes.shape
+         masks = masks.unsqueeze(-1)
+
+         xyxy_embedding = self.fourier_embedder(boxes)  # B*N*8 --> B*N*C
+
+         text_null = self.null_text_feature.view(1, 1, -1)
+         xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+         text_embeddings = text_embeddings * masks + (1 - masks) * text_null
+         xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+         objs = self.linears(torch.cat([text_embeddings, xyxy_embedding], dim=-1))
+         assert objs.shape == torch.Size([B, N, self.out_dim])
+         return objs
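
For reference, the `FourierEmbedder` in the file above maps each scalar coordinate to `2 * num_freqs` sin/cos features at frequencies `temperature ** (k / num_freqs)`, so with `fourier_freqs=8` a box's 8 coordinates yield the `position_dim = 128` used by `RBoxEncoder`. This plain-Python sketch mirrors the per-coordinate math (the feature ordering differs from the tensor version, which concatenates across coordinates):

```python
import math

def fourier_embed_scalar(x, num_freqs=8, temperature=100):
    # freq_k = temperature ** (k / num_freqs); emit sin and cos per frequency.
    feats = []
    for k in range(num_freqs):
        f = temperature ** (k / num_freqs)
        feats.append(math.sin(f * x))
        feats.append(math.cos(f * x))
    return feats

feats = fourier_embed_scalar(0.5)
# 8 frequencies * 2 (sin, cos) = 16 features per coordinate;
# 16 * 8 coordinates = 128 = position_dim in RBoxEncoder.
```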
model_index.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_class_name": ["pipeline", "AeroGenPipeline"],
+   "_diffusers_version": "0.25.0",
+   "condition_encoder": [
+     "pipeline",
+     "AeroGenPipeline"
+   ],
+   "scheduler": [
+     "diffusers",
+     "DDIMScheduler"
+   ],
+   "text_encoder": [
+     "pipeline",
+     "AeroGenPipeline"
+   ],
+   "unet": [
+     "pipeline",
+     "AeroGenPipeline"
+   ],
+   "vae": [
+     "pipeline",
+     "AeroGenPipeline"
+   ],
+   "scale_factor": 0.18215
+ }
modular_pipeline.py ADDED
@@ -0,0 +1,167 @@
+ """
+ AeroGen modular components: scheduler config, component loading, and path setup.
+
+ Self-contained - no ldm/bldm. Scheduler is created in-code (no scheduler/ folder required).
+ """
+
+ import importlib
+ import json
+ import sys
+ from pathlib import Path
+ from typing import Optional, Union
+
+ from diffusers import DDIMScheduler
+
+ # Ensure model dir is on path for local module imports (unet, text_encoder, condition_encoder)
+ _pipeline_dir = Path(__file__).resolve().parent
+ if str(_pipeline_dir) not in sys.path:
+     sys.path.insert(0, str(_pipeline_dir))
+
+ # Default DDIM scheduler config (matches scheduler/scheduler_config.json)
+ DEFAULT_SCHEDULER_CONFIG = {
+     "num_train_timesteps": 1000,
+     "beta_start": 0.00085,
+     "beta_end": 0.012,
+     "beta_schedule": "scaled_linear",
+     "clip_sample": False,
+     "set_alpha_to_one": False,
+     "prediction_type": "epsilon",
+ }
+
+
+ def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
+     """Add model repo to path so local modules can be imported. Returns resolved path."""
+     path = Path(pretrained_model_name_or_path)
+     if not path.exists():
+         from huggingface_hub import snapshot_download
+         path = Path(snapshot_download(pretrained_model_name_or_path))
+     path = path.resolve()
+     s = str(path)
+     if s not in sys.path:
+         sys.path.insert(0, s)
+     return path
+
+
+ def ensure_ldm_path_from_config(config_path: str) -> None:
+     """Walk up from config file dir to find project root and add to path."""
+     d = Path(config_path).resolve().parent
+     for _ in range(10):
+         if (d / "pipeline.py").exists() or (d / "unet").is_dir():
+             s = str(d)
+             if s not in sys.path:
+                 sys.path.insert(0, s)
+             return
+         parent = d.parent
+         if parent == d:
+             break
+         d = parent
+
+
+ def _get_class_from_string(target: str):
+     """Resolve class from dotted path (diffusers-style, no OmegaConf)."""
+     module_path, cls_name = target.rsplit(".", 1)
+     mod = importlib.import_module(module_path)
+     return getattr(mod, cls_name)
+
+
+ def _instantiate_from_config(config: dict):
+     """Instantiate from dict with 'target' and 'params' (diffusers-style, no OmegaConf)."""
+     if not isinstance(config, dict) or "target" not in config:
+         raise KeyError("Expected key 'target' to instantiate.")
+     cls = _get_class_from_string(config["target"])
+     params = dict(config.get("params") or {})
+     params.pop("ckpt_path", None)
+     params.pop("ignore_keys", None)
+     params.pop("target", None)  # avoid passing target into constructor
+     return cls(**params)
+
+
+ def create_scheduler(model_path: Path) -> DDIMScheduler:
+     """Create DDIMScheduler from path/scheduler if exists, else from defaults."""
+     scheduler_path = model_path / "scheduler"
+     if scheduler_path.exists() and (scheduler_path / "scheduler_config.json").exists():
+         return DDIMScheduler.from_pretrained(scheduler_path)
+     return DDIMScheduler(**DEFAULT_SCHEDULER_CONFIG)
+
+
+ def load_component(model_path: Path, name: str):
+     """Load a custom component (unet, vae, text_encoder, condition_encoder).
+
+     VAE: Uses diffusers AutoencoderKL.from_pretrained when saved in diffusers format
+     (config has down_block_types, no target). Otherwise uses target/params.
+
+     When diffusers loads a single component, it passes the component subfolder path
+     (e.g. .../unet). We detect that and use it directly.
+     """
+     import torch
+     path = Path(model_path)
+     # Ensure model root is on sys.path for imports (unet, text_encoder, condition_encoder)
+     root = path.parent if path.name in ("unet", "vae", "text_encoder", "condition_encoder") and (path / "config.json").exists() else path
+     ensure_ldm_path(root)
+     # If path is already a component folder (has config.json), use it directly
+     if (path / "config.json").exists() and path.name in ("unet", "vae", "text_encoder", "condition_encoder"):
+         comp_path = path
+     else:
+         comp_path = path / name
+     with open(comp_path / "config.json") as f:
+         cfg = json.load(f)
+
+     # Diffusers native format (e.g. AutoencoderKL.save_pretrained): no "target" key
+     if "target" not in cfg and name == "vae":
+         from diffusers import AutoencoderKL
+         return AutoencoderKL.from_pretrained(comp_path)
+
+     component = _instantiate_from_config(cfg)
+     safetensors_path = comp_path / "diffusion_pytorch_model.safetensors"
+     bin_path = comp_path / "diffusion_pytorch_model.bin"
+     if safetensors_path.exists():
+         import safetensors.torch
+         state = safetensors.torch.load_file(str(safetensors_path))
+     elif bin_path.exists():
+         try:
+             state = torch.load(str(bin_path), map_location="cpu", weights_only=True)
+         except TypeError:
+             state = torch.load(str(bin_path), map_location="cpu")
+     else:
+         raise FileNotFoundError(
+             f"No weights in {comp_path} "
+             "(expected diffusion_pytorch_model.safetensors or .bin)"
+         )
+     # UNet: AeroGenUNet2DConditionModel wraps UNetModel in self.model, so expects "model.xxx" keys.
+     # Older checkpoints may have been saved without the "model." prefix.
+     if name == "unet" and state and not any(k.startswith("model.") for k in state.keys()):
+         state = {"model." + k: v for k, v in state.items()}
+     component.load_state_dict(state, strict=True)
+     component.eval()
+     return component
+
+
+ def load_components(
+     model_path: Union[str, Path],
+ ) -> dict:
+     """Load all pipeline components. Returns dict with unet, vae, text_encoder, condition_encoder, scheduler, scale_factor."""
+     path = Path(ensure_ldm_path(model_path))
+     # If path points to a component subfolder (e.g. .../unet), use parent as model root
+     if path.name in ("unet", "vae", "text_encoder", "condition_encoder") and (path / "config.json").exists():
+         path = path.parent
+     scheduler = create_scheduler(path)
+     unet = load_component(path, "unet")
+     vae = load_component(path, "vae")
+     text_encoder = load_component(path, "text_encoder")
+     condition_encoder = load_component(path, "condition_encoder")
+
+     scale_factor = 0.18215
+     model_index_path = path / "model_index.json"
+     if model_index_path.exists():
+         with open(model_index_path) as f:
+             model_index = json.load(f)
+         scale_factor = model_index.get("scale_factor", scale_factor)
+
+     return {
+         "unet": unet,
+         "vae": vae,
+         "text_encoder": text_encoder,
+         "condition_encoder": condition_encoder,
+         "scheduler": scheduler,
+         "scale_factor": scale_factor,
+     }
pipeline.py ADDED
@@ -0,0 +1,377 @@
+ """
+ AeroGen Pipeline using native HuggingFace Diffusers.
+
+ This module provides a DiffusionPipeline subclass that wraps AeroGen's
+ custom UNet, condition encoder, VAE, and text encoder into a standard
+ diffusers pipeline interface, using DDIMScheduler for the denoising loop.
+
+ Usage:
+     # Load from config + checkpoint
+     pipeline = AeroGenPipeline.from_pretrained_checkpoint(
+         config_path="configs/.../v1-finetune-DIOR-R.yaml",
+         checkpoint_path="./ckpt/aerogen_diorr_last.ckpt",
+     )
+
+     # Load from diffusers-format (after convert_to_diffusers.py)
+     pipeline = AeroGenPipeline.from_pretrained("/path/to/AeroGen")
+ """
+
+ import json
+ import os
+ import sys
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Optional, Union
+
+ # Ensure model repo is on path for trust_remote_code / custom_pipeline loading
+ _pipeline_dir = Path(__file__).resolve().parent
+ if str(_pipeline_dir) not in sys.path:
+     sys.path.insert(0, str(_pipeline_dir))
+
+ import einops
+ import numpy as np
+ import torch
+ import yaml
+ from diffusers import DDIMScheduler, DiffusionPipeline
+ from diffusers.utils import BaseOutput
+ from PIL import Image
+
+ from modular_pipeline import (
+     ensure_ldm_path,
+     ensure_ldm_path_from_config,
+     load_component,
+     load_components,
+     create_scheduler,
+     _instantiate_from_config,
+ )
+
+
+ @dataclass
+ class AeroGenPipelineOutput(BaseOutput):
+     """Output class for AeroGen pipeline.
+
+     Attributes:
+         images: List of generated PIL images.
+     """
+
+     images: List[Image.Image]
+
+
+ class AeroGenPipeline(DiffusionPipeline):
+     """Pipeline for AeroGen: conditional aerial image generation with
+     bounding box and category controls.
+
+     This pipeline wraps AeroGen's custom components (UNet, condition encoder,
+     VAE, text encoder) and uses a native diffusers DDIMScheduler for the
+     denoising loop, replacing the original custom DDIM sampler.
+
+     Args:
+         unet: The custom UNet model (openaimodel_bbox_v2.UNetModel).
+         scheduler: A diffusers DDIMScheduler instance.
+         vae: The VAE model (AutoencoderKL) for latent encoding/decoding.
+         text_encoder: The frozen CLIP text encoder for prompt conditioning.
+         condition_encoder: The RBoxEncoder or BoxEncoder for bbox conditioning.
+         scale_factor: VAE latent scale factor (default: 0.18215 for SD 1.x).
+     """
+
+     def __init__(
+         self,
+         unet: torch.nn.Module,
+         scheduler: DDIMScheduler,
+         vae: torch.nn.Module,
+         text_encoder: torch.nn.Module,
+         condition_encoder: torch.nn.Module,
+         scale_factor: float = 0.18215,
+     ):
+         super().__init__()
+         self.register_modules(
+             unet=unet,
+             scheduler=scheduler,
+             vae=vae,
+             text_encoder=text_encoder,
+             condition_encoder=condition_encoder,
+         )
+         self.vae_scale_factor = scale_factor
+
+     @property
+     def device(self) -> torch.device:
+         """Return the device of the pipeline's first nn.Module parameter."""
+         for module in [self.unet, self.vae, self.text_encoder, self.condition_encoder]:
+             if isinstance(module, torch.nn.Module):
+                 params = list(module.parameters())
+                 if params:
+                     return params[0].device
+         return torch.device("cpu")
+
+     @property
+     def _execution_device(self) -> torch.device:
+         return self.device
+
+     @classmethod
+     def from_pretrained_checkpoint(
+         cls,
+         config_path: str,
+         checkpoint_path: str,
+         device: str = "cuda",
+     ) -> "AeroGenPipeline":
+         """Load an AeroGenPipeline from a YAML config and checkpoint.
+
+         DEPRECATED: ldm/bldm have been removed. Use from_pretrained() with a
+         diffusers-format model (converted via convert_to_diffusers_lowvram.py).
+         """
+         raise NotImplementedError(
+             "from_pretrained_checkpoint is no longer supported (ldm/bldm removed). "
+             "Use AeroGenPipeline.from_pretrained() with a diffusers-format model."
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Union[str, Path],
+         device: Optional[Union[str, torch.device]] = None,
+         subfolder: Optional[str] = None,
+         **kwargs,
+     ) -> Union["AeroGenPipeline", torch.nn.Module]:
+         """Load AeroGenPipeline from a diffusers-format directory.
+
+         Supports native diffusers loading via DiffusionPipeline.from_pretrained(..., trust_remote_code=True).
+         When subfolder is provided (e.g. by diffusers for component loading), returns only that component.
+
+         Args:
+             pretrained_model_name_or_path: Path to the diffusers-format
+                 directory or HuggingFace repo ID.
+             device: Device to load the model onto.
+             subfolder: If set, load only this component (unet, vae, text_encoder, condition_encoder).
+
+         Returns:
+             An AeroGenPipeline instance, or a single component if subfolder is set.
+         """
+         path = Path(ensure_ldm_path(pretrained_model_name_or_path))
+
+         # Single-component loading (for diffusers trust_remote_code component loading)
+         subfolder = kwargs.pop("subfolder", subfolder)
+         if subfolder in ("unet", "vae", "text_encoder", "condition_encoder"):
+             return load_component(path, subfolder)
+
+         # When diffusers loads a component, it passes the component subfolder path directly
+         if path.name in ("unet", "vae", "text_encoder", "condition_encoder") and (path / "config.json").exists():
+             ensure_ldm_path(path.parent)  # Ensure model root is on sys.path for imports
+             return load_component(path.parent, path.name)
+
+         # Ensure we have model root (diffusers may pass a subfolder when loading full pipeline)
+         if not (path / "model_index.json").exists():
+             for _ in range(5):
+                 parent = path.parent
+                 if (parent / "model_index.json").exists():
+                     path = parent
+                     break
+                 if parent == path:
+                     break
+                 path = parent
+
+         components = load_components(path)
+         pipe = cls(
+             unet=components["unet"],
+             scheduler=components["scheduler"],
+             vae=components["vae"],
+             text_encoder=components["text_encoder"],
+             condition_encoder=components["condition_encoder"],
+             scale_factor=components["scale_factor"],
+         )
+
+         if device is not None:
+             pipe = pipe.to(device)
+         return pipe
+
+     def _encode_prompt(self, prompt: Union[str, List[str]]) -> torch.Tensor:
+         """Encode text prompt(s) using the frozen CLIP text encoder."""
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         return self.text_encoder.encode(prompt)
+
+     def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+         """Decode latent representations using the VAE."""
+         latents = (1.0 / self.vae_scale_factor) * latents
+         image = self.vae.decode(latents)
+         return image
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         bboxes: torch.Tensor,
+         category_conditions: torch.Tensor,
+         mask_conditions: torch.Tensor,
+         mask_vector: torch.Tensor,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         eta: float = 0.2,
+         height: int = 512,
+         width: int = 512,
+         num_images_per_prompt: int = 1,
+         generator: Optional[torch.Generator] = None,
+         output_type: str = "pil",
+     ) -> AeroGenPipelineOutput:
+         """Generate aerial images conditioned on bounding boxes and categories.
+
+         Args:
+             prompt: Text prompt(s) describing the aerial scene.
+             bboxes: Bounding box coordinates tensor of shape (B, N, 8) for
+                 rotated boxes or (B, N, 4) for axis-aligned boxes.
+             category_conditions: Category embedding tensor of shape
+                 (B, N, 768).
+             mask_conditions: Spatial mask tensor of shape (B, N, H, W).
+             mask_vector: Binary vector indicating valid objects, shape (B, N).
+             num_inference_steps: Number of DDIM denoising steps.
+             guidance_scale: Classifier-free guidance scale. Values > 1.0
+                 enable guidance.
+             eta: DDIM eta parameter controlling stochasticity.
+             height: Output image height (must be divisible by 8).
+             width: Output image width (must be divisible by 8).
+             num_images_per_prompt: Number of images to generate per prompt.
+             generator: Optional torch.Generator for reproducibility.
+             output_type: Output format, either "pil" for PIL images or
+                 "tensor" for raw image tensors.
+
+         Returns:
+             AeroGenPipelineOutput with the generated images.
+         """
+         device = self._execution_device
+
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         batch_size = len(prompt)
+
+         # Repeat conditions for num_images_per_prompt
+         if num_images_per_prompt > 1:
+             prompt = prompt * num_images_per_prompt
+             bboxes = torch.cat([bboxes] * num_images_per_prompt, dim=0)
+             category_conditions = torch.cat([category_conditions] * num_images_per_prompt, dim=0)
+             mask_conditions = torch.cat([mask_conditions] * num_images_per_prompt, dim=0)
+             mask_vector = torch.cat([mask_vector] * num_images_per_prompt, dim=0)
+
+         total_batch = batch_size * num_images_per_prompt
+
+         # 1. Encode text prompts
+         text_embeddings = self._encode_prompt(prompt)
+
+         # 2. Encode unconditional prompt for CFG
+         if guidance_scale > 1.0:
+             uncond_embeddings = self._encode_prompt([""] * total_batch)
+
+         # 3. Move conditions to device
+         bboxes = bboxes.to(device).float()
+         category_conditions = category_conditions.to(device).float()
+         mask_conditions = mask_conditions.to(device).float()
+         mask_vector = mask_vector.to(device).float()
+
+         # 4. Encode bbox conditions
+         control = self.condition_encoder(
+             text_embeddings=[category_conditions],
+             masks=[mask_vector],
+             boxes=[bboxes],
+         )
+
+         # 5. Prepare latent noise
+         latent_shape = (total_batch, 4, height // 8, width // 8)
+         latents = torch.randn(latent_shape, device=device, generator=generator)
+
+         # 6. Set up scheduler timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+         # 7. Scale initial noise by scheduler init_noise_sigma
+         latents = latents * self.scheduler.init_noise_sigma
+
+         # 8. Denoising loop
+         for t in self.scheduler.timesteps:
+             timesteps = torch.full((total_batch,), t, device=device, dtype=torch.long)
+
+             if guidance_scale > 1.0:
+                 # Classifier-free guidance: run model twice
+                 latent_input = torch.cat([latents, latents], dim=0)
+                 timestep_input = torch.cat([timesteps, timesteps], dim=0)
+
+                 context_in = torch.cat([uncond_embeddings, text_embeddings], dim=0)
+                 control_in = torch.cat([control, control], dim=0)
+                 category_in = [torch.cat([category_conditions, category_conditions], dim=0)]
+                 mask_in = [torch.cat([mask_conditions, mask_conditions], dim=0)]
+
+                 noise_pred = self.unet(
+                     x=latent_input,
+                     timesteps=timestep_input,
+                     context=context_in,
+                     control=control_in,
+                     category_control=category_in,
+                     mask_control=mask_in,
+                 )
+
+                 noise_uncond, noise_text = noise_pred.chunk(2)
+                 noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+             else:
+                 noise_pred = self.unet(
+                     x=latents,
+                     timesteps=timesteps,
+                     context=text_embeddings,
+                     control=control,
+                     category_control=[category_conditions],
+                     mask_control=[mask_conditions],
+                 )
+
+             # Use diffusers scheduler step
+             scheduler_output = self.scheduler.step(
+                 model_output=noise_pred,
+                 timestep=t,
+                 sample=latents,
+                 eta=eta,
+                 generator=generator,
+             )
+             latents = scheduler_output.prev_sample
+
+         # 9. Decode latents
+         images = self._decode_latents(latents)
+
+         # 10. Post-process
+         if output_type == "pil":
+             images = einops.rearrange(images, "b c h w -> b h w c") * 127.5 + 127.5
+             images = images.cpu().numpy().clip(0, 255).astype(np.uint8)
+             images = [Image.fromarray(img) for img in images]
+         elif output_type == "tensor":
+             images = images.cpu()
+         else:
+             raise ValueError(
+                 f"Unknown output_type '{output_type}'. "
+                 "Use 'pil' or 'tensor'."
+             )
+
+         return AeroGenPipelineOutput(images=images)
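
The guidance combination in step 8 of the loop above reduces to a single expression; a toy, tensor-free sketch (plain lists stand in for the latent tensors, and `cfg_combine` is an illustrative helper, not part of the pipeline):

```python
def cfg_combine(noise_uncond, noise_text, guidance_scale):
    # Classifier-free guidance: push the unconditional prediction
    # toward the text-conditioned one by guidance_scale.
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]

out = cfg_combine([0.0, 2.0], [1.0, 2.0], 7.5)
# With guidance_scale = 1.0 the result is exactly the text-conditioned prediction.
```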
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_class_name": "DDIMScheduler",
+   "_diffusers_version": "0.37.0",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "clip_sample_range": 1.0,
+   "dynamic_thresholding_ratio": 0.995,
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "rescale_betas_zero_snr": false,
+   "sample_max_value": 1.0,
+   "set_alpha_to_one": false,
+   "steps_offset": 0,
+   "thresholding": false,
+   "timestep_spacing": "leading",
+   "trained_betas": null
+ }
text_encoder/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """AeroGen text encoder (CLIP)."""
+
+ from .clip_text_encoder import AeroGenCLIPTextEncoder
+
+ __all__ = ["AeroGenCLIPTextEncoder"]
text_encoder/clip_text_encoder.py ADDED
@@ -0,0 +1,43 @@
+ """CLIP text encoder for AeroGen. Uses transformers only (no ldm)."""
+
+ import torch.nn as nn
+ from transformers import CLIPTokenizer, CLIPTextModel
+
+
+ class AeroGenCLIPTextEncoder(nn.Module):
+     """CLIP text encoder compatible with FrozenCLIPEmbedder interface.
+     Uses transformers CLIPTextModel + CLIPTokenizer. No ldm dependency.
+     """
+
+     def __init__(self, version: str = "openai/clip-vit-large-patch14", device: str = "cuda", max_length: int = 77):
+         super().__init__()
+         self.tokenizer = CLIPTokenizer.from_pretrained(version)
+         self.transformer = CLIPTextModel.from_pretrained(version)
+         self.device = device
+         self.max_length = max_length
+         self.freeze()
+
+     def freeze(self):
+         self.transformer = self.transformer.eval()
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, text):
+         if isinstance(text, str):
+             text = [text]
+         batch_encoding = self.tokenizer(
+             text,
+             truncation=True,
+             max_length=self.max_length,
+             return_length=True,
+             return_overflowing_tokens=False,
+             padding="max_length",
+             return_tensors="pt",
+         )
+         device = next(self.parameters()).device
+         tokens = batch_encoding["input_ids"].to(device)
+         outputs = self.transformer(input_ids=tokens)
+         return outputs.last_hidden_state
+
+     def encode(self, text):
+         return self(text)
text_encoder/config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "target": "text_encoder.clip_text_encoder.AeroGenCLIPTextEncoder",
+   "params": {
+     "version": "openai/clip-vit-large-patch14"
+   }
+ }
text_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:651247bce4134453769880497b0ff59124fe047ee7cd7c91ed55308e6503195d
+ size 492267488
unet/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """AeroGen UNet components."""
+
+ from .unet_aerogen import AeroGenUNet2DConditionModel
+
+ __all__ = ["AeroGenUNet2DConditionModel"]
unet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (286 Bytes)
unet/__pycache__/attention_dual.cpython-312.pyc ADDED
Binary file (9.96 kB)
unet/__pycache__/diffusion_util.cpython-312.pyc ADDED
Binary file (5.76 kB)
unet/__pycache__/mask_attention.cpython-312.pyc ADDED
Binary file (8.41 kB)
unet/__pycache__/openaimodel_bbox.cpython-312.pyc ADDED
Binary file (31.3 kB)
unet/__pycache__/unet_aerogen.cpython-312.pyc ADDED
Binary file (2.47 kB)
unet/attention_dual.py ADDED
@@ -0,0 +1,136 @@
+"""SpatialTransformer with control support - self-contained, no ldm."""
+
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from .diffusion_util import checkpoint
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_k_control = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v_control = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def forward(self, x, context=None, control=None, mask=None, lambda_=1):
+        h = self.heads
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+        if exists(mask):
+            k_control = self.to_k_control(control)
+            v_control = self.to_v_control(control)
+            k_control, v_control = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (k_control, v_control))
+            sim_control = einsum("b i d, b j d -> b i j", q, k_control) * self.scale
+            attn_control = sim_control.softmax(dim=-1)
+            out_control = einsum("b i j, b j d -> b i d", attn_control, v_control)
+            out_control = rearrange(out_control, "(b h) n d -> b n (h d)", h=h)
+        attn = sim.softmax(dim=-1)
+        out = einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        if exists(mask):
+            out = out + lambda_ * out_control
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint_use=True):
+        super().__init__()
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout)
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint_use = checkpoint_use
+
+    def forward(self, x, context=None, control=None, mask=None):
+        return checkpoint(self._forward, (x, context, control, mask), self.parameters(), self.checkpoint_use)
+
+    def _forward(self, x, context=None, control=None, mask=None):
+        x = self.attn1(self.norm1(x)) + x
+        x = self.attn2(self.norm2(x), context=context, control=control, mask=mask) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None):
+        super().__init__()
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+        self.transformer_blocks = nn.ModuleList(
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for _ in range(depth)]
+        )
+        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+
+    def forward(self, x, context=None, control=None, mask=None):
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        x = self.proj_in(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        for block in self.transformer_blocks:
+            x = block(x, context=context, control=control, mask=mask)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = self.proj_out(x)
+        return x + x_in
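The dual-branch `CrossAttention` above attends to the text context and the layout-control context separately, then adds the control output scaled by `lambda_`. A minimal NumPy sketch of that combination (single head, no learned projections — shapes and symbols here are illustrative, not the module's exact tensors):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attend(q, k, v, scale):
    # q: (n, d); k, v: (m, d) -> (n, d) attention-weighted mix of values.
    return softmax(q @ k.T * scale) @ v

rng = np.random.default_rng(0)
d = 8
q = rng.standard_normal((4, d))        # image tokens (queries)
k = rng.standard_normal((6, d))        # text-context keys
v = rng.standard_normal((6, d))        # text-context values
k_ctrl = rng.standard_normal((5, d))   # layout-control keys
v_ctrl = rng.standard_normal((5, d))   # layout-control values
scale = d ** -0.5

lambda_ = 1.0  # control strength, as in the forward() signature above
out = attend(q, k, v, scale) + lambda_ * attend(q, k_ctrl, v_ctrl, scale)
```

With `lambda_ = 0` the control branch drops out and the result reduces to plain text-conditioned cross-attention.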
unet/config.json ADDED
@@ -0,0 +1,27 @@
+{
+  "target": "unet.unet_aerogen.AeroGenUNet2DConditionModel",
+  "params": {
+    "image_size": 32,
+    "in_channels": 4,
+    "out_channels": 4,
+    "model_channels": 320,
+    "attention_resolutions": [
+      4,
+      2,
+      1
+    ],
+    "num_res_blocks": 2,
+    "channel_mult": [
+      1,
+      2,
+      4,
+      4
+    ],
+    "num_heads": 8,
+    "use_spatial_transformer": true,
+    "transformer_depth": 1,
+    "context_dim": 768,
+    "use_checkpoint": true,
+    "legacy": false
+  }
+}
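Following the OpenAI-UNet convention used by the bundled `openaimodel_bbox.py`, `attention_resolutions` lists the downsample rates `ds` at which attention runs. With `image_size: 32` (512 px / VAE factor 8) and `channel_mult: [1, 2, 4, 4]`, a quick sketch of which levels get attention:

```python
image_size = 32                    # latent resolution
channel_mult = [1, 2, 4, 4]        # one entry per UNet level
attention_resolutions = [4, 2, 1]  # downsample rates that get attention

levels = []
ds = 1
for level, mult in enumerate(channel_mult):
    # (downsample rate, feature-map side, attention enabled at this level)
    levels.append((ds, image_size // ds, ds in attention_resolutions))
    if level != len(channel_mult) - 1:
        ds *= 2  # each level (except the last) halves the resolution
```

So attention runs on the 32×32, 16×16, and 8×8 feature maps, but not at the deepest 4×4 level.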
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd4aeb03ce266621e08c8c6bd1a964b44b6e0764045c7849dfd15b887a0533e7
+size 3622518160
unet/diffusion_util.py ADDED
@@ -0,0 +1,92 @@
+"""Diffusion util functions - self-contained, no ldm dependency."""
+
+import math
+import torch
+import torch.nn as nn
+from einops import repeat
+
+
+def checkpoint(func, inputs, params, flag):
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return _CheckpointFunction.apply(func, len(inputs), *args)
+    return func(*inputs)
+
+
+class _CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad():
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, "b -> b d", d=dim)
+    return embedding
+
+
+def zero_module(module):
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels):
+    return GroupNorm32(32, channels)
+
+
+def conv_nd(dims, *args, **kwargs):
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
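`timestep_embedding` above builds the standard sinusoidal embedding: cosines of the scaled timestep in the first half of the vector, sines in the second. A standalone NumPy re-derivation of the same formula (a sketch mirroring the torch code, not a drop-in for it):

```python
import numpy as np

def timestep_embedding_np(timesteps, dim, max_period=10000):
    # Geometric frequency ladder from 1 down to ~1/max_period.
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # zero-pad odd dims, as the torch version does
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb

# dim=320 matches model_channels in unet/config.json
emb = timestep_embedding_np([0, 500, 999], dim=320)
```

At `t = 0` every cosine term is 1 and every sine term is 0, which gives a quick sanity check on the layout.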
unet/mask_attention.py ADDED
@@ -0,0 +1,113 @@
+"""MaskCrossAttention - self-contained, no ldm/bldm."""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import einsum
+from einops import rearrange, repeat
+
+from .diffusion_util import zero_module
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if val is not None else d
+
+
+def Normalize(in_channels):
+    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class CrossAttention(nn.Module):
+    """Simple cross-attention for MaskCrossAttention (no control branch)."""
+
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+        q = self.to_q(x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+        if exists(mask):
+            mask = rearrange(mask, "b ... -> b (...)")
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, "b j -> (b h) j (c)", h=h, c=sim.shape[-1])
+            sim_copy = sim.clone()
+            sim_copy.masked_fill_(~mask, max_neg_value)
+            sim = sim_copy
+        sim = sim.softmax(dim=-1)
+        out = einsum("b i j, b j d -> b i d", sim, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
+
+
+class MaskCrossAttention(nn.Module):
+    def __init__(self, in_channels, n_heads, d_head, inner_dim=320, context_dim=None):
+        super().__init__()
+        self.in_channels = in_channels
+        self.inner_dim = inner_dim
+        self.norm1 = Normalize(in_channels)
+        self.norm2 = nn.LayerNorm(inner_dim)
+        self.norm3 = nn.LayerNorm(inner_dim)
+        self.ffn = FeedForward(dim=inner_dim, dim_out=inner_dim, dropout=0.0, glu=True)
+        self.proj_in = nn.Linear(in_channels, 320)
+        self.crossattn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim)
+        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+
+    def forward(self, x, category_control, mask_control, timesteps, attention_strength=0.2, ts_m=200):
+        ts = timesteps[0].item()
+        if ts < ts_m:
+            return x
+        x_in = x
+        mask_control = mask_control[0]
+        category_control = category_control[0]
+        b, c, h, w = x.shape
+        _, n, _, _ = mask_control.shape
+        x = repeat(x, "b c h w -> (b n) c h w", n=n)
+        mask_control_in = rearrange(mask_control, "b n h w -> (b n) h w").contiguous().bool()
+        category_control = rearrange(category_control.unsqueeze(2), "b n c l -> (b n) c l").contiguous()
+        x = rearrange(x, "(b n) c h w -> (b n) (h w) c", b=b, n=n).contiguous()
+        x = self.proj_in(x)
+        x = self.crossattn(self.norm2(x), category_control, mask_control_in) + x
+        x = self.ffn(self.norm3(x)) + x
+        x = self.proj_out(x)
+        x = rearrange(x, "(b n) (h w) c -> b n c h w", b=b, n=n, h=h, w=w).contiguous()
+        mask_control = mask_control.unsqueeze(2).expand(-1, -1, c, -1, -1)
+        x = x * mask_control
+        x_sum = x.sum(dim=1)
+        return attention_strength * x_sum + (1 - attention_strength) * x_in
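The `masked_fill_` step above implements the usual masked-softmax trick: disallowed positions are set to a huge negative value before the softmax, so they receive (numerically) zero attention weight. A generic NumPy rendering of that pattern (it does not reproduce the module's exact head/batch broadcasting):

```python
import numpy as np

def masked_softmax(sim, mask):
    # mask is True where attention is allowed; disallowed scores are pushed to
    # the most negative representable value before normalizing.
    neg = np.finfo(sim.dtype).max
    sim = np.where(mask, sim, -neg)
    e = np.exp(sim - sim.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

sim = np.array([[0.5, 1.0, -0.3]])          # one query, three keys
mask = np.array([[True, False, True]])      # key 1 is masked out
w = masked_softmax(sim, mask)               # weight on key 1 underflows to ~0
```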
unet/openaimodel_bbox.py ADDED
@@ -0,0 +1,761 @@
1
+ # attn+maskattn+noise
2
+
3
+ from abc import abstractmethod
4
+ from functools import partial
5
+ import math
6
+ from typing import Iterable
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .diffusion_util import (
14
+ checkpoint,
15
+ conv_nd,
16
+ linear,
17
+ avg_pool_nd,
18
+ zero_module,
19
+ normalization,
20
+ timestep_embedding,
21
+ )
22
+ from .attention_dual import SpatialTransformer
23
+ from .mask_attention import MaskCrossAttention
24
+
25
+
26
+ # dummy replace
27
+ def convert_module_to_f16(x):
28
+ pass
29
+
30
+ def convert_module_to_f32(x):
31
+ pass
32
+
33
+
34
+ ## go
35
+ class AttentionPool2d(nn.Module):
36
+ """
37
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ spacial_dim: int,
43
+ embed_dim: int,
44
+ num_heads_channels: int,
45
+ output_dim: int = None,
46
+ ):
47
+ super().__init__()
48
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
49
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
50
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
51
+ self.num_heads = embed_dim // num_heads_channels
52
+ self.attention = QKVAttention(self.num_heads)
53
+
54
+ def forward(self, x):
55
+ b, c, *_spatial = x.shape
56
+ x = x.reshape(b, c, -1) # NC(HW)
57
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
58
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
59
+ x = self.qkv_proj(x)
60
+ x = self.attention(x)
61
+ x = self.c_proj(x)
62
+ return x[:, :, 0]
63
+
64
+
65
+ class TimestepBlock(nn.Module):
66
+ """
67
+ Any module where forward() takes timestep embeddings as a second argument.
68
+ """
69
+
70
+ @abstractmethod
71
+ def forward(self, x, emb):
72
+ """
73
+ Apply the module to `x` given `emb` timestep embeddings.
74
+ """
75
+
76
+
77
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
78
+ """
79
+ A sequential module that passes timestep embeddings to the children that
80
+ support it as an extra input.
81
+ """
82
+
83
+ def forward(self, x, emb, context=None, control=None, mask=None):
84
+ for layer in self:
85
+ if isinstance(layer, TimestepBlock):
86
+ x = layer(x, emb)
87
+ elif isinstance(layer, SpatialTransformer):
88
+ x = layer(x, context, control = control, mask = mask)
89
+ else:
90
+ x = layer(x)
91
+ return x
92
+
93
+
94
+ class Upsample(nn.Module):
95
+ """
96
+ An upsampling layer with an optional convolution.
97
+ :param channels: channels in the inputs and outputs.
98
+ :param use_conv: a bool determining if a convolution is applied.
99
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
100
+ upsampling occurs in the inner-two dimensions.
101
+ """
102
+
103
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
104
+ super().__init__()
105
+ self.channels = channels
106
+ self.out_channels = out_channels or channels
107
+ self.use_conv = use_conv
108
+ self.dims = dims
109
+ if use_conv:
110
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
111
+
112
+ def forward(self, x):
113
+ assert x.shape[1] == self.channels
114
+ if self.dims == 3:
115
+ x = F.interpolate(
116
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
117
+ )
118
+ else:
119
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
120
+ if self.use_conv:
121
+ x = self.conv(x)
122
+ return x
123
+
124
+ class TransposedUpsample(nn.Module):
125
+ 'Learned 2x upsampling without padding'
126
+ def __init__(self, channels, out_channels=None, ks=5):
127
+ super().__init__()
128
+ self.channels = channels
129
+ self.out_channels = out_channels or channels
130
+
131
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
132
+
133
+ def forward(self,x):
134
+ return self.up(x)
135
+
136
+
137
+ class Downsample(nn.Module):
138
+ """
139
+ A downsampling layer with an optional convolution.
140
+ :param channels: channels in the inputs and outputs.
141
+ :param use_conv: a bool determining if a convolution is applied.
142
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
143
+ downsampling occurs in the inner-two dimensions.
144
+ """
145
+
146
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
147
+ super().__init__()
148
+ self.channels = channels
149
+ self.out_channels = out_channels or channels
150
+ self.use_conv = use_conv
151
+ self.dims = dims
152
+ stride = 2 if dims != 3 else (1, 2, 2)
153
+ if use_conv:
154
+ self.op = conv_nd(
155
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
156
+ )
157
+ else:
158
+ assert self.channels == self.out_channels
159
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
160
+
161
+ def forward(self, x):
162
+ assert x.shape[1] == self.channels
163
+ return self.op(x)
164
+
165
+
166
+ class ResBlock(TimestepBlock):
167
+ """
168
+ A residual block that can optionally change the number of channels.
169
+ :param channels: the number of input channels.
170
+ :param emb_channels: the number of timestep embedding channels.
171
+ :param dropout: the rate of dropout.
172
+ :param out_channels: if specified, the number of out channels.
173
+ :param use_conv: if True and out_channels is specified, use a spatial
174
+ convolution instead of a smaller 1x1 convolution to change the
175
+ channels in the skip connection.
176
+ :param dims: determines if the signal is 1D, 2D, or 3D.
177
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
178
+ :param up: if True, use this block for upsampling.
179
+ :param down: if True, use this block for downsampling.
180
+ """
181
+
182
+ def __init__(
183
+ self,
184
+ channels,
185
+ emb_channels,
186
+ dropout,
187
+ out_channels=None,
188
+ use_conv=False,
189
+ use_scale_shift_norm=False,
190
+ dims=2,
191
+ use_checkpoint=False,
192
+ up=False,
193
+ down=False,
194
+ ):
195
+ super().__init__()
196
+ self.channels = channels
197
+ self.emb_channels = emb_channels
198
+ self.dropout = dropout
199
+ self.out_channels = out_channels or channels
200
+ self.use_conv = use_conv
201
+ self.use_checkpoint = use_checkpoint
202
+ self.use_scale_shift_norm = use_scale_shift_norm
203
+
204
+ self.in_layers = nn.Sequential(
205
+ normalization(channels),
206
+ nn.SiLU(),
207
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
208
+ )
209
+
210
+ self.updown = up or down
211
+
212
+ if up:
213
+ self.h_upd = Upsample(channels, False, dims)
214
+ self.x_upd = Upsample(channels, False, dims)
215
+ elif down:
216
+ self.h_upd = Downsample(channels, False, dims)
217
+ self.x_upd = Downsample(channels, False, dims)
218
+ else:
219
+ self.h_upd = self.x_upd = nn.Identity()
220
+
221
+ self.emb_layers = nn.Sequential(
222
+ nn.SiLU(),
223
+ linear(
224
+ emb_channels,
225
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
226
+ ),
227
+ )
228
+ self.out_layers = nn.Sequential(
229
+ normalization(self.out_channels),
230
+ nn.SiLU(),
231
+ nn.Dropout(p=dropout),
232
+ zero_module(
233
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
234
+ ),
235
+ )
236
+
237
+ if self.out_channels == channels:
238
+ self.skip_connection = nn.Identity()
239
+ elif use_conv:
240
+ self.skip_connection = conv_nd(
241
+ dims, channels, self.out_channels, 3, padding=1
242
+ )
243
+ else:
244
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
245
+
246
+ def forward(self, x, emb):
247
+ """
248
+ Apply the block to a Tensor, conditioned on a timestep embedding.
249
+ :param x: an [N x C x ...] Tensor of features.
250
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
251
+ :return: an [N x C x ...] Tensor of outputs.
252
+ """
253
+ return checkpoint(
254
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
255
+ )
256
+
257
+
258
+ def _forward(self, x, emb):
259
+ if self.updown:
260
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
261
+ h = in_rest(x)
262
+ h = self.h_upd(h)
263
+ x = self.x_upd(x)
264
+ h = in_conv(h)
265
+ else:
266
+ h = self.in_layers(x)
267
+ emb_out = self.emb_layers(emb).type(h.dtype)
268
+ while len(emb_out.shape) < len(h.shape):
269
+ emb_out = emb_out[..., None]
270
+ if self.use_scale_shift_norm:
271
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
272
+ scale, shift = th.chunk(emb_out, 2, dim=1)
273
+ h = out_norm(h) * (1 + scale) + shift
274
+ h = out_rest(h)
275
+ else:
276
+ h = h + emb_out
277
+ h = self.out_layers(h)
278
+ return self.skip_connection(x) + h
279
+
280
+
281
+ class AttentionBlock(nn.Module):
282
+ """
283
+ An attention block that allows spatial positions to attend to each other.
284
+ Originally ported from here, but adapted to the N-d case.
285
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
286
+ """
287
+
288
+ def __init__(
289
+ self,
290
+ channels,
291
+ num_heads=1,
292
+ num_head_channels=-1,
293
+ use_checkpoint=False,
294
+ use_new_attention_order=False,
295
+ ):
296
+ super().__init__()
297
+ self.channels = channels
298
+ if num_head_channels == -1:
299
+ self.num_heads = num_heads
300
+ else:
301
+ assert (
302
+ channels % num_head_channels == 0
303
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
304
+ self.num_heads = channels // num_head_channels
305
+ self.use_checkpoint = use_checkpoint
306
+ self.norm = normalization(channels)
307
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
308
+ if use_new_attention_order:
309
+ # split qkv before split heads
310
+ self.attention = QKVAttention(self.num_heads)
311
+ else:
312
+ # split heads before split qkv
313
+ self.attention = QKVAttentionLegacy(self.num_heads)
314
+
315
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
316
+
317
+ def forward(self, x):
318
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
319
+ #return pt_checkpoint(self._forward, x) # pytorch
320
+
321
+ def _forward(self, x):
322
+ b, c, *spatial = x.shape
323
+ x = x.reshape(b, c, -1)
324
+ qkv = self.qkv(self.norm(x))
325
+ h = self.attention(qkv)
326
+ h = self.proj_out(h)
327
+ return (x + h).reshape(b, c, *spatial)
328
+
329
+
330
+ def count_flops_attn(model, _x, y):
331
+ """
332
+ A counter for the `thop` package to count the operations in an
333
+ attention operation.
334
+ Meant to be used like:
335
+ macs, params = thop.profile(
336
+ model,
337
+ inputs=(inputs, timestamps),
338
+ custom_ops={QKVAttention: QKVAttention.count_flops},
339
+ )
340
+ """
341
+ b, c, *spatial = y[0].shape
342
+ num_spatial = int(np.prod(spatial))
343
+ # We perform two matmuls with the same number of ops.
344
+ # The first computes the weight matrix, the second computes
345
+ # the combination of the value vectors.
346
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
347
+ model.total_ops += th.DoubleTensor([matmul_ops])
348
+
349
+
350
+ class QKVAttentionLegacy(nn.Module):
351
+ """
352
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
353
+ """
354
+
355
+ def __init__(self, n_heads):
356
+ super().__init__()
357
+ self.n_heads = n_heads
358
+
359
+ def forward(self, qkv):
360
+ """
361
+ Apply QKV attention.
362
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
363
+ :return: an [N x (H * C) x T] tensor after attention.
364
+ """
365
+ bs, width, length = qkv.shape
366
+ assert width % (3 * self.n_heads) == 0
367
+ ch = width // (3 * self.n_heads)
368
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
369
+ scale = 1 / math.sqrt(math.sqrt(ch))
370
+ weight = th.einsum(
371
+ "bct,bcs->bts", q * scale, k * scale
372
+ ) # More stable with f16 than dividing afterwards
373
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
374
+ a = th.einsum("bts,bcs->bct", weight, v)
375
+ return a.reshape(bs, -1, length)
376
+
377
+ @staticmethod
378
+ def count_flops(model, _x, y):
379
+ return count_flops_attn(model, _x, y)
380
+
381
+
382
+ class QKVAttention(nn.Module):
383
+ """
384
+ A module which performs QKV attention and splits in a different order.
385
+ """
386
+
387
+ def __init__(self, n_heads):
388
+ super().__init__()
389
+ self.n_heads = n_heads
390
+
391
+ def forward(self, qkv):
392
+ """
393
+ Apply QKV attention.
394
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
395
+ :return: an [N x (H * C) x T] tensor after attention.
396
+ """
397
+ bs, width, length = qkv.shape
398
+ assert width % (3 * self.n_heads) == 0
399
+ ch = width // (3 * self.n_heads)
400
+ q, k, v = qkv.chunk(3, dim=1)
401
+ scale = 1 / math.sqrt(math.sqrt(ch))
402
+ weight = th.einsum(
403
+ "bct,bcs->bts",
404
+ (q * scale).view(bs * self.n_heads, ch, length),
405
+ (k * scale).view(bs * self.n_heads, ch, length),
406
+ ) # More stable with f16 than dividing afterwards
407
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
408
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
409
+ return a.reshape(bs, -1, length)
410
+
411
+ @staticmethod
412
+ def count_flops(model, _x, y):
413
+ return count_flops_attn(model, _x, y)
414
+
415
+
416
+ class UNetModel(nn.Module):
417
+ """
418
+ The full UNet model with attention and timestep embedding.
419
+ :param in_channels: channels in the input Tensor.
420
+ :param model_channels: base channel count for the model.
421
+ :param out_channels: channels in the output Tensor.
422
+ :param num_res_blocks: number of residual blocks per downsample.
423
+ :param attention_resolutions: a collection of downsample rates at which
424
+ attention will take place. May be a set, list, or tuple.
425
+ For example, if this contains 4, then at 4x downsampling, attention
426
+ will be used.
427
+ :param dropout: the dropout probability.
428
+ :param channel_mult: channel multiplier for each level of the UNet.
429
+ :param conv_resample: if True, use learned convolutions for upsampling and
430
+ downsampling.
431
+ :param dims: determines if the signal is 1D, 2D, or 3D.
432
+ :param num_classes: if specified (as an int), then this model will be
433
+ class-conditional with `num_classes` classes.
434
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
435
+ :param num_heads: the number of attention heads in each attention layer.
436
+ :param num_heads_channels: if specified, ignore num_heads and instead use
437
+ a fixed channel width per attention head.
438
+ :param num_heads_upsample: works with num_heads to set a different number
439
+ of heads for upsampling. Deprecated.
440
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
441
+ :param resblock_updown: use residual blocks for up/downsampling.
442
+ :param use_new_attention_order: use a different attention pattern for potentially
443
+ increased efficiency.
444
+ """
445
+
446
+ def __init__(
447
+ self,
448
+ image_size,
449
+ in_channels,
450
+ model_channels,
451
+ out_channels,
452
+ num_res_blocks,
453
+ attention_resolutions,
454
+ dropout=0,
455
+ channel_mult=(1, 2, 4, 8),
456
+ conv_resample=True,
457
+ dims=2,
458
+ num_classes=None,
459
+ use_checkpoint=False,
460
+ use_fp16=False,
461
+ num_heads=-1,
462
+ num_head_channels=-1,
463
+ num_heads_upsample=-1,
464
+ use_scale_shift_norm=False,
465
+ resblock_updown=False,
466
+ use_new_attention_order=False,
467
+ use_spatial_transformer=False, # custom transformer support
468
+ transformer_depth=1, # custom transformer support
469
+ context_dim=None, # custom transformer support
470
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
471
+ legacy=True,
472
+
473
+
474
+ ):
+         super().__init__()
+         if use_spatial_transformer:
+             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+         if context_dim is not None:
+             assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+             if hasattr(context_dim, '__iter__') and not isinstance(context_dim, (int, float, str)):
+                 context_dim = list(context_dim)
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         if num_heads == -1:
+             assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+         if num_head_channels == -1:
+             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+         self.image_size = image_size
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.attention_resolutions = attention_resolutions
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.num_classes = num_classes
+         self.use_checkpoint = use_checkpoint
+         self.dtype = th.float16 if use_fp16 else th.float32
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+         self.predict_codebook_ids = n_embed is not None
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             linear(time_embed_dim, time_embed_dim),
+         )
+
+         if self.num_classes is not None:
+             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+         self.maskcrossattention = MaskCrossAttention(in_channels=320, n_heads=8,
+                                                      d_head=40, context_dim=768)
+
+         self.input_blocks = nn.ModuleList(
+             [
+                 TimestepEmbedSequential(
+                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                 )
+             ]
+         )
+         self._feature_size = model_channels
+         input_block_chans = [model_channels]
+         ch = model_channels
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for _ in range(num_res_blocks):
+                 layers = [
+                     ResBlock(
+                         ch,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=mult * model_channels,
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = mult * model_channels
+                 if ds in attention_resolutions:
+                     if num_head_channels == -1:
+                         dim_head = ch // num_heads
+                     else:
+                         num_heads = ch // num_head_channels
+                         dim_head = num_head_channels
+                     if legacy:
+                         # num_heads = 1
+                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads,
+                             num_head_channels=dim_head,
+                             use_new_attention_order=use_new_attention_order,
+                         ) if not use_spatial_transformer else SpatialTransformer(
+                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+                         )
+                     )
+                 self.input_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     TimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 ds *= 2
+                 self._feature_size += ch
+
+         if num_head_channels == -1:
+             dim_head = ch // num_heads
+         else:
+             num_heads = ch // num_head_channels
+             dim_head = num_head_channels
+         if legacy:
+             # num_heads = 1
+             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+         self.middle_block = TimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+             AttentionBlock(
+                 ch,
+                 use_checkpoint=use_checkpoint,
+                 num_heads=num_heads,
+                 num_head_channels=dim_head,
+                 use_new_attention_order=use_new_attention_order,
+             ) if not use_spatial_transformer else SpatialTransformer(
+                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+             ),
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+         )
+         self._feature_size += ch
+
+         self.output_blocks = nn.ModuleList([])
+         for level, mult in list(enumerate(channel_mult))[::-1]:
+             for i in range(num_res_blocks + 1):
+                 ich = input_block_chans.pop()
+                 layers = [
+                     ResBlock(
+                         ch + ich,
+                         time_embed_dim,
+                         dropout,
+                         out_channels=model_channels * mult,
+                         dims=dims,
+                         use_checkpoint=use_checkpoint,
+                         use_scale_shift_norm=use_scale_shift_norm,
+                     )
+                 ]
+                 ch = model_channels * mult
+                 if ds in attention_resolutions:
+                     if num_head_channels == -1:
+                         dim_head = ch // num_heads
+                     else:
+                         num_heads = ch // num_head_channels
+                         dim_head = num_head_channels
+                     if legacy:
+                         # num_heads = 1
+                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                     layers.append(
+                         AttentionBlock(
+                             ch,
+                             use_checkpoint=use_checkpoint,
+                             num_heads=num_heads_upsample,
+                             num_head_channels=dim_head,
+                             use_new_attention_order=use_new_attention_order,
+                         ) if not use_spatial_transformer else SpatialTransformer(
+                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+                         )
+                     )
+                 if level and i == num_res_blocks:
+                     out_ch = ch
+                     layers.append(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             up=True,
+                         )
+                         if resblock_updown
+                         else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                     )
+                     ds //= 2
+                 self.output_blocks.append(TimestepEmbedSequential(*layers))
+                 self._feature_size += ch
+
+         self.out = nn.Sequential(
+             normalization(ch),
+             nn.SiLU(),
+             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+         )
+         if self.predict_codebook_ids:
+             self.id_predictor = nn.Sequential(
+                 normalization(ch),
+                 conv_nd(dims, model_channels, n_embed, 1),
+                 # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+             )
+
+     def convert_to_fp16(self):
+         """
+         Convert the torso of the model to float16.
+         """
+         self.input_blocks.apply(convert_module_to_f16)
+         self.middle_block.apply(convert_module_to_f16)
+         self.output_blocks.apply(convert_module_to_f16)
+
+     def convert_to_fp32(self):
+         """
+         Convert the torso of the model to float32.
+         """
+         self.input_blocks.apply(convert_module_to_f32)
+         self.middle_block.apply(convert_module_to_f32)
+         self.output_blocks.apply(convert_module_to_f32)
+
+     # x=x_noisy, timesteps=t, context=control, mask_control=mask_control, masks=mask_vector
+     def forward(self, x, timesteps=None, context=None, control=None, category_control=None, mask_control=None, y=None, **kwargs):
+         """
+         Apply the model to an input batch.
+         :param x: an [N x C x ...] Tensor of inputs.
+         :param timesteps: a 1-D batch of timesteps.
+         :param context: conditioning plugged in via crossattn
+         :param y: an [N] Tensor of labels, if class-conditional.
+         :return: an [N x C x ...] Tensor of outputs.
+         """
+         assert (y is not None) == (
+             self.num_classes is not None
+         ), "must specify y if and only if the model is class-conditional"
+         hs = []
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+         emb = self.time_embed(t_emb)
+
+         # add noise mask
+         mask = th.max(mask_control[0], dim=1).values
+
+         if self.num_classes is not None:
+             assert y.shape == (x.shape[0],)
+             emb = emb + self.label_emb(y)
+
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             if mask_control is not None:
+                 h = module(h, emb, context, control, mask)
+                 h = self.maskcrossattention(h, category_control=category_control, mask_control=mask_control, timesteps=timesteps)
+                 category_control = None
+                 mask_control = None
+             else:
+                 h = module(h, emb, context, control, mask)
+             hs.append(h)
+         h = self.middle_block(h, emb, context, control, mask)
+         for module in self.output_blocks:
+             h = th.cat([h, hs.pop()], dim=1)
+             h = module(h, emb, context, control, mask)
+         h = h.type(x.dtype)
+         if self.predict_codebook_ids:
+             return self.id_predictor(h)
+         else:
+             return self.out(h)
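The `timestep_embedding(timesteps, self.model_channels, ...)` call in `forward` builds the standard sinusoidal timestep embedding before `self.time_embed` projects it to `model_channels * 4`. A minimal NumPy sketch of that standard embedding (assuming the usual OpenAI-style layout, cosine half followed by sine half — a sketch for intuition, not the bundled implementation):

```python
import numpy as np

def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embedding: [N] -> [N, dim] (cos half, then sin half)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = timesteps[:, None].astype(np.float64) * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)

emb = timestep_embedding(np.array([0, 500]), 320)
print(emb.shape)  # (2, 320)
```

With `model_channels = 320` (this model's config), each timestep maps to a 320-dim vector; at `t = 0` the cosine half is all ones and the sine half all zeros.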
unet/unet_aerogen.py ADDED
@@ -0,0 +1,73 @@
+ """
+ AeroGen UNet: diffusers ModelMixin wrapper for the custom bbox-conditioned UNet.
+
+ Self-contained - no ldm/bldm dependency. Uses local openaimodel_bbox.UNetModel.
+ """
+
+ from diffusers import ModelMixin
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+ from .openaimodel_bbox import UNetModel
+
+
+ class AeroGenUNet2DConditionModel(ModelMixin, ConfigMixin):
+     """
+     Diffusers-compatible wrapper for AeroGen's bbox-conditioned UNet.
+     Forward signature: x, timesteps, context, control, category_control, mask_control.
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         image_size: int = 32,
+         in_channels: int = 4,
+         out_channels: int = 4,
+         model_channels: int = 320,
+         attention_resolutions: tuple = (4, 2, 1),
+         num_res_blocks: int = 2,
+         channel_mult: tuple = (1, 2, 4, 4),
+         num_heads: int = 8,
+         use_spatial_transformer: bool = True,
+         transformer_depth: int = 1,
+         context_dim: int = 768,
+         use_checkpoint: bool = True,
+         legacy: bool = False,
+         **kwargs,
+     ):
+         super().__init__()
+         self.model = UNetModel(
+             image_size=image_size,
+             in_channels=in_channels,
+             model_channels=model_channels,
+             out_channels=out_channels,
+             num_res_blocks=num_res_blocks,
+             attention_resolutions=list(attention_resolutions),
+             channel_mult=list(channel_mult),
+             num_heads=num_heads,
+             use_spatial_transformer=use_spatial_transformer,
+             transformer_depth=transformer_depth,
+             context_dim=context_dim,
+             use_checkpoint=use_checkpoint,
+             legacy=legacy,
+             **kwargs,
+         )
+
+     def forward(
+         self,
+         x,
+         timesteps,
+         context=None,
+         control=None,
+         category_control=None,
+         mask_control=None,
+         **kwargs,
+     ):
+         return self.model(
+             x,
+             timesteps,
+             context=context,
+             control=control,
+             category_control=category_control or [],
+             mask_control=mask_control or [],
+             **kwargs,
+         )
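The wrapper's defaults (`model_channels=320`, `channel_mult=(1, 2, 4, 4)`, `attention_resolutions=(4, 2, 1)`) fix the encoder layout that `UNetModel.__init__` builds with its `ch`/`ds` bookkeeping. A plain-Python sketch of that bookkeeping (an illustration of the schedule, not the model itself) shows which levels receive attention/SpatialTransformer blocks:

```python
def encoder_schedule(model_channels=320, channel_mult=(1, 2, 4, 4),
                     attention_resolutions=(4, 2, 1)):
    """Per-level (channels, downsample factor, has-attention) for the encoder path."""
    schedule, ds = [], 1
    for level, mult in enumerate(channel_mult):
        ch = mult * model_channels
        schedule.append((ch, ds, ds in attention_resolutions))
        if level != len(channel_mult) - 1:
            ds *= 2  # a Downsample block between levels doubles the stride
    return schedule

print(encoder_schedule())
# [(320, 1, True), (640, 2, True), (1280, 4, True), (1280, 8, False)]
```

So the first three levels carry cross-attention to the 768-dim CLIP context, while the innermost 8x-downsampled level is convolution-only.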
vae/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "target": "ldm.models.autoencoder.AutoencoderKL",
+   "params": {
+     "embed_dim": 4,
+     "monitor": "val/rec_loss",
+     "ddconfig": {
+       "double_z": true,
+       "z_channels": 4,
+       "resolution": 256,
+       "in_channels": 3,
+       "out_ch": 3,
+       "ch": 128,
+       "ch_mult": [
+         1,
+         2,
+         4,
+         4
+       ],
+       "num_res_blocks": 2,
+       "attn_resolutions": [],
+       "dropout": 0.0
+     },
+     "lossconfig": {
+       "target": "torch.nn.Identity"
+     }
+   }
+ }
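In this ddconfig, the length of `ch_mult` sets the spatial compression: an LDM-style AutoencoderKL downsamples by 2 at each encoder level except the last, i.e. a factor of 2^(len(ch_mult) - 1) = 8, with `z_channels = 4` latent channels. A quick sketch (assuming that standard LDM behavior) of the latent shape for this model's 512x512 images:

```python
def latent_shape(image_hw, ch_mult=(1, 2, 4, 4), z_channels=4):
    """Latent (C, H, W): spatial size shrinks by 2 per encoder level except the last."""
    factor = 2 ** (len(ch_mult) - 1)  # (1, 2, 4, 4) -> 8x downsampling
    h, w = image_hw
    return (z_channels, h // factor, w // factor)

print(latent_shape((512, 512)))  # (4, 64, 64)
```

This matches the UNet's `in_channels=4`: diffusion runs on 4x64x64 latents that the VAE decodes back to 512x512 RGB.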
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3738ea388c6a1583f992ab0e164e18d8ad96b6bde143269a25cc0fc994de42b9
+ size 334640988