| |
| """Run ZoomLDM-BRCA demo inference using local demo assets.""" |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from diffusers import DiffusionPipeline |
|
|
|
|
def preprocess_brca_ssl(npy_path: Path) -> torch.Tensor:
    """Load, standardize and spatially reshape an SSL feature array.

    The ``.npy`` file is read as float32. A 1-D array is treated as a single
    column (a 1x1 spatial grid). The array is standardized along axis 0,
    so each column is normalized by its own mean and std over the rows. It is
    then reshaped to ``(rows, side, side)`` and, if ``side > 8``,
    adaptively average-pooled down to 8x8.

    Args:
        npy_path: Path to a ``.npy`` file holding a non-empty 1-D or 2-D array.
            For a 2-D array the second dimension must be a perfect square.

    Returns:
        A float32 tensor of shape ``(rows, side, side)`` where ``side <= 8``.
        It is exactly 8 when the input grid was larger than 8x8.

    Raises:
        ValueError: if the array is empty, is not 1-D/2-D, or its second
            dimension is not a perfect square.
    """
    feat = np.load(npy_path).astype(np.float32)
    if feat.ndim == 1:
        # A flat vector becomes a (rows, 1) matrix, i.e. a 1x1 spatial grid.
        feat = feat[:, None]
    if feat.ndim != 2:
        raise ValueError(
            f"Expected a 1-D or 2-D SSL feature array, got shape {feat.shape}"
        )
    if feat.size == 0:
        raise ValueError(f"SSL feature array is empty: shape {feat.shape}")

    # Standardize each column over axis 0; the epsilon guards zero-variance columns.
    mean = feat.mean(axis=0, keepdims=True)
    std = feat.std(axis=0, keepdims=True)
    feat = (feat - mean) / (std + 1e-8)

    n_locations = feat.shape[1]
    side = int(round(float(np.sqrt(n_locations))))
    # Without this check, reshape(-1, side, side) on a non-square length can
    # silently regroup the data into a wrong shape instead of failing.
    if side * side != n_locations:
        raise ValueError(
            f"SSL feature length {n_locations} is not a perfect square; "
            "cannot reshape to a square spatial grid"
        )

    grid = feat.reshape(feat.shape[0], side, side)
    tensor = torch.from_numpy(np.ascontiguousarray(grid))
    if side > 8:
        tensor = F.adaptive_avg_pool2d(tensor, (8, 8))
    return tensor
|
|
|
|
def main() -> None:
    """Run one ZoomLDM sampling pass on the bundled demo SSL feature.

    Loads the pipeline from this script's directory (``local_files_only=True``)
    using the custom pipeline file ``pipeline_zoomldm.py``. It conditions on
    ``demo_data/0_ssl_feat.npy`` with magnification value 0 and writes the
    first generated image to ``demo_images/output.jpeg``. Runs on CUDA when
    available, otherwise on CPU (much slower).

    Raises:
        FileNotFoundError: if ``demo_images/input.jpeg`` or
            ``demo_data/0_ssl_feat.npy`` is missing.
    """
    repo = Path(__file__).resolve().parent
    demo_dir = repo / "demo_images"
    demo_data = repo / "demo_data"
    demo_dir.mkdir(exist_ok=True)

    src_img = demo_dir / "input.jpeg"
    src_feat = demo_data / "0_ssl_feat.npy"
    if not src_img.exists():
        raise FileNotFoundError(f"Missing demo input image: {src_img}")
    if not src_feat.exists():
        raise FileNotFoundError(f"Missing demo SSL feature: {src_feat}")

    # Fall back to CPU so the demo still runs without a GPU instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Add a leading batch dimension to the 3-D preprocessed feature tensor.
    ssl_feat = preprocess_brca_ssl(src_feat).unsqueeze(0).to(device)
    magnification = torch.tensor([0], device=device, dtype=torch.long)

    pipe = DiffusionPipeline.from_pretrained(
        str(repo),
        custom_pipeline=str(repo / "pipeline_zoomldm.py"),
        trust_remote_code=True,
        local_files_only=True,
    ).to(device)

    # NOTE(review): the custom pipeline's __call__ is not visible here and may
    # not disable autograd itself; no_grad avoids retaining graphs while sampling.
    with torch.no_grad():
        out = pipe(
            ssl_features=ssl_feat,
            magnification=magnification,
            num_inference_steps=50,
            guidance_scale=2.0,
            generator=torch.Generator(device=device).manual_seed(42),
        )

    output_path = demo_dir / "output.jpeg"
    out.images[0].save(output_path)
    # The input image is only read-checked, never written, so don't report it as saved.
    print(f"Input: {src_img}")
    print(f"Saved {output_path}")
|
|
|
|
# Script entry point: run the demo only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|