---
language: en
license: mit
library_name: pytorch
tags:
- diffusion
- pathology
- wsi
- distillation
datasets:
- W8Yi/tcga-wsi-uni2h-features
base_model:
- StonyBrook-CVLab/PixCell-256
---

# W8Yi/distilled-wsi-diffusion

![Teacher vs Student](./tile.png)

`distilled-wsi-diffusion` is a distilled student model derived from **PixCell** for **UNI**-conditioned histopathology image generation. It is designed to preserve the visual behavior of the PixCell teacher while enabling substantially faster sampling with fewer denoising steps (a **7.06x** end-to-end speedup), making it practical for rapid research iteration, hypothesis testing, and interpretability workflows on WSI features.

## Why Use This Model

- Faster inference than full-step teacher sampling for UNI-conditioned generation.
- Compatible with the PixCell-based conditioning workflow already used in this repo.
- Useful for pathology-focused generative experiments where turnaround time matters.

## What Is Included

- `student_model.safetensors`: distilled student weights.
- `inference_config.json`: base model IDs and loading config.
- `training_args_full.json`: original training args captured from the checkpoint.
- `checkpoint_export_summary.json`: export metadata.

## Quick Use (In This Codebase)

This model was trained and tested with the helper code in `models/diffusion.py`.
```python
import json

import torch
from safetensors.torch import load_file

from models.diffusion import (
    PixCellConfig,
    build_pixcell_pipeline,
    build_teacher_student,
    sample_student_trajectory,
    decode_latents_to_images,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

cfg = json.load(open("inference_config.json", "r"))
pix_cfg = PixCellConfig(
    pix_model_id=cfg["pix_model_id"],
    pix_pipeline_id=cfg["pix_pipeline_id"],
    vae_model_id=cfg["vae_model_id"],
    vae_subfolder=cfg["vae_subfolder"],
    dtype=torch.float16,
)
pipeline = build_pixcell_pipeline(pix_cfg, device=device)

cond_dim = int(cfg["cond_dim"])
student_arch = cfg.get("student_arch", "pixcell")
teacher, student = build_teacher_student(
    pipeline,
    cond_dim=cond_dim,
    init_student_from_teacher=True,
    student_arch=student_arch,
)

state = load_file("student_model.safetensors")
student.load_state_dict(state, strict=True)
student.to(device=device, dtype=torch.float32).eval()

# Replace with a real UNI feature: shape [B, 1, 1536]
cond = torch.randn(1, 1, cond_dim, device=device, dtype=torch.float32)

latents = sample_student_trajectory(
    student=student,
    cond=cond,
    pipeline=pipeline,
    latent_channels=int(pipeline.vae.config.latent_channels),
    latent_size=int(cfg.get("latent_size", 32)),
    steps=int(cfg.get("default_sample_steps", 4)),
    guidance_scale=float(cfg.get("guidance_student", 1.0)),
)
img = decode_latents_to_images(pipeline, latents)[0]
```

## Generate In 3 Steps

1. Load the base PixCell pipeline plus this distilled student.
2. Feed one UNI feature (`[1, 1, 1536]`) as the condition.
3. Sample with a small step count (for example, 4) and decode.
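The quick-use snippet above uses `torch.randn` as a placeholder condition. In practice the condition is a single UNI embedding reshaped to the `[1, 1, 1536]` layout the pipeline expects; a minimal sketch of that reshaping, using a random tensor to stand in for a real feature vector:

```python
import torch

# Hypothetical stand-in for a real UNI feature vector; in practice this
# comes from a UNI encoder or a precomputed feature store (e.g. the
# W8Yi/tcga-wsi-uni2h-features dataset).
uni_feat = torch.randn(1536)

# Reshape [1536] -> [batch, tokens, dim] = [1, 1, 1536].
cond = uni_feat.unsqueeze(0).unsqueeze(1)
print(cond.shape)  # torch.Size([1, 1, 1536])
```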
## Teacher vs Student (Visualization + Timing)

`compare.png` (left = teacher, right = student):

![Teacher vs Student](./compare.png)

- Teacher rollout (35 steps): 0.8908s
- Student rollout (4 steps): **0.1137s**
- Teacher decode: 0.0147s
- Student decode: **0.0145s**
- Teacher total: 0.9055s
- Student total: **0.1282s**
- Rollout speedup: **7.84x**
- End-to-end speedup: **7.06x**

Use the following snippet to reproduce the side-by-side image and the speedup numbers:

```python
import random
import time

import numpy as np
import torch
from IPython.display import display
from PIL import Image

from models.diffusion import (
    make_uncond_embedding,
    scheduler_rollout,
    decode_latents_to_images,
)

# cond: [1, 1, 1536] from the test manifest (as in the previous cell);
# student, teacher, pipeline already loaded.
idx = random.randrange(len(test_ds))
uni_feat = test_ds[idx]  # [1536]
cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32)  # [1, 1, 1536]

student.eval()
teacher.eval()

latent_channels = int(pipeline.vae.config.latent_channels)
latent_size = 32
steps_student = 4
steps_teacher = 35
guidance_student = 1.0
guidance_teacher = 3.0

# Fixed noise for a fair comparison
g = torch.Generator(device=device)
g.manual_seed(1234)
xT = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=g,
    device=device,
    dtype=torch.float32,  # base noise dtype
)

def sync_if_cuda(dev):
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)

with torch.no_grad():
    # Teacher/original PixCell timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, teacher_states = scheduler_rollout(
        model=teacher,
        pipeline=pipeline,
        xT=xT.to(dtype=next(teacher.parameters()).dtype),
        cond=cond.to(dtype=next(teacher.parameters()).dtype),
        num_steps=steps_teacher,
        guidance_scale=guidance_teacher,
    )
    sync_if_cuda(device)
    t_teacher_rollout = time.perf_counter() - t0
    lat_teacher = teacher_states[-1]

    # Student timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, student_states = scheduler_rollout(
        model=student,
        pipeline=pipeline,
        xT=xT.to(dtype=next(student.parameters()).dtype),
        cond=cond.to(dtype=next(student.parameters()).dtype),
        num_steps=steps_student,
        guidance_scale=guidance_student,
    )
    sync_if_cuda(device)
    t_student_rollout = time.perf_counter() - t0
    lat_student = student_states[-1]

    # Teacher decode timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_teacher = decode_latents_to_images(pipeline, lat_teacher)[0]
    sync_if_cuda(device)
    t_teacher_decode = time.perf_counter() - t0

    # Student decode timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_student = decode_latents_to_images(pipeline, lat_student)[0]
    sync_if_cuda(device)
    t_student_decode = time.perf_counter() - t0

arr_t = (img_teacher.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
arr_s = (img_student.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1)))  # left = teacher, right = student

teacher_total = t_teacher_rollout + t_teacher_decode
student_total = t_student_rollout + t_student_decode
print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
print(f"Teacher decode: {t_teacher_decode:.4f}s")
print(f"Student decode: {t_student_decode:.4f}s")
print(f"Teacher total: {teacher_total:.4f}s")
print(f"Student total: {student_total:.4f}s")
print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")
```

## Notes

- This is a distilled student checkpoint intended for research use.
- Base model/pipeline dependencies:
  - `StonyBrook-CVLab/PixCell-256`
  - `StonyBrook-CVLab/PixCell-pipeline`
  - `stabilityai/stable-diffusion-3.5-large` (VAE subfolder `vae`)
- Please check and comply with upstream model licenses/terms.
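The timing snippet above only displays the decoded tensors inline. Assuming the same convention it relies on (decoded images are `[3, H, W]` float tensors in `[0, 1]`), a minimal sketch for writing one to disk, with a random tensor standing in for a real decoded image:

```python
import numpy as np
import torch
from PIL import Image

# Hypothetical stand-in for a decoded image; in practice use
# decode_latents_to_images(pipeline, latents)[0].
img = torch.rand(3, 256, 256)

# Same conversion as the comparison snippet: CHW float [0, 1] -> HWC uint8.
arr = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
Image.fromarray(arr).save("sample.png")
```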