W8Yi committed on
Commit
3e24f5d
·
verified ·
1 Parent(s): 1f1004e

Upload distilled WSI diffusion model package

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. README.md +121 -0
  3. compare.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ compare.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -76,6 +76,127 @@ latents = sample_student_trajectory(
76
  img = decode_latents_to_images(pipeline, latents)[0]
77
  ```
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  ## Notes
80
 
81
  - This is a distilled student checkpoint intended for research.
 
76
  img = decode_latents_to_images(pipeline, latents)[0]
77
  ```
78
 
79
+ ## Generate In 3 Steps
80
+
81
+ 1. Load base PixCell pipeline + this distilled student.
82
+ 2. Feed one UNI feature (`[1,1,1536]`) as condition.
83
+ 3. Sample with a small step count (for example, 4) and decode.
84
+
85
+ ## Teacher vs Student (Visualization + Timing)
86
+
87
+ `compare.png` (left = teacher, right = student):
88
+
89
+ ![Teacher vs Student](./compare.png)
90
+
91
+ Use the following snippet to reproduce side-by-side image and speedup numbers:
92
+
93
+ ```python
94
+ import time
95
+ import random
96
+ import torch
97
+ import numpy as np
98
+ from PIL import Image
99
+ from IPython.display import display
100
+
101
+ from models.diffusion import (
102
+ make_uncond_embedding,
103
+ scheduler_rollout,
104
+ decode_latents_to_images,
105
+ )
106
+
107
+ idx = random.randrange(len(test_ds))
108
+ uni_feat = test_ds[idx] # [1536]
109
+ cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32) # [1,1,1536]
110
+
111
+ # cond: [1,1,1536] from test manifest (as in previous cell)
112
+ # student, teacher, pipeline already loaded
113
+ student.eval()
114
+ teacher.eval()
115
+
116
+ latent_channels = int(pipeline.vae.config.latent_channels)
117
+ latent_size = 32
118
+ steps_student = 4
119
+ steps_teacher = 35
120
+ guidance_student = 1.0
121
+ guidance_teacher = 3.0
122
+
123
+ # fixed noise for fair comparison
124
+ g = torch.Generator(device=device)
125
+ g.manual_seed(1234)
126
+ xT = torch.randn(
127
+ (1, latent_channels, latent_size, latent_size),
128
+ generator=g,
129
+ device=device,
130
+ dtype=torch.float32, # base noise dtype
131
+ )
132
+
133
+ def sync_if_cuda(dev):
134
+ if dev.type == "cuda":
135
+ torch.cuda.synchronize(dev)
136
+
137
+ with torch.no_grad():
138
+ # teacher/original PixCell timing
139
+ sync_if_cuda(device)
140
+ t0 = time.perf_counter()
141
+ _, teacher_states = scheduler_rollout(
142
+ model=teacher,
143
+ pipeline=pipeline,
144
+ xT=xT.to(dtype=next(teacher.parameters()).dtype),
145
+ cond=cond.to(dtype=next(teacher.parameters()).dtype),
146
+ num_steps=steps_teacher,
147
+ guidance_scale=guidance_teacher,
148
+ )
149
+ sync_if_cuda(device)
150
+ t_teacher_rollout = time.perf_counter() - t0
151
+ lat_teacher = teacher_states[-1]
152
+
153
+ # student timing
154
+ sync_if_cuda(device)
155
+ t0 = time.perf_counter()
156
+ _, student_states = scheduler_rollout(
157
+ model=student,
158
+ pipeline=pipeline,
159
+ xT=xT.to(dtype=next(student.parameters()).dtype),
160
+ cond=cond.to(dtype=next(student.parameters()).dtype),
161
+ num_steps=steps_student,
162
+ guidance_scale=guidance_student,
163
+ )
164
+ sync_if_cuda(device)
165
+ t_student_rollout = time.perf_counter() - t0
166
+ lat_student = student_states[-1]
167
+
168
+ # teacher decode timing
169
+ sync_if_cuda(device)
170
+ t0 = time.perf_counter()
171
+ img_teacher = decode_latents_to_images(pipeline, lat_teacher)[0]
172
+ sync_if_cuda(device)
173
+ t_teacher_decode = time.perf_counter() - t0
174
+
175
+ # student decode timing
176
+ sync_if_cuda(device)
177
+ t0 = time.perf_counter()
178
+ img_student = decode_latents_to_images(pipeline, lat_student)[0]
179
+ sync_if_cuda(device)
180
+ t_student_decode = time.perf_counter() - t0
181
+
182
+ arr_t = (img_teacher.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
183
+ arr_s = (img_student.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
184
+
185
+ display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1))) # left=teacher, right=student
186
+
187
+ teacher_total = t_teacher_rollout + t_teacher_decode
188
+ student_total = t_student_rollout + t_student_decode
189
+
190
+ print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
191
+ print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
192
+ print(f"Teacher decode: {t_teacher_decode:.4f}s")
193
+ print(f"Student decode: {t_student_decode:.4f}s")
194
+ print(f"Teacher total: {teacher_total:.4f}s")
195
+ print(f"Student total: {student_total:.4f}s")
196
+ print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
197
+ print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")
198
+ ```
199
+
200
  ## Notes
201
 
202
  - This is a distilled student checkpoint intended for research.
compare.png ADDED

Git LFS Details

  • SHA256: ce7868abb9ca039e024b10da7dc309532b47d7223e32828187642d194ec1a68d
  • Pointer size: 131 Bytes
  • Size of remote file: 281 kB