---
language: en
license: mit
library_name: pytorch
tags:
- diffusion
- pathology
- wsi
- distillation
datasets:
- W8Yi/tcga-wsi-uni2h-features
base_model:
- StonyBrook-CVLab/PixCell-256
---

# W8Yi/distilled-wsi-diffusion
`distilled-wsi-diffusion` is a distilled student model derived from PixCell for
UNI-conditioned histopathology image generation. It is designed to preserve the
visual behavior of the PixCell teacher while enabling substantially faster
sampling with fewer denoising steps (7.06x end-to-end speedup), making it practical
for rapid research iteration, hypothesis testing, and interpretability workflows
on WSI features.
## Why Use This Model
- Faster inference than full-step teacher sampling for UNI-conditioned generation.
- Compatible with the PixCell-based conditioning workflow already used in this repo.
- Useful for pathology-focused generative experiments where turnaround time matters.
## What Is Included

- `student_model.safetensors`: distilled student weights.
- `inference_config.json`: base model IDs and loading config.
- `training_args_full.json`: original training args captured from the checkpoint.
- `checkpoint_export_summary.json`: export metadata.
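Before loading any heavy weights, it can be useful to sanity-check `inference_config.json` against the keys the quick-use code reads. The key set below is inferred from this card's snippets, not from a published schema, so treat it as an assumption and verify against the shipped file:

```python
import json

# Required/optional keys inferred from the quick-use snippet in this card.
REQUIRED_KEYS = {"pix_model_id", "pix_pipeline_id", "vae_model_id", "vae_subfolder", "cond_dim"}
OPTIONAL_KEYS = {"student_arch", "latent_size", "default_sample_steps", "guidance_student"}

def missing_keys(cfg: dict) -> list:
    """Return the sorted list of required keys absent from cfg."""
    return sorted(REQUIRED_KEYS - set(cfg.keys()))

# Example with a hypothetical config dict (real code: json.load the file).
cfg = {
    "pix_model_id": "StonyBrook-CVLab/PixCell-256",
    "pix_pipeline_id": "StonyBrook-CVLab/PixCell-pipeline",
    "vae_model_id": "stabilityai/stable-diffusion-3.5-large",
    "vae_subfolder": "vae",
    "cond_dim": 1536,
}
print(missing_keys(cfg))  # []
```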
## Quick Use (In This Codebase)

This model was trained and tested with the helper code in `models/diffusion.py`.
```python
import json

import torch
from safetensors.torch import load_file

from models.diffusion import (
    PixCellConfig,
    build_pixcell_pipeline,
    build_teacher_student,
    sample_student_trajectory,
    decode_latents_to_images,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open("inference_config.json", "r") as f:
    cfg = json.load(f)

pix_cfg = PixCellConfig(
    pix_model_id=cfg["pix_model_id"],
    pix_pipeline_id=cfg["pix_pipeline_id"],
    vae_model_id=cfg["vae_model_id"],
    vae_subfolder=cfg["vae_subfolder"],
    dtype=torch.float16,
)
pipeline = build_pixcell_pipeline(pix_cfg, device=device)

cond_dim = int(cfg["cond_dim"])
student_arch = cfg.get("student_arch", "pixcell")
teacher, student = build_teacher_student(
    pipeline,
    cond_dim=cond_dim,
    init_student_from_teacher=True,
    student_arch=student_arch,
)

# Load the distilled student weights on top of the teacher-initialized model.
state = load_file("student_model.safetensors")
student.load_state_dict(state, strict=True)
student.to(device=device, dtype=torch.float32).eval()

# Replace with a real UNI feature: shape [B, 1, 1536]
cond = torch.randn(1, 1, cond_dim, device=device, dtype=torch.float32)

latents = sample_student_trajectory(
    student=student,
    cond=cond,
    pipeline=pipeline,
    latent_channels=int(pipeline.vae.config.latent_channels),
    latent_size=int(cfg.get("latent_size", 32)),
    steps=int(cfg.get("default_sample_steps", 4)),
    guidance_scale=float(cfg.get("guidance_student", 1.0)),
)
img = decode_latents_to_images(pipeline, latents)[0]
```
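Judging from the comparison snippet later in this card, `decode_latents_to_images` is assumed to return float tensors in `[0, 1]` with shape `[3, H, W]`. Converting such a tensor to an 8-bit PIL image for saving could look like the sketch below, shown with a NumPy placeholder so it runs standalone (real code would pass `img.cpu().numpy()`):

```python
import numpy as np
from PIL import Image

def to_pil(img_chw: np.ndarray) -> Image.Image:
    """Convert a float [3, H, W] array in [0, 1] to an 8-bit RGB PIL image."""
    arr = (np.clip(img_chw, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr.transpose(1, 2, 0))  # CHW -> HWC

# Placeholder standing in for a decoded sample.
fake_img = np.random.rand(3, 256, 256).astype(np.float32)
pil = to_pil(fake_img)
pil.save("student_sample.png")
print(pil.size)  # (256, 256)
```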
## Generate In 3 Steps

1. Load the base PixCell pipeline plus this distilled student.
2. Feed one UNI feature (`[1, 1, 1536]`) as the condition.
3. Sample with a small step count (for example, 4) and decode.
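Step 2's shape convention (a `[1536]` feature vector becomes a `[1, 1, 1536]` condition) is easy to get wrong. A NumPy sketch of the reshape, used here only so the example runs without torch (torch's `unsqueeze(0).unsqueeze(1)` produces the same shape):

```python
import numpy as np

uni_feat = np.random.randn(1536).astype(np.float32)  # one UNI feature vector
cond = uni_feat[None, None, :]                       # [1536] -> [1, 1, 1536]
print(cond.shape)  # (1, 1, 1536)

# Batching several features keeps the singleton token axis: [B, 1, 1536].
batch = np.stack([uni_feat, uni_feat])[:, None, :]
print(batch.shape)  # (2, 1, 1536)
```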
## Teacher vs Student (Visualization + Timing)

`compare.png` (left = teacher, right = student):

| Measurement | Teacher (35 steps) | Student (4 steps) |
|---|---|---|
| Rollout | 0.8908s | 0.1137s |
| Decode | 0.0147s | 0.0145s |
| Total | 0.9055s | 0.1282s |

Rollout speedup: 7.84x. End-to-end speedup: 7.06x.
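The speedup figures follow directly from the timings. Because decode cost is nearly identical for teacher and student, the end-to-end speedup is lower than the rollout speedup. Recomputing from the rounded values above (which gives 7.83x rather than the reported 7.84x, purely due to input rounding):

```python
# Timings as reported above (seconds, rounded to 4 decimals).
t_teacher_rollout, t_student_rollout = 0.8908, 0.1137
t_teacher_decode, t_student_decode = 0.0147, 0.0145

rollout_speedup = t_teacher_rollout / t_student_rollout
e2e_speedup = (t_teacher_rollout + t_teacher_decode) / (t_student_rollout + t_student_decode)

print(f"{rollout_speedup:.2f}x")  # 7.83x (7.84x before input rounding)
print(f"{e2e_speedup:.2f}x")      # 7.06x
```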
Use the following snippet to reproduce the side-by-side image and the speedup numbers:
```python
import random
import time

import numpy as np
import torch
from IPython.display import display
from PIL import Image

from models.diffusion import (
    scheduler_rollout,
    decode_latents_to_images,
)

# Assumes `student`, `teacher`, `pipeline`, `device`, and a `test_ds` of UNI
# features are already loaded (as in the previous cell).
idx = random.randrange(len(test_ds))
uni_feat = test_ds[idx]  # [1536]
cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32)  # [1, 1, 1536]

student.eval()
teacher.eval()

latent_channels = int(pipeline.vae.config.latent_channels)
latent_size = 32
steps_student = 4
steps_teacher = 35
guidance_student = 1.0
guidance_teacher = 3.0

# Fixed noise for a fair comparison.
g = torch.Generator(device=device)
g.manual_seed(1234)
xT = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=g,
    device=device,
    dtype=torch.float32,  # base noise dtype
)

def sync_if_cuda(dev):
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)

with torch.no_grad():
    # Teacher (original PixCell) rollout timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, teacher_states = scheduler_rollout(
        model=teacher,
        pipeline=pipeline,
        xT=xT.to(dtype=next(teacher.parameters()).dtype),
        cond=cond.to(dtype=next(teacher.parameters()).dtype),
        num_steps=steps_teacher,
        guidance_scale=guidance_teacher,
    )
    sync_if_cuda(device)
    t_teacher_rollout = time.perf_counter() - t0
    lat_teacher = teacher_states[-1]

    # Student rollout timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, student_states = scheduler_rollout(
        model=student,
        pipeline=pipeline,
        xT=xT.to(dtype=next(student.parameters()).dtype),
        cond=cond.to(dtype=next(student.parameters()).dtype),
        num_steps=steps_student,
        guidance_scale=guidance_student,
    )
    sync_if_cuda(device)
    t_student_rollout = time.perf_counter() - t0
    lat_student = student_states[-1]

    # Teacher decode timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_teacher = decode_latents_to_images(pipeline, lat_teacher)[0]
    sync_if_cuda(device)
    t_teacher_decode = time.perf_counter() - t0

    # Student decode timing.
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_student = decode_latents_to_images(pipeline, lat_student)[0]
    sync_if_cuda(device)
    t_student_decode = time.perf_counter() - t0

arr_t = (img_teacher.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
arr_s = (img_student.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1)))  # left = teacher, right = student

teacher_total = t_teacher_rollout + t_teacher_decode
student_total = t_student_rollout + t_student_decode
print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
print(f"Teacher decode: {t_teacher_decode:.4f}s")
print(f"Student decode: {t_student_decode:.4f}s")
print(f"Teacher total: {teacher_total:.4f}s")
print(f"Student total: {student_total:.4f}s")
print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")
```
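Single-run timings like the ones above can be noisy (warmup effects, GPU clock variation). A framework-agnostic sketch of a median-of-N timing helper that could wrap each rollout call for more stable speedup numbers; for CUDA work, the synchronization should happen inside the timed callable:

```python
import time
from statistics import median

def time_median(fn, n=5, warmup=1):
    """Run fn() `warmup` times untimed, then n timed runs; return the median seconds."""
    for _ in range(warmup):
        fn()
    runs = []
    for _ in range(n):
        t0 = time.perf_counter()
        fn()
        runs.append(time.perf_counter() - t0)
    return median(runs)

# Example with a cheap stand-in workload (real use: a rollout + sync closure).
t = time_median(lambda: sum(range(10_000)), n=5)
print(t >= 0.0)  # True
```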
## Notes

- This is a distilled student checkpoint intended for research use.
- Base model/pipeline dependencies:
  - StonyBrook-CVLab/PixCell-256
  - StonyBrook-CVLab/PixCell-pipeline
  - stabilityai/stable-diffusion-3.5-large (VAE subfolder `vae`)
- Please check and comply with upstream model licenses/terms.

