BiliSakura commited on
Commit
3dde604
·
verified ·
1 Parent(s): 070ea57

Add files using upload-large-folder tool

Browse files
README.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: diffusers
4
+ pipeline_tag: unconditional-image-generation
5
+ tags:
6
+ - zoomldm
7
+ - cdm
8
+ - dit
9
+ - histopathology
10
+ - brca
11
+ - custom-pipeline
12
+ ---
13
+
14
+ # BiliSakura/ZoomLDM-CDM-brca
15
+
16
+ Diffusers-style wrapped **CDM (DiT)** checkpoint for BRCA, converted from ZoomLDM `cdm_dit` training outputs.
17
+
18
+ ## Model Description
19
+
20
+ - **Architecture:** DiT-B style conditioning diffusion model (CDM)
21
+ - **Domain:** BRCA conditioning space used by ZoomLDM
22
+ - **Output:** conditioning tokens/embeddings (`(B, 512, 65)`)
23
+ - **Format:** custom diffusers pipeline (`pipeline.py`)
24
+
25
+ ## Intended Use
26
+
27
+ Use this model to sample BRCA conditioning embeddings that can be consumed by downstream ZoomLDM workflows.
28
+
29
+ ## Out-of-Scope Use
30
+
31
+ - Not a complete pixel-space generator by itself.
32
+ - Not intended for clinical or diagnostic use.
33
+ - Not validated for non-BRCA domains without adaptation.
34
+
35
+ ## Files
36
+
37
+ - `pipeline.py`: custom `DiffusionPipeline` implementation (`CDMDiTPipeline`)
38
+ - `model_index.json`: diffusers metadata
39
+ - `cdm/`: active model weights/config used by pipeline
40
+ - `scheduler/`: DDIM scheduler config
41
+ - `model_raw.safetensors`: non-EMA training weights (optional)
42
+ - `optimizer.pt`: optimizer state (optional)
43
+ - `config.json`: conversion metadata
44
+
45
+ ## Usage
46
+
47
+ ```python
48
+ import torch
49
+ from diffusers import DiffusionPipeline
50
+
51
+ pipe = DiffusionPipeline.from_pretrained(
52
+ "BiliSakura/ZoomLDM-CDM-brca",
53
+ custom_pipeline="pipelin.py",
54
+ trust_remote_code=True,
55
+ ).to("cuda")
56
+
57
+ out = pipe(
58
+ batch_size=2,
59
+ magnification=torch.tensor([0, 0], device="cuda"), # class labels 0..7
60
+ num_inference_steps=50,
61
+ guidance_scale=1.0,
62
+ )
63
+
64
+ samples = out.samples # (B, 512, 65)
65
+ ```
66
+
67
+ ## Limitations
68
+
69
+ - Produces conditioning embeddings, not final images.
70
+ - Requires correct class/magnification label conventions.
71
+ - Inherits data biases and quality limits from the original training data.
72
+
73
+ ## Citation
74
+
75
+ ```bibtex
76
+ @InProceedings{Yellapragada_2025_CVPR,
77
+ author = {Yellapragada, Srikar and Graikos, Alexandros and Triaridis, Kostas and Prasanna, Prateek and Gupta, Rajarsi and Saltz, Joel and Samaras, Dimitris},
78
+ title = {ZoomLDM: Latent Diffusion Model for Multi-scale Image Generation},
79
+ booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
80
+ month = {June},
81
+ year = {2025},
82
+ pages = {23453-23463}
83
+ }
84
+
85
+ @inproceedings{Peebles2023DiT,
86
+ title={Scalable Diffusion Models with Transformers},
87
+ author={Peebles, William and Xie, Saining},
88
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
89
+ year={2023}
90
+ }
91
+ ```
cdm/config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "CDMDiTModel",
3
+ "_diffusers_version": "0.30.0",
4
+ "num_patches": 65,
5
+ "in_channels": 512,
6
+ "hidden_size": 768,
7
+ "depth": 12,
8
+ "num_heads": 12,
9
+ "mlp_ratio": 4.0,
10
+ "num_classes": 8,
11
+ "learn_sigma": true
12
+ }
cdm/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f593642111dd86a2c03994be5dafd1d5c8f170a7bce0c1338849e5b171fb6043
3
+ size 523000344
model_index.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "CDMDiTPipeline",
3
+ "_diffusers_version": "0.30.0",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "DDIMScheduler"
7
+ ]
8
+ }
pipeline.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom diffusers pipeline for ZoomLDM CDM (DiT backbone).
3
+ """
4
+
5
+ import math
6
+ import json
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from diffusers import DDIMScheduler, DiffusionPipeline
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import BaseOutput
17
+
18
+
19
+ def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
20
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21
+
22
+
23
+ class Attention(nn.Module):
24
+ """Minimal ViT-style self-attention with timm-compatible parameter names."""
25
+
26
+ def __init__(self, dim: int, num_heads: int, qkv_bias: bool = True):
27
+ super().__init__()
28
+ if dim % num_heads != 0:
29
+ raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
30
+ self.num_heads = num_heads
31
+ self.head_dim = dim // num_heads
32
+ self.scale = self.head_dim**-0.5
33
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
34
+ self.proj = nn.Linear(dim, dim, bias=True)
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ bsz, tokens, dim = x.shape
38
+ qkv = self.qkv(x)
39
+ qkv = qkv.reshape(bsz, tokens, 3, self.num_heads, self.head_dim)
40
+ qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, H, T, D
41
+ q, k, v = qkv.unbind(0)
42
+
43
+ attn = (q @ k.transpose(-2, -1)) * self.scale
44
+ attn = attn.softmax(dim=-1)
45
+
46
+ x = attn @ v
47
+ x = x.transpose(1, 2).reshape(bsz, tokens, dim)
48
+ return self.proj(x)
49
+
50
+
51
+ class Mlp(nn.Module):
52
+ """Minimal timm-like MLP block with matching names."""
53
+
54
+ def __init__(self, in_features: int, hidden_features: int):
55
+ super().__init__()
56
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
57
+ self.act = nn.GELU(approximate="tanh")
58
+ self.fc2 = nn.Linear(hidden_features, in_features, bias=True)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ return self.fc2(self.act(self.fc1(x)))
62
+
63
+
64
class DiTBlock(nn.Module):
    """Transformer block with adaLN-Zero conditioning (DiT).

    The conditioning vector ``c`` is mapped to six per-sample terms:
    (shift, scale, gate) for the attention sub-layer and again for the MLP.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio))
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        mods = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = mods
        # Gated residual around modulated attention, then around the modulated MLP.
        x = x + gate_a.unsqueeze(1) * self.attn(_modulate(self.norm1(x), shift_a, scale_a))
        x = x + gate_m.unsqueeze(1) * self.mlp(_modulate(self.norm2(x), shift_m, scale_m))
        return x
79
+
80
+
81
class TimestepEmbedder(nn.Module):
    """Map scalar diffusion timesteps to embedding vectors (sinusoidal features + MLP)."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """Sinusoidal embedding of shape (len(t), dim); cos halves precede sin halves."""
        half = dim // 2
        log_scale = -math.log(max_period)
        freqs = torch.exp(log_scale * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd target dims get one zero-padded trailing column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
106
+
107
+
108
class LabelEmbedder(nn.Module):
    """Lookup table mapping integer class labels to conditioning vectors.

    Note: no dedicated classifier-free-guidance "null" token is allocated here;
    the pipeline passes label 0 for its unconditional branch.
    """

    def __init__(self, num_classes: int, hidden_size: int):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes, hidden_size)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        # Plain embedding lookup.
        embeddings = self.embedding_table(labels)
        return embeddings
115
+
116
+
117
class FinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection."""

    def __init__(self, hidden_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # Only shift and scale here — no gate, since there is no residual path.
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)
        modulated = _modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
128
+
129
+
130
class CDMDiTModel(ModelMixin, ConfigMixin):
    """DiT backbone over token sequences; serves as ZoomLDM's conditioning model (CDM).

    Layout is channels-first: (B, in_channels, num_patches) in,
    (B, out_channels, num_patches) out, where out_channels doubles when
    ``learn_sigma`` is enabled (noise and sigma predictions stacked on channels).
    """

    @register_to_config
    def __init__(
        self,
        num_patches: int = 65,
        in_channels: int = 512,
        hidden_size: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        num_classes: int = 8,
        learn_sigma: bool = True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = 2 * in_channels if learn_sigma else in_channels
        self.num_patches = num_patches
        # No patchify step: each of the num_patches token vectors is projected directly.
        self.x_embedder = nn.Linear(in_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size)
        # Frozen positional embedding; actual values are expected from the checkpoint.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
        self.blocks = nn.ModuleList(
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        )
        self.final_layer = FinalLayer(hidden_size, self.out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Run the denoiser.

        Args:
            x: noisy tokens, shape (B, in_channels, num_patches).
            t: diffusion timesteps, shape (B,).
            y: integer class labels (magnification bins), shape (B,).

        Returns:
            Tensor of shape (B, out_channels, num_patches).
        """
        tokens = self.x_embedder(x.transpose(1, 2)) + self.pos_embed
        # A single conditioning vector per sample drives every adaLN modulation.
        cond = self.t_embedder(t) + self.y_embedder(y)
        for block in self.blocks:
            tokens = block(tokens, cond)
        tokens = self.final_layer(tokens, cond)
        return tokens.transpose(1, 2)
166
+
167
+
168
@dataclass
class CDMPipelineOutput(BaseOutput):
    """Output of ``CDMDiTPipeline.__call__``."""

    # Sampled conditioning tokens of shape (B, in_channels, num_patches),
    # e.g. (B, 512, 65) for the BRCA checkpoint.
    samples: torch.Tensor
171
+
172
+
173
class CDMDiTPipeline(DiffusionPipeline):
    """Samples CDM conditioning tokens with a DDIM loop and classifier-free guidance.

    Only the scheduler is registered as a diffusers submodule; the CDM itself is
    loaded lazily from the repository's ``cdm/`` folder the first time it is needed.
    """

    def __init__(self, scheduler: DDIMScheduler, cdm: Optional[CDMDiTModel] = None):
        super().__init__()
        self.register_modules(scheduler=scheduler)
        self.cdm = cdm
        # Root directory of the pipeline repo, used to locate cdm/ for lazy loading.
        self._cdm_root = None
        # Infer the root from where the scheduler config was resolved, so the cdm/
        # subfolder can be found later without the caller passing a path.
        scheduler_path = getattr(getattr(scheduler, "config", None), "_name_or_path", None)
        if scheduler_path:
            p = Path(scheduler_path)
            self._cdm_root = p.parent if p.name == "scheduler" else p

    @property
    def device(self) -> torch.device:
        """Device of the (lazily loaded) CDM parameters."""
        self._load_cdm_if_needed()
        return next(self.cdm.parameters()).device

    def to(self, *args, **kwargs):
        # Force-load first so `.to("cuda")` works immediately after from_pretrained.
        self._load_cdm_if_needed()
        self.cdm.to(*args, **kwargs)
        return self

    def _load_cdm_if_needed(self):
        """Instantiate the CDM from ``cdm/config.json`` and load its weights once.

        Raises:
            RuntimeError: if no model root directory can be inferred.
            FileNotFoundError: if neither safetensors nor .bin weights exist in cdm/.
        """
        if self.cdm is not None:
            return
        if self._cdm_root is None:
            # Fall back to the pipeline's own resolved path from its config.
            root_from_cfg = self.config.get("_name_or_path", None)
            if root_from_cfg:
                self._cdm_root = Path(root_from_cfg)
        if self._cdm_root is None:
            raise RuntimeError("Could not infer model root for loading CDM weights.")

        cdm_dir = self._cdm_root / "cdm"
        with open(cdm_dir / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)
        # Strip diffusers bookkeeping keys before passing to the constructor.
        cfg.pop("_class_name", None)
        cfg.pop("_diffusers_version", None)

        cdm = CDMDiTModel(**cfg)
        safetensors_path = cdm_dir / "diffusion_pytorch_model.safetensors"
        bin_path = cdm_dir / "diffusion_pytorch_model.bin"
        if safetensors_path.exists():
            from safetensors.torch import load_file

            state = load_file(str(safetensors_path))
        elif bin_path.exists():
            try:
                # weights_only avoids arbitrary-code pickle execution (newer torch).
                state = torch.load(bin_path, map_location="cpu", weights_only=True)
            except TypeError:
                # Older torch versions lack the weights_only kwarg.
                state = torch.load(bin_path, map_location="cpu")
        else:
            raise FileNotFoundError(
                "No CDM weights found in cdm/ (expected diffusion_pytorch_model.safetensors or .bin)."
            )
        cdm.load_state_dict(state, strict=True)
        cdm.eval()
        self.cdm = cdm

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        magnification: Optional[torch.Tensor] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        num_patches: Optional[int] = None,
        return_dict: bool = True,
    ):
        """Sample conditioning tokens via DDIM with classifier-free guidance.

        Args:
            batch_size: number of samples when ``magnification`` is None; otherwise
                the batch size is taken from ``magnification``'s length.
            magnification: integer class labels; defaults to all-zeros.
            num_inference_steps: DDIM steps.
            guidance_scale: CFG weight; 1.0 reduces to the conditional prediction.
            num_patches: override for the token count (defaults to the model config).
            return_dict: return ``CDMPipelineOutput`` if True, else a 1-tuple.

        Returns:
            CDMPipelineOutput with ``samples`` of shape (B, in_channels, num_patches),
            or ``(samples,)`` when ``return_dict`` is False.
        """
        self._load_cdm_if_needed()
        device = self.device
        dtype = next(self.cdm.parameters()).dtype

        if magnification is None:
            magnification = torch.zeros(batch_size, dtype=torch.long, device=device)
        else:
            magnification = magnification.to(device=device, dtype=torch.long)
            if magnification.ndim == 0:
                # Promote a scalar label to a batch of one.
                magnification = magnification.view(1)

        # Labels, not the batch_size argument, determine the effective batch size.
        batch_size = int(magnification.shape[0])
        tokens = num_patches or self.cdm.config.num_patches
        channels = self.cdm.config.in_channels

        latents = torch.randn((batch_size, channels, tokens), device=device, dtype=dtype)
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        for t in self.progress_bar(self.scheduler.timesteps):
            # Duplicate the batch: first half unconditional, second half conditional.
            model_in = torch.cat([latents, latents], dim=0)
            t_batch = t.expand(model_in.shape[0]).to(device)
            # NOTE(review): label 0 doubles as the "unconditional" branch — presumably
            # the training convention; confirm against the original training code.
            y_in = torch.cat([torch.zeros_like(magnification), magnification], dim=0)

            model_out = self.cdm(model_in, t_batch, y_in)
            # With learn_sigma the channel axis holds (eps, sigma); sigma is unused here.
            eps, _sigma = model_out.chunk(2, dim=1) if self.cdm.config.learn_sigma else (model_out, None)
            eps_uncond, eps_cond = eps.chunk(2, dim=0)
            eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

            latents = self.scheduler.step(eps_guided, t, latents).prev_sample

        if not return_dict:
            return (latents,)
        return CDMPipelineOutput(samples=latents)
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }