BiliSakura commited on
Commit
66a2b45
·
verified ·
1 Parent(s): 9818de9

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +127 -0
  2. __pycache__/pipeline_hsigene.cpython-312.pyc +0 -0
  3. global_content_adapter/__init__.py +5 -0
  4. global_content_adapter/__pycache__/__init__.cpython-312.pyc +0 -0
  5. global_content_adapter/__pycache__/model.cpython-312.pyc +0 -0
  6. global_content_adapter/config.json +8 -0
  7. global_content_adapter/model.py +56 -0
  8. global_text_adapter/__init__.py +5 -0
  9. global_text_adapter/config.json +4 -0
  10. global_text_adapter/model.py +52 -0
  11. local_adapter/__pycache__/attention.cpython-312.pyc +0 -0
  12. local_adapter/__pycache__/diffusion.cpython-312.pyc +0 -0
  13. local_adapter/__pycache__/model.cpython-312.pyc +0 -0
  14. local_adapter/__pycache__/utils.cpython-312.pyc +0 -0
  15. local_adapter/attention.py +271 -0
  16. local_adapter/config.json +36 -0
  17. local_adapter/diffusion.py +608 -0
  18. local_adapter/model.py +435 -0
  19. local_adapter/utils.py +90 -0
  20. metadata_encoder/__init__.py +5 -0
  21. metadata_encoder/config.json +7 -0
  22. metadata_encoder/model.py +77 -0
  23. model_index.json +14 -0
  24. modular_pipeline.py +111 -0
  25. pipeline_hsigene.py +468 -0
  26. scheduler/scheduler_config.json +19 -0
  27. text_encoder/__init__.py +1 -0
  28. text_encoder/__pycache__/__init__.cpython-312.pyc +0 -0
  29. text_encoder/__pycache__/model.cpython-312.pyc +0 -0
  30. text_encoder/config.json +4 -0
  31. text_encoder/model.py +41 -0
  32. unet/__init__.py +5 -0
  33. unet/__pycache__/__init__.cpython-312.pyc +0 -0
  34. unet/__pycache__/attention.cpython-312.pyc +0 -0
  35. unet/__pycache__/diffusion.cpython-312.pyc +0 -0
  36. unet/__pycache__/model.cpython-312.pyc +0 -0
  37. unet/__pycache__/utils.cpython-312.pyc +0 -0
  38. unet/attention.py +271 -0
  39. unet/config.json +25 -0
  40. unet/diffusion.py +608 -0
  41. unet/model.py +35 -0
  42. unet/utils.py +90 -0
  43. vae/__init__.py +1 -0
  44. vae/__pycache__/__init__.cpython-312.pyc +0 -0
  45. vae/__pycache__/model.cpython-312.pyc +0 -0
  46. vae/__pycache__/vae_blocks.cpython-312.pyc +0 -0
  47. vae/config.json +21 -0
  48. vae/model.py +90 -0
  49. vae/utils.py +10 -0
  50. vae/vae_blocks.py +441 -0
README.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: diffusers
4
+ tags:
5
+ - hsigene
6
+ - hyperspectral
7
+ - latent-diffusion
8
+ - controlnet
9
+ - arxiv:2409.12470
10
+ pipeline_tag: image-to-image
11
+ ---
12
+
13
+ # BiliSakura/HSIGene
14
+
15
+ **Hyperspectral image generation** — HSIGene converted to diffusers format. Conditional generation with local controls (HED, MLSD, sketch, segmentation), global controls (content, text), and metadata embeddings. Outputs 48-band hyperspectral images (256×256 pixels).
16
+
17
+ > Source: [HSIGene](https://arxiv.org/abs/2409.12470). Converted to diffusers format; model dir is self-contained (no external project for inference).
18
+
19
+ ## Conversion
20
+
21
+ The main diffusion checkpoint (`last.ckpt`) must be downloaded from [GoogleDrive](https://drive.google.com/file/d/1euJAbsxCgG1wIu_Eh5nPfmiSP9suWsR4/view?usp=drive_link) and placed in `projects/HSIGene-Diffusers/checkpoints/`.
22
+
23
+ **Note:** `models/raw/HSIGene` contains annotator/auxiliary models (body pose, depth, SAM, etc.) only — not the main diffusion checkpoint.
24
+
25
+ ```bash
26
+ cd projects/HSIGene-Diffusers
27
+ python convert_to_diffusers.py \
28
+ --config_path configs/inference.yaml \
29
+ --ckpt_path checkpoints/last.ckpt \
30
+ --output_dir /root/workspace/models/BiliSakura/HSIGene
31
+ ```
32
+
33
+ ## Repository Structure (after conversion)
34
+
35
+ | Component | Path |
36
+ |------------------------|--------------------------|
37
+ | UNet (LocalControlUNet)| `unet/` |
38
+ | VAE | `vae/` |
39
+ | Text encoder (CLIP) | `text_encoder/` |
40
+ | Local adapter | `local_adapter/` |
41
+ | Global content adapter| `global_content_adapter/`|
42
+ | Global text adapter | `global_text_adapter/` |
43
+ | Metadata encoder | `metadata_encoder/` |
44
+ | Scheduler | `scheduler/` |
45
+ | Pipeline | `pipeline_hsigene.py` |
46
+ | Config | `model_index.json` |
47
+
48
+ ## Usage
49
+
50
+ **Option 1 – No `sys.path.insert` (AeroGen-style):** Load the pipeline from the model path via `importlib`; the model dir is added to the path automatically.
51
+
52
+ ```python
53
+ import importlib.util
54
+ import sys
55
+
56
+ model_path = "/path/to/HSIGene" # or "BiliSakura/HSIGene" for Hub
57
+ spec = importlib.util.spec_from_file_location("pipeline_hsigene", f"{model_path}/pipeline_hsigene.py")
58
+ mod = importlib.util.module_from_spec(spec)
59
+ sys.modules["pipeline_hsigene"] = mod
60
+ spec.loader.exec_module(mod)
61
+
62
+ pipe = mod.HSIGenePipeline.from_pretrained(model_path)
63
+ pipe = pipe.to("cuda")
64
+ ```
65
+
66
+ **Option 2 – With `sys.path.insert`:** Simpler if you are fine adding the model dir to the path once.
67
+
68
+ ```python
69
+ import sys
70
+ sys.path.insert(0, "/path/to/HSIGene")
71
+ from pipeline_hsigene import HSIGenePipeline
72
+
73
+ pipe = HSIGenePipeline.from_pretrained("/path/to/HSIGene")
74
+ pipe = pipe.to("cuda")
75
+ ```
76
+
77
+ **Option 3 – `DiffusionPipeline.from_pretrained`:** May work with `trust_remote_code=True`. If you see "raw config (list)" errors (e.g. when loading from cache), use Option 1 or 2 instead.
78
+
79
+ ```python
80
+ from diffusers import DiffusionPipeline
81
+ pipe = DiffusionPipeline.from_pretrained("/path/to/HSIGene", trust_remote_code=True)
82
+ pipe = pipe.to("cuda")
83
+ ```
84
+
85
+ **Dependencies:** `pip install diffusers transformers torch einops safetensors`
86
+
87
+ ```python
88
+ # Conditional generation
89
+ output = pipe(
90
+ prompt="Wasteland",
91
+ num_samples=1,
92
+ height=256,
93
+ width=256,
94
+ num_inference_steps=50,
95
+ local_conditions=local_tensor, # (B, 18, H, W) or None
96
+ global_conditions=global_tensor, # (B, 768) or None
97
+ metadata=metadata_tensor, # (7,) or (B, 7) or None
98
+ guidance_scale=1.0,
99
+ )
100
+ images = output.images # (B, H, W, 48) in [0, 1]
101
+ ```
102
+
103
+ ### Conditioning
104
+
105
+ - **Local**: 18-channel maps (HED, MLSD, sketch, segmentation, etc.) at 512×512 default.
106
+ - **Global**: 768-dim CLIP features from reference images.
107
+ - **Metadata**: 7-dim vector.
108
+ - **Text**: Via `prompt`; use `text_strength` to scale.
109
+
110
+ ## Model Sources
111
+
112
+ - **Paper**: [HSIGene: A Foundation Model For Hyperspectral Image Generation](https://arxiv.org/abs/2409.12470)
113
+ - **Checkpoint**: [GoogleDrive](https://drive.google.com/file/d/1euJAbsxCgG1wIu_Eh5nPfmiSP9suWsR4/view?usp=drive_link)
114
+ - **Annotators**: [BaiduNetdisk](https://pan.baidu.com/s/1K1Y__blA6uJVV9l1QG7QvQ?pwd=98f1) (code: 98f1) → `data_prepare/annotator/ckpts`
115
+
116
+ ## Citation
117
+
118
+ ```bibtex
119
+ @misc{pang2024hsigenefoundationmodelhyperspectral,
120
+ title={HSIGene: A Foundation Model For Hyperspectral Image Generation},
121
+ author={Li Pang and Datao Tang and Shuang Xu and Deyu Meng and Xiangyong Cao},
122
+ year={2024},
123
+ eprint={2409.12470},
124
+ archivePrefix={arXiv},
125
+ primaryClass={cs.CV},
126
+ }
127
+ ```
__pycache__/pipeline_hsigene.cpython-312.pyc ADDED
Binary file (20.8 kB). View file
 
global_content_adapter/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Global content adapter for HSIGene."""
2
+
3
+ from .model import GlobalContentAdapter
4
+
5
+ __all__ = ["GlobalContentAdapter"]
global_content_adapter/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (301 Bytes). View file
 
global_content_adapter/__pycache__/model.cpython-312.pyc ADDED
Binary file (3.89 kB). View file
 
global_content_adapter/config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.GlobalContentAdapter",
3
+ "in_dim": 768,
4
+ "channel_mult": [
5
+ 2,
6
+ 4
7
+ ]
8
+ }
global_content_adapter/model.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GlobalContentAdapter - FFN-based adapter for global content conditioning."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
class GEGLU(nn.Module):
    """Gated GELU projection: one linear layer emits value and gate halves."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Doubled output width: first half is the value, second half the gate.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
17
+
18
+
19
class FeedForward(nn.Module):
    """Two-stage MLP: (Linear+GELU or GEGLU) -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
37
+
38
+
39
class GlobalContentAdapter(nn.Module):
    """Expands a global content embedding (B, in_dim) into token form.

    Two gated feed-forward stages widen the feature to
    in_dim * channel_mult[-1], then the result is reshaped to
    (B, channel_mult[-1], in_dim) so it can be consumed as a token sequence.
    """

    def __init__(self, in_dim, channel_mult=None):
        super().__init__()
        if channel_mult is None:
            channel_mult = [2, 4]
        self.in_dim = in_dim
        self.channel_mult = channel_mult
        dim_out1 = in_dim * channel_mult[0]
        dim_out2 = in_dim * channel_mult[1]
        mult1 = channel_mult[0] * 2
        mult2 = channel_mult[1] * 2 // channel_mult[0]
        self.ff1 = FeedForward(in_dim, dim_out=dim_out1, mult=mult1, glu=True, dropout=0.0)
        self.ff2 = FeedForward(dim_out1, dim_out=dim_out2, mult=mult2, glu=True, dropout=0.0)
        self.norm1 = nn.LayerNorm(in_dim)
        self.norm2 = nn.LayerNorm(dim_out1)

    def forward(self, x):
        x = self.ff2(self.norm2(self.ff1(self.norm1(x))))
        # (b, n*d) -> (b, n, d); equivalent to einops "b (n d) -> b n d".
        return x.reshape(x.shape[0], self.channel_mult[-1], self.in_dim).contiguous()
global_text_adapter/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Global text adapter for HSIGene."""
2
+
3
+ from .model import GlobalTextAdapter
4
+
5
+ __all__ = ["GlobalTextAdapter"]
global_text_adapter/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.GlobalTextAdapter",
3
+ "in_dim": 768
4
+ }
global_text_adapter/model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GlobalTextAdapter - FFN-based adapter for global text conditioning."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
class GEGLU(nn.Module):
    """Gated GELU projection: one linear layer emits value and gate halves."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Doubled output width: first half is the value, second half the gate.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
16
+
17
+
18
class FeedForward(nn.Module):
    """Two-stage MLP: (Linear+GELU or GEGLU) -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
36
+
37
+
38
class GlobalTextAdapter(nn.Module):
    """Refines a global text embedding through two gated feed-forward stages.

    Input and output both have trailing dimension ``in_dim``.  ``max_len`` is
    accepted for signature compatibility but is never read by this module.
    """

    def __init__(self, in_dim, max_len=768):
        super().__init__()
        self.in_dim = in_dim
        hidden_dim = in_dim * 2
        self.ff1 = FeedForward(in_dim, dim_out=hidden_dim, mult=2, glu=True, dropout=0.0)
        self.ff2 = FeedForward(hidden_dim, dim_out=in_dim, mult=4, glu=True, dropout=0.0)
        self.norm1 = nn.LayerNorm(in_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        return self.ff2(self.norm2(self.ff1(self.norm1(x))))
local_adapter/__pycache__/attention.cpython-312.pyc ADDED
Binary file (14.3 kB). View file
 
local_adapter/__pycache__/diffusion.cpython-312.pyc ADDED
Binary file (22.7 kB). View file
 
local_adapter/__pycache__/model.cpython-312.pyc ADDED
Binary file (17.7 kB). View file
 
local_adapter/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.94 kB). View file
 
local_adapter/attention.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene attention modules - FeedForward, CrossAttention, SpatialTransformer."""
2
+
3
+ from inspect import isfunction
4
+ from typing import Optional, Any
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ from torch import einsum
11
+
12
+ from .utils import checkpoint, zero_module, exists
13
+
14
+ try:
15
+ import xformers
16
+ import xformers.ops
17
+ XFORMERS_IS_AVAILABLE = True
18
+ except ImportError:
19
+ XFORMERS_IS_AVAILABLE = False
20
+
21
+
22
def default(val, d):
    """Return ``val`` when it is present, otherwise the fallback ``d``.

    A function fallback is invoked lazily, so expensive defaults are only
    constructed when actually needed.
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d
26
+
27
+
28
+ import os
29
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
30
+
31
+
32
class GEGLU(nn.Module):
    """Gated GELU projection: one linear layer emits value and gate halves."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Doubled output width: first half is the value, second half the gate.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
40
+
41
+
42
class FeedForward(nn.Module):
    """Transformer MLP block: (Linear+GELU or GEGLU) -> Dropout -> Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            first = GEGLU(dim, inner_dim)
        else:
            first = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
        self.net = nn.Sequential(
            first,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
60
+
61
+
62
def Normalize(in_channels, num_groups=32):
    """GroupNorm with this file's conventions: eps=1e-6, affine parameters on."""
    return nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
64
+
65
+
66
class CrossAttention(nn.Module):
    """Standard multi-head cross-attention.

    When ``context`` is None the layer degenerates to self-attention.  With
    ``_ATTN_PRECISION == "fp32"`` (the default) the q·k similarity is computed
    outside autocast in fp32 — presumably to avoid fp16 overflow in the
    logits, as in upstream latent-diffusion.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # Self-attention fallback: context dims default to the query space.
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        # Fold heads into the batch dimension: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        if _ATTN_PRECISION == "fp32":
            # Disable autocast so the similarity matrix stays fp32.
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        del q, k  # release the largest intermediates before softmax
        if exists(mask):
            # Flatten mask to (b, j) and broadcast over heads and query positions.
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)
        sim = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", sim, v)
        # Unfold heads back out of the batch dimension.
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
104
+
105
+
106
class MemoryEfficientCrossAttention(nn.Module):
    """Cross-attention using xformers' memory-efficient kernel when available.

    Falls back to a naive softmax(q·kᵀ)·v implementation otherwise.
    NOTE(review): the ``mask`` argument is accepted but ignored in both paths —
    confirm callers never pass a mask here.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # Self-attention fallback: context dims default to the query space.
        context_dim = default(context_dim, query_dim)
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )
        # Optional explicit xformers attention op; None lets xformers choose.
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        b, _, _ = q.shape
        # (b, n, h*d) -> (b*h, n, d), contiguous for the attention kernel.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )
        if XFORMERS_IS_AVAILABLE:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=self.attention_op
            )
        else:
            # Naive fallback: materializes the full (n_q, n_k) attention matrix.
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            sim = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", sim, v)
        # (b*h, n, d) -> (b, n, h*d)
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
152
+
153
+
154
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attn, cross-attn, feed-forward.

    With ``disable_self_attn`` the first attention also receives ``context``
    and thus becomes a second cross-attention.
    """

    # Implementation chosen at construction time based on xformers availability.
    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,  # NOTE: shadows the module-level checkpoint() within __init__
        disable_self_attn=False,
    ):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            # None -> pure self-attention; context_dim -> cross-attention.
            context_dim=context_dim if self.disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Route through the gradient-checkpointing helper from .utils.
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        # Pre-norm residual wiring: each sub-layer adds onto x.
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
203
+
204
+
205
class SpatialTransformer(nn.Module):
    """Transformer over 2D feature maps.

    Projects (b, c, h, w) to a token sequence, runs ``depth``
    BasicTransformerBlocks with optional cross-attention context, projects
    back, and adds the input residually.  ``use_linear`` selects linear
    projections (applied in token layout) instead of 1x1 convolutions
    (applied in image layout).
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        # Normalize to a list so each block can get its own context dim.
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d] if isinstance(context_dim, list) else context_dim,
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        # Zero-initialized output projection: the module starts as identity.
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Conv projection happens in image layout, linear in token layout.
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # Reuse the first context when fewer contexts than blocks were given.
            ctx = context[i] if i < len(context) else context[0]
            x = block(x, context=ctx)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
local_adapter/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.LocalAdapter",
3
+ "in_channels": 4,
4
+ "model_channels": 320,
5
+ "local_channels": 18,
6
+ "inject_channels": [
7
+ 192,
8
+ 256,
9
+ 384,
10
+ 512
11
+ ],
12
+ "inject_layers": [
13
+ 1,
14
+ 4,
15
+ 7,
16
+ 10
17
+ ],
18
+ "num_res_blocks": 2,
19
+ "attention_resolutions": [
20
+ 4,
21
+ 2,
22
+ 1
23
+ ],
24
+ "channel_mult": [
25
+ 1,
26
+ 2,
27
+ 4,
28
+ 4
29
+ ],
30
+ "use_checkpoint": true,
31
+ "num_heads": 8,
32
+ "use_spatial_transformer": true,
33
+ "transformer_depth": 1,
34
+ "context_dim": 768,
35
+ "legacy": false
36
+ }
local_adapter/diffusion.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene diffusion modules - UNet, ResBlock, etc. From openaimodel."""
2
+
3
+ from abc import abstractmethod
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .utils import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ zero_module,
16
+ normalization,
17
+ timestep_embedding,
18
+ exists,
19
+ )
20
+ from .attention import SpatialTransformer
21
+
22
+
23
def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D average-pooling module for the given dimensionality.

    Raises ValueError for any other ``dims``.
    """
    pools = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pools:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pools[dims](*args, **kwargs)
32
+
33
+
34
def convert_module_to_f16(x):
    """No-op fp16 conversion hook; presumably kept for openaimodel API compatibility — confirm."""
    pass
36
+
37
+
38
def convert_module_to_f32(x):
    """No-op fp32 conversion hook; presumably kept for openaimodel API compatibility — confirm."""
    pass
40
+
41
+
42
class TimestepBlock(nn.Module):
    """Any module where forward() takes timestep embeddings as a second argument.

    TimestepEmbedSequential dispatches on this type to decide whether to pass
    ``emb`` to a child layer.
    """

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to `x` given `emb` timestep embeddings."""
48
+
49
+
50
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Sequential module that passes timestep embeddings to children that support it.

    Per-child dispatch:
      * TimestepBlock      -> layer(x, emb)
      * SpatialTransformer -> layer(x, context)
      * anything else      -> layer(x)
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
62
+
63
+
64
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling with an optional trailing convolution.

    For ``dims == 3`` only the last two (spatial) axes are doubled; the depth
    axis keeps its size.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
87
+
88
+
89
class Downsample(nn.Module):
    """2x downsampling via strided convolution or average pooling.

    For ``dims == 3`` only the last two (spatial) axes are halved; the depth
    axis keeps its size.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
110
+
111
+
112
class ResBlock(TimestepBlock):
    """Residual block with timestep conditioning.

    The timestep embedding is injected either additively, or — with
    ``use_scale_shift_norm`` — as a FiLM-style (scale, shift) applied after
    the output normalization.  With ``up``/``down`` the block also resamples
    both the hidden state and the skip path.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                # scale+shift needs two vectors; plain addition needs one.
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized so the block starts out as an identity residual.
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        # Skip path: identity when widths match, otherwise 3x3 or 1x1 conv.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        # Gradient checkpointing trades compute for memory when enabled.
        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if self.updown:
            # Resample between norm/activation and the conv so the hidden
            # state and the skip input end up at the same resolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = emb_out.chunk(2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
198
+
199
+
200
class AttentionBlock(nn.Module):
    """Spatial self-attention block over flattened spatial positions.

    NOTE(review): forward() passes a hard-coded ``True`` to checkpoint(), so
    the block always checkpoints and ``self.use_checkpoint`` is stored but
    never read — this matches upstream openaimodel, but confirm it is intended.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            # Derive head count from a fixed per-head width.
            assert channels % num_head_channels == 0
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        # 1D conv over flattened spatial positions producing q, k, v at once.
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = (
            QKVAttention(self.num_heads)
            if use_new_attention_order
            else QKVAttentionLegacy(self.num_heads)
        )
        # Zero-initialized so the block starts out as an identity residual.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)  # flatten all spatial dims into one axis
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
238
+
239
+
240
class QKVAttentionLegacy(nn.Module):
    """QKV attention in the legacy ordering: heads are split before q/k/v."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Map qkv of shape (B, 3*heads*ch, T) to output (B, heads*ch, T)."""
        batch, width, seq = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(batch * self.n_heads, ch * 3, seq).split(ch, dim=1)
        # Scale q and k symmetrically by ch^-1/4 each (product gives ch^-1/2).
        scale = 1 / math.sqrt(math.sqrt(ch))
        attn = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        # Softmax in fp32 for stability, then cast back.
        attn = torch.softmax(attn.float(), dim=-1).type(attn.dtype)
        out = torch.einsum("bts,bcs->bct", attn, v)
        return out.reshape(batch, -1, seq)
257
+
258
+
259
class QKVAttention(nn.Module):
    """QKV attention in the "new" ordering: q/k/v are split before heads."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Map qkv of shape (B, 3*heads*ch, T) to output (B, heads*ch, T)."""
        batch, width, seq = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k symmetrically by ch^-1/4 each (product gives ch^-1/2).
        scale = 1 / math.sqrt(math.sqrt(ch))
        attn = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(batch * self.n_heads, ch, seq),
            (k * scale).view(batch * self.n_heads, ch, seq),
        )
        # Softmax in fp32 for stability, then cast back.
        attn = torch.softmax(attn.float(), dim=-1).type(attn.dtype)
        out = torch.einsum("bts,bcs->bct", attn, v.reshape(batch * self.n_heads, ch, seq))
        return out.reshape(batch, -1, seq)
280
+
281
+
282
class UNetModel(nn.Module):
    """Full UNet with attention and timestep embedding.

    Denoising backbone in the latent-diffusion style: an encoder
    (``input_blocks``), a ``middle_block``, and a skip-connected decoder
    (``output_blocks``).  Channel widths follow
    ``model_channels * channel_mult[level]``; attention layers are inserted at
    the downsampling factors listed in ``attention_resolutions``.
    ``use_spatial_transformer`` switches between plain QKV attention blocks and
    cross-attention transformer blocks conditioned on ``context`` of width
    ``context_dim``.  ``metadata`` (if given) is added to the timestep
    embedding, as is an optional class-label embedding ``y``.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        # Spatial-transformer conditioning and context_dim must be set together.
        if use_spatial_transformer:
            assert context_dim is not None
        if context_dim is not None:
            assert use_spatial_transformer
            if hasattr(context_dim, "__iter__") and not isinstance(context_dim, (list, tuple)):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
        # Exactly one of num_heads / num_head_channels must be specified.
        if num_heads == -1:
            assert num_head_channels != -1
        if num_head_channels == -1:
            assert num_heads != -1

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        # Scalar num_res_blocks is broadcast to every resolution level.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            assert len(num_res_blocks) == len(channel_mult)
            self.num_res_blocks = num_res_blocks

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Timestep embedding MLP: model_channels -> 4 * model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if num_classes is not None:
            if isinstance(num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif num_classes == "continuous":
                self.label_emb = nn.Linear(1, time_embed_dim)
            else:
                raise ValueError()

        # --- Encoder ---
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor

        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads_cur = ch // num_head_channels  # NOTE(review): unused
                        dim_head = num_head_channels
                    if legacy:
                        # Legacy head sizing from the original SD checkpoints.
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    disabled_sa = (
                        disable_self_attentions[level]
                        if exists(disable_self_attentions)
                        else False
                    )
                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        attn_block = (
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                        layers.append(attn_block)
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                down_block = (
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    )
                    if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                )
                self.input_blocks.append(TimestepEmbedSequential(down_block))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # --- Middle block: ResBlock -> attention -> ResBlock ---
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads_cur = ch // num_head_channels  # NOTE(review): unused
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_attn = (
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            )
        )
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            mid_attn,
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # --- Decoder: mirror of the encoder, consuming skip connections ---
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()  # skip-connection width
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads_cur = ch // num_head_channels  # NOTE(review): unused
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads if use_spatial_transformer else num_head_channels
                        )
                    disabled_sa = (
                        disable_self_attentions[level]
                        if exists(disable_self_attentions)
                        else False
                    )
                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        attn_block = (
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                        layers.append(attn_block)
                if level and i == self.num_res_blocks[level]:
                    # Upsample at the end of each level except the outermost.
                    out_ch = ch
                    up_block = (
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    layers.append(up_block)
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # NOTE(review): the final conv maps from model_channels, relying on
        # ch == model_channels after decoding (true when channel_mult[0] == 1)
        # — confirm channel_mult always starts at 1.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
            )

    def forward(self, x, timesteps=None, metadata=None, context=None, y=None, **kwargs):
        """Run the UNet.

        x: noisy input batch; timesteps: per-sample diffusion steps;
        metadata: optional embedding added to the timestep embedding;
        context: cross-attention conditioning; y: class labels (required iff
        the model was built with num_classes).
        """
        assert (y is not None) == (self.num_classes is not None)
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if metadata is not None:
            # Unwrap a single-element list/tuple wrapper, then add to the embedding.
            if isinstance(metadata, (list, tuple)) and len(metadata) == 1:
                metadata = metadata[0]
            emb = emb + metadata

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Concatenate the matching encoder skip before each decoder block.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        return self.out(h)
local_adapter/model.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene adapters - LocalAdapter, LocalControlUNetModel, GlobalContentAdapter, etc."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+ from .utils import (
9
+ checkpoint,
10
+ conv_nd,
11
+ linear,
12
+ zero_module,
13
+ timestep_embedding,
14
+ exists,
15
+ )
16
+ from .attention import SpatialTransformer
17
+ from .diffusion import (
18
+ TimestepBlock,
19
+ TimestepEmbedSequential,
20
+ ResBlock,
21
+ Downsample,
22
+ AttentionBlock,
23
+ )
24
+
25
+
26
class LocalTimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Sequential container that routes extra inputs to the layers that need
    them: (x, emb) to TimestepBlocks, (x, context) to SpatialTransformers,
    (x, emb, local_features) to LocalResBlocks, plain x to everything else."""

    def forward(self, x, emb, context=None, local_features=None):
        h = x
        for module in self:
            if isinstance(module, LocalResBlock):
                h = module(h, emb, local_features)
            elif isinstance(module, TimestepBlock):
                h = module(h, emb)
            elif isinstance(module, SpatialTransformer):
                h = module(h, context)
            else:
                h = module(h)
        return h
40
+
41
+
42
class FDN(nn.Module):
    """Spatially-conditioned normalization (SPADE-style, by the code shape):
    parameter-free GroupNorm followed by per-pixel scale/shift predicted from
    the condition maps by two 3x3 convolutions."""

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        kernel_size = 3
        padding = kernel_size // 2
        self.param_free_norm = nn.GroupNorm(32, norm_nc, affine=False)
        self.conv_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=kernel_size, padding=padding)
        self.conv_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=kernel_size, padding=padding)

    def forward(self, x, local_features):
        # Condition maps must match x spatially so gamma/beta align per pixel.
        assert local_features.size()[2:] == x.size()[2:]
        normalized = self.param_free_norm(x)
        scale = self.conv_gamma(local_features)
        shift = self.conv_beta(local_features)
        return normalized * (1 + scale) + shift
57
+
58
+
59
class SelfAttention(nn.Module):
    """Non-local self-attention over spatial positions with a residual add.
    Query/key channels are reduced to in_dim // 8 (so in_dim must be >= 8)."""

    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, h, w = x.size()
        n = h * w
        q = self.query_conv(x).view(b, -1, n).permute(0, 2, 1)  # (B, N, C//8)
        k = self.key_conv(x).view(b, -1, n)                     # (B, C//8, N)
        v = self.value_conv(x).view(b, -1, n)                   # (B, C, N)
        attn = self.softmax(torch.bmm(q, k))                    # (B, N, N)
        out = torch.bmm(v, attn.transpose(1, 2)).view(b, c, h, w)
        # Residual connection (no learnable gate on the attention branch).
        return out + x
76
+
77
+
78
class EnhancedFDN(nn.Module):
    """FDN with a self-attention pre-pass: the input first goes through
    SelfAttention, then is normalized/modulated by FDN."""

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        self.fdn = FDN(norm_nc, label_nc)
        self.attention = SelfAttention(norm_nc)

    def forward(self, x, local_features):
        attended = self.attention(x)
        return self.fdn(attended, local_features)
87
+
88
+
89
class LocalResBlock(nn.Module):
    """Residual block whose two normalization layers are EnhancedFDN modules
    conditioned on spatially-aligned local feature maps.

    NOTE(review): the ``nn.Identity()`` placeholders in ``in_layers`` /
    ``out_layers`` appear to preserve the Sequential indices of a standard
    ResBlock (where a GroupNorm sits at position 0) so pretrained state dicts
    line up — confirm against the checkpoint layout.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        dims=2,
        use_checkpoint=False,
        inject_channels=None,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_checkpoint = use_checkpoint
        # Condition-dependent normalizations, applied before each conv stack.
        self.norm_in = EnhancedFDN(channels, inject_channels)
        self.norm_out = EnhancedFDN(self.out_channels, inject_channels)

        self.in_layers = nn.Sequential(
            nn.Identity(),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, self.out_channels),
        )
        self.out_layers = nn.Sequential(
            nn.Identity(),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized so the residual branch starts as a no-op.
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            # 1x1 conv to match channel widths on the skip path.
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb, local_conditions):
        # Optionally trade compute for memory via gradient checkpointing.
        return checkpoint(
            self._forward, (x, emb, local_conditions), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb, local_conditions):
        h = self.norm_in(x, local_conditions)
        h = self.in_layers(h)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the (B, C) embedding over the spatial dims.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        h = h + emb_out
        h = self.norm_out(h, local_conditions)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
145
+
146
+
147
class FeatureExtractor(nn.Module):
    """Encodes raw local conditions into four progressively-downsampled
    feature maps (one per injection width), each passed through a
    zero-initialized conv so injection starts as a no-op."""

    def __init__(self, local_channels, inject_channels, dims=2):
        super().__init__()
        # Shared stem: two stride-2 stages bring the input to 128 channels.
        self.pre_extractor = LocalTimestepEmbedSequential(
            conv_nd(dims, local_channels, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 64, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 64, 64, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 64, 128, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 128, 128, 3, padding=1),
            nn.SiLU(),
        )
        # Four stride-2 stages: 128 -> inject[0] -> inject[1] -> inject[2] -> inject[3].
        stage_channels = [128, *inject_channels]
        self.extractors = nn.ModuleList([
            LocalTimestepEmbedSequential(
                conv_nd(dims, stage_channels[i], stage_channels[i + 1], 3, padding=1, stride=2),
                nn.SiLU(),
            )
            for i in range(4)
        ])
        self.zero_convs = nn.ModuleList([
            zero_module(conv_nd(dims, inject_channels[i], inject_channels[i], 3, padding=1))
            for i in range(4)
        ])

    def forward(self, local_conditions):
        features = self.pre_extractor(local_conditions, None)
        outputs = []
        for extractor, zero_conv in zip(self.extractors, self.zero_convs):
            features = extractor(features, None)
            outputs.append(zero_conv(features))
        return outputs
194
+
195
+
196
class LocalAdapter(nn.Module):
    """ControlNet-style adapter: a copy of the UNet encoder whose selected
    blocks are LocalResBlocks conditioned on multi-scale features extracted
    from ``local_conditions``.  ``forward`` returns one zero-conv output per
    encoder block plus one for the middle block, to be added into the main
    UNet's activations.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        local_channels,
        inject_channels,
        inject_layers,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if context_dim is not None:
            if hasattr(context_dim, "__iter__") and not isinstance(context_dim, (list, tuple)):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.dims = dims
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.inject_layers = inject_layers
        # Scalar num_res_blocks is broadcast to every resolution level.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            assert len(num_res_blocks) == len(channel_mult)
            self.num_res_blocks = num_res_blocks

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Multi-scale condition features, one map per injection level.
        self.feature_extractor = FeatureExtractor(local_channels, inject_channels)
        self.input_blocks = nn.ModuleList([
            LocalTimestepEmbedSequential(
                conv_nd(dims, in_channels, model_channels, 3, padding=1)
            ),
        ])
        # One zero conv per encoder block; outputs start as zeros.
        self.zero_convs = nn.ModuleList([self._make_zero_conv(model_channels)])

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor

        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                # NOTE(review): the index formula (1 + 3*level + nr) assumes
                # 2 res blocks + 1 downsample per level — confirm for other
                # num_res_blocks settings.
                if (1 + 3 * level + nr) in self.inject_layers:
                    layers = [
                        LocalResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            inject_channels=inject_channels[level],
                        )
                    ]
                else:
                    layers = [
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        dim_head = num_head_channels
                    if legacy:
                        # Legacy head sizing from the original SD checkpoints.
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    disabled_sa = (
                        disable_self_attentions[level]
                        if exists(disable_self_attentions)
                        else False
                    )
                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        block = (
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                        layers.append(block)
                self.input_blocks.append(LocalTimestepEmbedSequential(*layers))
                self.zero_convs.append(self._make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                down_block = (
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    )
                    if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                )
                self.input_blocks.append(LocalTimestepEmbedSequential(down_block))
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self._make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        # --- Middle block: ResBlock -> attention -> ResBlock ---
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_attn = (
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            )
        )
        self.middle_block = LocalTimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            mid_attn,
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self._make_zero_conv(ch)
        self._feature_size += ch

    def _make_zero_conv(self, channels):
        """1x1 zero-initialized conv wrapped so it accepts (x, emb, context)."""
        return LocalTimestepEmbedSequential(
            zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))
        )

    def forward(self, x, timesteps, context, local_conditions, **kwargs):
        """Return the list of control residuals (one per encoder block + middle)."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        local_features = self.feature_extractor(local_conditions)

        outs = []
        h = x.type(self.dtype)
        for layer_idx, (module, zero_conv) in enumerate(zip(self.input_blocks, self.zero_convs)):
            if layer_idx in self.inject_layers:
                # Feed the matching scale of condition features to this block.
                feat_idx = self.inject_layers.index(layer_idx)
                h = module(h, emb, context, local_features[feat_idx])
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))
        return outs
local_adapter/utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene utilities - no ldm/models imports."""
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import repeat
7
+
8
+
9
def exists(val):
    """True for any value except None (falsy values like 0 and "" count as existing)."""
    return val is not None
11
+
12
+
13
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """Sinusoidal timestep embeddings of width ``dim`` (cosine half first,
    then sine half; odd widths are padded with a zero column)."""
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)
    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/max_period.
    exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=timesteps.device)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
25
+
26
+
27
def conv_nd(dims, *args, **kwargs):
    """Build a 1-, 2-, or 3-D convolution; extra args go to the nn.ConvNd ctor."""
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_types[dims](*args, **kwargs)
35
+
36
+
37
def linear(*args, **kwargs):
    """Create an ``nn.Linear`` module (thin alias, mirrors ``conv_nd``'s factory style)."""
    return nn.Linear(*args, **kwargs)
39
+
40
+
41
def zero_module(module):
    """Zero every parameter of *module* in place and return the module."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
45
+
46
+
47
def checkpoint(func, inputs, params, flag):
    """Evaluate ``func(*inputs)``; when *flag* is set, use the custom
    gradient-checkpointing Function so activations are recomputed in backward."""
    if not flag:
        return func(*inputs)
    packed = tuple(inputs) + tuple(params)
    return _CheckpointFunction.apply(func, len(inputs), *packed)
51
+
52
+
53
class _CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: forward runs under no_grad (activations are not
    stored); backward re-runs the function under the saved autocast state and
    differentiates the recomputed graph."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # First `length` args are tensors fed to run_function; the rest are
        # parameters kept only so their grads are produced in backward.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture autocast state so recomputation matches the original pass.
        # NOTE(review): torch.get_autocast_gpu_dtype / torch.cuda.amp.autocast
        # are legacy APIs on newer torch — confirm target torch version.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # view_as makes shallow copies so grads attach to the detached leaves.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # None, None correspond to run_function and length.
        return (None, None) + input_grads
81
+
82
+
83
+ def normalization(channels):
84
+ return GroupNorm32(32, channels)
85
+
86
+
87
+ class GroupNorm32(nn.GroupNorm):
88
+ def forward(self, x):
89
+ return super().forward(x.float()).type(x.dtype)
90
+
metadata_encoder/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Metadata encoder for HSIGene."""
2
+
3
+ from .model import MetadataEmbeddings, metadata_embeddings
4
+
5
+ __all__ = ["MetadataEmbeddings", "metadata_embeddings"]
metadata_encoder/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.metadata_embeddings",
3
+ "max_value": 1000,
4
+ "embedding_dim": 320,
5
+ "metadata_dim": 7,
6
+ "max_period": 10000
7
+ }
metadata_encoder/model.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Metadata embeddings - SinusoidalEmbedding + MLPs for metadata conditioning."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class SinusoidalEmbedding(nn.Module):
    """Sinusoidal embedding for metadata.

    Each scalar metadata value is multiplied by ``max_value`` and expanded into
    an ``embedding_dim``-wide vector of interleaved sin/cos features at
    frequencies ``omega ** (-2i / embedding_dim)``.  Per-field embeddings are
    concatenated, so the output has shape ``(batch, num_fields * embedding_dim)``.

    Fix: the original per-element Python double loop (O(fields * dims) tensor
    item writes) is replaced with vectorized tensor ops; numerics are unchanged.
    """

    def __init__(self, max_value, embedding_dim):
        super().__init__()
        self.max_value = max_value
        self.embedding_dim = embedding_dim
        # Frequency base, as in transformer positional encodings.
        self.omega = 10000.0

    def forward(self, k):
        # k: (batch, num_fields); assumed floating point — TODO confirm callers.
        half = self.embedding_dim // 2
        # Frequencies in float64 to match Python-float precision of the
        # original loop, then cast to the input dtype.
        idx = torch.arange(half, device=k.device, dtype=torch.float64)
        freqs = (self.omega ** (-2.0 * idx / self.embedding_dim)).to(k.dtype)
        # NOTE: despite the original "normalized" naming, this multiplies.
        k_scaled = k * self.max_value
        angles = k_scaled.unsqueeze(-1) * freqs  # (batch, fields, half)
        embedding = torch.zeros(
            (k.size(0), k.size(1), self.embedding_dim),
            device=k.device,
            dtype=k.dtype,
        )
        # Even columns get sin, odd columns cos; a trailing column stays zero
        # when embedding_dim is odd (matching the original loop bounds).
        embedding[..., 0 : 2 * half : 2] = torch.sin(angles)
        embedding[..., 1 : 2 * half : 2] = torch.cos(angles)
        return embedding.view(k.size(0), -1)
30
+
31
+
32
def create_condition_vector(embedded_metadata, mlp_models, embedding_dim):
    """Slice the concatenated per-field embeddings, project each slice through
    its own MLP, and sum the projections into one condition vector."""
    projections = (
        mlp(embedded_metadata[:, i * embedding_dim : (i + 1) * embedding_dim])
        for i, mlp in enumerate(mlp_models)
    )
    return sum(projections)
39
+
40
+
41
class MetadataMLP(nn.Module):
    """Single linear projection from one per-field embedding to the shared width."""

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        projected = self.fc1(x)
        return projected
48
+
49
+
50
class MetadataEmbeddings(nn.Module):
    """Metadata conditioning: sinusoidal encoding per field, one MLP per field,
    projections summed into a single vector (see ``create_condition_vector``)."""

    def __init__(self, max_value, embedding_dim, max_period, metadata_dim):
        super().__init__()
        self.sinusoidal_embedding = SinusoidalEmbedding(max_value, embedding_dim)
        self.mlp_models = nn.ModuleList(
            MetadataMLP(embedding_dim, embedding_dim * 4) for _ in range(metadata_dim)
        )
        self.max_period = max_period
        self.embedding_dim = embedding_dim
        self.metadata_dim = metadata_dim
        self.max_value = max_value

    def forward(self, metadata=None):
        # Unwrap nested single-element list/tuple wrappers, e.g. [[tensor]] -> tensor.
        while isinstance(metadata, (list, tuple)) and len(metadata) == 1:
            metadata = metadata[0]
        # Promote a single sample to a batch of one.
        if metadata.dim() == 1:
            metadata = metadata.unsqueeze(0)
        encoded = self.sinusoidal_embedding(metadata)
        return create_condition_vector(encoded, self.mlp_models, self.embedding_dim)
74
+
75
+
76
+ # Alias for config compatibility
77
+ metadata_embeddings = MetadataEmbeddings
model_index.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["pipeline_hsigene", "HSIGenePipeline"],
3
+ "_diffusers_version": "0.25.0",
4
+ "scheduler": ["diffusers", "DDIMScheduler"],
5
+ "unet": ["pipeline_hsigene", "HSIGenePipeline"],
6
+ "vae": ["pipeline_hsigene", "HSIGenePipeline"],
7
+ "text_encoder": ["pipeline_hsigene", "HSIGenePipeline"],
8
+ "local_adapter": ["pipeline_hsigene", "HSIGenePipeline"],
9
+ "global_content_adapter": ["pipeline_hsigene", "HSIGenePipeline"],
10
+ "global_text_adapter": ["pipeline_hsigene", "HSIGenePipeline"],
11
+ "metadata_encoder": ["pipeline_hsigene", "HSIGenePipeline"],
12
+ "scale_factor": 0.18215,
13
+ "conditioning_key": "crossattn"
14
+ }
modular_pipeline.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HSIGene modular components: path setup and component loading.
3
+
4
+ AeroGen-style: ensure_ldm_path adds model dir to sys.path so hsigene can be imported.
5
+ No manual sys.path.insert needed when using DiffusionPipeline.from_pretrained(path).
6
+ """
7
+
8
+ import importlib
9
+ import json
10
+ import sys
11
+ from pathlib import Path
12
+ from typing import Union
13
+
14
+ from diffusers import DDIMScheduler
15
+
16
+ # Ensure model dir is on path for hsigene imports
17
+ _pipeline_dir = Path(__file__).resolve().parent
18
+ if str(_pipeline_dir) not in sys.path:
19
+ sys.path.insert(0, str(_pipeline_dir))
20
+
21
+ _COMPONENT_NAMES = (
22
+ "unet", "vae", "text_encoder", "local_adapter",
23
+ "global_content_adapter", "global_text_adapter", "metadata_encoder",
24
+ )
25
+
26
+ _TARGET_MAP = {
27
+ "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
28
+ "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
29
+ "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
30
+ "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
31
+ "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
32
+ "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
33
+ "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
34
+ "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
35
+ "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
36
+ "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
37
+ "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
38
+ "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
39
+ "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
40
+ "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
41
+ }
42
+
43
+
44
def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Put the model repo directory on ``sys.path`` and return its resolved path.

    A non-existent local path is treated as a Hub repo id and downloaded via
    ``snapshot_download``.
    """
    resolved = Path(pretrained_model_name_or_path)
    if not resolved.exists():
        # Not a local directory: assume a HF Hub repo id and fetch it.
        from huggingface_hub import snapshot_download
        resolved = Path(snapshot_download(pretrained_model_name_or_path))
    resolved = resolved.resolve()
    entry = str(resolved)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return resolved
55
+
56
+
57
+ def _get_class(target: str):
58
+ module_path, cls_name = target.rsplit(".", 1)
59
+ mod = importlib.import_module(module_path)
60
+ return getattr(mod, cls_name)
61
+
62
+
63
def load_component(model_path: Path, name: str):
    """Instantiate one pipeline component and load its weights.

    Args:
        model_path: Repo root, or the component directory itself.
        name: Component name ("unet", "vae", ...) used when *model_path*
            is the repo root.

    Returns:
        The component in eval mode; weights are loaded strictly when a
        ``diffusion_pytorch_model.*`` file is present next to its config.
    """
    import torch
    path = Path(model_path)
    # If handed a component dir directly, its parent is the repo root.
    root = path.parent if path.name in _COMPONENT_NAMES and (path / "config.json").exists() else path
    ensure_ldm_path(root)
    comp_path = path if (path / "config.json").exists() and path.name in _COMPONENT_NAMES else path / name
    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    # Remap legacy/alias class paths to the modules shipped in this repo.
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    # Remaining non-underscore keys are constructor kwargs.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)
    # Prefer safetensors; fall back to a .bin checkpoint.
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if wp.exists():
            if wfile.endswith(".safetensors"):
                from safetensors.torch import load_file
                state = load_file(str(wp))
            else:
                try:
                    # weights_only is unavailable on older torch versions.
                    state = torch.load(wp, map_location="cpu", weights_only=True)
                except TypeError:
                    state = torch.load(wp, map_location="cpu")
            comp.load_state_dict(state, strict=True)
            break
    comp.eval()
    return comp
94
+
95
+
96
def load_components(model_path: Union[str, Path]) -> dict:
    """Load all pipeline components.

    Returns a dict keyed by component name plus "scheduler" and
    "scale_factor".
    """
    path = Path(ensure_ldm_path(model_path))
    # Accept a component subdirectory and climb to the repo root.
    if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
        path = path.parent
    scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
    components = {}
    for name in _COMPONENT_NAMES:
        components[name] = load_component(path, name)
    # Latent scaling factor; 0.18215 is the default used elsewhere in this repo.
    scale_factor = 0.18215
    if (path / "model_index.json").exists():
        with open(path / "model_index.json") as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)
    components["scheduler"] = scheduler
    components["scale_factor"] = scale_factor
    return components
pipeline_hsigene.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGenePipeline - diffusers DiffusionPipeline for HSIGene hyperspectral generation.
2
+
3
+ AeroGen-style loading: use DiffusionPipeline.from_pretrained(path) - no sys.path.insert needed.
4
+ Self-contained: loading logic inlined (no separate modular_pipeline import).
5
+ """
6
+
7
+ import importlib
8
+ import json
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import List, Optional, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from dataclasses import dataclass
17
+
18
+ from diffusers import DDIMScheduler, DiffusionPipeline
19
+ from diffusers.utils import BaseOutput
20
+
21
+ # Re-export for diffusers component loading (load_method lookup)
22
+ DiffusionPipeline = DiffusionPipeline
23
+
24
+ # Inline path/loading (AeroGen-style) - self-contained for diffusers cache loading
25
+ _pipeline_dir = Path(__file__).resolve().parent
26
+ if str(_pipeline_dir) not in sys.path:
27
+ sys.path.insert(0, str(_pipeline_dir))
28
+
29
+ # Register as "pipeline_hsigene" so diffusers' get_class_obj_and_candidates finds us when it does
30
+ # importlib.import_module("pipeline_hsigene") during component loading. (We may be loaded as
31
+ # "diffusers_modules.local.xxx.pipeline_hsigene" from cache, so this alias is required.)
32
+ sys.modules["pipeline_hsigene"] = sys.modules[__name__]
33
+
34
+ _COMPONENT_NAMES = (
35
+ "unet", "vae", "text_encoder", "local_adapter",
36
+ "global_content_adapter", "global_text_adapter", "metadata_encoder",
37
+ )
38
+
39
+ _TARGET_MAP = {
40
+ "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
41
+ "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
42
+ "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
43
+ "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
44
+ "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
45
+ "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
46
+ "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
47
+ "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
48
+ "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
49
+ "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
50
+ "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
51
+ "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
52
+ "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
53
+ "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
54
+ }
55
+
56
+
57
def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Put the model repo directory on ``sys.path`` and return its resolved path.

    A non-existent local path is treated as a Hub repo id and downloaded.
    """
    repo = Path(pretrained_model_name_or_path)
    if not repo.exists():
        # Not on disk: assume a HF Hub repo id and fetch it.
        from huggingface_hub import snapshot_download
        repo = Path(snapshot_download(pretrained_model_name_or_path))
    repo = repo.resolve()
    repo_str = str(repo)
    if repo_str not in sys.path:
        sys.path.insert(0, repo_str)
    return repo
68
+
69
+
70
+ def _get_class(target: str):
71
+ module_path, cls_name = target.rsplit(".", 1)
72
+ mod = importlib.import_module(module_path)
73
+ return getattr(mod, cls_name)
74
+
75
+
76
def load_component(model_path: Path, name: str):
    """Instantiate one pipeline component and load its weights.

    *model_path* may be the repo root or the component directory itself;
    *name* selects the subdirectory in the former case. Weights are loaded
    strictly from ``diffusion_pytorch_model.*`` when present; the component
    is returned in eval mode.
    """
    path = Path(model_path)
    # If handed a component dir directly, its parent is the repo root.
    root = path.parent if path.name in _COMPONENT_NAMES and (path / "config.json").exists() else path
    ensure_ldm_path(root)
    comp_path = path if (path / "config.json").exists() and path.name in _COMPONENT_NAMES else path / name
    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    # Remap legacy/alias class paths to the modules shipped in this repo.
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    # Remaining non-underscore keys are constructor kwargs.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)
    # Prefer safetensors; fall back to a .bin checkpoint.
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if wp.exists():
            if wfile.endswith(".safetensors"):
                from safetensors.torch import load_file
                state = load_file(str(wp))
            else:
                try:
                    # weights_only is unavailable on older torch versions.
                    state = torch.load(wp, map_location="cpu", weights_only=True)
                except TypeError:
                    state = torch.load(wp, map_location="cpu")
            comp.load_state_dict(state, strict=True)
            break
    comp.eval()
    return comp
106
+
107
+
108
def load_components(model_path: Union[str, Path]) -> dict:
    """Load all pipeline components.

    Returns a dict keyed by component name plus "scheduler" and
    "scale_factor" (read from model_index.json when available).
    """
    path = Path(ensure_ldm_path(model_path))
    # Accept a component subdirectory and climb to the repo root.
    if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
        path = path.parent
    scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
    components = {}
    for name in _COMPONENT_NAMES:
        components[name] = load_component(path, name)
    scale_factor = 0.18215
    if (path / "model_index.json").exists():
        with open(path / "model_index.json") as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)
    components["scheduler"] = scheduler
    components["scale_factor"] = scale_factor
    return components
124
+
125
+
126
class _CRSModelWrapper(torch.nn.Module):
    """Wrapper that mimics CRSControlNet interface.

    Exposes the attribute names the original CRS code expects
    (``model.diffusion_model``, ``first_stage_model``, ``cond_stage_model``)
    on top of the diffusers-style components loaded by this repo.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_emb,
        scale_factor=0.18215,
        local_control_scales=None,
    ):
        super().__init__()
        # Throwaway type reproduces the `self.model.diffusion_model` attribute
        # path of the original model. NOTE(review): `self.model` is not an
        # nn.Module, so the unet is presumably not registered as a submodule
        # of this wrapper — confirm parameters()/to() behave as intended.
        self.model = type("Model", (), {"diffusion_model": unet})()
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_emb
        self.scale_factor = scale_factor
        # Per-level scaling of the 13 local-control residuals.
        self.local_control_scales = local_control_scales or [1.0] * 13

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode text prompts with the frozen text encoder."""
        return self.cond_stage_model(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, text_strength=1.0, **kwargs):
        """Predict noise for `x_noisy` at timestep `t` under conditioning `cond`.

        `cond` carries lists under "c_crossattn", "local_control",
        "global_control" and "metadata". Text and global-content embeddings
        are L2-normalized, scaled by their strengths, and concatenated along
        the sequence dimension before cross-attention.
        """
        if metadata is None:
            metadata = cond["metadata"]
        metadata_emb = self.metadata_emb(metadata)
        content_t = cond["global_control"][0]
        global_control = self.global_content_adapter(content_t)
        cond_txt = torch.cat(cond["c_crossattn"], 1)
        cond_txt = self.global_text_adapter(cond_txt)
        cond_txt = F.normalize(cond_txt, p=2, dim=-1) * text_strength
        global_control = F.normalize(global_control, p=2, dim=-1) * global_strength
        cond_txt = torch.cat([cond_txt, global_control], dim=1)
        local_control = torch.cat(cond["local_control"], 1)
        local_control = self.local_adapter(
            x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
        )
        # Scale each control residual by its configured weight.
        local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]
        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata_emb,
            context=cond_txt,
            local_control=local_control,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Undo latent scaling and decode latents with the VAE."""
        z = (1.0 / self.scale_factor) * z
        return self.first_stage_model.decode(z)

    def low_vram_shift(self, is_diffusing):
        """Swap diffusion modules and VAE/text encoder between GPU and CPU.

        During diffusion the denoising modules live on GPU and the
        encode/decode modules on CPU; outside diffusion the placement is
        inverted.
        """
        if is_diffusing:
            self.model.diffusion_model = self.model.diffusion_model.cuda()
            self.local_adapter = self.local_adapter.cuda()
            self.global_text_adapter = self.global_text_adapter.cuda()
            self.global_content_adapter = self.global_content_adapter.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model.diffusion_model = self.model.diffusion_model.cpu()
            self.local_adapter = self.local_adapter.cpu()
            self.global_text_adapter = self.global_text_adapter.cpu()
            self.global_content_adapter = self.global_content_adapter.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
200
+
201
+
202
@dataclass
class HSIGeneOutput(BaseOutput):
    """Output class for HSIGene pipeline.

    Exactly one of the two fields is populated, depending on `output_type`:
    `images` holds decoded arrays clipped to [0, 1]; `latents` holds the raw
    latent tensor when `output_type="latent"`.
    """

    images: Optional[np.ndarray] = None
    latents: Optional[torch.Tensor] = None
208
+
209
+
210
+ def _is_component_list(v):
211
+ """Check if value is raw config format [library, class_name]."""
212
+ return isinstance(v, (list, tuple)) and len(v) == 2 and isinstance(v[0], str) and isinstance(v[1], str)
213
+
214
+
215
class HSIGenePipeline(DiffusionPipeline):
    """Pipeline for HSIGene hyperspectral image generation.

    AeroGen-style: load with DiffusionPipeline.from_pretrained(path) - no sys.path.insert.
    """

    def register_modules(self, **kwargs):
        """Override to handle list-format component specs from diffusers config."""
        for name, module in kwargs.items():
            # None, or a (None, ...) placeholder tuple from the config.
            if module is None or (isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None):
                self.register_to_config(**{name: (None, None)})
                setattr(self, name, module)
            # Raw [library, class_name] spec not yet resolved to an object.
            elif _is_component_list(module):
                self.register_to_config(**{name: (module[0], module[1])})
                setattr(self, name, module)
            else:
                # Real module instance: record its (library, class) pair the
                # same way the stock diffusers implementation does.
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple
                library, class_name = _fetch_class_library_tuple(module)
                self.register_to_config(**{name: (library, class_name)})
                setattr(self, name, module)

    def __init__(
        self,
        unet=None,
        vae=None,
        text_encoder=None,
        local_adapter=None,
        global_content_adapter=None,
        global_text_adapter=None,
        metadata_encoder=None,
        scheduler=None,
        crs_model=None,
        scale_factor=0.18215,
    ):
        """Assemble the pipeline; builds a _CRSModelWrapper unless one is given."""
        super().__init__()
        if crs_model is not None:
            self.register_modules(crs_model=crs_model, scheduler=scheduler)
        else:
            # Guard against diffusers handing us unresolved [library, class]
            # specs instead of instantiated components.
            if any(_is_component_list(x) for x in (unet, vae, text_encoder, local_adapter,
                   global_content_adapter, global_text_adapter, metadata_encoder) if x is not None):
                raise ValueError(
                    "HSIGene received raw config (list) instead of loaded components. "
                    "Use HSIGenePipeline.from_pretrained(path) directly, or ensure the model "
                    "directory (with hsigene package) is on the path when loading."
                )
            crs_model = _CRSModelWrapper(
                unet=unet,
                vae=vae,
                text_encoder=text_encoder,
                local_adapter=local_adapter,
                global_content_adapter=global_content_adapter,
                global_text_adapter=global_text_adapter,
                metadata_emb=metadata_encoder,
                scale_factor=scale_factor,
            )
            self.register_modules(
                unet=unet,
                vae=vae,
                text_encoder=text_encoder,
                local_adapter=local_adapter,
                global_content_adapter=global_content_adapter,
                global_text_adapter=global_text_adapter,
                metadata_encoder=metadata_encoder,
                scheduler=scheduler,
                crs_model=crs_model,
            )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Load from diffusers-format directory. Supports subfolder for single-component loading."""
        path = Path(ensure_ldm_path(pretrained_model_name_or_path))
        # Allow subfolder passed positionally or via kwargs.
        subfolder = kwargs.pop("subfolder", subfolder)

        # A known component subfolder: return just that component.
        if subfolder in ("unet", "vae", "text_encoder", "local_adapter",
                         "global_content_adapter", "global_text_adapter", "metadata_encoder"):
            return load_component(path, subfolder)

        # Path pointing directly at a component directory.
        if path.name in ("unet", "vae", "text_encoder", "local_adapter",
                         "global_content_adapter", "global_text_adapter", "metadata_encoder"):
            if (path / "config.json").exists():
                ensure_ldm_path(path.parent)
                return load_component(path.parent, path.name)

        # Walk up at most 5 levels looking for the repo root (model_index.json).
        if not (path / "model_index.json").exists():
            for _ in range(5):
                parent = path.parent
                if (parent / "model_index.json").exists():
                    path = parent
                    break
                if parent == path:
                    break
                path = parent

        components = load_components(path)
        pipe = cls(
            unet=components["unet"],
            vae=components["vae"],
            text_encoder=components["text_encoder"],
            local_adapter=components["local_adapter"],
            global_content_adapter=components["global_content_adapter"],
            global_text_adapter=components["global_text_adapter"],
            metadata_encoder=components["metadata_encoder"],
            scheduler=components["scheduler"],
            scale_factor=components["scale_factor"],
        )
        if device is not None:
            pipe = pipe.to(device)
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = "",
        num_samples: int = 1,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        eta: float = 0.0,
        global_strength: float = 1.0,
        text_strength: Optional[float] = None,
        local_conditions: Optional[torch.Tensor] = None,
        global_conditions: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None,
        condition_resolution: int = 512,
        guidance_scale: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "numpy",
        return_dict: bool = True,
        save_memory: bool = False,
    ):
        """Run DDIM sampling and return generated images (or latents).

        Missing conditions default to zero tensors: 18-channel local-control
        maps at `condition_resolution`, a 768-dim global embedding, and a
        7-dim metadata vector. `guidance_scale > 1.0` enables classifier-free
        guidance with `negative_prompt` (empty string by default).
        `save_memory` ping-pongs modules between CPU/GPU via low_vram_shift.
        """
        device = next(self.crs_model.parameters()).device
        # Text strength follows global strength unless set explicitly.
        if text_strength is None:
            text_strength = global_strength

        if isinstance(prompt, str):
            prompts = [prompt] * num_samples
        else:
            prompts = list(prompt)
            # A list of prompts overrides num_samples.
            num_samples = len(prompts)

        if save_memory:
            self.crs_model.low_vram_shift(is_diffusing=False)

        text_embedding = self.crs_model.get_learned_conditioning(prompts)

        if local_conditions is None:
            local_conditions = torch.zeros(
                num_samples, 18, condition_resolution, condition_resolution,
                device=device, dtype=torch.float32,
            )
        else:
            local_conditions = local_conditions.to(device=device, dtype=torch.float32)

        if global_conditions is None:
            global_conditions = torch.zeros(
                num_samples, 768, device=device, dtype=torch.float32,
            )
        else:
            global_conditions = global_conditions.to(device=device, dtype=torch.float32)

        if metadata is None:
            metadata = torch.zeros(7, device=device, dtype=torch.float32)
        else:
            metadata = metadata.to(device=device, dtype=torch.float32)

        cond = {
            "local_control": [local_conditions],
            "c_crossattn": [text_embedding],
            "global_control": [global_conditions],
            "metadata": [metadata],
        }

        do_cfg = guidance_scale > 1.0
        if do_cfg:
            if negative_prompt is None:
                neg_prompts = [""] * num_samples
            elif isinstance(negative_prompt, str):
                neg_prompts = [negative_prompt] * num_samples
            else:
                neg_prompts = list(negative_prompt)
            uc_text = self.crs_model.get_learned_conditioning(neg_prompts)
            # Unconditional branch keeps local control/metadata but zeroes the
            # global content embedding.
            uncond = {
                "local_control": [local_conditions],
                "c_crossattn": [uc_text],
                "global_control": [torch.zeros_like(global_conditions)],
                "metadata": [metadata],
            }

        # NOTE(review): latent spatial size is height//4 x width//4 — assumes a
        # 4x VAE downsample factor; confirm against the VAE config.
        latent_shape = (num_samples, 4, height // 4, width // 4)
        if latents is None:
            latents = torch.randn(
                latent_shape, device=device, generator=generator, dtype=torch.float32,
            )
        else:
            latents = latents.to(device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        if save_memory:
            self.crs_model.low_vram_shift(is_diffusing=True)

        for t in self.progress_bar(self.scheduler.timesteps):
            t_batch = t.expand(num_samples)
            if do_cfg:
                noise_pred_cond = self.crs_model.apply_model(
                    latents, t_batch, cond,
                    metadata=metadata,
                    global_strength=global_strength,
                    text_strength=text_strength,
                )
                noise_pred_uncond = self.crs_model.apply_model(
                    latents, t_batch, uncond,
                    metadata=metadata,
                    global_strength=global_strength,
                    text_strength=text_strength,
                )
                # Standard classifier-free guidance combination.
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = self.crs_model.apply_model(
                    latents, t_batch, cond,
                    metadata=metadata,
                    global_strength=global_strength,
                    text_strength=text_strength,
                )
            latents = self.scheduler.step(
                noise_pred, t, latents, eta=eta, generator=generator,
            ).prev_sample

        if output_type == "latent":
            if not return_dict:
                return (latents,)
            return HSIGeneOutput(latents=latents)

        if save_memory:
            self.crs_model.low_vram_shift(is_diffusing=False)

        # Decode, move channels last, rescale from [-1, 1] to [0, 1], clip.
        images = self.crs_model.decode_first_stage(latents)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        images = images * 0.5 + 0.5
        images = np.clip(images, 0, 1)

        if not return_dict:
            return (images,)
        return HSIGeneOutput(images=images)
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.37.0",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.0001,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 0,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """HSIGene text encoder component."""
text_encoder/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (203 Bytes). View file
 
text_encoder/__pycache__/model.cpython-312.pyc ADDED
Binary file (2.35 kB). View file
 
text_encoder/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.CLIPTextEncoder",
3
+ "version": "openai/clip-vit-large-patch14"
4
+ }
text_encoder/model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLIP text encoder - same interface as FrozenCLIPEmbedder (forward(text) returns last_hidden_state)."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import CLIPTokenizer, CLIPTextModel
6
+
7
+
8
class CLIPTextEncoder(nn.Module):
    """CLIP text encoder wrapping transformers CLIPTokenizer + CLIPTextModel.
    Same interface as FrozenCLIPEmbedder: forward(text) returns last_hidden_state.
    """

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        max_length: int = 77,
        freeze: bool = True,
    ):
        """Load tokenizer and text model from *version*; optionally freeze weights."""
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.max_length = max_length
        if freeze:
            # Eval mode plus requires_grad=False: the encoder stays fixed
            # during any downstream training.
            self.transformer.eval()
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, text):
        """Encode text. Returns last_hidden_state (B, seq_len, dim)."""
        if isinstance(text, str):
            text = [text]
        # Pad/truncate every prompt to exactly max_length tokens.
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Move token ids to whatever device the model parameters live on.
        tokens = batch_encoding["input_ids"].to(next(self.parameters()).device)
        outputs = self.transformer(input_ids=tokens)
        return outputs.last_hidden_state
unet/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """HSIGene UNet component."""
2
+
3
+ from .model import HSIGeneUNet
4
+
5
+ __all__ = ["HSIGeneUNet"]
unet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (261 Bytes). View file
 
unet/__pycache__/attention.cpython-312.pyc ADDED
Binary file (14.3 kB). View file
 
unet/__pycache__/diffusion.cpython-312.pyc ADDED
Binary file (22.7 kB). View file
 
unet/__pycache__/model.cpython-312.pyc ADDED
Binary file (1.81 kB). View file
 
unet/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.93 kB). View file
 
unet/attention.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene attention modules - FeedForward, CrossAttention, SpatialTransformer."""
2
+
3
+ from inspect import isfunction
4
+ from typing import Optional, Any
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ from torch import einsum
11
+
12
+ from .utils import checkpoint, zero_module, exists
13
+
14
+ try:
15
+ import xformers
16
+ import xformers.ops
17
+ XFORMERS_IS_AVAILABLE = True
18
+ except ImportError:
19
+ XFORMERS_IS_AVAILABLE = False
20
+
21
+
22
+ def default(val, d):
23
+ if exists(val):
24
+ return val
25
+ return d() if isfunction(d) else d
26
+
27
+
28
+ import os
29
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
30
+
31
+
32
+ class GEGLU(nn.Module):
33
+ def __init__(self, dim_in, dim_out):
34
+ super().__init__()
35
+ self.proj = nn.Linear(dim_in, dim_out * 2)
36
+
37
+ def forward(self, x):
38
+ x, gate = self.proj(x).chunk(2, dim=-1)
39
+ return x * F.gelu(gate)
40
+
41
+
42
+ class FeedForward(nn.Module):
43
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
44
+ super().__init__()
45
+ inner_dim = int(dim * mult)
46
+ dim_out = default(dim_out, dim)
47
+ project_in = (
48
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
49
+ if not glu
50
+ else GEGLU(dim, inner_dim)
51
+ )
52
+ self.net = nn.Sequential(
53
+ project_in,
54
+ nn.Dropout(dropout),
55
+ nn.Linear(inner_dim, dim_out),
56
+ )
57
+
58
+ def forward(self, x):
59
+ return self.net(x)
60
+
61
+
62
+ def Normalize(in_channels, num_groups=32):
63
+ return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
64
+
65
+
66
+ class CrossAttention(nn.Module):
67
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
68
+ super().__init__()
69
+ inner_dim = dim_head * heads
70
+ context_dim = default(context_dim, query_dim)
71
+ self.scale = dim_head ** -0.5
72
+ self.heads = heads
73
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
74
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
75
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
76
+ self.to_out = nn.Sequential(
77
+ nn.Linear(inner_dim, query_dim),
78
+ nn.Dropout(dropout),
79
+ )
80
+
81
+ def forward(self, x, context=None, mask=None):
82
+ h = self.heads
83
+ q = self.to_q(x)
84
+ context = default(context, x)
85
+ k = self.to_k(context)
86
+ v = self.to_v(context)
87
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
88
+ if _ATTN_PRECISION == "fp32":
89
+ with torch.autocast(enabled=False, device_type="cuda"):
90
+ q, k = q.float(), k.float()
91
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
92
+ else:
93
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
94
+ del q, k
95
+ if exists(mask):
96
+ mask = rearrange(mask, "b ... -> b (...)")
97
+ max_neg_value = -torch.finfo(sim.dtype).max
98
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
99
+ sim.masked_fill_(~mask, max_neg_value)
100
+ sim = sim.softmax(dim=-1)
101
+ out = einsum("b i j, b j d -> b i d", sim, v)
102
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
103
+ return self.to_out(out)
104
+
105
+
106
+ class MemoryEfficientCrossAttention(nn.Module):
107
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
108
+ super().__init__()
109
+ inner_dim = dim_head * heads
110
+ context_dim = default(context_dim, query_dim)
111
+ self.heads = heads
112
+ self.dim_head = dim_head
113
+ self.scale = dim_head ** -0.5
114
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
115
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
116
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
117
+ self.to_out = nn.Sequential(
118
+ nn.Linear(inner_dim, query_dim),
119
+ nn.Dropout(dropout),
120
+ )
121
+ self.attention_op: Optional[Any] = None
122
+
123
+ def forward(self, x, context=None, mask=None):
124
+ q = self.to_q(x)
125
+ context = default(context, x)
126
+ k = self.to_k(context)
127
+ v = self.to_v(context)
128
+ b, _, _ = q.shape
129
+ q, k, v = map(
130
+ lambda t: t.unsqueeze(3)
131
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
132
+ .permute(0, 2, 1, 3)
133
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
134
+ .contiguous(),
135
+ (q, k, v),
136
+ )
137
+ if XFORMERS_IS_AVAILABLE:
138
+ out = xformers.ops.memory_efficient_attention(
139
+ q, k, v, attn_bias=None, op=self.attention_op
140
+ )
141
+ else:
142
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
143
+ sim = sim.softmax(dim=-1)
144
+ out = torch.einsum("b i j, b j d -> b i d", sim, v)
145
+ out = (
146
+ out.unsqueeze(0)
147
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
148
+ .permute(0, 2, 1, 3)
149
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
150
+ )
151
+ return self.to_out(out)
152
+
153
+
154
+ class BasicTransformerBlock(nn.Module):
155
+ ATTENTION_MODES = {
156
+ "softmax": CrossAttention,
157
+ "softmax-xformers": MemoryEfficientCrossAttention,
158
+ }
159
+
160
+ def __init__(
161
+ self,
162
+ dim,
163
+ n_heads,
164
+ d_head,
165
+ dropout=0.0,
166
+ context_dim=None,
167
+ gated_ff=True,
168
+ checkpoint=True,
169
+ disable_self_attn=False,
170
+ ):
171
+ super().__init__()
172
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
173
+ attn_cls = self.ATTENTION_MODES[attn_mode]
174
+ self.disable_self_attn = disable_self_attn
175
+ self.attn1 = attn_cls(
176
+ query_dim=dim,
177
+ heads=n_heads,
178
+ dim_head=d_head,
179
+ dropout=dropout,
180
+ context_dim=context_dim if self.disable_self_attn else None,
181
+ )
182
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
183
+ self.attn2 = attn_cls(
184
+ query_dim=dim,
185
+ context_dim=context_dim,
186
+ heads=n_heads,
187
+ dim_head=d_head,
188
+ dropout=dropout,
189
+ )
190
+ self.norm1 = nn.LayerNorm(dim)
191
+ self.norm2 = nn.LayerNorm(dim)
192
+ self.norm3 = nn.LayerNorm(dim)
193
+ self.checkpoint = checkpoint
194
+
195
+ def forward(self, x, context=None):
196
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
197
+
198
+ def _forward(self, x, context=None):
199
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
200
+ x = self.attn2(self.norm2(x), context=context) + x
201
+ x = self.ff(self.norm3(x)) + x
202
+ return x
203
+
204
+
205
+ class SpatialTransformer(nn.Module):
206
+ def __init__(
207
+ self,
208
+ in_channels,
209
+ n_heads,
210
+ d_head,
211
+ depth=1,
212
+ dropout=0.0,
213
+ context_dim=None,
214
+ disable_self_attn=False,
215
+ use_linear=False,
216
+ use_checkpoint=True,
217
+ ):
218
+ super().__init__()
219
+ if exists(context_dim) and not isinstance(context_dim, list):
220
+ context_dim = [context_dim]
221
+ self.in_channels = in_channels
222
+ inner_dim = n_heads * d_head
223
+ self.norm = Normalize(in_channels)
224
+ if not use_linear:
225
+ self.proj_in = nn.Conv2d(
226
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
227
+ )
228
+ else:
229
+ self.proj_in = nn.Linear(in_channels, inner_dim)
230
+ self.transformer_blocks = nn.ModuleList(
231
+ [
232
+ BasicTransformerBlock(
233
+ inner_dim,
234
+ n_heads,
235
+ d_head,
236
+ dropout=dropout,
237
+ context_dim=context_dim[d] if isinstance(context_dim, list) else context_dim,
238
+ disable_self_attn=disable_self_attn,
239
+ checkpoint=use_checkpoint,
240
+ )
241
+ for d in range(depth)
242
+ ]
243
+ )
244
+ if not use_linear:
245
+ self.proj_out = zero_module(
246
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
247
+ )
248
+ else:
249
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
250
+ self.use_linear = use_linear
251
+
252
+ def forward(self, x, context=None):
253
+ if not isinstance(context, list):
254
+ context = [context]
255
+ b, c, h, w = x.shape
256
+ x_in = x
257
+ x = self.norm(x)
258
+ if not self.use_linear:
259
+ x = self.proj_in(x)
260
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
261
+ if self.use_linear:
262
+ x = self.proj_in(x)
263
+ for i, block in enumerate(self.transformer_blocks):
264
+ ctx = context[i] if i < len(context) else context[0]
265
+ x = block(x, context=ctx)
266
+ if self.use_linear:
267
+ x = self.proj_out(x)
268
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
269
+ if not self.use_linear:
270
+ x = self.proj_out(x)
271
+ return x + x_in
unet/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.HSIGeneUNet",
3
+ "image_size": 32,
4
+ "in_channels": 4,
5
+ "model_channels": 320,
6
+ "out_channels": 4,
7
+ "num_res_blocks": 2,
8
+ "attention_resolutions": [
9
+ 4,
10
+ 2,
11
+ 1
12
+ ],
13
+ "channel_mult": [
14
+ 1,
15
+ 2,
16
+ 4,
17
+ 4
18
+ ],
19
+ "use_checkpoint": true,
20
+ "num_heads": 8,
21
+ "use_spatial_transformer": true,
22
+ "transformer_depth": 1,
23
+ "context_dim": 768,
24
+ "legacy": false
25
+ }
unet/diffusion.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene diffusion modules - UNet, ResBlock, etc. From openaimodel."""
2
+
3
+ from abc import abstractmethod
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from .utils import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ zero_module,
16
+ normalization,
17
+ timestep_embedding,
18
+ exists,
19
+ )
20
+ from .attention import SpatialTransformer
21
+
22
+
23
def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D average pooling module for the given `dims`."""
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    pool_cls = pool_classes.get(dims)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
32
+
33
+
34
def convert_module_to_f16(x):
    """No-op placeholder kept for API compatibility; fp16 conversion is unused here."""
    return None
36
+
37
+
38
def convert_module_to_f32(x):
    """No-op placeholder kept for API compatibility; fp32 conversion is unused here."""
    return None
40
+
41
+
42
class TimestepBlock(nn.Module):
    """Any module where forward() takes timestep embeddings as a second argument.

    Serves as a marker base class: TimestepEmbedSequential uses an
    isinstance check against it to decide whether to pass `emb` along.
    """

    @abstractmethod
    def forward(self, x, emb):
        """Apply the module to `x` given `emb` timestep embeddings."""
48
+
49
+
50
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Sequential container that routes timestep embeddings and context to
    the children that accept them (TimestepBlock / SpatialTransformer)."""

    def forward(self, x, emb, context=None):
        out = x
        for module in self:
            # SpatialTransformer is not a TimestepBlock, so check order is safe.
            if isinstance(module, SpatialTransformer):
                out = module(out, context)
            elif isinstance(module, TimestepBlock):
                out = module(out, emb)
            else:
                out = module(out)
        return out
62
+
63
+
64
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling with an optional trailing convolution.

    For 3D inputs the depth dimension is kept fixed; only the two spatial
    dimensions are doubled.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            # 3D: keep depth, double height and width only.
            upsampled = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        return self.conv(upsampled) if self.use_conv else upsampled
87
+
88
+
89
class Downsample(nn.Module):
    """2x downsampling via a strided convolution or average pooling.

    For 3D inputs the depth dimension is left untouched (stride (1, 2, 2)).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
110
+
111
+
112
class ResBlock(TimestepBlock):
    """Residual block with timestep conditioning.

    Can change the channel count, fold an up/down resample into the block,
    and apply the embedding either additively or as a FiLM-style
    scale/shift after the second GroupNorm (``use_scale_shift_norm``).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        """
        Args:
            channels: Input channel count.
            emb_channels: Dimension of the timestep embedding.
            dropout: Dropout rate before the final conv.
            out_channels: Output channels; defaults to ``channels``.
            use_conv: Use a 3x3 conv (instead of 1x1) for the skip
                projection when the channel count changes.
            use_scale_shift_norm: Apply the embedding as scale/shift.
            dims: 1/2/3-D convolutions.
            use_checkpoint: Gradient-checkpoint the forward pass.
            up: Fold a 2x upsample into the block.
            down: Fold a 2x downsample into the block.
        """
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down
        if up:
            # Resample both the residual branch (h_upd) and the skip (x_upd).
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                # Scale/shift needs two tensors of out_channels each.
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-init so the block starts as (a projection of) the identity.
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """Apply the block, optionally under gradient checkpointing."""
        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if self.updown:
            # Resample between the norm/activation and the first conv so the
            # conv runs at the new resolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the trailing spatial dims.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = emb_out.chunk(2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
198
+
199
+
200
class AttentionBlock(nn.Module):
    """Spatial self-attention block over the flattened (H*W) positions.

    Used instead of SpatialTransformer when ``use_spatial_transformer``
    is off in the UNet.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        # 1D conv over the flattened positions produces concatenated q,k,v.
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = (
            QKVAttention(self.num_heads)
            if use_new_attention_order
            else QKVAttentionLegacy(self.num_heads)
        )
        # Zero-init output projection: block starts as the identity.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE(review): checkpointing is forced on here (flag=True) regardless
        # of self.use_checkpoint — confirm this is intentional.
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        # Flatten all spatial dims into one sequence axis.
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
238
+
239
+
240
class QKVAttentionLegacy(nn.Module):
    """QKV attention that splits heads before splitting q/k/v (legacy order)."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """qkv: (batch, 3 * heads * head_dim, length) -> (batch, heads * head_dim, length)."""
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(batch * self.n_heads, 3 * head_dim, length).split(head_dim, dim=1)
        # Scale q and k each by head_dim^(-1/4) so their product carries 1/sqrt(head_dim).
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        # Softmax in float32 for stability, then cast back.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        out = torch.einsum("bts,bcs->bct", probs, v)
        return out.reshape(batch, -1, length)
257
+
258
+
259
class QKVAttention(nn.Module):
    """QKV attention that splits q/k/v before splitting heads (new order)."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """qkv: (batch, 3 * heads * head_dim, length) -> (batch, heads * head_dim, length)."""
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # head_dim^(-1/4) applied to q and k each == 1/sqrt(head_dim) overall.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        q_flat = (q * scale).view(batch * self.n_heads, head_dim, length)
        k_flat = (k * scale).view(batch * self.n_heads, head_dim, length)
        logits = torch.einsum("bct,bcs->bts", q_flat, k_flat)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        v_flat = v.reshape(batch * self.n_heads, head_dim, length)
        out = torch.einsum("bts,bcs->bct", probs, v_flat)
        return out.reshape(batch, -1, length)
280
+
281
+
282
class UNetModel(nn.Module):
    """Full UNet with attention and timestep embedding.

    Encoder (``input_blocks``) downsamples through ``channel_mult`` stages,
    ``middle_block`` applies ResBlock-Attention-ResBlock at the lowest
    resolution, and the decoder (``output_blocks``) upsamples while
    concatenating encoder skips. Attention (AttentionBlock or
    SpatialTransformer) is inserted at the downsample factors listed in
    ``attention_resolutions``.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        # Cross-attention requires the spatial transformer and vice versa.
        if use_spatial_transformer:
            assert context_dim is not None
        if context_dim is not None:
            assert use_spatial_transformer
            if hasattr(context_dim, "__iter__") and not isinstance(context_dim, (list, tuple)):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
        # Exactly one of num_heads / num_head_channels must be specified.
        if num_heads == -1:
            assert num_head_channels != -1
        if num_head_channels == -1:
            assert num_heads != -1

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        # Allow a single int (replicated per level) or an explicit per-level list.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            assert len(num_res_blocks) == len(channel_mult)
            self.num_res_blocks = num_res_blocks

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Timestep embedding MLP: model_channels -> 4*model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if num_classes is not None:
            if isinstance(num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif num_classes == "continuous":
                self.label_emb = nn.Linear(1, time_embed_dim)
            else:
                raise ValueError()

        # ---- Encoder ----
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsample factor

        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        # NOTE(review): num_heads_cur is computed but unused.
                        num_heads_cur = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    disabled_sa = (
                        disable_self_attentions[level]
                        if exists(disable_self_attentions)
                        else False
                    )
                    # Cap attention blocks per level when num_attention_blocks given.
                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        attn_block = (
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                        layers.append(attn_block)
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                down_block = (
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    )
                    if resblock_updown
                    else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                )
                self.input_blocks.append(TimestepEmbedSequential(down_block))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- Middle block ----
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            # NOTE(review): num_heads_cur is computed but unused.
            num_heads_cur = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_attn = (
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            )
        )
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            mid_attn,
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ---- Decoder (mirrors the encoder, consuming skips in reverse) ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,  # concatenated skip channels
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        # NOTE(review): num_heads_cur is computed but unused.
                        num_heads_cur = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads if use_spatial_transformer else num_head_channels
                        )
                    disabled_sa = (
                        disable_self_attentions[level]
                        if exists(disable_self_attentions)
                        else False
                    )
                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        attn_block = (
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                        layers.append(attn_block)
                # Upsample at the end of each level except the outermost.
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    up_block = (
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    layers.append(up_block)
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Final projection; zero-init conv so the model starts predicting zeros.
        # NOTE(review): norm uses `ch` but the conv uses `model_channels` as
        # input — these match only when channel_mult starts at 1; confirm.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
            )

    def forward(self, x, timesteps=None, metadata=None, context=None, y=None, **kwargs):
        """Predict the model output for noisy latents `x` at `timesteps`.

        Args:
            x: Noisy input, shape (B, in_channels, H, W).
            timesteps: 1-D tensor of diffusion timesteps, shape (B,).
            metadata: Optional pre-computed embedding added to the timestep
                embedding; a single-element list/tuple is unwrapped first.
            context: Cross-attention conditioning for SpatialTransformer layers.
            y: Class labels, required iff the model is class-conditional.
        """
        assert (y is not None) == (self.num_classes is not None)
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if metadata is not None:
            if isinstance(metadata, (list, tuple)) and len(metadata) == 1:
                metadata = metadata[0]
            emb = emb + metadata

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Concatenate the matching encoder skip before each decoder block.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        return self.out(h)
unet/model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene UNet - LocalControlUNetModel for hyperspectral generation."""
2
+
3
+ import torch
4
+
5
+ from .diffusion import UNetModel
6
+ from .utils import timestep_embedding
7
+
8
+
9
class HSIGeneUNet(UNetModel):
    """UNet that accepts metadata embeddings and LocalAdapter control.

    Behaves like UNetModel but (a) adds the pre-computed ``metadata``
    embedding to the timestep embedding and (b) injects the residual
    feature maps produced by the LocalAdapter (``local_control``) into the
    middle block output and into every decoder skip connection.
    """

    def forward(
        self,
        x,
        timesteps=None,
        metadata=None,
        context=None,
        local_control=None,
        meta=False,
        **kwargs,
    ):
        """Run the controlled UNet.

        Args:
            x: Noisy latent, shape (B, in_channels, H, W).
            timesteps: 1-D tensor of diffusion timesteps, shape (B,).
            metadata: Optional embedding added to the timestep embedding.
                Previously a None metadata crashed with a TypeError; now it
                is guarded (and single-element lists unwrapped) exactly as
                in UNetModel.forward.
            context: Cross-attention conditioning for the transformers.
            local_control: List of control residuals, consumed from the end
                (middle block first, then one per decoder skip). The
                caller's list is no longer mutated. When None, the model
                runs as a plain UNet.
            meta: Unused; kept for interface compatibility.

        Returns:
            Predicted output, shape (B, out_channels, H, W).
        """
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        if metadata is not None:
            # Consistent with UNetModel.forward: unwrap 1-element containers.
            if isinstance(metadata, (list, tuple)) and len(metadata) == 1:
                metadata = metadata[0]
            emb = emb + metadata

        # Copy so that popping does not destroy the caller's list.
        control = list(local_control) if local_control is not None else None

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        if control is not None:
            h = h + control.pop()
        for module in self.output_blocks:
            skip = hs.pop()
            if control is not None:
                skip = skip + control.pop()
            h = torch.cat([h, skip], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        return self.out(h)
unet/utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene utilities - no ldm/models imports."""
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import repeat
7
+
8
+
9
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
11
+
12
+
13
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """Build sinusoidal timestep embeddings of width *dim* (cos half first,
    then sin half); odd dims are padded with one zero column.

    With repeat_only=True the raw timestep is simply broadcast to *dim*.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
    freqs = torch.exp(-math.log(max_period) * exponents)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat((angles.cos(), angles.sin()), dim=-1)
    if dim % 2 == 1:
        emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
    return emb
25
+
26
+
27
def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module for the given `dims`."""
    conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
    if conv_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)
35
+
36
+
37
def linear(*args, **kwargs):
    """Thin alias for nn.Linear, kept so call sites mirror conv_nd."""
    return nn.Linear(*args, **kwargs)
39
+
40
+
41
def zero_module(module):
    """Zero every parameter of *module* in place and return the module."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
45
+
46
+
47
def checkpoint(func, inputs, params, flag):
    """Evaluate func(*inputs), optionally via gradient checkpointing.

    When *flag* is falsy the function is called directly; otherwise
    activations are recomputed in the backward pass via _CheckpointFunction.
    """
    if not flag:
        return func(*inputs)
    args = tuple(inputs) + tuple(params)
    return _CheckpointFunction.apply(func, len(inputs), *args)
51
+
52
+
53
class _CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: skip storing activations in forward and
    recompute them under enable_grad during backward.

    NOTE(review): torch.get_autocast_gpu_dtype and torch.cuda.amp.autocast
    are deprecated in recent PyTorch in favor of torch.get_autocast_dtype /
    torch.amp.autocast — verify against the pinned torch version.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # First `length` args are tensor inputs; the rest are parameters
        # kept only so grads flow to them in backward.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Capture the autocast state so the recompute matches the original run.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # view_as produces shallow copies so the recompute does not
            # modify the detached tensors' autograd graph in place.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # None, None correspond to run_function and length.
        return (None, None) + input_grads
81
+
82
+
83
def normalization(channels):
    """Standard 32-group GroupNorm (float32 internals via GroupNorm32)."""
    return GroupNorm32(32, channels)
85
+
86
+
87
class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32, with the result cast back to the
    input's dtype (keeps normalization stable under fp16 training)."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
90
+
vae/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """HSIGene VAE component."""
vae/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (185 Bytes). View file
 
vae/__pycache__/model.cpython-312.pyc ADDED
Binary file (4.26 kB). View file
 
vae/__pycache__/vae_blocks.cpython-312.pyc ADDED
Binary file (22.4 kB). View file
 
vae/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_target": "hsigene.HSIGeneAutoencoderKL",
3
+ "in_channels": 48,
4
+ "out_channels": 48,
5
+ "latent_channels": 96,
6
+ "embed_dim": 4,
7
+ "block_out_channels": [
8
+ 64,
9
+ 128,
10
+ 256
11
+ ],
12
+ "num_res_blocks": 4,
13
+ "attn_resolutions": [
14
+ 16,
15
+ 32,
16
+ 64
17
+ ],
18
+ "dropout": 0.0,
19
+ "double_z": true,
20
+ "resolution": 256
21
+ }
vae/model.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene AutoencoderKL - nn.Module, no Lightning. Loss = Identity."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .vae_blocks import Encoder, Decoder, DiagonalGaussianDistribution
7
+
8
+
9
class AutoencoderKL(nn.Module):
    """
    AutoencoderKL - nn.Module (not Lightning).
    Uses Encoder, Decoder, quant_conv, post_quant_conv.
    encode() returns posterior, decode() takes z.
    Loss = Identity (no-op) - this class is inference-only here.
    """

    def __init__(
        self,
        ddconfig,
        embed_dim=4,
        lossconfig=None,
        **kwargs,
    ):
        """
        Args:
            ddconfig: dict of Encoder/Decoder kwargs; must contain
                ``z_channels`` and have ``double_z`` truthy (encoder emits
                mean+logvar, hence the 2x factors below).
            embed_dim: latent channel count after quant_conv.
            lossconfig: ignored; loss is nn.Identity.
        """
        super().__init__()
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig.get("double_z", True)
        z_channels = ddconfig["z_channels"]
        # 2x: encoder output stacks mean and logvar along channels.
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1)
        self.embed_dim = embed_dim
        self.loss = nn.Identity()

    def encode(self, x):
        """Encode *x* into a DiagonalGaussianDistribution posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        # deterministic=True: posterior variance is suppressed, so sample()
        # presumably returns the mean — confirm in vae_blocks.
        posterior = DiagonalGaussianDistribution(moments, deterministic=True)
        return posterior

    def decode(self, z):
        """Decode latent *z* back to image space."""
        z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, input, sample_posterior=True):
        """Full encode/decode round trip; returns (reconstruction, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
52
+
53
+
54
class HSIGeneAutoencoderKL(AutoencoderKL):
    """
    HSIGene VAE with diffusers-style config.
    Accepts in_channels, out_channels, latent_channels, block_out_channels
    and translates them into the ldm-style ``ddconfig`` dict expected by
    the base AutoencoderKL.
    """

    def __init__(
        self,
        in_channels: int = 48,
        out_channels: int = 48,
        latent_channels: int = 96,
        embed_dim: int = 4,
        block_out_channels: tuple = (64, 128, 256),
        num_res_blocks: int = 4,
        attn_resolutions: tuple = (16, 32, 64),
        dropout: float = 0.0,
        double_z: bool = True,
        resolution: int = 256,
        **kwargs,
    ):
        ch = block_out_channels[0]
        # Derive multipliers from the channel ladder.
        # NOTE(review): integer division assumes every entry is an exact
        # multiple of the first — confirm for non-default configs.
        ch_mult = tuple(
            block_out_channels[i] // ch for i in range(len(block_out_channels))
        )
        ddconfig = dict(
            double_z=double_z,
            z_channels=latent_channels,
            resolution=resolution,
            in_channels=in_channels,
            out_ch=out_channels,
            ch=ch,
            ch_mult=list(ch_mult),
            num_res_blocks=num_res_blocks,
            attn_resolutions=list(attn_resolutions),
            dropout=dropout,
        )
        super().__init__(ddconfig=ddconfig, embed_dim=embed_dim, **kwargs)
vae/utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """VAE utilities."""
2
+
3
+ import torch.nn as nn
4
+
5
+
6
def zero_module(module):
    """Zero out the parameters of a module and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
vae/vae_blocks.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSIGene VAE blocks - ResnetBlock, Encoder, Decoder, DiagonalGaussianDistribution."""
2
+
3
+ from typing import Optional, Any
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+ XFORMERS_IS_AVAILABLE = True
15
+ except ImportError:
16
+ XFORMERS_IS_AVAILABLE = False
17
+
18
+
19
def nonlinearity(x):
    """SiLU / swish activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
21
+
22
+
23
def Normalize(in_channels, num_groups=32):
    """GroupNorm with eps=1e-6 and affine parameters, as used by the VAE blocks."""
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
25
+
26
+
27
class ResnetBlock(nn.Module):
    """VAE residual block: norm -> swish -> conv, twice, with an optional
    timestep-embedding injection and a learned shortcut when the channel
    count changes."""

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        """
        Args:
            in_channels: Input channel count.
            out_channels: Output channels; defaults to ``in_channels``.
            conv_shortcut: Use a 3x3 conv (instead of 1x1) for the shortcut
                when the channel count changes.
            dropout: Dropout rate between the two convolutions.
            temb_channels: Timestep-embedding dim; 0 disables the projection.
        """
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                # 1x1 "network-in-network" projection.
                self.nin_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)
        if temb is not None:
            # Broadcast the projected embedding over H and W.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
77
+
78
+
79
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over the H*W positions
    (the plain / non-xformers path)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        # q: (b, hw, c); k: (b, c, hw) -> attention logits (b, hw_q, hw_k).
        q = q.reshape(b, c, h * w).permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k) * (int(c) ** -0.5)
        # Softmax over the key axis.
        w_ = F.softmax(w_, dim=2)
        v = v.reshape(b, c, h * w)
        # (b, c, hw_k) @ (b, hw_k, hw_q) -> (b, c, hw_q).
        h_ = torch.bmm(v, w_.permute(0, 2, 1))
        h_ = h_.reshape(b, c, h, w)
        h_ = self.proj_out(h_)
        return x + h_
105
+
106
+
107
class MemoryEfficientAttnBlock(nn.Module):
    """AttnBlock using xformers when available.

    Same single-head spatial self-attention contract as AttnBlock, but the
    attention itself is computed by xformers.ops.memory_efficient_attention.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        # 1x1 convolutions act as per-pixel linear projections for q/k/v and output.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # Optional xformers attention operator override; None lets xformers pick.
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        B, C, H, W = q.shape
        # Flatten spatial dims to token sequences: (B, H*W, C).
        q, k, v = map(lambda t: rearrange(t, "b c h w -> b (h w) c"), (q, k, v))
        # Fold a single attention head into the batch dimension
        # ((B, HW, C) -> (B*1, HW, C)), mirroring the multi-head layout
        # memory_efficient_attention expects; .contiguous() because xformers
        # requires contiguous inputs.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )
        # Undo the head folding: (B*1, HW, C) -> (B, HW, C).
        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        # Restore the (B, C, H, W) feature-map layout.
        out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x + out
147
+
148
+
149
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    """Build an attention module of the requested flavor.

    "vanilla" is transparently upgraded to the xformers implementation when
    xformers is installed; "none" yields an identity module.
    """
    assert attn_type in ["vanilla", "vanilla-xformers", "none"]
    if attn_type == "vanilla" and XFORMERS_IS_AVAILABLE:
        attn_type = "vanilla-xformers"
    if attn_type == "none":
        return nn.Identity()
    if attn_type == "vanilla-xformers":
        return MemoryEfficientAttnBlock(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    # Unreachable given the assert above; kept as a guard.
    raise NotImplementedError(f"attn_type {attn_type}")
160
+
161
+
162
class Downsample(nn.Module):
    """Halve spatial resolution, via a strided conv or average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # padding=0 on purpose: asymmetric padding is applied in forward().
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad only right/bottom so the stride-2 conv halves H and W exactly.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
177
+
178
+
179
class Upsample(nn.Module):
    """Double spatial resolution with nearest-neighbor upsampling (optional 3x3 conv)."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        up = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(up) if self.with_conv else up
191
+
192
+
193
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a latent feature map.

    Pipeline: stem conv -> one stage per entry of ``ch_mult`` (each stage is
    ``num_res_blocks`` ResnetBlocks, with attention at resolutions listed in
    ``attn_resolutions``, followed by a 2x downsample except on the last
    stage) -> middle block/attn/block -> norm + nonlinearity + output conv.
    With ``double_z`` the output has ``2 * z_channels`` channels (mean and
    logvar halves consumed by DiagonalGaussianDistribution).
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            # NOTE(review): make_attn() in this file asserts attn_type is one of
            # "vanilla"/"vanilla-xformers"/"none", so "linear" would fail its
            # assert — confirm whether a linear-attention variant exists elsewhere.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder path
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        curr_res = resolution
        # Stage i maps ch*in_ch_mult[i] -> ch*ch_mult[i] channels.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                # One attention module per res block at the flagged resolutions,
                # so attn and block stay index-aligned in forward().
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # Middle section at the lowest resolution: ResnetBlock -> attn -> ResnetBlock.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # Head: normalize, activate, project to (2*)z_channels.
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels if double_z else z_channels,
            kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Encode ``x`` to the latent feature map (pre-distribution moments)."""
        temb = None  # ResnetBlocks accept a timestep embedding; unused here
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # Middle blocks and output head on the deepest features.
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
290
+
291
+
292
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent feature map back to an image.

    Mirror of Encoder: stem conv from ``z_channels`` -> middle
    block/attn/block -> one stage per entry of ``ch_mult`` processed from
    deepest to shallowest (``num_res_blocks + 1`` ResnetBlocks each, with
    attention at the resolutions in ``attn_resolutions``, and a 2x upsample
    after every stage except the shallowest) -> norm + nonlinearity + output
    conv to ``out_ch`` channels (optionally tanh-squashed).
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            # NOTE(review): make_attn() in this file asserts attn_type is one of
            # "vanilla"/"vanilla-xformers"/"none", so "linear" would fail its
            # assert — confirm whether a linear-attention variant exists elsewhere.
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding in the autoencoder path
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        in_ch_mult = (1,) + tuple(ch_mult)  # NOTE(review): computed but unused below
        # Start from the deepest stage's channel width and lowest resolution.
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        # Middle section at the lowest resolution: ResnetBlock -> attn -> ResnetBlock.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # Build stages deepest-first, then insert(0, ...) so self.up is indexed
        # shallow-to-deep like the Encoder's self.down.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # One extra block per stage relative to the Encoder.
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)

        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent ``z`` to an image-space tensor."""
        self.last_z_shape = z.shape  # stashed for external inspection/debugging
        temb = None  # ResnetBlocks accept a timestep embedding; unused here
        h = self.conv_in(z)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # Walk stages deep-to-shallow, upsampling between them.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)
        if self.give_pre_end:
            # Caller wants raw features before the output head.
            return h
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
395
+
396
+
397
class DiagonalGaussianDistribution:
    """Factorized Gaussian over latents, parameterized by concatenated mean/logvar.

    ``parameters`` holds mean and log-variance stacked along dim 1; with
    ``deterministic=True`` the variance is forced to zero so sampling
    returns the mean and divergences vanish.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        # Clamp for numerical stability (upper bound 0 caps the variance at 1).
        self.logvar = torch.clamp(logvar, -20.0, 0.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: zero spread everywhere.
            self.var = self.std = torch.zeros_like(self.mean, device=parameters.device)

    def sample(self):
        """Draw one reparameterized sample: mean + std * eps."""
        eps = torch.randn(self.mean.shape, device=self.parameters.device)
        return self.mean + self.std * eps

    def kl(self, other=None):
        """KL divergence to ``other`` (or to the standard normal), per batch item."""
        if self.deterministic:
            return torch.tensor(0.0, device=self.parameters.device)
        if other is None:
            # KL(N(mu, var) || N(0, 1)), reduced over channel/spatial dims.
            term = torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
            return 0.5 * torch.sum(term, dim=[1, 2, 3])
        term = (
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar
        )
        return 0.5 * torch.sum(term, dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of ``sample`` under this Gaussian."""
        if self.deterministic:
            return torch.tensor(0.0, device=self.parameters.device)
        log_two_pi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            log_two_pi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Most likely value — the mean for a Gaussian."""
        return self.mean