---
language: en
license: mit
library_name: pytorch
tags:
- diffusion
- pathology
- wsi
- distillation
datasets:
- W8Yi/tcga-wsi-uni2h-features
base_model:
- StonyBrook-CVLab/PixCell-256
---
# W8Yi/distilled-wsi-diffusion
![Teacher vs Student](./tile.png)
`distilled-wsi-diffusion` is a distilled student model derived from **PixCell** for
**UNI**-conditioned histopathology image generation. It is designed to preserve the
visual behavior of the PixCell teacher while enabling substantially faster
sampling with far fewer denoising steps (a **7.06x** end-to-end speedup), making it
practical for rapid research iteration, hypothesis testing, and interpretability
workflows on WSI features.
## Why Use This Model
- Faster inference than full-step teacher sampling for UNI-conditioned generation.
- Compatible with the PixCell-based conditioning workflow already used in this repo.
- Useful for pathology-focused generative experiments where turnaround time matters.
## What Is Included
- `student_model.safetensors`: distilled student weights.
- `inference_config.json`: base model IDs and loading config.
- `training_args_full.json`: the original training arguments captured from the checkpoint.
- `checkpoint_export_summary.json`: export metadata.
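Before building the pipeline, it can help to sanity-check that `inference_config.json` carries the fields the loading code reads. A minimal sketch, assuming the keys used in the Quick Use snippet below (`pix_model_id`, `vae_model_id`, etc.); the exact required set may differ in your checkout:

```python
import json

# Keys the Quick Use snippet reads directly (assumed minimal set).
REQUIRED_KEYS = {
    "pix_model_id",
    "pix_pipeline_id",
    "vae_model_id",
    "vae_subfolder",
    "cond_dim",
}

def check_config(path: str) -> dict:
    """Load the inference config and fail fast on missing keys."""
    with open(path, "r") as f:
        cfg = json.load(f)
    missing = REQUIRED_KEYS - cfg.keys()
    if missing:
        raise KeyError(f"inference_config.json is missing: {sorted(missing)}")
    return cfg
```

Optional keys such as `student_arch`, `latent_size`, and `default_sample_steps` are read with `cfg.get(...)` fallbacks, so they are not checked here.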
## Quick Use (In This Codebase)
This model was trained/tested with the helper code in `models/diffusion.py`.
```python
import json

import torch
from safetensors.torch import load_file

from models.diffusion import (
    PixCellConfig,
    build_pixcell_pipeline,
    build_teacher_student,
    sample_student_trajectory,
    decode_latents_to_images,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open("inference_config.json", "r") as f:
    cfg = json.load(f)

pix_cfg = PixCellConfig(
    pix_model_id=cfg["pix_model_id"],
    pix_pipeline_id=cfg["pix_pipeline_id"],
    vae_model_id=cfg["vae_model_id"],
    vae_subfolder=cfg["vae_subfolder"],
    dtype=torch.float16,
)
pipeline = build_pixcell_pipeline(pix_cfg, device=device)

cond_dim = int(cfg["cond_dim"])
student_arch = cfg.get("student_arch", "pixcell")
teacher, student = build_teacher_student(
    pipeline,
    cond_dim=cond_dim,
    init_student_from_teacher=True,
    student_arch=student_arch,
)

# Load the distilled student weights.
state = load_file("student_model.safetensors")
student.load_state_dict(state, strict=True)
student.to(device=device, dtype=torch.float32).eval()

# Replace with a real UNI feature: shape [B, 1, 1536].
cond = torch.randn(1, 1, cond_dim, device=device, dtype=torch.float32)

latents = sample_student_trajectory(
    student=student,
    cond=cond,
    pipeline=pipeline,
    latent_channels=int(pipeline.vae.config.latent_channels),
    latent_size=int(cfg.get("latent_size", 32)),
    steps=int(cfg.get("default_sample_steps", 4)),
    guidance_scale=float(cfg.get("guidance_student", 1.0)),
)
img = decode_latents_to_images(pipeline, latents)[0]
```
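`decode_latents_to_images` returns per-image tensors in `[C, H, W]` float layout (the timing snippet further below converts them the same way). A small sketch of turning one into a savable PIL image, assuming values lie in `[0, 1]`:

```python
import numpy as np
from PIL import Image

def chw_float_to_pil(arr: np.ndarray) -> Image.Image:
    """Convert a [C, H, W] float array in [0, 1] to a PIL RGB image."""
    hwc = np.transpose(arr, (1, 2, 0))              # [C,H,W] -> [H,W,C]
    u8 = (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(u8)

# Usage with the snippet above (img is a torch tensor):
# chw_float_to_pil(img.cpu().numpy()).save("sample.png")
```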
## Generate In 3 Steps
1. Load the base PixCell pipeline plus this distilled student.
2. Feed a single UNI feature of shape `[1, 1, 1536]` as the condition.
3. Sample with a small step count (for example, 4) and decode.
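The shape contract in step 2 is the only subtle part: the student expects one UNI embedding as a `[1, 1, 1536]` tensor (batch, token, feature). A minimal sketch of that reshape, shown here with NumPy on a placeholder vector; the repo code does the equivalent with `unsqueeze` on a torch tensor:

```python
import numpy as np

# Placeholder UNI feature vector; in practice this comes from the UNI
# encoder or the precomputed feature dataset.
uni_feat = np.random.randn(1536).astype(np.float32)  # [1536]

# Add batch and token dimensions: [1536] -> [1, 1, 1536].
cond = uni_feat[None, None, :]
print(cond.shape)  # (1, 1, 1536)
```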
## Teacher vs Student (Visualization + Timing)
`compare.png` (left = teacher, right = student):
![Teacher vs Student](./compare.png)
| Metric  | Teacher (35 steps) | Student (4 steps) |
|---------|--------------------|-------------------|
| Rollout | 0.8908s            | **0.1137s**       |
| Decode  | 0.0147s            | **0.0145s**       |
| Total   | 0.9055s            | **0.1282s**       |

Rollout speedup: **7.84x**; end-to-end speedup: **7.06x**.
Use the following snippet to reproduce the side-by-side comparison image and the speedup numbers:
```python
import random
import time

import numpy as np
import torch
from IPython.display import display
from PIL import Image

from models.diffusion import (
    make_uncond_embedding,
    scheduler_rollout,
    decode_latents_to_images,
)

# Pick a random UNI feature from the test manifest.
idx = random.randrange(len(test_ds))
uni_feat = test_ds[idx]  # [1536]
cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32)  # [1, 1, 1536]

# student, teacher, pipeline already loaded (see Quick Use above).
student.eval()
teacher.eval()

latent_channels = int(pipeline.vae.config.latent_channels)
latent_size = 32
steps_student = 4
steps_teacher = 35
guidance_student = 1.0
guidance_teacher = 3.0

# Fixed starting noise so both models denoise the same latent.
g = torch.Generator(device=device)
g.manual_seed(1234)
xT = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=g,
    device=device,
    dtype=torch.float32,  # base noise dtype
)

def sync_if_cuda(dev):
    """Synchronize CUDA so wall-clock timings are accurate."""
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)

with torch.no_grad():
    # Teacher (original PixCell) rollout timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, teacher_states = scheduler_rollout(
        model=teacher,
        pipeline=pipeline,
        xT=xT.to(dtype=next(teacher.parameters()).dtype),
        cond=cond.to(dtype=next(teacher.parameters()).dtype),
        num_steps=steps_teacher,
        guidance_scale=guidance_teacher,
    )
    sync_if_cuda(device)
    t_teacher_rollout = time.perf_counter() - t0
    lat_teacher = teacher_states[-1]

    # Student rollout timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, student_states = scheduler_rollout(
        model=student,
        pipeline=pipeline,
        xT=xT.to(dtype=next(student.parameters()).dtype),
        cond=cond.to(dtype=next(student.parameters()).dtype),
        num_steps=steps_student,
        guidance_scale=guidance_student,
    )
    sync_if_cuda(device)
    t_student_rollout = time.perf_counter() - t0
    lat_student = student_states[-1]

    # Teacher decode timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_teacher = decode_latents_to_images(pipeline, lat_teacher)[0]
    sync_if_cuda(device)
    t_teacher_decode = time.perf_counter() - t0

    # Student decode timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_student = decode_latents_to_images(pipeline, lat_student)[0]
    sync_if_cuda(device)
    t_student_decode = time.perf_counter() - t0

arr_t = (img_teacher.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
arr_s = (img_student.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1)))  # left = teacher, right = student

teacher_total = t_teacher_rollout + t_teacher_decode
student_total = t_student_rollout + t_student_decode
print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
print(f"Teacher decode: {t_teacher_decode:.4f}s")
print(f"Student decode: {t_student_decode:.4f}s")
print(f"Teacher total: {teacher_total:.4f}s")
print(f"Student total: {student_total:.4f}s")
print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")
```
## Notes
- This is a distilled student checkpoint intended for research.
- Base model/pipeline dependencies are:
- `StonyBrook-CVLab/PixCell-256`
- `StonyBrook-CVLab/PixCell-pipeline`
- `stabilityai/stable-diffusion-3.5-large` (VAE subfolder `vae`)
- Please check and comply with upstream model licenses/terms.