---
language: en
license: mit
library_name: pytorch
tags:
  - diffusion
  - pathology
  - wsi
  - distillation
datasets:
  - W8Yi/tcga-wsi-uni2h-features
base_model:
  - StonyBrook-CVLab/PixCell-256
---

W8Yi/distilled-wsi-diffusion


distilled-wsi-diffusion is a distilled student model derived from PixCell for UNI-conditioned histopathology image generation. It is designed to preserve the visual behavior of the PixCell teacher while sampling with far fewer denoising steps (a 7.06x end-to-end speedup), making it practical for rapid research iteration, hypothesis testing, and interpretability workflows on WSI features.

Why Use This Model

  • Faster inference than full-step teacher sampling for UNI-conditioned generation.
  • Compatible with PixCell-based conditioning workflow already used in this repo.
  • Useful for pathology-focused generative experiments where turnaround time matters.

What Is Included

  • student_model.safetensors: distilled student weights.
  • inference_config.json: base model IDs and loading config.
  • training_args_full.json: original training args captured from checkpoint.
  • checkpoint_export_summary.json: export metadata.

Quick Use (In This Codebase)

This model was trained/tested with the helper code in models/diffusion.py.

import json
import torch
from safetensors.torch import load_file

from models.diffusion import (
    PixCellConfig,
    build_pixcell_pipeline,
    build_teacher_student,
    sample_student_trajectory,
    decode_latents_to_images,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open("inference_config.json", "r") as f:
    cfg = json.load(f)
pix_cfg = PixCellConfig(
    pix_model_id=cfg["pix_model_id"],
    pix_pipeline_id=cfg["pix_pipeline_id"],
    vae_model_id=cfg["vae_model_id"],
    vae_subfolder=cfg["vae_subfolder"],
    dtype=torch.float16,
)
pipeline = build_pixcell_pipeline(pix_cfg, device=device)

cond_dim = int(cfg["cond_dim"])
student_arch = cfg.get("student_arch", "pixcell")
teacher, student = build_teacher_student(
    pipeline,
    cond_dim=cond_dim,
    init_student_from_teacher=True,
    student_arch=student_arch,
)
state = load_file("student_model.safetensors")
student.load_state_dict(state, strict=True)
student.to(device=device, dtype=torch.float32).eval()

# Replace with a real UNI feature: shape [B,1,1536]
cond = torch.randn(1, 1, cond_dim, device=device, dtype=torch.float32)

latents = sample_student_trajectory(
    student=student,
    cond=cond,
    pipeline=pipeline,
    latent_channels=int(pipeline.vae.config.latent_channels),
    latent_size=int(cfg.get("latent_size", 32)),
    steps=int(cfg.get("default_sample_steps", 4)),
    guidance_scale=float(cfg.get("guidance_student", 1.0)),
)
img = decode_latents_to_images(pipeline, latents)[0]
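The decoded `img` is a tensor, so it needs a conversion step before it can be viewed or saved. A minimal sketch for writing it out as a PNG, assuming (as the comparison snippet below does) that the decoder returns a [3, H, W] float tensor with values in [0, 1]; a random tensor stands in for the real decoder output here:

```python
import numpy as np
import torch
from PIL import Image

# Stand-in for the decoded sample; the real `img` comes from
# decode_latents_to_images(pipeline, latents)[0] and is assumed
# to be a [3, H, W] float tensor with values in [0, 1].
img = torch.rand(3, 256, 256)

# Clamp, move channels last, scale to 8-bit, and save.
arr = (img.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
Image.fromarray(arr).save("student_sample.png")
```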

Generate In 3 Steps

  1. Load base PixCell pipeline + this distilled student.
  2. Feed one UNI feature ([1,1,1536]) as condition.
  3. Sample with a small step count (for example, 4) and decode.
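Step 2 is only a shape fix-up: a single 1536-dim UNI feature vector becomes a [1, 1, 1536] conditioning tensor by adding batch and token dimensions. A minimal sketch (a random vector stands in for a real UNI feature):

```python
import torch

# A raw UNI feature vector for one patch (random stand-in; shape [1536]).
uni_feat = torch.randn(1536)

# Add batch and token dimensions to get the [1, 1, 1536] conditioning shape.
cond = uni_feat.view(1, 1, -1).to(dtype=torch.float32)
```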

Teacher vs Student (Visualization + Timing)

compare.png (left = teacher, right = student):

Teacher rollout (35 steps): 0.8908s
Student rollout (4 steps): 0.1137s
Teacher decode: 0.0147s
Student decode: 0.0145s
Teacher total: 0.9055s
Student total: 0.1282s
Rollout speedup: 7.84x
End-to-end speedup: 7.06x
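As a sanity check, the speedup lines follow directly from the timings above; recomputing from the rounded printed values reproduces them to within rounding error (the 7.84x rollout figure was computed from unrounded timings):

```python
# Timings as printed above (rounded to 4 decimal places).
t_teacher_rollout, t_student_rollout = 0.8908, 0.1137
t_teacher_decode, t_student_decode = 0.0147, 0.0145

rollout_speedup = t_teacher_rollout / t_student_rollout
end_to_end_speedup = (t_teacher_rollout + t_teacher_decode) / (
    t_student_rollout + t_student_decode
)

print(f"Rollout speedup: {rollout_speedup:.2f}x")        # ~7.83x from rounded inputs
print(f"End-to-end speedup: {end_to_end_speedup:.2f}x")  # ~7.06x
```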

Use the following snippet to reproduce the side-by-side image and the speedup numbers:

import time
import random
import torch
import numpy as np
from PIL import Image
from IPython.display import display

from models.diffusion import (
    make_uncond_embedding,
    scheduler_rollout,
    decode_latents_to_images,
)

# `test_ds` is assumed to be a dataset of precomputed UNI features
# (loaded from the test manifest in a previous cell).
idx = random.randrange(len(test_ds))
uni_feat = test_ds[idx]                 # [1536]
cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32)  # [1,1,1536]

# cond: [1,1,1536] from test manifest (as in previous cell)
# student, teacher, pipeline already loaded
student.eval()
teacher.eval()

latent_channels = int(pipeline.vae.config.latent_channels)
latent_size = 32
steps_student = 4
steps_teacher = 35
guidance_student = 1.0
guidance_teacher = 3.0

# fixed noise for fair comparison
g = torch.Generator(device=device)
g.manual_seed(1234)
xT = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=g,
    device=device,
    dtype=torch.float32,   # base noise dtype
)

def sync_if_cuda(dev):
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)

with torch.no_grad():
    # teacher/original PixCell timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, teacher_states = scheduler_rollout(
        model=teacher,
        pipeline=pipeline,
        xT=xT.to(dtype=next(teacher.parameters()).dtype),
        cond=cond.to(dtype=next(teacher.parameters()).dtype),
        num_steps=steps_teacher,
        guidance_scale=guidance_teacher,
    )
    sync_if_cuda(device)
    t_teacher_rollout = time.perf_counter() - t0
    lat_teacher = teacher_states[-1]

    # student timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, student_states = scheduler_rollout(
        model=student,
        pipeline=pipeline,
        xT=xT.to(dtype=next(student.parameters()).dtype),
        cond=cond.to(dtype=next(student.parameters()).dtype),
        num_steps=steps_student,
        guidance_scale=guidance_student,
    )
    sync_if_cuda(device)
    t_student_rollout = time.perf_counter() - t0
    lat_student = student_states[-1]

    # teacher decode timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_teacher = decode_latents_to_images(pipeline, lat_teacher)[0]
    sync_if_cuda(device)
    t_teacher_decode = time.perf_counter() - t0

    # student decode timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_student = decode_latents_to_images(pipeline, lat_student)[0]
    sync_if_cuda(device)
    t_student_decode = time.perf_counter() - t0

arr_t = (img_teacher.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
arr_s = (img_student.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1)))  # left=teacher, right=student

teacher_total = t_teacher_rollout + t_teacher_decode
student_total = t_student_rollout + t_student_decode

print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
print(f"Teacher decode: {t_teacher_decode:.4f}s")
print(f"Student decode: {t_student_decode:.4f}s")
print(f"Teacher total: {teacher_total:.4f}s")
print(f"Student total: {student_total:.4f}s")
print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")

Notes

  • This is a distilled student checkpoint intended for research.
  • Base model/pipeline dependencies are:
    • StonyBrook-CVLab/PixCell-256
    • StonyBrook-CVLab/PixCell-pipeline
    • stabilityai/stable-diffusion-3.5-large (VAE, subfolder vae)
  • Please check and comply with upstream model licenses/terms.