ZoomLDM-brca / run_demo_inference.py
BiliSakura's picture
Add files using upload-large-folder tool
f7038f8 verified
#!/usr/bin/env python3
"""Run ZoomLDM-BRCA demo inference using local demo assets."""
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
def preprocess_brca_ssl(npy_path: Path, max_grid: int = 8) -> torch.Tensor:
    """Load an SSL feature array from ``npy_path`` as a ``(C, h, h)`` tensor.

    Steps (copied from the dataset preprocessing):
      1) cast to float32;
      2) standardize each column of the 2-D array across axis 0;
      3) reshape ``(C, h*h)`` -> ``(C, h, h)``;
      4) adaptive-average-pool to ``max_grid x max_grid`` when ``h > max_grid``.

    Args:
        npy_path: Path to a ``.npy`` file. A 1-D array of length C is treated
            as a single column ``(C, 1)``. A 2-D array is expected to be
            ``(C, h*h)``; the demo assumes C == 1024, but that is not enforced.
        max_grid: Largest spatial side length of the returned tensor
            (default 8, matching the original hard-coded limit).

    Returns:
        A float32 tensor of shape ``(C, h, h)``, or
        ``(C, max_grid, max_grid)`` when ``h > max_grid``.

    Raises:
        ValueError: if the array is empty, if its second dimension is not a
            perfect square, or if it cannot be reshaped to ``(C, h, h)``.
    """
    feat = np.load(npy_path).astype(np.float32)
    if feat.ndim == 1:
        feat = feat[:, None]
    if feat.size == 0:
        raise ValueError(f"SSL feature array in {npy_path} is empty")

    # Standardize along axis 0; the epsilon avoids division by zero for
    # constant columns.
    mean = feat.mean(axis=0, keepdims=True)
    std = feat.std(axis=0, keepdims=True)
    feat = (feat - mean) / (std + 1e-8)

    n_pos = feat.shape[1]
    h = math.isqrt(n_pos)
    if h * h != n_pos:
        # Without this check a non-square count would be silently folded into
        # extra channels by the reshape below.
        raise ValueError(
            f"Expected the second dimension of {npy_path} to be a perfect "
            f"square, got {n_pos}"
        )
    # Reshape with the explicit channel count (not -1) so an element-count
    # mismatch raises instead of producing a tensor with the wrong C.
    feat_t = torch.tensor(feat.reshape((feat.shape[0], h, h))).float()  # (C, h, h)
    if h > max_grid:
        feat_t = F.adaptive_avg_pool2d(feat_t, (max_grid, max_grid))
    return feat_t
def main() -> None:
    """Generate one ZoomLDM-BRCA demo image from repo-local assets.

    Reads ``demo_data/0_ssl_feat.npy`` from the directory containing this
    script. Loads the pipeline from that directory using the local
    ``pipeline_zoomldm.py`` custom pipeline, runs 50 sampling steps with
    guidance scale 2.0 and seed 42, and saves the first returned image to
    ``demo_images/output.jpeg``. Uses CUDA when available, otherwise CPU.

    Raises:
        FileNotFoundError: if ``demo_images/input.jpeg`` or
            ``demo_data/0_ssl_feat.npy`` is missing.
    """
    repo = Path(__file__).resolve().parent
    demo_dir = repo / "demo_images"
    demo_data = repo / "demo_data"
    demo_dir.mkdir(exist_ok=True)

    # Use repo-local demo assets only.
    src_img = demo_dir / "input.jpeg"
    src_feat = demo_data / "0_ssl_feat.npy"
    out_img = demo_dir / "output.jpeg"

    # Fail fast, before the expensive pipeline load.
    if not src_img.exists():
        raise FileNotFoundError(f"Missing demo input image: {src_img}")
    if not src_feat.exists():
        raise FileNotFoundError(f"Missing demo SSL feature: {src_feat}")

    # Prefer CUDA; fall back to CPU (slow) instead of crashing without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): channel count is 1024 for the demo asset, per the
    # preprocessing comment; not enforced here.
    ssl_feat = preprocess_brca_ssl(src_feat).unsqueeze(0).to(device)  # (1, C, h, h)
    # Magnification index 0; its meaning is defined by the custom pipeline.
    magnification = torch.tensor([0], device=device, dtype=torch.long)

    pipe = DiffusionPipeline.from_pretrained(
        str(repo),
        custom_pipeline=str(repo / "pipeline_zoomldm.py"),
        trust_remote_code=True,
        local_files_only=True,
    ).to(device)
    out = pipe(
        ssl_features=ssl_feat,
        magnification=magnification,
        num_inference_steps=50,
        guidance_scale=2.0,
        # The generator must live on the same device as the sampling tensors.
        generator=torch.Generator(device=device).manual_seed(42),
    )
    out.images[0].save(out_img)

    # input.jpeg is only a reference image; this script never writes it.
    print(f"Input image: {src_img}")
    print(f"Saved {out_img}")
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not on
    # import. main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())