Upload distilled WSI diffusion model package
Browse files- .gitattributes +1 -0
- README.md +121 -0
- compare.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
compare.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -76,6 +76,127 @@ latents = sample_student_trajectory(
|
|
| 76 |
img = decode_latents_to_images(pipeline, latents)[0]
|
| 77 |
```
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
## Notes
|
| 80 |
|
| 81 |
- This is a distilled student checkpoint intended for research.
|
|
|
|
| 76 |
img = decode_latents_to_images(pipeline, latents)[0]
|
| 77 |
```
|
| 78 |
|
| 79 |
+
## Generate in 3 Steps
|
| 80 |
+
|
| 81 |
+
1. Load base PixCell pipeline + this distilled student.
|
| 82 |
+
2. Feed one UNI feature (`[1,1,1536]`) as condition.
|
| 83 |
+
3. Sample with a small step count (for example, 4) and decode.
|
| 84 |
+
|
| 85 |
+
## Teacher vs Student (Visualization + Timing)
|
| 86 |
+
|
| 87 |
+
`compare.png` (left = teacher, right = student):
|
| 88 |
+
|
| 89 |
+

|
| 90 |
+
|
| 91 |
+
Use the following snippet to reproduce side-by-side image and speedup numbers:
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
import time
import random
import torch
import numpy as np
from PIL import Image
from IPython.display import display

from models.diffusion import (
    make_uncond_embedding,
    scheduler_rollout,
    decode_latents_to_images,
)

# Pick a random UNI feature from the test split and shape it as the
# [1, 1, 1536] conditioning tensor both models consume.
idx = random.randrange(len(test_ds))
uni_feat = test_ds[idx]  # [1536] -- assumes test_ds yields a 1-D feature tensor; verify
cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32)  # [1,1,1536]

# cond: [1,1,1536] from test manifest (as in previous cell)
# student, teacher, pipeline already loaded; switch to inference mode.
student.eval()
teacher.eval()

latent_channels = int(pipeline.vae.config.latent_channels)
latent_size = 32
steps_student = 4
steps_teacher = 35
guidance_student = 1.0
guidance_teacher = 3.0

# Fixed noise so teacher and student denoise the exact same starting point.
g = torch.Generator(device=device)
g.manual_seed(1234)
xT = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=g,
    device=device,
    dtype=torch.float32,  # base noise dtype
)

def sync_if_cuda(dev):
    """Block until pending CUDA work on *dev* finishes (no-op on CPU)."""
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)

def timed(fn):
    """Run fn() with device sync before/after; return (result, elapsed seconds).

    The synchronization makes wall-clock timing honest on CUDA, where kernel
    launches are asynchronous.
    """
    sync_if_cuda(device)
    t0 = time.perf_counter()
    out = fn()
    sync_if_cuda(device)
    return out, time.perf_counter() - t0

def rollout(model, num_steps, guidance_scale):
    """Roll *model* out from the shared noise xT and return the final latent.

    Inputs are cast to the model's parameter dtype so fp16/fp32 checkpoints
    both work.
    """
    dtype = next(model.parameters()).dtype
    _, states = scheduler_rollout(
        model=model,
        pipeline=pipeline,
        xT=xT.to(dtype=dtype),
        cond=cond.to(dtype=dtype),
        num_steps=num_steps,
        guidance_scale=guidance_scale,
    )
    return states[-1]

with torch.no_grad():
    # teacher/original PixCell timing
    lat_teacher, t_teacher_rollout = timed(
        lambda: rollout(teacher, steps_teacher, guidance_teacher)
    )

    # student timing
    lat_student, t_student_rollout = timed(
        lambda: rollout(student, steps_student, guidance_student)
    )

    # decode timing, teacher then student
    img_teacher, t_teacher_decode = timed(
        lambda: decode_latents_to_images(pipeline, lat_teacher)[0]
    )
    img_student, t_student_decode = timed(
        lambda: decode_latents_to_images(pipeline, lat_student)[0]
    )

    # CHW float -> HWC uint8 (assumes decoder returns values in [0,1] -- TODO confirm)
    arr_t = (img_teacher.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    arr_s = (img_student.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1)))  # left=teacher, right=student

    teacher_total = t_teacher_rollout + t_teacher_decode
    student_total = t_student_rollout + t_student_decode

    print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
    print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
    print(f"Teacher decode: {t_teacher_decode:.4f}s")
    print(f"Student decode: {t_student_decode:.4f}s")
    print(f"Teacher total: {teacher_total:.4f}s")
    print(f"Student total: {student_total:.4f}s")
    print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
    print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
## Notes
|
| 201 |
|
| 202 |
- This is a distilled student checkpoint intended for research.
|
compare.png
ADDED
|
Git LFS Details
|