---
language: en
license: mit
library_name: pytorch
tags:
- diffusion
- pathology
- wsi
- distillation
datasets:
- W8Yi/tcga-wsi-uni2h-features
base_model:
- StonyBrook-CVLab/PixCell-256
---

# W8Yi/distilled-wsi-diffusion

![](sample_banner.png)
|
`distilled-wsi-diffusion` is a distilled student model derived from **PixCell** for
**UNI**-conditioned histopathology image generation. It is designed to preserve the
visual behavior of the PixCell teacher while enabling substantially faster sampling
with far fewer denoising steps (a **7.06x** end-to-end speedup), making it practical
for rapid research iteration, hypothesis testing, and interpretability workflows on
WSI features.
|
## Why Use This Model

- Faster inference than full-step teacher sampling for UNI-conditioned generation.
- Compatible with the PixCell-based conditioning workflow already used in this repo.
- Useful for pathology-focused generative experiments where turnaround time matters.
|
## What Is Included

- `student_model.safetensors`: distilled student weights.
- `inference_config.json`: base model IDs and loading config.
- `training_args_full.json`: original training args captured from the checkpoint.
- `checkpoint_export_summary.json`: export metadata.
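The files above can be pulled from the Hub with `huggingface_hub`. A minimal sketch, assuming the repo id matches this model card (the download is deferred into a function so the snippet imports cleanly without the package):

```python
FILES = [
    "student_model.safetensors",
    "inference_config.json",
    "training_args_full.json",
    "checkpoint_export_summary.json",
]

def fetch_all(repo_id: str = "W8Yi/distilled-wsi-diffusion"):
    """Download each artifact into the local Hugging Face cache; return local paths."""
    from huggingface_hub import hf_hub_download  # requires `pip install huggingface_hub`
    return [hf_hub_download(repo_id=repo_id, filename=name) for name in FILES]
```

`fetch_all()` returns the cached paths in the same order as `FILES`, so the config and weights can be opened directly from the returned list.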
|
## Quick Use (In This Codebase)

This model was trained and tested with the helper code in `models/diffusion.py`.
|
```python
import json

import torch
from safetensors.torch import load_file

from models.diffusion import (
    PixCellConfig,
    build_pixcell_pipeline,
    build_teacher_student,
    sample_student_trajectory,
    decode_latents_to_images,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open("inference_config.json", "r") as f:
    cfg = json.load(f)
pix_cfg = PixCellConfig(
    pix_model_id=cfg["pix_model_id"],
    pix_pipeline_id=cfg["pix_pipeline_id"],
    vae_model_id=cfg["vae_model_id"],
    vae_subfolder=cfg["vae_subfolder"],
    dtype=torch.float16,
)
pipeline = build_pixcell_pipeline(pix_cfg, device=device)

cond_dim = int(cfg["cond_dim"])
student_arch = cfg.get("student_arch", "pixcell")
teacher, student = build_teacher_student(
    pipeline,
    cond_dim=cond_dim,
    init_student_from_teacher=True,
    student_arch=student_arch,
)
state = load_file("student_model.safetensors")
student.load_state_dict(state, strict=True)
student.to(device=device, dtype=torch.float32).eval()

# Replace with a real UNI feature: shape [B, 1, 1536]
cond = torch.randn(1, 1, cond_dim, device=device, dtype=torch.float32)

latents = sample_student_trajectory(
    student=student,
    cond=cond,
    pipeline=pipeline,
    latent_channels=int(pipeline.vae.config.latent_channels),
    latent_size=int(cfg.get("latent_size", 32)),
    steps=int(cfg.get("default_sample_steps", 4)),
    guidance_scale=float(cfg.get("guidance_student", 1.0)),
)
img = decode_latents_to_images(pipeline, latents)[0]
```
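`img` appears to be a CHW float tensor in `[0, 1]` (the timing snippet further down converts it with `permute(1, 2, 0)`). A minimal NumPy sketch of turning such a tensor into a savable 8-bit image; the shapes here are illustrative stand-ins, not output from the model:

```python
import numpy as np

# Stand-in for a decoded image: CHW float array in [0, 1]
chw = np.random.rand(3, 256, 256).astype(np.float32)

# Clamp, move channels last, and quantize to 8-bit
hwc = (np.clip(chw, 0.0, 1.0).transpose(1, 2, 0) * 255).astype(np.uint8)
print(hwc.shape, hwc.dtype)  # (256, 256, 3) uint8
```

With Pillow installed, `Image.fromarray(hwc).save("sample.png")` then writes the result to disk.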
|
## Generate In 3 Steps

1. Load the base PixCell pipeline plus this distilled student.
2. Feed one UNI feature (`[1,1,1536]`) as the condition.
3. Sample with a small step count (for example, 4) and decode.
|
## Teacher vs Student (Visualization + Timing)

`compare.png` (left = teacher, right = student):

![](compare.png)

| Stage | Teacher (35 steps) | Student (4 steps) |
| --- | --- | --- |
| Rollout | 0.8908 s | **0.1137 s** |
| Decode | 0.0147 s | **0.0145 s** |
| Total | 0.9055 s | **0.1282 s** |

- Rollout speedup: **7.84x**
- End-to-end speedup: **7.06x**
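A quick sanity check that the totals and the end-to-end speedup follow from the per-stage timings (pure Python, numbers copied from above; the reported rollout speedup was presumably computed from unrounded timings, since re-deriving it from the 4-decimal values gives roughly 7.83x):

```python
t_teacher_rollout, t_teacher_decode = 0.8908, 0.0147
t_student_rollout, t_student_decode = 0.1137, 0.0145

teacher_total = t_teacher_rollout + t_teacher_decode  # 0.9055
student_total = t_student_rollout + t_student_decode  # 0.1282

print(f"end-to-end speedup: {teacher_total / student_total:.2f}x")  # 7.06x
```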
|
Use the following snippet to reproduce the side-by-side image and the speedup numbers:
|
```python
import time
import random

import numpy as np
import torch
from IPython.display import display
from PIL import Image

from models.diffusion import (
    make_uncond_embedding,
    scheduler_rollout,
    decode_latents_to_images,
)

idx = random.randrange(len(test_ds))
uni_feat = test_ds[idx]  # [1536]
cond = uni_feat.unsqueeze(0).unsqueeze(1).to(device=device, dtype=torch.float32)  # [1, 1, 1536]

# cond: [1, 1, 1536] from the test manifest (as in the previous cell)
# student, teacher, pipeline already loaded
student.eval()
teacher.eval()

latent_channels = int(pipeline.vae.config.latent_channels)
latent_size = 32
steps_student = 4
steps_teacher = 35
guidance_student = 1.0
guidance_teacher = 3.0

# Fixed noise so both models denoise the same starting latent
g = torch.Generator(device=device)
g.manual_seed(1234)
xT = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=g,
    device=device,
    dtype=torch.float32,  # base noise dtype
)

def sync_if_cuda(dev):
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)

with torch.no_grad():
    # Teacher (original PixCell) rollout timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, teacher_states = scheduler_rollout(
        model=teacher,
        pipeline=pipeline,
        xT=xT.to(dtype=next(teacher.parameters()).dtype),
        cond=cond.to(dtype=next(teacher.parameters()).dtype),
        num_steps=steps_teacher,
        guidance_scale=guidance_teacher,
    )
    sync_if_cuda(device)
    t_teacher_rollout = time.perf_counter() - t0
    lat_teacher = teacher_states[-1]

    # Student rollout timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    _, student_states = scheduler_rollout(
        model=student,
        pipeline=pipeline,
        xT=xT.to(dtype=next(student.parameters()).dtype),
        cond=cond.to(dtype=next(student.parameters()).dtype),
        num_steps=steps_student,
        guidance_scale=guidance_student,
    )
    sync_if_cuda(device)
    t_student_rollout = time.perf_counter() - t0
    lat_student = student_states[-1]

    # Teacher decode timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_teacher = decode_latents_to_images(pipeline, lat_teacher)[0]
    sync_if_cuda(device)
    t_teacher_decode = time.perf_counter() - t0

    # Student decode timing
    sync_if_cuda(device)
    t0 = time.perf_counter()
    img_student = decode_latents_to_images(pipeline, lat_student)[0]
    sync_if_cuda(device)
    t_student_decode = time.perf_counter() - t0

# Clamp before quantizing so out-of-range values cannot wrap around in uint8
arr_t = (img_teacher.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
arr_s = (img_student.clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

display(Image.fromarray(np.concatenate([arr_t, arr_s], axis=1)))  # left = teacher, right = student

teacher_total = t_teacher_rollout + t_teacher_decode
student_total = t_student_rollout + t_student_decode

print(f"Teacher rollout ({steps_teacher} steps): {t_teacher_rollout:.4f}s")
print(f"Student rollout ({steps_student} steps): {t_student_rollout:.4f}s")
print(f"Teacher decode: {t_teacher_decode:.4f}s")
print(f"Student decode: {t_student_decode:.4f}s")
print(f"Teacher total: {teacher_total:.4f}s")
print(f"Student total: {student_total:.4f}s")
print(f"Rollout speedup: {t_teacher_rollout / max(t_student_rollout, 1e-9):.2f}x")
print(f"End-to-end speedup: {teacher_total / max(student_total, 1e-9):.2f}x")
```
|
## Notes

- This is a distilled student checkpoint intended for research use.
- Base model/pipeline dependencies:
  - `StonyBrook-CVLab/PixCell-256`
  - `StonyBrook-CVLab/PixCell-pipeline`
  - `stabilityai/stable-diffusion-3.5-large` (VAE subfolder `vae`)
- Please check and comply with upstream model licenses/terms.
|