diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..9b08b06191f61126a482ed40be4f95f3a04edcfa 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,99 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00013_test_0_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00013_test_0_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00013_test_0_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00013_test_0_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00013_test_0_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00013_test_0_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00198_test_0_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00198_test_0_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00198_test_0_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00198_test_0_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00198_test_0_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_0_00198_test_0_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00276_test_1+_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00276_test_1+_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00276_test_1+_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00276_test_1+_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00276_test_1+_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00276_test_1+_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00791_test_1+_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00791_test_1+_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00791_test_1+_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00791_test_1+_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00791_test_1+_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_1+_00791_test_1+_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00259_test_2+_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00259_test_2+_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00259_test_2+_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00259_test_2+_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00259_test_2+_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00259_test_2+_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00293_test_2+_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00293_test_2+_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00293_test_2+_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00293_test_2+_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00293_test_2+_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_2+_00293_test_2+_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00220_test_3+_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00220_test_3+_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00220_test_3+_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00220_test_3+_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00220_test_3+_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00220_test_3+_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00277_test_3+_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00277_test_3+_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00277_test_3+_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00277_test_3+_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00277_test_3+_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/BCI_HER2_3+_00277_test_3+_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_35M2101733_7_3_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_35M2101733_7_3_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_35M2101733_7_3_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_35M2101733_7_3_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_35M2101733_7_3_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_35M2101733_7_3_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_40M2101566_28_8_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_40M2101566_28_8_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_40M2101566_28_8_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_40M2101566_28_8_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_40M2101566_28_8_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_ER_40M2101566_28_8_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_19M2102438_35_28_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_19M2102438_35_28_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_19M2102438_35_28_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_19M2102438_35_28_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_19M2102438_35_28_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_19M2102438_35_28_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_67M2100642_15_18_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_67M2100642_15_18_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_67M2100642_15_18_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_67M2100642_15_18_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_67M2100642_15_18_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_HER2_67M2100642_15_18_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_10M2102916_10_20_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_10M2102916_10_20_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_10M2102916_10_20_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_10M2102916_10_20_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_10M2102916_10_20_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_10M2102916_10_20_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_80M2100377_14_29_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_80M2100377_14_29_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_80M2100377_14_29_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_80M2100377_14_29_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_80M2100377_14_29_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_Ki67_80M2100377_14_29_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_17M2102569_15_14_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_17M2102569_15_14_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_17M2102569_15_14_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_17M2102569_15_14_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_17M2102569_15_14_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_17M2102569_15_14_he.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_28M2101987_14_22_gen_er.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_28M2101987_14_22_gen_her2.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_28M2101987_14_22_gen_ki67.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_28M2101987_14_22_gen_pr.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_28M2101987_14_22_gt.png filter=lfs diff=lfs merge=lfs -text
+gallery/images/MIST_PR_28M2101987_14_22_he.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 846959bad6b3c7b176ce6bf8c6e2482f04f30320..aa7c1199f04b91a94489488b8a785ab0005ac7d7 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,28 @@
 ---
-title: UNIStainNet
-emoji: 👁
-colorFrom: yellow
+title: UNIStainNet - Virtual IHC Staining
+emoji: 🔬
+colorFrom: blue
 colorTo: purple
 sdk: gradio
-sdk_version: 6.10.0
+sdk_version: "6.10.0"
 app_file: app.py
 pinned: false
+license: mit
+hardware: cpu-basic
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# UNIStainNet: Foundation-Model-Guided Virtual Staining
+
+Virtual staining of H&E histopathology images to IHC (HER2, Ki67, ER, PR) using a single unified 42M-parameter SPADE-UNet conditioned on dense spatial tokens from a frozen UNI pathology foundation model.
+
+## Features
+- **Upload** an H&E image and generate IHC stains in real time
+- **Cross-stain comparison**: Generate all 4 stains from a single input
+- **Gallery**: Browse pre-computed examples (no GPU needed)
+
+## Architecture
+| Component | Details |
+|-----------|---------|
+| Generator | SPADE-UNet with UNI spatial conditioning + FiLM stain embeddings |
+| UNI Features | 4x4 sub-crop tiling → UNI ViT-L/16 → 32x32 spatial tokens (1024-dim) |
+| Parameters | 42M (generator), UNI frozen (303M) |
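The FiLM half of the generator conditioning in the table above can be pictured with a short sketch. This is an illustrative stand-in, not the module shipped in this commit; the class name, channel width, and embedding size are assumptions.

```python
import torch
import torch.nn as nn

class FiLMStainConditioning(nn.Module):
    """Hypothetical sketch: per-stain feature-wise affine modulation (FiLM)."""

    def __init__(self, num_stains=4, channels=256, embed_dim=64):
        super().__init__()
        self.embed = nn.Embedding(num_stains, embed_dim)      # one vector per stain label
        self.to_gamma_beta = nn.Linear(embed_dim, 2 * channels)

    def forward(self, feats, stain_labels):
        # feats: [B, C, H, W]; stain_labels: [B] integer IDs (e.g. STAIN_TO_LABEL values)
        gamma, beta = self.to_gamma_beta(self.embed(stain_labels)).chunk(2, dim=-1)
        # Broadcast the per-stain scale/shift over the spatial dimensions.
        return feats * (1 + gamma[:, :, None, None]) + beta[:, :, None, None]

# Usage: same-shaped output, modulated per requested stain.
feats = torch.randn(2, 256, 32, 32)
out = FiLMStainConditioning()(feats, torch.tensor([0, 2]))
```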
+""" + +import json +import os +import time +from pathlib import Path + +import gradio as gr +import numpy as np +import torch + +# ZeroGPU support: use @spaces.GPU if available, otherwise no-op +try: + import spaces + GPU_AVAILABLE = torch.cuda.is_available() +except ImportError: + spaces = None + GPU_AVAILABLE = torch.cuda.is_available() + + +def _gpu_decorator(duration=60): + """Apply @spaces.GPU if available, otherwise return identity decorator.""" + if spaces is not None and hasattr(spaces, "GPU"): + return spaces.GPU(duration=duration) + return lambda fn: fn +import torch.nn.functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from PIL import Image +from huggingface_hub import hf_hub_download + +from src.models.trainer import UNIStainNetTrainer +from src.data.mist_dataset import STAIN_TO_LABEL, LABEL_TO_STAIN + +# ── Constants ──────────────────────────────────────────────────────── +STAIN_NAMES = ["HER2", "Ki67", "ER", "PR"] +GALLERY_DIR = Path(__file__).parent / "gallery" +TARGET_SIZE = 512 + +# Model repo where checkpoint is stored (uploaded separately) +MODEL_REPO = os.environ.get("MODEL_REPO", "faceless-void/UNIStainNet") +CHECKPOINT_FILENAME = "mist_multistain_last.ckpt" + +# ── Global model cache (loaded lazily on GPU request) ──────────────── +_model_cache = {"model": None, "uni_model": None, "spatial_pool_size": 32} + + +def _get_checkpoint_path(): + """Download checkpoint from HF Hub if not local.""" + local_path = Path(__file__).parent / "checkpoints" / CHECKPOINT_FILENAME + if local_path.exists(): + return str(local_path) + # Download from HF model repo + return hf_hub_download(repo_id=MODEL_REPO, filename=CHECKPOINT_FILENAME) + + +def _load_models(): + """Load UNIStainNet + UNI (called inside @spaces.GPU function).""" + if _model_cache["model"] is None: + import timm + + ckpt_path = _get_checkpoint_path() + print(f"Loading UNIStainNet from {ckpt_path} ...") + model = UNIStainNetTrainer.load_from_checkpoint(ckpt_path, strict=False) + model = model.cuda().eval() + _model_cache["model"] = model + _model_cache["spatial_pool_size"] = getattr( + model.hparams, "uni_spatial_size", 32 + ) + print(" Generator loaded") + + print("Loading UNI ViT-L/16 ...") + uni_model = timm.create_model( + "hf-hub:MahmoodLab/uni", + pretrained=True, + init_values=1e-5, + dynamic_img_size=True, + ) + uni_model = uni_model.cuda().eval() + _model_cache["uni_model"] = uni_model + print(" UNI loaded") + else: + # Models already loaded — move to current GPU device + _model_cache["model"] = _model_cache["model"].cuda() + _model_cache["uni_model"] = _model_cache["uni_model"].cuda() + + return _model_cache["model"], _model_cache["uni_model"], _model_cache["spatial_pool_size"] + + +# ── Preprocessing helpers ──────────────────────────────────────────── + +def preprocess_he(pil_image, target_size=TARGET_SIZE): + """Center-crop and resize H&E to target_size x target_size.""" + w, h = pil_image.size + short = min(w, h) + left = (w - short) // 2 + top = (h - short) // 2 + pil_image = pil_image.crop((left, top, left + short, top + short)) + if short != target_size: + pil_image = pil_image.resize((target_size, target_size), Image.BICUBIC) + return pil_image + + +def pil_to_tensor(pil_image): + """PIL → [1, 3, H, W] in [-1, 1].""" + t = TF.to_tensor(pil_image) + t = TF.normalize(t, [0.5] * 3, [0.5] * 3) + return t.unsqueeze(0) + + +def tensor_to_pil(tensor): + """[1, 3, H, W] in [-1, 1] → PIL.""" + t = ((tensor[0].cpu() + 1) / 2).clamp(0, 1) + return 
+
+
+def extract_uni_features(uni_model, he_tensor_01, spatial_pool_size=32):
+    """Extract UNI spatial features from H&E crop ([1,3,H,W] in [0,1])."""
+    uni_transform = T.Normalize(
+        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+    )
+    B = he_tensor_01.shape[0]
+    num_crops = 4
+    patches_per_side = 14
+    crop_h = he_tensor_01.shape[2] // num_crops
+    crop_w = he_tensor_01.shape[3] // num_crops
+
+    sub_crops = []
+    for i in range(num_crops):
+        for j in range(num_crops):
+            sub = he_tensor_01[
+                :, :, i * crop_h : (i + 1) * crop_h, j * crop_w : (j + 1) * crop_w
+            ]
+            sub = F.interpolate(sub, size=(224, 224), mode="bicubic", align_corners=False)
+            sub = torch.stack([uni_transform(s) for s in sub])
+            sub_crops.append(sub)
+
+    all_crops = torch.stack(sub_crops, dim=1).reshape(B * 16, 3, 224, 224).cuda()
+
+    with torch.no_grad():
+        all_feats = uni_model.forward_features(all_crops)
+        patch_tokens = all_feats[:, 1:, :]
+
+    patch_tokens = patch_tokens.reshape(
+        B, num_crops, num_crops, patches_per_side, patches_per_side, 1024
+    )
+    full_size = num_crops * patches_per_side
+    full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5).reshape(
+        B, full_size, full_size, 1024
+    )
+
+    S = spatial_pool_size
+    if S < full_size:
+        grid_bchw = full_grid.permute(0, 3, 1, 2)
+        pooled = F.adaptive_avg_pool2d(grid_bchw, S)
+        result = pooled.permute(0, 2, 3, 1)
+    else:
+        result = full_grid
+
+    return result.reshape(B, S * S, 1024)
+
+
+# ── GPU-accelerated inference functions ──────────────────────────────
+
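+# Note: guidance_scale is simply forwarded to model.generate(); values above
+# 1.0 presumably apply the usual classifier-free-guidance combination,
+#   out = G(x, uncond) + s * (G(x, cond) - G(x, uncond)),
+# while 1.0 disables CFG (hence the "1.0 = no CFG" slider label below).
+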
+@_gpu_decorator(duration=60)
+def generate_single_stain(image, stain, guidance_scale):
+    """Generate a single IHC stain from an H&E image (GPU)."""
+    if image is None:
+        return None, "No image uploaded"
+
+    t0 = time.time()
+    model, uni_model, spatial_pool_size = _load_models()
+
+    he_pil = preprocess_he(image)
+    he_tensor = pil_to_tensor(he_pil).cuda()
+    he_01 = ((he_tensor + 1) / 2).clamp(0, 1)
+
+    uni_feats = extract_uni_features(uni_model, he_01, spatial_pool_size).cuda()
+    label = STAIN_TO_LABEL[stain]
+    labels = torch.tensor([label], device="cuda", dtype=torch.long)
+
+    with torch.no_grad():
+        gen = model.generate(he_tensor, uni_feats, labels, guidance_scale=guidance_scale)
+
+    result = tensor_to_pil(gen)
+    elapsed = time.time() - t0
+    return result, f"{elapsed:.2f}s"
+
+
+@_gpu_decorator(duration=120)
+def generate_all_stains(image, guidance_scale):
+    """Generate all 4 IHC stains from one H&E image (GPU)."""
+    if image is None:
+        return None, None, None, None, None, "No image uploaded"
+
+    t0 = time.time()
+    model, uni_model, spatial_pool_size = _load_models()
+
+    he_pil = preprocess_he(image)
+    he_tensor = pil_to_tensor(he_pil).cuda()
+    he_01 = ((he_tensor + 1) / 2).clamp(0, 1)
+
+    uni_feats = extract_uni_features(uni_model, he_01, spatial_pool_size).cuda()
+
+    results = {}
+    for stain in STAIN_NAMES:
+        label = STAIN_TO_LABEL[stain]
+        labels = torch.tensor([label], device="cuda", dtype=torch.long)
+        with torch.no_grad():
+            gen = model.generate(
+                he_tensor, uni_feats, labels, guidance_scale=guidance_scale
+            )
+        results[stain] = tensor_to_pil(gen)
+
+    elapsed = time.time() - t0
+    return (
+        he_pil,
+        results["HER2"],
+        results["Ki67"],
+        results["ER"],
+        results["PR"],
+        f"{elapsed:.2f}s",
+    )
+
+
+# ── Gallery helpers ──────────────────────────────────────────────────
+
+def load_gallery():
+    meta_path = GALLERY_DIR / "metadata.json"
+    if not meta_path.exists():
+        return None
+    with open(meta_path) as f:
+        return json.load(f)
+
+
+def show_gallery(name, gallery):
+    if not name or not gallery or name not in gallery:
+        return None, None, None, None, None, None
+    entry = gallery[name]
+    base = GALLERY_DIR / "images"
+    he = Image.open(base / entry["he"]).convert("RGB") if "he" in entry else None
+    gt = Image.open(base / entry["gt"]).convert("RGB") if "gt" in entry else None
+    gen_her2 = Image.open(base / entry["gen_her2"]).convert("RGB") if "gen_her2" in entry else None
+    gen_ki67 = Image.open(base / entry["gen_ki67"]).convert("RGB") if "gen_ki67" in entry else None
+    gen_er = Image.open(base / entry["gen_er"]).convert("RGB") if "gen_er" in entry else None
+    gen_pr = Image.open(base / entry["gen_pr"]).convert("RGB") if "gen_pr" in entry else None
+    return he, gt, gen_her2, gen_ki67, gen_er, gen_pr
+
+
+# ── Build Gradio App ─────────────────────────────────────────────────
+
+gallery = load_gallery()
+gallery_names = list(gallery.keys()) if gallery else []
+
+# Note: the theme must be set on gr.Blocks(); launch() does not accept it.
+with gr.Blocks(title="UNIStainNet — Virtual IHC Staining", theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # UNIStainNet: Foundation-Model-Guided Virtual Staining
+        **H&E → IHC (HER2, Ki67, ER, PR)** |
+        Single unified model | 42M parameters |
+        UNI spatial conditioning
+        """
+    )
+
+    # ── Tab 1: Single Stain ──────────────────────────────────────
+    with gr.Tab("Virtual Staining"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(type="pil", label="Upload H&E Image", height=400)
+                stain_choice = gr.Radio(
+                    choices=STAIN_NAMES, value="HER2", label="Target IHC Stain"
+                )
+                guidance_slider = gr.Slider(
+                    minimum=1.0, maximum=3.0, step=0.1, value=1.0,
+                    label="Guidance Scale (1.0 = no CFG)",
+                )
+                generate_btn = gr.Button("Generate", variant="primary")
+                gen_time = gr.Textbox(label="Time", interactive=False)
+            with gr.Column(scale=1):
+                output_image = gr.Image(type="pil", label="Generated IHC", height=400)
+
+        generate_btn.click(
+            fn=generate_single_stain,
+            inputs=[input_image, stain_choice, guidance_slider],
+            outputs=[output_image, gen_time],
+        )
+
+    # ── Tab 2: Cross-Stain ───────────────────────────────────────
+    with gr.Tab("Cross-Stain Comparison"):
+        gr.Markdown(
+            "Generate **all 4 IHC stains** from a single H&E input. "
+            "Demonstrates the unified multi-stain capability."
+        )
+        with gr.Row():
+            cross_input = gr.Image(type="pil", label="Upload H&E Image", height=350)
+            cross_guidance = gr.Slider(
+                minimum=1.0, maximum=3.0, step=0.1, value=1.0,
+                label="Guidance Scale",
+            )
+        cross_btn = gr.Button("Generate All Stains", variant="primary")
+        cross_time = gr.Textbox(label="Time", interactive=False)
+
+        with gr.Row():
+            cross_he_out = gr.Image(type="pil", label="H&E Input", height=300)
+            cross_her2 = gr.Image(type="pil", label="HER2", height=300)
+            cross_ki67 = gr.Image(type="pil", label="Ki67", height=300)
+            cross_er = gr.Image(type="pil", label="ER", height=300)
+            cross_pr = gr.Image(type="pil", label="PR", height=300)
+
+        cross_btn.click(
+            fn=generate_all_stains,
+            inputs=[cross_input, cross_guidance],
+            outputs=[cross_he_out, cross_her2, cross_ki67, cross_er, cross_pr, cross_time],
+        )
+
+    # ── Tab 3: Gallery ───────────────────────────────────────────
+    with gr.Tab("Gallery"):
+        if not gallery_names:
+            gr.Markdown("No pre-computed gallery available.")
+        else:
+            gr.Markdown(
+                "Pre-computed examples — no GPU required. "
+                "Select an example to view the H&E input and generated IHC stains."
+            )
+            gallery_dropdown = gr.Dropdown(
+                choices=gallery_names,
+                value=gallery_names[0] if gallery_names else None,
+                label="Select Example",
+            )
+            with gr.Row():
+                gal_he = gr.Image(type="pil", label="H&E Input", height=300)
+                gal_gt = gr.Image(type="pil", label="Ground Truth IHC", height=300)
+            with gr.Row():
+                gal_her2 = gr.Image(type="pil", label="Generated HER2", height=300)
+                gal_ki67 = gr.Image(type="pil", label="Generated Ki67", height=300)
+                gal_er = gr.Image(type="pil", label="Generated ER", height=300)
+                gal_pr = gr.Image(type="pil", label="Generated PR", height=300)
+
+            gallery_dropdown.change(
+                fn=lambda name: show_gallery(name, gallery),
+                inputs=[gallery_dropdown],
+                outputs=[gal_he, gal_gt, gal_her2, gal_ki67, gal_er, gal_pr],
+            )
+
+    # ── Tab 4: About ─────────────────────────────────────────────
+    with gr.Tab("About"):
+        gr.Markdown(
+            """
+            ## UNIStainNet
+
+            A SPADE-UNet generator conditioned on dense spatial tokens from a frozen
+            [UNI](https://github.com/mahmoodlab/UNI) pathology foundation model (ViT-L/16).
+
+            **Key features:**
+            - Dense UNI spatial conditioning (32x32 = 1,024 tokens)
+            - Misalignment-aware loss suite for consecutive-section training pairs
+            - Single unified model serves 4 IHC markers (HER2, Ki67, ER, PR)
+            - 42M generator parameters, single forward pass inference
+
+            ### Architecture
+
+            | Component | Details |
+            |-----------|---------|
+            | Generator | SPADE-UNet with UNI spatial conditioning + FiLM stain embeddings |
+            | Discriminator | Multi-scale PatchGAN (512 + 256) with spectral norm |
+            | UNI Features | 4x4 sub-crop tiling → UNI ViT-L/16 → 32x32 spatial tokens |
+            | Parameters | 42M (generator), UNI frozen (303M) |
+
+            ### Results (MIST, unified model)
+
+            | Stain | FID ↓ | KID×1k ↓ | Pearson-R ↑ | DAB KL ↓ |
+            |-------|-------|----------|-------------|----------|
+            | HER2  | 34.5  | 2.2      | 0.929       | 0.166    |
+            | Ki67  | 27.2  | 1.8      | 0.927       | 0.119    |
+            | ER    | 29.2  | 1.8      | 0.949       | 0.182    |
+            | PR    | 29.0  | 1.1      | 0.943       | 0.171    |
+            """
+        )
+
+if __name__ == "__main__":
+    demo.launch()
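Before the gallery assets, a note on programmatic access: once the Space is running, the standard gradio_client package can drive the endpoints wired above. This is a minimal sketch, not part of this commit; the api_name and argument order are assumptions, so verify them against the Space's auto-generated "Use via API" docs.

```python
from gradio_client import Client, handle_file

# Hypothetical client-side call against the deployed Space.
client = Client("faceless-void/UNIStainNet")
result = client.predict(
    handle_file("my_he_patch.png"),   # H&E input image
    "HER2",                           # target stain (one of STAIN_NAMES)
    1.0,                              # guidance scale (1.0 = no CFG)
    api_name="/generate_single_stain",
)
print(result)  # (path to generated IHC image, elapsed-time string)
```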
diff --git a/gallery/images/BCI_HER2_0_00013_test_0_gen_er.png b/gallery/images/BCI_HER2_0_00013_test_0_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a1f55955b0d08b351fd860f8a24f3402b9cd4a3
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00013_test_0_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc789fa197425c49b7143d09cc64ce779a8f2ab90d02f74e458184c4ba869dca
+size 520962
diff --git a/gallery/images/BCI_HER2_0_00013_test_0_gen_her2.png b/gallery/images/BCI_HER2_0_00013_test_0_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f039dde5ada21e19315b22d755d4998785af823
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00013_test_0_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57c0fd8ffda5493549c298837eaa047628b1438883e54217bf5bca159ecd5987
+size 518663
diff --git a/gallery/images/BCI_HER2_0_00013_test_0_gen_ki67.png b/gallery/images/BCI_HER2_0_00013_test_0_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad31edd294d040db2a74a7962c63d3ef48e16b62
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00013_test_0_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8e223fc1af2edf7b59e0d819726dac78fffc9e8acd493fbfca7b23cf1cb9bab
+size 501484
diff --git a/gallery/images/BCI_HER2_0_00013_test_0_gen_pr.png b/gallery/images/BCI_HER2_0_00013_test_0_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..66d7f68564f0fe1b14c0f8b9304185d5906ec138
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00013_test_0_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7f58670da43960e877b49db11836fd0f9707f23d462f17eae73b10751c7545b
+size 499433
diff --git a/gallery/images/BCI_HER2_0_00013_test_0_gt.png b/gallery/images/BCI_HER2_0_00013_test_0_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..6e505f7b6427602823087d912202fb7f9f2b36fb
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00013_test_0_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a6d161935ea85422f14cfa9a0a358fc2ae7489e9ea1b52deb28cdd6bb23d59e2
+size 501621
diff --git a/gallery/images/BCI_HER2_0_00013_test_0_he.png b/gallery/images/BCI_HER2_0_00013_test_0_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..7cb330211f604529cde1da30d66bd21c562eed05
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00013_test_0_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69687a14b1a4c2269283376c2d537ee46642a46f95c73f0d331a2308c147aa8c
+size 561743
diff --git a/gallery/images/BCI_HER2_0_00198_test_0_gen_er.png b/gallery/images/BCI_HER2_0_00198_test_0_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..72d7f724f387b8f5b85dcfc69798e7217453d0aa
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00198_test_0_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9678827b6612a91c8eab8a21b309ea5ef1e34e50c7c1c28d64919860adeaaa5
+size 282535
diff --git a/gallery/images/BCI_HER2_0_00198_test_0_gen_her2.png b/gallery/images/BCI_HER2_0_00198_test_0_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..37dfe9f5a38a47733fb07b2c2200a444f81fae4f
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00198_test_0_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8d29e119978a5543d1b6dcf1068e28ea869cf71d184d6601251aeabadf14cb1
+size 288666
diff --git a/gallery/images/BCI_HER2_0_00198_test_0_gen_ki67.png b/gallery/images/BCI_HER2_0_00198_test_0_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..252710cf0edd8c41994cf412615c18f09dde3fab
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00198_test_0_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd8ae812c7e61e9e8376526f9b2438ba5cfb7e8b322882ca5b17435bab1fb361
+size 274378
diff --git a/gallery/images/BCI_HER2_0_00198_test_0_gen_pr.png b/gallery/images/BCI_HER2_0_00198_test_0_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..198ef6dc7b02444a2d297c34710696e6e3eefcff
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00198_test_0_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e008512f7b111fdbeba4850c400a9cf5ba0286c878dae0a7755375affacd561
+size 289659
diff --git a/gallery/images/BCI_HER2_0_00198_test_0_gt.png b/gallery/images/BCI_HER2_0_00198_test_0_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..6831bd115f77f81bf3ec0ccbe8a81d201a4688b7
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00198_test_0_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38ffc41a62699b2cf736cfbce735104f91cb3984414b54b7a9284fbb6876befb
+size 243659
diff --git a/gallery/images/BCI_HER2_0_00198_test_0_he.png b/gallery/images/BCI_HER2_0_00198_test_0_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..f186973245c16eb671078043c640411a04c43b3f
--- /dev/null
+++ b/gallery/images/BCI_HER2_0_00198_test_0_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd6fb1f6cbc7b9c32fe17fb7541e85168aa6161690f36b7c4c8a78f8a13e464e
+size 302054
diff --git a/gallery/images/BCI_HER2_1+_00276_test_1+_gen_er.png b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..4425f5f1ac84154adf4243ade00160275c8d8c65
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c94ede19f468234baf4d730a790de1880f4c1d32bedfb96038b559244cdfe05b
+size 481362
diff --git a/gallery/images/BCI_HER2_1+_00276_test_1+_gen_her2.png b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a9ebe65dfa685ebaa66d7b4bc9f2b22c925ccb07
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec91a421bdee47bac2c61efe98a4b7c26a68e8318501ab7845ce2b275682f622
+size 499579
diff --git a/gallery/images/BCI_HER2_1+_00276_test_1+_gen_ki67.png b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..4d77280d373ae023eeaff53b2b86fa6794ddce7e
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5d01b80e336fd66650eb331a4a2d19c21ebc66d64aa23f012ad47f33c2e9a8e
+size 493574
diff --git a/gallery/images/BCI_HER2_1+_00276_test_1+_gen_pr.png b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..7121d7e4e2401854a242efe6a5fd4d44beeb13a4
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00276_test_1+_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f7f2de3f3776824258817ec744a0943b0119a4c2931f7e3883cb157af2f7a9e6
+size 429973
diff --git a/gallery/images/BCI_HER2_1+_00276_test_1+_gt.png b/gallery/images/BCI_HER2_1+_00276_test_1+_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..9d18aa018ef312da240d8f58d6c5d8dfde214381
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00276_test_1+_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f42edca3b23a7880ca67d4dec49efc8e04b0474f2b9933b224893306c9ead690
+size 418460
diff --git a/gallery/images/BCI_HER2_1+_00276_test_1+_he.png b/gallery/images/BCI_HER2_1+_00276_test_1+_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d72b4cad621cd81b8e65ca9505f0f6b6f1aa278
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00276_test_1+_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50868fff527ff0d2773a2b0e570a833f16ccada880cea2416b6ee2e1d646f16f
+size 539147
diff --git a/gallery/images/BCI_HER2_1+_00791_test_1+_gen_er.png b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..daeea57f70029d24dd2be1cb7b3d21b4982152d0
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d7010b47a0efd3c5a0a9dd41dfb363fb36fa0d3d0aa7f94bc905e46decef2fd
+size 489913
diff --git a/gallery/images/BCI_HER2_1+_00791_test_1+_gen_her2.png b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..122fd9923a0c036ce04b44a1597ff15dcb4290c9
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b4d42a453a446925da699587173d9fb4a8f723be35b65315b9fd08dfb437489
+size 495567
diff --git a/gallery/images/BCI_HER2_1+_00791_test_1+_gen_ki67.png b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..17f177f2c6bca5093f6196fca0d62a0c3c76a533
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02b57becf88737f53d75f98a82c8223895253f9703e2e6b6fc1fddbcfe712558
+size 484957
diff --git a/gallery/images/BCI_HER2_1+_00791_test_1+_gen_pr.png b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..6730b22de98cdc1e9c13af2cc2d4b03fa48656c4
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00791_test_1+_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83662375fb4d3984d06240578cfedfd9c423cad53abb3a834a1abbd5fd5d39cd
+size 455866
diff --git a/gallery/images/BCI_HER2_1+_00791_test_1+_gt.png b/gallery/images/BCI_HER2_1+_00791_test_1+_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e4b2a75da605d56c0a3fcc088881c2a972516bd
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00791_test_1+_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9de571b040f68a8075457e591194ec965057cefded78a331d780aad92e0a91b7
+size 465135
diff --git a/gallery/images/BCI_HER2_1+_00791_test_1+_he.png b/gallery/images/BCI_HER2_1+_00791_test_1+_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..42bc0edde4b2046732f3f89d62d4063337331df3
--- /dev/null
+++ b/gallery/images/BCI_HER2_1+_00791_test_1+_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c5b181b97abfadc0f5407b4bcd629b2ebbe73e20401abc70f33441b35cea914
+size 570402
diff --git a/gallery/images/BCI_HER2_2+_00259_test_2+_gen_er.png b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..a637031bc5714802bc09636ebbbafb0563bd1f40
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:969fdbaba07d7da7e041f43ca87d62702b62ed37e0de8ded38c8c678534adfbe
+size 447344
diff --git a/gallery/images/BCI_HER2_2+_00259_test_2+_gen_her2.png b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..56c40d791e31ab9121677e7f6932ebf1a7853cf2
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec43b658c02834878de5b23171b8a046aabcdc993f2dfde5ed50a187e53e47f1
+size 454960
diff --git a/gallery/images/BCI_HER2_2+_00259_test_2+_gen_ki67.png b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..1fbbb37676a61c6aff5030fe078a5e97c902bf3a
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a53cdb569ac5722d91d16d58a760f84d2209034cb32299629f76b1ae02d57d11
+size 455003
diff --git a/gallery/images/BCI_HER2_2+_00259_test_2+_gen_pr.png b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7d7fa5640f2b125ce217cdfd12929ba3c30ac99
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00259_test_2+_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91b892c18687414106685daa03682bdcdfc6ce38fa52a4fe1349e6bc4c35d1a5
+size 424666
diff --git a/gallery/images/BCI_HER2_2+_00259_test_2+_gt.png b/gallery/images/BCI_HER2_2+_00259_test_2+_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..eaf23f54336b07e9431d8032464d188b8a678e26
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00259_test_2+_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ed9c7407ddeb12ca2e91a753b80b16740fc1528a76f776b45905bf233d5d465
+size 426255
diff --git a/gallery/images/BCI_HER2_2+_00259_test_2+_he.png b/gallery/images/BCI_HER2_2+_00259_test_2+_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..478cf037fa2a38cddd4e20c4f6863d27a730ca36
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00259_test_2+_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b9244bad717ddb156ad2b27166820bef4d8e7e18480076985ec0191fdff8767b
+size 515210
diff --git a/gallery/images/BCI_HER2_2+_00293_test_2+_gen_er.png b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..417691bdd41cb4325249d3343eb9509baf894059
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3c4a3689bde21544f9e4405215508eee5417f612b3f6153e84503a1c7a2d684
+size 501335
diff --git a/gallery/images/BCI_HER2_2+_00293_test_2+_gen_her2.png b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..682336ea133e78250dcda499f390beacaccfa78a
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7baaffd3656d5c1f6200a8054dee201a211889d5f5c95dc8b8ed09c155da3218
+size 468595
diff --git a/gallery/images/BCI_HER2_2+_00293_test_2+_gen_ki67.png b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..2bd031e29672dc37a254f79613e66cc3abe81104
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9323d4924d59b0387c5324d8223201326a4cc470a28d6660182b1c59e547a4c5
+size 459066
diff --git a/gallery/images/BCI_HER2_2+_00293_test_2+_gen_pr.png b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..31b765443a6130b6ff49d8157b3b7b2357d5080e
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00293_test_2+_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e44ded8eb730b8dae1eb9c7cc51f5b9efc07131a6cf268f4552f85094b19af3
+size 444836
diff --git a/gallery/images/BCI_HER2_2+_00293_test_2+_gt.png b/gallery/images/BCI_HER2_2+_00293_test_2+_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..950fe0ce5857978f6ed0caa493ba8c3085db4168
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00293_test_2+_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ac2b306b50b1eb945c46dfa1a47909cdf4bb93b81f659d0090a080ea7a3663a
+size 475961
diff --git a/gallery/images/BCI_HER2_2+_00293_test_2+_he.png b/gallery/images/BCI_HER2_2+_00293_test_2+_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..ea6c29913ab256b31c5db0f93c92ecaf0aa51022
--- /dev/null
+++ b/gallery/images/BCI_HER2_2+_00293_test_2+_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afab90825ca78657d59802893887726befb8ffa66f348f550b12ab020fa6202e
+size 540057
diff --git a/gallery/images/BCI_HER2_3+_00220_test_3+_gen_er.png b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..c0c328e1ebb46b87d727fe8e822eb61cfbca1ffa
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:139bcccf67267f052eea9ca5b04aefeee13d135d54be3f865b34d470edb64e53
+size 528812
diff --git a/gallery/images/BCI_HER2_3+_00220_test_3+_gen_her2.png b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a30160deba43291c146472ef722fc03ab2fdd78e
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:306ef597f5ce18501e0498af2ec0282f8578d844d680b44532aa3620d5e0cbf3
+size 530336
diff --git a/gallery/images/BCI_HER2_3+_00220_test_3+_gen_ki67.png b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..ecaebfb4986300d28d6db9d660e37a62f05d8121
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b61bda4b3e2f149a8e4c9ca88f42e91fe34dbe9e2cdd47bb29b36ee3fe8d40d1
+size 517304
diff --git a/gallery/images/BCI_HER2_3+_00220_test_3+_gen_pr.png b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..f585ecf4d69042cfe8a26f9d8bd6e26583c47141
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00220_test_3+_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3918decb0c055332cd490e478e207ccdd2a32480a1c0ee3bb68f8718ab5a1d3
+size 502579
diff --git a/gallery/images/BCI_HER2_3+_00220_test_3+_gt.png b/gallery/images/BCI_HER2_3+_00220_test_3+_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..2d5a639cb59c43bc3439b251fd4adad731777995
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00220_test_3+_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e1c374e468a8739e9369d6d5123d96ee54d2bdb5582553a3356adb58ccb91ed
+size 461306
diff --git a/gallery/images/BCI_HER2_3+_00220_test_3+_he.png b/gallery/images/BCI_HER2_3+_00220_test_3+_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..84bdebd11450c41ff7aca3a125092450bc63d9a7
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00220_test_3+_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eefa14a632fb7629cd4fa29431f2f376325858cd5c0ed763a9ecf05ff6581965
+size 561587
diff --git a/gallery/images/BCI_HER2_3+_00277_test_3+_gen_er.png b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c0e7a11a8ac8fb7e8afc112cbeb8a89db915de6
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb891640988f393c3d237049a21892d8b693f366458c041f32c07b328a9b7729
+size 573792
diff --git a/gallery/images/BCI_HER2_3+_00277_test_3+_gen_her2.png b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a1c1f9b44e302fbcf519db4dd1d7efd5727e7605
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d5e1774b08a110ab48c7355a87094be61f1200faa5251e25a9e3c185dd407e
+size 575224
diff --git a/gallery/images/BCI_HER2_3+_00277_test_3+_gen_ki67.png b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..feee93ba7cd2094d738e004e76cb3fc7faebcb58
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4623157751e69b12b78c994b0e8cc6f03a57c47e1c9638d870b7db0bc2450fe
+size 564634
diff --git a/gallery/images/BCI_HER2_3+_00277_test_3+_gen_pr.png b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..39333a50d02b7a79f10e1a21cff0fbb2278ad623
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00277_test_3+_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:418a0805d44a33ae41330a8c8a0fc40420a421d28b140680dbe41731a7167950
+size 553819
diff --git a/gallery/images/BCI_HER2_3+_00277_test_3+_gt.png b/gallery/images/BCI_HER2_3+_00277_test_3+_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..68d13a113c01baa017b7539a8ed5b2d0bedae0c8
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00277_test_3+_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45f2757b9fa0e2b427c40ee10a00a5c4101d4eb824dc1bbd50e193d690f29641
+size 468337
diff --git a/gallery/images/BCI_HER2_3+_00277_test_3+_he.png b/gallery/images/BCI_HER2_3+_00277_test_3+_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..5ee06b986aded8320243df3527ac3e4a0a14fc79
--- /dev/null
+++ b/gallery/images/BCI_HER2_3+_00277_test_3+_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af08a63cb16f938213af1d6195a42804dc4324004b0c476bfc9720c15e129f37
+size 606372
diff --git a/gallery/images/MIST_ER_35M2101733_7_3_gen_er.png b/gallery/images/MIST_ER_35M2101733_7_3_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..79f29e49805a6bddc1d72821393ff269be5bfe15
--- /dev/null
+++ b/gallery/images/MIST_ER_35M2101733_7_3_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f084aa5ae48a218ccbdca0fb5dd105e5ea50969c42c6d01e2196b0dd3cc4e3dd
+size 396943
diff --git a/gallery/images/MIST_ER_35M2101733_7_3_gen_her2.png b/gallery/images/MIST_ER_35M2101733_7_3_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d6d28ba4bb683aeeaaf811e7ded0252cea64c6d
--- /dev/null
+++ b/gallery/images/MIST_ER_35M2101733_7_3_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c45f3d9b08da18c3c74888c0b55282c6c1b9b0d1a39af1c2f7173720b315106
+size 417578
diff --git a/gallery/images/MIST_ER_35M2101733_7_3_gen_ki67.png b/gallery/images/MIST_ER_35M2101733_7_3_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e9229b4f62a97e15b9e1a950869a076aceff654
--- /dev/null
+++ b/gallery/images/MIST_ER_35M2101733_7_3_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6082155a19e23aa37fb3098d9dbf2794f46c44682a4eed20d3b7fdb0bcfa4bc2
+size 384326
diff --git a/gallery/images/MIST_ER_35M2101733_7_3_gen_pr.png b/gallery/images/MIST_ER_35M2101733_7_3_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..208785aeb9a44526203b6f051c8b6faefeaa934b
--- /dev/null
+++ b/gallery/images/MIST_ER_35M2101733_7_3_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c6b8236e77361a5a55e09ccce1aecc11f7c85b1b8054b0970b250c8f55535e4
+size 410296
diff --git a/gallery/images/MIST_ER_35M2101733_7_3_gt.png b/gallery/images/MIST_ER_35M2101733_7_3_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6738c990f0d16b391332cd7e0a8e1b01d08797a
--- /dev/null
+++ b/gallery/images/MIST_ER_35M2101733_7_3_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58651a906063d92387ffd1b2315b48ba7ae80e23370b710ff2b269d2d39e720b
+size 432277
diff --git a/gallery/images/MIST_ER_35M2101733_7_3_he.png b/gallery/images/MIST_ER_35M2101733_7_3_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..f626d6b3e22a43736fc613aff95b0db255ad2dad
--- /dev/null
+++ b/gallery/images/MIST_ER_35M2101733_7_3_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7107c4bdbc2d64a18518dedd61b3dfec1c202e8f8a5d8adaf2a05d3e9328660
+size 483871
diff --git a/gallery/images/MIST_ER_40M2101566_28_8_gen_er.png b/gallery/images/MIST_ER_40M2101566_28_8_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e1c85149b9f560875ddb7d4fcee266d10d299cd
--- /dev/null
+++ b/gallery/images/MIST_ER_40M2101566_28_8_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a15b89bb7a77ac23bd882e8ab240654db2bbf20b7170899b660c06a69f4e833
+size 460257
diff --git a/gallery/images/MIST_ER_40M2101566_28_8_gen_her2.png b/gallery/images/MIST_ER_40M2101566_28_8_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9f656c32eea8b75bdc3fb92618346747b4937a5
--- /dev/null
+++ b/gallery/images/MIST_ER_40M2101566_28_8_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:138b6e4e5c047cfb084f1cc3bf36a7d192007ec2211477a4dec39c42771aa1dc
+size 467533
diff --git a/gallery/images/MIST_ER_40M2101566_28_8_gen_ki67.png b/gallery/images/MIST_ER_40M2101566_28_8_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ff5345a26596fdabdcc06ae4ff5bcaea632fce0
--- /dev/null
+++ b/gallery/images/MIST_ER_40M2101566_28_8_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e95a75cc37ab06cdaaea9e53a4a8ea3a42bb08acc49a767b148be6b129d49e64
+size 454109
diff --git a/gallery/images/MIST_ER_40M2101566_28_8_gen_pr.png b/gallery/images/MIST_ER_40M2101566_28_8_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..af3a74970d3c6c03198066fefe9f8548e2f346ae
--- /dev/null
+++ b/gallery/images/MIST_ER_40M2101566_28_8_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b103070fea202f7e6de21e60070ca77189b3ccbf53ec02b6632030a0c7ad492d
+size 451342
diff --git a/gallery/images/MIST_ER_40M2101566_28_8_gt.png b/gallery/images/MIST_ER_40M2101566_28_8_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..0aa4d1d7b3111efb4777969b9a8666938eb9929b
--- /dev/null
+++ b/gallery/images/MIST_ER_40M2101566_28_8_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f631f91cd843a91b4a97cc9cc814065d2c4349b163b104ed1c53cc0978dae465
+size 468728
diff --git a/gallery/images/MIST_ER_40M2101566_28_8_he.png b/gallery/images/MIST_ER_40M2101566_28_8_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..ec32b9a2daf38f8ce770d9cb229243ff13aefe19
--- /dev/null
+++ b/gallery/images/MIST_ER_40M2101566_28_8_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e42057699d85549ddefbf3229f43d32923b22a080a0a0426e362550f762612c
+size 536554
diff --git a/gallery/images/MIST_HER2_19M2102438_35_28_gen_er.png b/gallery/images/MIST_HER2_19M2102438_35_28_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..6083b6d3b80a458f8f98fce747a783f9df175ebe
--- /dev/null
+++ b/gallery/images/MIST_HER2_19M2102438_35_28_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfd1ed93a06825099a864627541dea8adca9c50c9039d94adf1a77bd2adba27a
+size 535119
diff --git a/gallery/images/MIST_HER2_19M2102438_35_28_gen_her2.png b/gallery/images/MIST_HER2_19M2102438_35_28_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d09354bd5939bc96db8c954abf1747b2d68558a
--- /dev/null
+++ b/gallery/images/MIST_HER2_19M2102438_35_28_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0a3de12457872c60dd41527285f6b797d8974e010f434f523dd798b6df3a8d9
+size 522453
diff --git a/gallery/images/MIST_HER2_19M2102438_35_28_gen_ki67.png b/gallery/images/MIST_HER2_19M2102438_35_28_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..a6116c8901c64444356d8b42a5dc5a82370421da
--- /dev/null
+++ b/gallery/images/MIST_HER2_19M2102438_35_28_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ed6453882b8c83196e7143593fa0e9ec1abde16f81da1538e4694949728aaa6
+size 504066
diff --git a/gallery/images/MIST_HER2_19M2102438_35_28_gen_pr.png b/gallery/images/MIST_HER2_19M2102438_35_28_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..73e5a66af7d3d4d4a58ab96f5e3bb81c4d75b9c2
--- /dev/null
+++ b/gallery/images/MIST_HER2_19M2102438_35_28_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aef304babaa5c6c46088fe2cef2709f638b7294dfbe51d80d876703252a6b496
+size 542238
diff --git a/gallery/images/MIST_HER2_19M2102438_35_28_gt.png b/gallery/images/MIST_HER2_19M2102438_35_28_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..2dc64f4b2e7ae61ddd4986ba0ddeb149539edb06
--- /dev/null
+++ b/gallery/images/MIST_HER2_19M2102438_35_28_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e280c0febbb027c0e063cb276f73df7bf1f7543d489a33e28c274f2b0076cae
+size 567182
diff --git a/gallery/images/MIST_HER2_19M2102438_35_28_he.png b/gallery/images/MIST_HER2_19M2102438_35_28_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..6090b9b69a675494dc568a706817d479d8ec610d
--- /dev/null
+++ b/gallery/images/MIST_HER2_19M2102438_35_28_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84d3e5caf52727383ad70d03bdd475f8d0b8353e4fe13ab91aa686ddc62de961
+size 583985
diff --git a/gallery/images/MIST_HER2_67M2100642_15_18_gen_er.png b/gallery/images/MIST_HER2_67M2100642_15_18_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..49657ac360e2293e555b2be2dc5dc452c636f906
--- /dev/null
+++ b/gallery/images/MIST_HER2_67M2100642_15_18_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5672b9603efc3930e72bcfbd8013f02c52bbafe2e832293f5f5dfe5cd63f07e3
+size 464348
diff --git a/gallery/images/MIST_HER2_67M2100642_15_18_gen_her2.png b/gallery/images/MIST_HER2_67M2100642_15_18_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..635013acece0e4bb3f59f0685aca837da730a743
--- /dev/null
+++ b/gallery/images/MIST_HER2_67M2100642_15_18_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec0539532d9e487f0bc79e56d9f1fd4c8c610c291baeb296e57352368f8486ed
+size 437226
diff --git a/gallery/images/MIST_HER2_67M2100642_15_18_gen_ki67.png b/gallery/images/MIST_HER2_67M2100642_15_18_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..faf2ea0863e8e5b858b35f6a59195f6d8a307ba1
--- /dev/null
+++ b/gallery/images/MIST_HER2_67M2100642_15_18_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0390b9a576eab3e1597aa4b78c610a80f3c33ee0f34b24dd12b5b018e98c79d4
+size 423039
diff --git a/gallery/images/MIST_HER2_67M2100642_15_18_gen_pr.png b/gallery/images/MIST_HER2_67M2100642_15_18_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e26c6d1a75157ba760d00b3e46390c28a8b316e
--- /dev/null
+++ b/gallery/images/MIST_HER2_67M2100642_15_18_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74f4ec5c69a12d6430a1aade998a95a488e190abf225406fb24f02cbf266bdfa
+size 459437
diff --git a/gallery/images/MIST_HER2_67M2100642_15_18_gt.png b/gallery/images/MIST_HER2_67M2100642_15_18_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..634c1c46e916b756257045e9ab71bafe74908196
--- /dev/null
+++ b/gallery/images/MIST_HER2_67M2100642_15_18_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b970fb417bd6a7b707756ec6da0eab12cc8f83e64bd290e0cd22741bd6e68207
+size 437301
diff --git a/gallery/images/MIST_HER2_67M2100642_15_18_he.png b/gallery/images/MIST_HER2_67M2100642_15_18_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..a218897812c9284bb173482a67d15999660c8adb
--- /dev/null
+++ b/gallery/images/MIST_HER2_67M2100642_15_18_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d8df18932b56b83d03aff8e4fd6f9dcdeaf464de4c2de39e8e34142bdfe49f2
+size 519715
diff --git a/gallery/images/MIST_Ki67_10M2102916_10_20_gen_er.png b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..d4e3dd8ab9711168d07c5d082f887a45f85776e1
--- /dev/null
+++ b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c56514c184dd41774c0d28a0b04bb5a7bf795484be508cd986eccc1d5d33678
+size 550739
diff --git a/gallery/images/MIST_Ki67_10M2102916_10_20_gen_her2.png b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f29f5e85ea1d5c3a6f4467c45124babd53e1f4b
--- /dev/null
+++ b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:899a325e87900d67c20cf3dc01602ce2e7c09ff1f6256abb89be5d9ba4047818
+size 554013
diff --git a/gallery/images/MIST_Ki67_10M2102916_10_20_gen_ki67.png b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e2264d0586b004e15327223f1abd7057ab32464
--- /dev/null
+++ b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:179465e65e234835142485d826d882ba0edc0aa87f2fb5fbe31333b92d4b78ec
+size 571168
diff --git a/gallery/images/MIST_Ki67_10M2102916_10_20_gen_pr.png b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..e28c90fea3eb7015f2a963f51e5b51685f995c54
--- /dev/null
+++ b/gallery/images/MIST_Ki67_10M2102916_10_20_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96c165a39d1ec3ac4a39f949fac4dab8ffd030d966ca982af9d0cf9607d4ee76
+size 560383
diff --git a/gallery/images/MIST_Ki67_10M2102916_10_20_gt.png b/gallery/images/MIST_Ki67_10M2102916_10_20_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0321c23bd66c613445f4bf91260ed6fc351be4e
--- /dev/null
+++ b/gallery/images/MIST_Ki67_10M2102916_10_20_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ed01ece975fda62108d27b059e4c3707f16aa6f9d2e90ee9fd38d20e1b6fc24
+size 555055
diff --git a/gallery/images/MIST_Ki67_10M2102916_10_20_he.png b/gallery/images/MIST_Ki67_10M2102916_10_20_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..e5132270a715930f01b4dbbebb016ce23292ecb7
--- /dev/null
+++ b/gallery/images/MIST_Ki67_10M2102916_10_20_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13f6f0b573fd7bbb668e3bd52a9291872dbc06191e9c28009673c2b3e61033da
+size 607497
diff --git a/gallery/images/MIST_Ki67_80M2100377_14_29_gen_er.png b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..85c25b61d72f848ae7c28045d253c97b3ffc1a15
--- /dev/null
+++ b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10976b3d592ca673295329e6b30ad107dd97227a7c3b571a8a31ab5d9c4e7069
+size 563442
diff --git a/gallery/images/MIST_Ki67_80M2100377_14_29_gen_her2.png b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..fe5f99619f52507c2f37899e124fa6b7c30c53ae
--- /dev/null
+++ b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2694679fd3236577a2b2d3a6b3700364ab76be3609312fa1621fd454736c38aa
+size 565467
diff --git a/gallery/images/MIST_Ki67_80M2100377_14_29_gen_ki67.png b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..c15439120732fe5e1818427e23358b642a8c70c2
--- /dev/null
+++ b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a283a95fb9a0425757f253df0d386ba5fb4e24309ccdd0fb5e4549eabbc39b6
+size 555820
diff --git a/gallery/images/MIST_Ki67_80M2100377_14_29_gen_pr.png b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..e8b50050ba7d032b4dac0a99eebef5c001a9e036
--- /dev/null
+++ b/gallery/images/MIST_Ki67_80M2100377_14_29_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5a7605dc7b8e6ef27324f4428e36fab6146b091f23164f8744d9c5af0a6cd13c
+size 560133
diff --git a/gallery/images/MIST_Ki67_80M2100377_14_29_gt.png b/gallery/images/MIST_Ki67_80M2100377_14_29_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..bfd04defc5d99dd2a54ed8b4b9affa43853f471c
--- /dev/null
+++ b/gallery/images/MIST_Ki67_80M2100377_14_29_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a32518be61d5a9773122b1d2e643646d2463dddba9b9778c9b0bd7816043c546
+size 557575
diff --git a/gallery/images/MIST_Ki67_80M2100377_14_29_he.png b/gallery/images/MIST_Ki67_80M2100377_14_29_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..1626126d096fd8d3c0e317472f3a0fd5cacc47f0
--- /dev/null
+++ b/gallery/images/MIST_Ki67_80M2100377_14_29_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e45a13831f08161bfef22a3333fd37c086339af3c73ffb2569e1dcc62b4c242
+size 609805
diff --git a/gallery/images/MIST_PR_17M2102569_15_14_gen_er.png b/gallery/images/MIST_PR_17M2102569_15_14_gen_er.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf9dc344f0247a18b9826b183b5b502e8b17a0cb
--- /dev/null
+++ b/gallery/images/MIST_PR_17M2102569_15_14_gen_er.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be5964cd2c3a0d0ee09e0e4694b53535dfcfeb44fb239788cca875a8c1366096
+size 457576
diff --git a/gallery/images/MIST_PR_17M2102569_15_14_gen_her2.png b/gallery/images/MIST_PR_17M2102569_15_14_gen_her2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ae7ffe0a08dd3d75c06ff26f8a105edeea3bd7b
--- /dev/null
+++ b/gallery/images/MIST_PR_17M2102569_15_14_gen_her2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a66517a1dac9448dc0b1b09cf2813e3cb6a346b45a43f99c883e6dc06433e68d
+size 476432
diff --git a/gallery/images/MIST_PR_17M2102569_15_14_gen_ki67.png b/gallery/images/MIST_PR_17M2102569_15_14_gen_ki67.png
new file mode 100644
index 0000000000000000000000000000000000000000..67fe8b21abcf3c110f4375a03f1fb340d6dcb186
--- /dev/null
+++ b/gallery/images/MIST_PR_17M2102569_15_14_gen_ki67.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef0808b1142b000eb2ec30df15e0f5f0058fade2098b2d3da3b08eac6bfb194a
+size 445830
diff --git a/gallery/images/MIST_PR_17M2102569_15_14_gen_pr.png b/gallery/images/MIST_PR_17M2102569_15_14_gen_pr.png
new file mode 100644
index 0000000000000000000000000000000000000000..8b8175c83a23af333b682552eb11b83c6e89b01d
--- /dev/null
+++ b/gallery/images/MIST_PR_17M2102569_15_14_gen_pr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48afa26cf1ab35454079e3a81d02dce97c9761f3af29527b1f36ae4cc6f923f0
+size 497701
diff --git a/gallery/images/MIST_PR_17M2102569_15_14_gt.png b/gallery/images/MIST_PR_17M2102569_15_14_gt.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4e841067015b3bb71e11a20cd12d498f900ba47
--- /dev/null
+++ b/gallery/images/MIST_PR_17M2102569_15_14_gt.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:159ccb7aafb71b646f60dd00761313f52736ff318b2a9b841ee5197c05b43bd6
+size 515776
diff --git a/gallery/images/MIST_PR_17M2102569_15_14_he.png b/gallery/images/MIST_PR_17M2102569_15_14_he.png
new file mode 100644
index 0000000000000000000000000000000000000000..90c01902492c69804386bb8d17425d2a6cb567b6
--- /dev/null
+++ b/gallery/images/MIST_PR_17M2102569_15_14_he.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccc9630dba43241ceef6a926e9dbb8d462c62844f7cc2716f56464001e0c1302
+size 571879 diff --git a/gallery/images/MIST_PR_28M2101987_14_22_gen_er.png b/gallery/images/MIST_PR_28M2101987_14_22_gen_er.png new file mode 100644 index 0000000000000000000000000000000000000000..89d148c3e6ddc6b8908cdbccb655b13fb110fef1 --- /dev/null +++ b/gallery/images/MIST_PR_28M2101987_14_22_gen_er.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4a0230839ff74897065dc6aa0150cc69b04e1703e8af117d18eb33bdaffa1d5 +size 552814 diff --git a/gallery/images/MIST_PR_28M2101987_14_22_gen_her2.png b/gallery/images/MIST_PR_28M2101987_14_22_gen_her2.png new file mode 100644 index 0000000000000000000000000000000000000000..3dad5d40a739adbb45b53e16f74ee6a14ddc180f --- /dev/null +++ b/gallery/images/MIST_PR_28M2101987_14_22_gen_her2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18dec03b49486ff51c587869347d2e0bcc9ea5d25278beb9c61edef34672b8a6 +size 551314 diff --git a/gallery/images/MIST_PR_28M2101987_14_22_gen_ki67.png b/gallery/images/MIST_PR_28M2101987_14_22_gen_ki67.png new file mode 100644 index 0000000000000000000000000000000000000000..5a91962c4ae08c428ae64ba670899ba5bd6e4225 --- /dev/null +++ b/gallery/images/MIST_PR_28M2101987_14_22_gen_ki67.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cc4e18a125b3bbcc74db80aded5b63ff92ae20e8850a3a04bd657d241c43251 +size 540178 diff --git a/gallery/images/MIST_PR_28M2101987_14_22_gen_pr.png b/gallery/images/MIST_PR_28M2101987_14_22_gen_pr.png new file mode 100644 index 0000000000000000000000000000000000000000..c3cddcfa4b1d40d60d845551a39b00caed92ebcb --- /dev/null +++ b/gallery/images/MIST_PR_28M2101987_14_22_gen_pr.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:175cc6f97112ba986bb563d5f20f87327fe34b22334f9c3bc8a68fd09256ff9e +size 540688 diff --git a/gallery/images/MIST_PR_28M2101987_14_22_gt.png b/gallery/images/MIST_PR_28M2101987_14_22_gt.png new file mode 100644 index 0000000000000000000000000000000000000000..0dbea5054933c03a9b71e2db2c98f89663c345e7 --- /dev/null +++ b/gallery/images/MIST_PR_28M2101987_14_22_gt.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2154e2684b7e884b7bd636f5b5f56c1a50616b1a8c554df2a1bd581fc17604a8 +size 569890 diff --git a/gallery/images/MIST_PR_28M2101987_14_22_he.png b/gallery/images/MIST_PR_28M2101987_14_22_he.png new file mode 100644 index 0000000000000000000000000000000000000000..b4316ff1790314392beeabd83a1e5b87c88f8da4 --- /dev/null +++ b/gallery/images/MIST_PR_28M2101987_14_22_he.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a01ba13a34b963e61d56177728247634e95fd2bd7534b37a80654007bbb2b38 +size 603765 diff --git a/gallery/metadata.json b/gallery/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..c4d6468c063488e15ffe8283630d9b0c748f7af4 --- /dev/null +++ b/gallery/metadata.json @@ -0,0 +1,162 @@ +{ + "BCI_HER2_0_00198_test_0": { + "he": "BCI_HER2_0_00198_test_0_he.png", + "gt": "BCI_HER2_0_00198_test_0_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_0_00198_test_0_gen_her2.png", + "gen_ki67": "BCI_HER2_0_00198_test_0_gen_ki67.png", + "gen_er": "BCI_HER2_0_00198_test_0_gen_er.png", + "gen_pr": "BCI_HER2_0_00198_test_0_gen_pr.png" + }, + "BCI_HER2_0_00013_test_0": { + "he": "BCI_HER2_0_00013_test_0_he.png", + "gt": "BCI_HER2_0_00013_test_0_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_0_00013_test_0_gen_her2.png", + "gen_ki67": "BCI_HER2_0_00013_test_0_gen_ki67.png", + 
"gen_er": "BCI_HER2_0_00013_test_0_gen_er.png", + "gen_pr": "BCI_HER2_0_00013_test_0_gen_pr.png" + }, + "BCI_HER2_1+_00791_test_1+": { + "he": "BCI_HER2_1+_00791_test_1+_he.png", + "gt": "BCI_HER2_1+_00791_test_1+_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_1+_00791_test_1+_gen_her2.png", + "gen_ki67": "BCI_HER2_1+_00791_test_1+_gen_ki67.png", + "gen_er": "BCI_HER2_1+_00791_test_1+_gen_er.png", + "gen_pr": "BCI_HER2_1+_00791_test_1+_gen_pr.png" + }, + "BCI_HER2_1+_00276_test_1+": { + "he": "BCI_HER2_1+_00276_test_1+_he.png", + "gt": "BCI_HER2_1+_00276_test_1+_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_1+_00276_test_1+_gen_her2.png", + "gen_ki67": "BCI_HER2_1+_00276_test_1+_gen_ki67.png", + "gen_er": "BCI_HER2_1+_00276_test_1+_gen_er.png", + "gen_pr": "BCI_HER2_1+_00276_test_1+_gen_pr.png" + }, + "BCI_HER2_2+_00293_test_2+": { + "he": "BCI_HER2_2+_00293_test_2+_he.png", + "gt": "BCI_HER2_2+_00293_test_2+_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_2+_00293_test_2+_gen_her2.png", + "gen_ki67": "BCI_HER2_2+_00293_test_2+_gen_ki67.png", + "gen_er": "BCI_HER2_2+_00293_test_2+_gen_er.png", + "gen_pr": "BCI_HER2_2+_00293_test_2+_gen_pr.png" + }, + "BCI_HER2_2+_00259_test_2+": { + "he": "BCI_HER2_2+_00259_test_2+_he.png", + "gt": "BCI_HER2_2+_00259_test_2+_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_2+_00259_test_2+_gen_her2.png", + "gen_ki67": "BCI_HER2_2+_00259_test_2+_gen_ki67.png", + "gen_er": "BCI_HER2_2+_00259_test_2+_gen_er.png", + "gen_pr": "BCI_HER2_2+_00259_test_2+_gen_pr.png" + }, + "BCI_HER2_3+_00277_test_3+": { + "he": "BCI_HER2_3+_00277_test_3+_he.png", + "gt": "BCI_HER2_3+_00277_test_3+_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_3+_00277_test_3+_gen_her2.png", + "gen_ki67": "BCI_HER2_3+_00277_test_3+_gen_ki67.png", + "gen_er": "BCI_HER2_3+_00277_test_3+_gen_er.png", + "gen_pr": "BCI_HER2_3+_00277_test_3+_gen_pr.png" + }, + "BCI_HER2_3+_00220_test_3+": { + "he": "BCI_HER2_3+_00220_test_3+_he.png", + "gt": "BCI_HER2_3+_00220_test_3+_gt.png", + "gt_stain": "HER2", + "source": "BCI", + "gen_her2": "BCI_HER2_3+_00220_test_3+_gen_her2.png", + "gen_ki67": "BCI_HER2_3+_00220_test_3+_gen_ki67.png", + "gen_er": "BCI_HER2_3+_00220_test_3+_gen_er.png", + "gen_pr": "BCI_HER2_3+_00220_test_3+_gen_pr.png" + }, + "MIST_HER2_67M2100642_15_18": { + "he": "MIST_HER2_67M2100642_15_18_he.png", + "gt": "MIST_HER2_67M2100642_15_18_gt.png", + "gt_stain": "HER2", + "source": "MIST", + "gen_her2": "MIST_HER2_67M2100642_15_18_gen_her2.png", + "gen_ki67": "MIST_HER2_67M2100642_15_18_gen_ki67.png", + "gen_er": "MIST_HER2_67M2100642_15_18_gen_er.png", + "gen_pr": "MIST_HER2_67M2100642_15_18_gen_pr.png" + }, + "MIST_HER2_19M2102438_35_28": { + "he": "MIST_HER2_19M2102438_35_28_he.png", + "gt": "MIST_HER2_19M2102438_35_28_gt.png", + "gt_stain": "HER2", + "source": "MIST", + "gen_her2": "MIST_HER2_19M2102438_35_28_gen_her2.png", + "gen_ki67": "MIST_HER2_19M2102438_35_28_gen_ki67.png", + "gen_er": "MIST_HER2_19M2102438_35_28_gen_er.png", + "gen_pr": "MIST_HER2_19M2102438_35_28_gen_pr.png" + }, + "MIST_Ki67_10M2102916_10_20": { + "he": "MIST_Ki67_10M2102916_10_20_he.png", + "gt": "MIST_Ki67_10M2102916_10_20_gt.png", + "gt_stain": "Ki67", + "source": "MIST", + "gen_her2": "MIST_Ki67_10M2102916_10_20_gen_her2.png", + "gen_ki67": "MIST_Ki67_10M2102916_10_20_gen_ki67.png", + "gen_er": "MIST_Ki67_10M2102916_10_20_gen_er.png", + "gen_pr": 
"MIST_Ki67_10M2102916_10_20_gen_pr.png" + }, + "MIST_Ki67_80M2100377_14_29": { + "he": "MIST_Ki67_80M2100377_14_29_he.png", + "gt": "MIST_Ki67_80M2100377_14_29_gt.png", + "gt_stain": "Ki67", + "source": "MIST", + "gen_her2": "MIST_Ki67_80M2100377_14_29_gen_her2.png", + "gen_ki67": "MIST_Ki67_80M2100377_14_29_gen_ki67.png", + "gen_er": "MIST_Ki67_80M2100377_14_29_gen_er.png", + "gen_pr": "MIST_Ki67_80M2100377_14_29_gen_pr.png" + }, + "MIST_ER_40M2101566_28_8": { + "he": "MIST_ER_40M2101566_28_8_he.png", + "gt": "MIST_ER_40M2101566_28_8_gt.png", + "gt_stain": "ER", + "source": "MIST", + "gen_her2": "MIST_ER_40M2101566_28_8_gen_her2.png", + "gen_ki67": "MIST_ER_40M2101566_28_8_gen_ki67.png", + "gen_er": "MIST_ER_40M2101566_28_8_gen_er.png", + "gen_pr": "MIST_ER_40M2101566_28_8_gen_pr.png" + }, + "MIST_ER_35M2101733_7_3": { + "he": "MIST_ER_35M2101733_7_3_he.png", + "gt": "MIST_ER_35M2101733_7_3_gt.png", + "gt_stain": "ER", + "source": "MIST", + "gen_her2": "MIST_ER_35M2101733_7_3_gen_her2.png", + "gen_ki67": "MIST_ER_35M2101733_7_3_gen_ki67.png", + "gen_er": "MIST_ER_35M2101733_7_3_gen_er.png", + "gen_pr": "MIST_ER_35M2101733_7_3_gen_pr.png" + }, + "MIST_PR_28M2101987_14_22": { + "he": "MIST_PR_28M2101987_14_22_he.png", + "gt": "MIST_PR_28M2101987_14_22_gt.png", + "gt_stain": "PR", + "source": "MIST", + "gen_her2": "MIST_PR_28M2101987_14_22_gen_her2.png", + "gen_ki67": "MIST_PR_28M2101987_14_22_gen_ki67.png", + "gen_er": "MIST_PR_28M2101987_14_22_gen_er.png", + "gen_pr": "MIST_PR_28M2101987_14_22_gen_pr.png" + }, + "MIST_PR_17M2102569_15_14": { + "he": "MIST_PR_17M2102569_15_14_he.png", + "gt": "MIST_PR_17M2102569_15_14_gt.png", + "gt_stain": "PR", + "source": "MIST", + "gen_her2": "MIST_PR_17M2102569_15_14_gen_her2.png", + "gen_ki67": "MIST_PR_17M2102569_15_14_gen_ki67.png", + "gen_er": "MIST_PR_17M2102569_15_14_gen_er.png", + "gen_pr": "MIST_PR_17M2102569_15_14_gen_pr.png" + } +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..56069bbd78b45f6d8129e58eb536b7b66bbbea86 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch>=2.0 +torchvision>=0.15 +pytorch-lightning>=2.0 +timm>=0.9 +lpips>=0.1.4 +torchmetrics>=1.0 +scipy>=1.10 +scikit-learn>=1.2 +Pillow>=9.0 +numpy>=1.24 +huggingface_hub diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/bci_dataset.py b/src/data/bci_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b15d379047d6632be3e8330696505df3d9181a1e --- /dev/null +++ b/src/data/bci_dataset.py @@ -0,0 +1,315 @@ +""" +Crop-based dataset loaders for training on random 512x512 crops from native 1024x1024. + +Both BCI and MIST variants share the same crop + augmentation logic. +UNI features are extracted on-the-fly on GPU (not pre-computed). +""" + +import os +import random +from pathlib import Path +from typing import Optional, Tuple + +import torch +from torch.utils.data import Dataset, DataLoader +from PIL import Image +import torchvision.transforms as T +import torchvision.transforms.functional as TF +import pytorch_lightning as pl + + +class CropPairedDataset(Dataset): + """Base class for random-crop paired H&E/IHC datasets. 
+ + Loads 1024x1024 images, takes a random 512x512 crop (same position for both), + and returns the crop + a UNI-ready version for on-the-fly feature extraction. + """ + + def __init__( + self, + he_dir: str, + ihc_dir: str, + image_size: Tuple[int, int] = (512, 512), + crop_size: int = 512, + augment: bool = False, + labels: Optional[list] = None, + null_class: int = 4, + ): + self.he_dir = Path(he_dir) + self.ihc_dir = Path(ihc_dir) + self.image_size = image_size + self.crop_size = crop_size + self.augment = augment + self.null_class = null_class + self.labels = labels + + # UNI normalization (ImageNet stats, 224x224 per sub-crop) + self.uni_crop_transform = T.Compose([ + T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + + def _random_crop_pair(self, he_img, ihc_img): + """Take the same random 512x512 crop from both images.""" + w, h = he_img.size + if w < self.crop_size or h < self.crop_size: + raise ValueError( + f"Image size {w}x{h} smaller than crop size {self.crop_size}" + ) + if w == self.crop_size and h == self.crop_size: + return he_img, ihc_img + + left = random.randint(0, w - self.crop_size) + top = random.randint(0, h - self.crop_size) + he_crop = he_img.crop((left, top, left + self.crop_size, top + self.crop_size)) + ihc_crop = ihc_img.crop((left, top, left + self.crop_size, top + self.crop_size)) + return he_crop, ihc_crop + + def _prepare_uni_sub_crops(self, he_pil): + """Split 512x512 PIL crop into 4x4 sub-crops, each resized to 224x224 with UNI normalization. + + Returns: [16, 3, 224, 224] tensor ready for UNI forward pass on GPU. + """ + w, h = he_pil.size + num_crops = 4 + cw = w // num_crops + ch = h // num_crops + + sub_crops = [] + for i in range(num_crops): + for j in range(num_crops): + left = j * cw + top = i * ch + sub = he_pil.crop((left, top, left + cw, top + ch)) + sub_crops.append(self.uni_crop_transform(sub)) + + return torch.stack(sub_crops) # [16, 3, 224, 224] + + def _apply_paired_augmentations(self, he_img, ihc_img): + """Apply identical spatial transforms to both images.""" + if random.random() > 0.5: + he_img = TF.hflip(he_img) + ihc_img = TF.hflip(ihc_img) + if random.random() > 0.5: + he_img = TF.vflip(he_img) + ihc_img = TF.vflip(ihc_img) + if random.random() > 0.5: + k = random.choice([1, 2, 3]) + he_img = TF.rotate(he_img, k * 90) + ihc_img = TF.rotate(ihc_img, k * 90) + if random.random() > 0.7: + angle = random.uniform(-15, 15) + translate = [random.uniform(-0.05, 0.05) * self.image_size[1], + random.uniform(-0.05, 0.05) * self.image_size[0]] + scale = random.uniform(0.9, 1.1) + he_img = TF.affine(he_img, angle, translate, scale, shear=0, + interpolation=T.InterpolationMode.BILINEAR) + ihc_img = TF.affine(ihc_img, angle, translate, scale, shear=0, + interpolation=T.InterpolationMode.BILINEAR) + return he_img, ihc_img + + def _apply_he_color_augmentation(self, he_img): + """Apply color jitter to H&E only (simulates staining variability).""" + if random.random() > 0.5: + he_img = TF.adjust_brightness(he_img, random.uniform(0.9, 1.1)) + if random.random() > 0.5: + he_img = TF.adjust_contrast(he_img, random.uniform(0.9, 1.1)) + if random.random() > 0.5: + he_img = TF.adjust_saturation(he_img, random.uniform(0.9, 1.1)) + return he_img + + def _process_pair(self, he_img, ihc_img, label, filename): + """Common processing: crop -> augment -> tensorize -> UNI sub-crops. 
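+
+        Downstream, the [16, 3, 224, 224] sub-crops are intended for a single
+        batched UNI forward pass on GPU. A minimal sketch (the hub id and the
+        timm call are assumptions for illustration, not part of this module):
+
+            import timm
+            uni = timm.create_model("hf-hub:MahmoodLab/UNI",
+                                    pretrained=True).eval().cuda()
+            flat = uni_sub_crops.flatten(0, 1).cuda()  # [B*16, 3, 224, 224]
+            feats = uni(flat).view(-1, 16, 1024)       # [B, 16, 1024] for the generator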
+ + Returns: (he_tensor, ihc_tensor, uni_sub_crops, label, filename) + - he_tensor: [3, 512, 512] in [-1, 1] + - ihc_tensor: [3, 512, 512] in [-1, 1] + - uni_sub_crops: [16, 3, 224, 224] with ImageNet normalization + """ + # Random crop (same position for both) + he_crop, ihc_crop = self._random_crop_pair(he_img, ihc_img) + + # Augmentations (applied to PIL before UNI extraction, so features match) + if self.augment: + he_crop, ihc_crop = self._apply_paired_augmentations(he_crop, ihc_crop) + he_aug = self._apply_he_color_augmentation(he_crop) + else: + he_aug = he_crop + + # Prepare UNI sub-crops from the augmented H&E crop + uni_sub_crops = self._prepare_uni_sub_crops(he_aug) + + # Convert to training tensors [-1, 1] + he_tensor = TF.normalize(TF.to_tensor(he_aug), [0.5]*3, [0.5]*3) + ihc_tensor = TF.normalize(TF.to_tensor(ihc_crop), [0.5]*3, [0.5]*3) + + return he_tensor, ihc_tensor, uni_sub_crops, label, filename + + +class BCICropDataset(CropPairedDataset): + """BCI dataset with random 512 crops from 1024x1024 native images.""" + + HER2_LABEL_MAP = {'0': 0, '1+': 1, '2+': 2, '3+': 3} + + def __init__(self, he_dir, ihc_dir, image_size=(512, 512), + crop_size=512, augment=False): + super().__init__(he_dir, ihc_dir, image_size, crop_size, augment) + + self.he_images = sorted([f for f in os.listdir(he_dir) if f.endswith('.png')]) + self.ihc_images = sorted([f for f in os.listdir(ihc_dir) if f.endswith('.png')]) + assert len(self.he_images) == len(self.ihc_images) + + self.labels = [self._parse_label(f) for f in self.he_images] + + from collections import Counter + dist = Counter(self.labels) + print(f"BCI Crop Dataset: {len(self)} images, classes: {dict(sorted(dist.items()))}") + + def _parse_label(self, filename): + parts = filename.replace('.png', '').split('_') + if len(parts) >= 3: + level = parts[2] + if level in self.HER2_LABEL_MAP: + return self.HER2_LABEL_MAP[level] + raise ValueError(f"Cannot parse label from: {filename}") + + def __len__(self): + return len(self.he_images) + + def __getitem__(self, idx): + filename = self.he_images[idx] + he_img = Image.open(self.he_dir / filename).convert('RGB') + ihc_img = Image.open(self.ihc_dir / self.ihc_images[idx]).convert('RGB') + return self._process_pair(he_img, ihc_img, self.labels[idx], filename) + + +class MISTCropDataset(CropPairedDataset): + """MIST dataset with random 512 crops from 1024x1024 native images.""" + + def __init__(self, he_dir, ihc_dir, image_size=(512, 512), + crop_size=512, augment=False, null_class=4): + super().__init__(he_dir, ihc_dir, image_size, crop_size, augment, + null_class=null_class) + + valid_exts = ('.jpg', '.jpeg', '.png') + self.he_images = sorted([f for f in os.listdir(he_dir) + if f.lower().endswith(valid_exts)]) + self.ihc_images = sorted([f for f in os.listdir(ihc_dir) + if f.lower().endswith(valid_exts)]) + + # Verify pairing + he_stems = {Path(f).stem for f in self.he_images} + ihc_stems = {Path(f).stem for f in self.ihc_images} + if he_stems != ihc_stems: + common = he_stems & ihc_stems + self.he_images = sorted([f for f in self.he_images if Path(f).stem in common]) + self.ihc_images = sorted([f for f in self.ihc_images if Path(f).stem in common]) + print(f"Using {len(self.he_images)} matched pairs") + + print(f"MIST Crop Dataset: {len(self)} images (null_class={null_class})") + + def __len__(self): + return len(self.he_images) + + def __getitem__(self, idx): + filename = self.he_images[idx] + he_img = Image.open(self.he_dir / filename).convert('RGB') + ihc_img = Image.open(self.ihc_dir / 
self.ihc_images[idx]).convert('RGB') + return self._process_pair(he_img, ihc_img, self.null_class, filename) + + +class BCICropDataModule(pl.LightningDataModule): + def __init__(self, data_dir, batch_size=4, + num_workers=4, image_size=(512, 512), crop_size=512): + super().__init__() + self.data_dir = Path(data_dir) + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = image_size + self.crop_size = crop_size + + def setup(self, stage=None): + if stage == 'fit' or stage is None: + self.train_dataset = BCICropDataset( + he_dir=self.data_dir / 'HE' / 'train', + ihc_dir=self.data_dir / 'IHC' / 'train', + image_size=self.image_size, + crop_size=self.crop_size, + augment=True, + ) + if stage in ('fit', 'validate', 'test') or stage is None: + self.val_dataset = BCICropDataset( + he_dir=self.data_dir / 'HE' / 'test', + ihc_dir=self.data_dir / 'IHC' / 'test', + image_size=self.image_size, + crop_size=self.crop_size, + augment=False, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, pin_memory=True, + persistent_workers=self.num_workers > 0, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=True, + persistent_workers=self.num_workers > 0, + ) + + def test_dataloader(self): + return self.val_dataloader() + + +class MISTCropDataModule(pl.LightningDataModule): + def __init__(self, data_dir, batch_size=4, + num_workers=4, image_size=(512, 512), crop_size=512, null_class=4): + super().__init__() + self.data_dir = Path(data_dir) + self.batch_size = batch_size + self.num_workers = num_workers + self.image_size = image_size + self.crop_size = crop_size + self.null_class = null_class + + def setup(self, stage=None): + if stage == 'fit' or stage is None: + self.train_dataset = MISTCropDataset( + he_dir=self.data_dir / 'trainA', + ihc_dir=self.data_dir / 'trainB', + image_size=self.image_size, + crop_size=self.crop_size, + augment=True, + null_class=self.null_class, + ) + if stage in ('fit', 'validate', 'test') or stage is None: + self.val_dataset = MISTCropDataset( + he_dir=self.data_dir / 'valA', + ihc_dir=self.data_dir / 'valB', + image_size=self.image_size, + crop_size=self.crop_size, + augment=False, + null_class=self.null_class, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, pin_memory=True, + persistent_workers=self.num_workers > 0, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=True, + persistent_workers=self.num_workers > 0, + ) + + def test_dataloader(self): + return self.val_dataloader() diff --git a/src/data/mist_dataset.py b/src/data/mist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..24297eb050c23e732235867cc98972de8359a3cc --- /dev/null +++ b/src/data/mist_dataset.py @@ -0,0 +1,169 @@ +""" +Multi-stain dataset for training a single model on all MIST IHC stains. + +Combines HER2, Ki67, ER, PR into one dataset, returning a stain label (0-3) +instead of a class label. Reuses the same crop + UNI sub-crop pipeline +from CropPairedDataset. 
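+
+A minimal usage sketch (the base_dir path is illustrative):
+
+    dm = MISTMultiStainCropDataModule(base_dir="/data/MIST",
+                                      stains=["HER2", "Ki67", "ER", "PR"])
+    dm.setup("fit")
+    he, ihc, uni_crops, stain_label, fname = dm.train_dataset[0]
+    # he/ihc: [3, 512, 512] in [-1, 1]; uni_crops: [16, 3, 224, 224]; stain_label in 0-3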
+ +Stain label mapping: + 0 = HER2, 1 = Ki67, 2 = ER, 3 = PR, 4 = null (CFG dropout) +""" + +import os +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +from torch.utils.data import Dataset, DataLoader +from PIL import Image +import pytorch_lightning as pl + +from src.data.bci_dataset import CropPairedDataset + + +STAIN_TO_LABEL = {'HER2': 0, 'Ki67': 1, 'ER': 2, 'PR': 3} +LABEL_TO_STAIN = {v: k for k, v in STAIN_TO_LABEL.items()} + + +class MISTMultiStainCropDataset(CropPairedDataset): + """Multi-stain MIST dataset with random 512 crops from native 1024x1024. + + Loads all 4 MIST stains into a single dataset. Each sample returns a + stain label (0-3) as the conditioning signal, reusing the class embedding + slot in the generator. + """ + + def __init__( + self, + base_dir: str, + stains: List[str], + split: str = 'train', + image_size: Tuple[int, int] = (512, 512), + crop_size: int = 512, + augment: bool = False, + null_class: int = 4, + ): + super().__init__( + he_dir='.', # placeholder, we override __getitem__ + ihc_dir='.', + image_size=image_size, + crop_size=crop_size, + augment=augment, + null_class=null_class, + ) + + self.base_dir = Path(base_dir) + self.samples = [] # (he_path, ihc_path, stain_label) + + split_he = 'trainA' if split == 'train' else 'valA' + split_ihc = 'trainB' if split == 'train' else 'valB' + valid_exts = ('.jpg', '.jpeg', '.png') + + for stain in stains: + if stain not in STAIN_TO_LABEL: + raise ValueError(f"Unknown stain: {stain}. Must be one of {list(STAIN_TO_LABEL.keys())}") + + stain_label = STAIN_TO_LABEL[stain] + he_dir = self.base_dir / stain / 'TrainValAB' / split_he + ihc_dir = self.base_dir / stain / 'TrainValAB' / split_ihc + + if not he_dir.exists(): + raise FileNotFoundError(f"H&E directory not found: {he_dir}") + if not ihc_dir.exists(): + raise FileNotFoundError(f"IHC directory not found: {ihc_dir}") + + he_files = sorted([f for f in os.listdir(he_dir) + if f.lower().endswith(valid_exts)]) + ihc_files = sorted([f for f in os.listdir(ihc_dir) + if f.lower().endswith(valid_exts)]) + + # Match by stem (H&E may be .jpg, IHC may be .png) + he_stems = {Path(f).stem: f for f in he_files} + ihc_stems = {Path(f).stem: f for f in ihc_files} + common = sorted(set(he_stems.keys()) & set(ihc_stems.keys())) + + for stem in common: + self.samples.append(( + he_dir / he_stems[stem], + ihc_dir / ihc_stems[stem], + stain_label, + )) + + print(f" {stain} ({split}): {len(common)} pairs") + + # Per-stain counts for logging + from collections import Counter + dist = Counter(s[2] for s in self.samples) + stain_counts = {LABEL_TO_STAIN[k]: v for k, v in sorted(dist.items())} + print(f"Multi-Stain Crop Dataset ({split}): {len(self.samples)} total | {stain_counts}") + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + he_path, ihc_path, stain_label = self.samples[idx] + he_img = Image.open(he_path).convert('RGB') + ihc_img = Image.open(ihc_path).convert('RGB') + return self._process_pair(he_img, ihc_img, stain_label, he_path.name) + + +class MISTMultiStainCropDataModule(pl.LightningDataModule): + """Lightning DataModule for multi-stain MIST training.""" + + def __init__( + self, + base_dir: str, + stains: Optional[List[str]] = None, + batch_size: int = 4, + num_workers: int = 4, + image_size: Tuple[int, int] = (512, 512), + crop_size: int = 512, + null_class: int = 4, + ): + super().__init__() + self.base_dir = base_dir + self.stains = stains or ['HER2', 'Ki67', 'ER', 'PR'] + self.batch_size = batch_size + 
self.num_workers = num_workers + self.image_size = image_size + self.crop_size = crop_size + self.null_class = null_class + + def setup(self, stage=None): + if stage == 'fit' or stage is None: + self.train_dataset = MISTMultiStainCropDataset( + base_dir=self.base_dir, + stains=self.stains, + split='train', + image_size=self.image_size, + crop_size=self.crop_size, + augment=True, + null_class=self.null_class, + ) + if stage in ('fit', 'validate', 'test') or stage is None: + self.val_dataset = MISTMultiStainCropDataset( + base_dir=self.base_dir, + stains=self.stains, + split='val', + image_size=self.image_size, + crop_size=self.crop_size, + augment=False, + null_class=self.null_class, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, pin_memory=True, + persistent_workers=self.num_workers > 0, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, batch_size=self.batch_size, shuffle=False, + num_workers=self.num_workers, pin_memory=True, + persistent_workers=self.num_workers > 0, + ) + + def test_dataloader(self): + return self.val_dataloader() diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/blocks.py b/src/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..75a4241d69c437e32c167e75a025c8c93e18552c --- /dev/null +++ b/src/models/blocks.py @@ -0,0 +1,109 @@ +""" +Building blocks for UNIStainNet generator. + +- SPADEBlock: SPADE + FiLM normalization (UNI spatial + class channel modulation) +- ResBlock: Residual block with InstanceNorm +- SelfAttention: Self-attention for global context at bottleneck +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SPADEBlock(nn.Module): + """SPADE + FiLM normalization block. + + Combines spatially-adaptive normalization from UNI features (SPADE) + with channel-wise affine modulation from class embedding (FiLM). 
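+
+    The combined modulation applied in forward:
+
+        out = (gamma_spade(uni) + gamma_film(class)) * InstanceNorm(x)
+              + (beta_spade(uni) + beta_film(class))
+
+    where the SPADE terms are [B, C, H, W] spatial maps and the FiLM terms
+    are [B, C, 1, 1] channel vectors broadcast over space.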
+ """ + + def __init__(self, norm_channels, uni_channels, class_dim=64): + super().__init__() + self.norm = nn.InstanceNorm2d(norm_channels, affine=False) + + # SPADE: learn spatial gamma/beta from UNI features + hidden = min(128, norm_channels) + self.spade_shared = nn.Sequential( + nn.Conv2d(uni_channels, hidden, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + self.spade_gamma = nn.Conv2d(hidden, norm_channels, 3, padding=1) + self.spade_beta = nn.Conv2d(hidden, norm_channels, 3, padding=1) + + # FiLM: learn channel gamma/beta from class embedding + self.film_gamma = nn.Linear(class_dim, norm_channels) + self.film_beta = nn.Linear(class_dim, norm_channels) + + # Init SPADE gamma/beta near zero (ControlNet-style gradual activation) + nn.init.zeros_(self.spade_gamma.weight) + nn.init.zeros_(self.spade_gamma.bias) + nn.init.zeros_(self.spade_beta.weight) + nn.init.zeros_(self.spade_beta.bias) + + # Init FiLM gamma near 1, beta near 0 + nn.init.ones_(self.film_gamma.weight) + nn.init.zeros_(self.film_gamma.bias) + nn.init.zeros_(self.film_beta.weight) + nn.init.zeros_(self.film_beta.bias) + + def forward(self, x, uni_spatial, class_emb): + """ + Args: + x: [B, C, H, W] feature map + uni_spatial: [B, uni_ch, H, W] UNI features at matching resolution + class_emb: [B, class_dim] class embedding + """ + normalized = self.norm(x) + + # SPADE modulation from UNI features + shared = self.spade_shared(uni_spatial) + gamma_s = self.spade_gamma(shared) + beta_s = self.spade_beta(shared) + + # FiLM modulation from class + gamma_c = self.film_gamma(class_emb).unsqueeze(-1).unsqueeze(-1) # [B, C, 1, 1] + beta_c = self.film_beta(class_emb).unsqueeze(-1).unsqueeze(-1) + + # Combined: (gamma_spade + gamma_film) * norm(x) + (beta_spade + beta_film) + return (gamma_s + gamma_c) * normalized + (beta_s + beta_c) + + +class ResBlock(nn.Module): + """Residual block with InstanceNorm.""" + + def __init__(self, channels): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(channels, channels, 3, padding=1), + nn.InstanceNorm2d(channels), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(channels, channels, 3, padding=1), + nn.InstanceNorm2d(channels), + ) + self.act = nn.LeakyReLU(0.2, inplace=True) + + def forward(self, x): + return self.act(x + self.block(x)) + + +class SelfAttention(nn.Module): + """Self-attention layer for global context at bottleneck.""" + + def __init__(self, channels): + super().__init__() + self.norm = nn.GroupNorm(32, channels) + self.qkv = nn.Conv2d(channels, channels * 3, 1) + self.proj = nn.Conv2d(channels, channels, 1) + self.scale = channels ** -0.5 + + def forward(self, x): + B, C, H, W = x.shape + h = self.norm(x) + qkv = self.qkv(h).reshape(B, 3, C, H * W) + q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2] + + attn = (q.transpose(-1, -2) @ k) * self.scale + attn = attn.softmax(dim=-1) + out = (v @ attn.transpose(-1, -2)).reshape(B, C, H, W) + return x + self.proj(out) diff --git a/src/models/discriminator.py b/src/models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..6129e8a6cbc6dfbbc706505a653f658716da7961 --- /dev/null +++ b/src/models/discriminator.py @@ -0,0 +1,164 @@ +""" +PatchGAN discriminator for HER2 image realism scoring. + +Supports both unconditional (3ch HER2 only) and conditional (6ch H&E+HER2) modes. +Returns patch-level logits AND intermediate features for feature matching loss. 
+
+Architecture:
+    C64(SN) -> C128(SN+IN) -> C256(SN+IN) -> C512(SN+IN,s1) -> 1ch(SN,s1)
+    70x70 receptive field; output [B, 1, 62, 62] for 512x512 input
+    ([B, 1, 30, 30] at 256x256). ~2.8M params for both the 3ch and 6ch
+    variants (the 3 extra input channels add under 0.01M weights).
+
+References:
+    - Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks" (CVPR 2017)
+    - Miyato et al., "Spectral Normalization for GANs" (ICLR 2018)
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.autograd as autograd
+from torch.nn.utils import spectral_norm
+
+
+class PatchDiscriminator(nn.Module):
+    """PatchGAN discriminator with spectral normalization.
+
+    Returns both logits and intermediate features (for feature matching loss).
+
+    Args:
+        in_channels: 3 for unconditional (HER2 only), 6 for conditional (H&E + HER2)
+        ndf: base number of discriminator filters
+        n_layers: number of intermediate conv layers
+    """
+
+    def __init__(self, in_channels=3, ndf=64, n_layers=3):
+        super().__init__()
+        self.n_layers = n_layers
+
+        # Build layers as a list (not sequential) so we can extract features
+        self.layers = nn.ModuleList()
+
+        # First layer: spectral norm, no instance norm
+        self.layers.append(nn.Sequential(
+            spectral_norm(nn.Conv2d(in_channels, ndf, 4, stride=2, padding=1)),
+            nn.LeakyReLU(0.2, inplace=True),
+        ))
+
+        # Intermediate layers: spectral norm + instance norm
+        nf_mult = 1
+        for n in range(1, n_layers):
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            self.layers.append(nn.Sequential(
+                spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 4, stride=2, padding=1)),
+                nn.InstanceNorm2d(ndf * nf_mult),
+                nn.LeakyReLU(0.2, inplace=True),
+            ))
+
+        # Penultimate layer: stride 1
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        self.layers.append(nn.Sequential(
+            spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 4, stride=1, padding=1)),
+            nn.InstanceNorm2d(ndf * nf_mult),
+            nn.LeakyReLU(0.2, inplace=True),
+        ))
+
+        # Final layer: 1-channel output, no activation (hinge loss uses raw logits)
+        self.layers.append(nn.Sequential(
+            spectral_norm(nn.Conv2d(ndf * nf_mult, 1, 4, stride=1, padding=1)),
+        ))
+
+    def forward(self, x, return_features=False):
+        """
+        Args:
+            x: [B, C, H, W] in [-1, 1]. C=3 (unconditional) or C=6 (conditional).
+            return_features: if True, also return intermediate features for FM loss.
+
+        Returns:
+            logits: [B, 1, H', W'] patch-level real/fake logits
+            features: list of intermediate feature maps (only if return_features=True)
+        """
+        features = []
+        h = x
+        for layer in self.layers:
+            h = layer(h)
+            if return_features:
+                features.append(h)
+
+        if return_features:
+            return h, features
+        return h
+
+
+# ======================================================================
+# Loss functions
+# ======================================================================
+
+def hinge_loss_d(d_real, d_fake):
+    """Discriminator hinge loss."""
+    return (torch.relu(1.0 - d_real).mean() + torch.relu(1.0 + d_fake).mean()) / 2
+
+
+def hinge_loss_g(d_fake):
+    """Generator hinge loss."""
+    return -d_fake.mean()
+
+
+def r1_gradient_penalty(discriminator, real_images, weight=10.0):
+    """R1 gradient penalty (Mescheder et al., 2018).
+
+    Regularizes discriminator to have small gradients on real data,
+    which prevents the discriminator from becoming too confident and
+    stabilizes GAN training.
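+
+    Computes weight * E[ ||grad_x D(x)||^2 ] over real images only.
+
+    Usage sketch (assumes a single-scale PatchDiscriminator, whose forward
+    returns a logits tensor; MultiScaleDiscriminator returns a list, so it
+    would need one call per scale):
+
+        d = PatchDiscriminator(in_channels=6)
+        penalty = r1_gradient_penalty(d, torch.cat([ihc_real, he], dim=1))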
+ """ + real_images = real_images.detach().requires_grad_(True) + d_real = discriminator(real_images) + grad_real = autograd.grad( + outputs=d_real.sum(), + inputs=real_images, + create_graph=True, + )[0] + penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() + return weight * penalty + + +def feature_matching_loss(d_feats_fake, d_feats_real): + """Feature matching loss: L1 between discriminator features of fake vs real. + + Matches statistics at each discriminator layer. Alignment-free because + it compares feature distributions, not pixel-level correspondence. + """ + loss = 0.0 + for feat_fake, feat_real in zip(d_feats_fake, d_feats_real): + loss += torch.nn.functional.l1_loss(feat_fake, feat_real.detach()) + return loss / len(d_feats_fake) + + +class MultiScaleDiscriminator(nn.Module): + """Two PatchGAN discriminators at different scales.""" + + def __init__(self, in_channels=6, ndf=64, n_layers=3): + super().__init__() + self.disc_512 = PatchDiscriminator(in_channels, ndf, n_layers) + self.disc_256 = PatchDiscriminator(in_channels, ndf, n_layers) + + def forward(self, x, return_features=False): + """ + Args: + x: [B, 6, 512, 512] concat(output, H&E) + + Returns: + list of (logits, [features]) from each scale + """ + x_256 = F.interpolate(x, size=256, mode='bilinear', align_corners=False) + + if return_features: + out_512, feats_512 = self.disc_512(x, return_features=True) + out_256, feats_256 = self.disc_256(x_256, return_features=True) + return [(out_512, feats_512), (out_256, feats_256)] + else: + out_512 = self.disc_512(x) + out_256 = self.disc_256(x_256) + return [out_512, out_256] diff --git a/src/models/edge_encoder.py b/src/models/edge_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1bcccb0edb0ed1509eb48851e1fd8489c1100812 --- /dev/null +++ b/src/models/edge_encoder.py @@ -0,0 +1,184 @@ +""" +Edge encoders for UNIStainNet: parallel structure pathway from H&E edges. + +- EdgeEncoder (v1): Sequential Sobel → multi-scale CNN +- MultiScaleEdgeEncoder (v2): Independent per-scale edge extraction with RGB input +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class EdgeEncoder(nn.Module): + """Lightweight encoder that extracts multi-scale edge features from H&E input. + + Extracts Sobel edges from grayscale H&E, then encodes them through a small + CNN to produce multi-scale feature maps. These are concatenated with the + main encoder's skip connections in the decoder, giving the generator an + explicit structural signal. + + Key insight: H&E input and generated output share the exact same spatial + frame (no misalignment). So edge features from H&E are pixel-aligned with + the decoder's output — unlike real HER2 ground truth. 
+ """ + + def __init__(self, base_ch=32): + super().__init__() + # Sobel kernels (fixed, not learned) + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=torch.float32).view(1, 1, 3, 3) + sobel_y = sobel_x.transpose(-1, -2) + self.register_buffer('sobel_x', sobel_x) + self.register_buffer('sobel_y', sobel_y) + + # Edge feature encoder: 2ch (grad_x, grad_y) → multi-scale features + # Mirrors the main encoder's spatial hierarchy + self.enc1 = nn.Sequential( # 512→256, out: base_ch + nn.Conv2d(2, base_ch, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc2 = nn.Sequential( # 256→128, out: base_ch*2 + nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1), + nn.InstanceNorm2d(base_ch * 2), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc3 = nn.Sequential( # 128→64, out: base_ch*4 + nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1), + nn.InstanceNorm2d(base_ch * 4), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc4 = nn.Sequential( # 64→32, out: base_ch*4 + nn.Conv2d(base_ch * 4, base_ch * 4, 4, stride=2, padding=1), + nn.InstanceNorm2d(base_ch * 4), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, he_images): + """ + Args: + he_images: [B, 3, 512, 512] in [-1, 1] + + Returns: + dict of edge features at each decoder resolution: + 256: [B, base_ch, 256, 256] + 128: [B, base_ch*2, 128, 128] + 64: [B, base_ch*4, 64, 64] + 32: [B, base_ch*4, 32, 32] + """ + # Convert to grayscale [0, 1] + gray = ((he_images + 1) / 2).mean(dim=1, keepdim=True) # [B, 1, 512, 512] + + # Sobel edge detection + gx = F.conv2d(gray, self.sobel_x, padding=1) + gy = F.conv2d(gray, self.sobel_y, padding=1) + edges = torch.cat([gx, gy], dim=1) # [B, 2, 512, 512] + + # Multi-scale encoding + e1 = self.enc1(edges) # [B, base_ch, 256, 256] + e2 = self.enc2(e1) # [B, base_ch*2, 128, 128] + e3 = self.enc3(e2) # [B, base_ch*4, 64, 64] + e4 = self.enc4(e3) # [B, base_ch*4, 32, 32] + + return {256: e1, 128: e2, 64: e3, 32: e4} + + +class MultiScaleEdgeEncoder(nn.Module): + """Multi-scale edge encoder with independent per-scale edge extraction. + + Improvements over EdgeEncoder: + 1. RGB-aware: Learnable first layer on full RGB (can discover stain-specific + edges — e.g., hematoxylin boundaries vs eosin boundaries carry different + information for HER2 staining). + 2. Multi-scale Sobel: Extracts edges independently at each resolution before + encoding. Fine 2-5px edges don't get lost through sequential downsampling. + 3. Edge features at 512: Provides features at output resolution for fine + structure preservation (cell walls, membrane patterns). 
+ """ + + def __init__(self, base_ch=32): + super().__init__() + # Fixed Sobel kernels for structural prior + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=torch.float32).view(1, 1, 3, 3) + sobel_y = sobel_x.transpose(-1, -2) + self.register_buffer('sobel_x', sobel_x) + self.register_buffer('sobel_y', sobel_y) + + # Per-scale feature extractors + # Input: 3ch RGB + 2ch Sobel = 5ch at each scale + in_ch = 5 + + # 512→512 (edge features at output resolution) + self.scale_512 = nn.Sequential( + nn.Conv2d(in_ch, base_ch, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_ch, base_ch, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 256×256 + self.scale_256 = nn.Sequential( + nn.Conv2d(in_ch, base_ch, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_ch, base_ch, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 128×128 + self.scale_128 = nn.Sequential( + nn.Conv2d(in_ch, base_ch * 2, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_ch * 2, base_ch * 2, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 64×64 + self.scale_64 = nn.Sequential( + nn.Conv2d(in_ch, base_ch * 4, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_ch * 4, base_ch * 4, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 32×32 + self.scale_32 = nn.Sequential( + nn.Conv2d(in_ch, base_ch * 4, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_ch * 4, base_ch * 4, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + def _extract_edges_at_scale(self, he_01, size): + """Downsample H&E, extract Sobel edges, return RGB+edges.""" + if size < 512: + h = F.interpolate(he_01, size=size, mode='bilinear', align_corners=False) + else: + h = he_01 + gray = h.mean(dim=1, keepdim=True) + gx = F.conv2d(gray, self.sobel_x, padding=1) + gy = F.conv2d(gray, self.sobel_y, padding=1) + return torch.cat([h, gx, gy], dim=1) # [B, 5, size, size] + + def forward(self, he_images): + """ + Args: + he_images: [B, 3, 512, 512] in [-1, 1] + + Returns: + dict of edge features at each decoder resolution: + 512: [B, base_ch, 512, 512] + 256: [B, base_ch, 256, 256] + 128: [B, base_ch*2, 128, 128] + 64: [B, base_ch*4, 64, 64] + 32: [B, base_ch*4, 32, 32] + """ + he_01 = (he_images + 1) / 2 # [0, 1] for consistent edge magnitudes + + return { + 512: self.scale_512(self._extract_edges_at_scale(he_01, 512)), + 256: self.scale_256(self._extract_edges_at_scale(he_01, 256)), + 128: self.scale_128(self._extract_edges_at_scale(he_01, 128)), + 64: self.scale_64(self._extract_edges_at_scale(he_01, 64)), + 32: self.scale_32(self._extract_edges_at_scale(he_01, 32)), + } diff --git a/src/models/generator.py b/src/models/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fb8c39764fc6e89df5aa5db026b99b411abfb34e --- /dev/null +++ b/src/models/generator.py @@ -0,0 +1,300 @@ +""" +SPADEUNetGenerator: H&E → IHC translation generator. + +SPADE-UNet conditioned on UNI pathology features + HER2 class embedding. +Encoder processes H&E input, decoder uses SPADE conditioning from UNI features ++ FiLM from class embedding, with skip connections. + +~30M params at 512, supports 1024 with extra encoder/decoder levels. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models.blocks import SPADEBlock, ResBlock, SelfAttention +from src.models.edge_encoder import EdgeEncoder, MultiScaleEdgeEncoder +from src.models.uni_processor import UNIFeatureProcessor, UNIFeatureProcessorHighRes + + +class SPADEUNetGenerator(nn.Module): + """SPADE-UNet generator for H&E → HER2 translation. + + Encoder processes H&E input into multi-scale features. + Decoder uses SPADE conditioning from UNI features + FiLM from class embedding. + Skip connections from encoder to decoder. + + ~30M params. + """ + + def __init__(self, num_classes=5, class_dim=64, uni_dim=1024, + input_skip=False, edge_encoder=False, edge_base_ch=32, + uni_spatial_size=4, image_size=512, uni_spade_at_512=False): + super().__init__() + self.num_classes = num_classes + self.class_dim = class_dim + self.input_skip = input_skip + self.edge_encoder_flag = edge_encoder + self.uni_spatial_size = uni_spatial_size + self.image_size = image_size + self.uni_spade_at_512 = uni_spade_at_512 + + # Class embedding (5 classes: 0, 1+, 2+, 3+, null) + self.class_embed = nn.Embedding(num_classes, class_dim) + + # UNI feature processor — choose based on spatial resolution + if uni_spatial_size >= 16: + # High-res patch tokens (e.g., 32x32 = 1024 tokens) + self.uni_processor = UNIFeatureProcessorHighRes( + uni_dim=uni_dim, base_channels=512, spatial_size=uni_spatial_size, + output_512=(uni_spade_at_512 and image_size == 1024), + ) + else: + # Original CLS-token features (4x4 = 16 tokens) + self.uni_processor = UNIFeatureProcessor( + uni_dim=uni_dim, base_channels=512, + ) + + # Edge encoder (parallel structure pathway) + # Note: edge encoder always operates at 512 resolution. + # For 1024 input, H&E is downsampled to 512 before edge extraction. 
+ self.edge_encoder_type = edge_encoder # False, 'v1', or 'v2' + if edge_encoder == 'v2': + self.edge_encoder = MultiScaleEdgeEncoder(base_ch=edge_base_ch) + edge_ch = {512: edge_base_ch, 256: edge_base_ch, 128: edge_base_ch * 2, + 64: edge_base_ch * 4, 32: edge_base_ch * 4} + elif edge_encoder: # True or 'v1' + self.edge_encoder = EdgeEncoder(base_ch=edge_base_ch) + edge_ch = {512: 0, 256: edge_base_ch, 128: edge_base_ch * 2, + 64: edge_base_ch * 4, 32: edge_base_ch * 4} + else: + self.edge_encoder = None + edge_ch = {512: 0, 256: 0, 128: 0, 64: 0, 32: 0} + + # === 1024 support: extra encoder/decoder levels === + if image_size == 1024: + # enc0: 1024→512 (lightweight, just spatial downsample) + self.enc0 = nn.Sequential( + nn.Conv2d(3, 32, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + enc1_in_ch = 32 # enc1 takes enc0 output, not raw H&E + else: + self.enc0 = None + enc1_in_ch = 3 # enc1 takes raw H&E at 512 + + # Encoder + self.enc1 = nn.Sequential( # 512→256 + nn.Conv2d(enc1_in_ch, 64, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc2 = nn.Sequential( # 256→128 + nn.Conv2d(64, 128, 4, stride=2, padding=1), + nn.InstanceNorm2d(128), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc3 = nn.Sequential( # 128→64 + nn.Conv2d(128, 256, 4, stride=2, padding=1), + nn.InstanceNorm2d(256), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc4 = nn.Sequential( # 64→32 + nn.Conv2d(256, 512, 4, stride=2, padding=1), + nn.InstanceNorm2d(512), + nn.LeakyReLU(0.2, inplace=True), + ) + self.enc5 = nn.Sequential( # 32→16 + nn.Conv2d(512, 512, 4, stride=2, padding=1), + nn.InstanceNorm2d(512), + nn.LeakyReLU(0.2, inplace=True), + ) + + # Bottleneck (at 16×16) + self.bottleneck = nn.Sequential( + ResBlock(512), + SelfAttention(512), + ResBlock(512), + ) + + # Decoder with SPADE conditioning + # Channel counts: main_skip + edge_skip (if enabled) + upsampled + # D5: 512 (up) + 512 (skip e4) + edge_ch[32] → 512 + self.dec5_conv = nn.Conv2d(512 + 512 + edge_ch[32], 512, 3, padding=1) + self.dec5_spade = SPADEBlock(512, uni_channels=512, class_dim=class_dim) + self.dec5_act = nn.LeakyReLU(0.2, inplace=True) + + # D4: 512 (up) + 256 (skip e3) + edge_ch[64] → 256 + self.dec4_conv = nn.Conv2d(512 + 256 + edge_ch[64], 256, 3, padding=1) + self.dec4_spade = SPADEBlock(256, uni_channels=256, class_dim=class_dim) + self.dec4_act = nn.LeakyReLU(0.2, inplace=True) + + # D3: 256 (up) + 128 (skip e2) + edge_ch[128] → 128 + self.dec3_conv = nn.Conv2d(256 + 128 + edge_ch[128], 128, 3, padding=1) + self.dec3_spade = SPADEBlock(128, uni_channels=128, class_dim=class_dim) + self.dec3_act = nn.LeakyReLU(0.2, inplace=True) + + # D2: 128 (up) + 64 (skip e1) + edge_ch[256] → 64 + self.dec2_conv = nn.Conv2d(128 + 64 + edge_ch[256], 64, 3, padding=1) + self.dec2_spade = SPADEBlock(64, uni_channels=64, class_dim=class_dim) + self.dec2_act = nn.LeakyReLU(0.2, inplace=True) + + if image_size == 1024: + # D1 (new): upsample 256→512, skip from enc0 (32ch) + edge@512 + dec1_in_ch = 64 + 32 + edge_ch[512] + if uni_spade_at_512: + # UNI SPADE conditioning at 512 level (uni_ch=32 at this scale) + self.dec1_conv = nn.Conv2d(dec1_in_ch, 64, 3, padding=1) + self.dec1_spade = SPADEBlock(64, uni_channels=32, class_dim=class_dim) + self.dec1_act = nn.LeakyReLU(0.2, inplace=True) + else: + self.dec1_conv = nn.Sequential( + nn.Conv2d(dec1_in_ch, 64, 3, padding=1), + nn.InstanceNorm2d(64), + nn.LeakyReLU(0.2, inplace=True), + ) + self.dec1_spade = None + self.dec1_act = None + # Output: upsample 512→1024, 
optional H&E input skip + output_in_ch = 64 + (3 if input_skip else 0) + self.output = nn.Sequential( + nn.Conv2d(output_in_ch, 64, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, 3, padding=1), + nn.Tanh(), + ) + else: + self.dec1_conv = None + # Output: concat H&E input (3ch if input_skip) + edge@512 (if v2) + output_in_ch = 64 + (3 if input_skip else 0) + edge_ch[512] + self.output = nn.Sequential( + nn.Conv2d(output_in_ch, 64, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, 3, padding=1), + nn.Tanh(), + ) + + def encode(self, images): + """Extract intermediate encoder features for PatchNCE loss. + + Args: + images: [B, 3, H, H] in [-1, 1] (H&E or generated IHC) + + Returns: + dict mapping layer index to feature tensor: + {1: [B, 64, 256, 256], 2: [B, 128, 128, 128], + 3: [B, 256, 64, 64], 4: [B, 512, 32, 32]} + """ + if self.enc0 is not None: + e0 = self.enc0(images) + enc1_input = e0 + else: + enc1_input = images + + e1 = self.enc1(enc1_input) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + return {1: e1, 2: e2, 3: e3, 4: e4} + + def forward(self, he_images, uni_features, labels): + """ + Args: + he_images: [B, 3, H, H] in [-1, 1] where H=512 or H=1024 + uni_features: [B, N, 1024] where N=16 (4x4 CLS) or N=1024 (32x32 patch) + labels: [B] int class labels (0-4) + + Returns: + output: [B, 3, H, H] in [-1, 1] + """ + class_emb = self.class_embed(labels) + uni_maps = self.uni_processor(uni_features) + + # Edge encoder (parallel structure pathway) + # Edge encoder always operates at 512 resolution + if self.edge_encoder_type: + if self.image_size == 1024: + he_512 = F.interpolate(he_images, size=512, mode='bilinear', align_corners=False) + edge_maps = self.edge_encoder(he_512) + else: + edge_maps = self.edge_encoder(he_images) + else: + edge_maps = None + + # === 1024: extra encoder level === + if self.enc0 is not None: + e0 = self.enc0(he_images) # [B, 32, 512, 512] + enc1_input = e0 + else: + e0 = None + enc1_input = he_images + + # Encoder + e1 = self.enc1(enc1_input) # [B, 64, 256, 256] + e2 = self.enc2(e1) # [B, 128, 128, 128] + e3 = self.enc3(e2) # [B, 256, 64, 64] + e4 = self.enc4(e3) # [B, 512, 32, 32] + e5 = self.enc5(e4) # [B, 512, 16, 16] + + # Bottleneck at 16×16 + x = self.bottleneck(e5) # [B, 512, 16, 16] + + # D5: upsample 16→32, skip from e4 + edge@32, UNI at 32 + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + skip5 = [x, e4] + ([edge_maps[32]] if edge_maps else []) + x = torch.cat(skip5, dim=1) + x = self.dec5_conv(x) + x = self.dec5_spade(x, uni_maps[32], class_emb) + x = self.dec5_act(x) + + # D4: upsample 32→64, skip from e3 + edge@64, UNI at 64 + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + skip4 = [x, e3] + ([edge_maps[64]] if edge_maps else []) + x = torch.cat(skip4, dim=1) + x = self.dec4_conv(x) + x = self.dec4_spade(x, uni_maps[64], class_emb) + x = self.dec4_act(x) + + # D3: upsample 64→128, skip from e2 + edge@128, UNI at 128 + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + skip3 = [x, e2] + ([edge_maps[128]] if edge_maps else []) + x = torch.cat(skip3, dim=1) + x = self.dec3_conv(x) + x = self.dec3_spade(x, uni_maps[128], class_emb) + x = self.dec3_act(x) + + # D2: upsample 128→256, skip from e1 + edge@256, UNI at 256 + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + skip2 = [x, e1] + ([edge_maps[256]] if edge_maps else []) + x = torch.cat(skip2, dim=1) + x = self.dec2_conv(x) + x = 
self.dec2_spade(x, uni_maps[256], class_emb)
+        x = self.dec2_act(x)
+
+        if self.image_size == 1024:
+            # D1: upsample 256→512, skip from e0 (32ch) + edge@512
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+            # Guard by key: the v1 EdgeEncoder produces no 512-level map
+            # (edge_ch[512] is 0 there), so indexing edge_maps[512] directly
+            # would raise KeyError for image_size=1024 with edge_encoder='v1'.
+            skip1 = [x, e0] + ([edge_maps[512]] if edge_maps and 512 in edge_maps else [])
+            x = torch.cat(skip1, dim=1)
+            x = self.dec1_conv(x)
+            if self.dec1_spade is not None:
+                x = self.dec1_spade(x, uni_maps[512], class_emb)
+                x = self.dec1_act(x)
+            # [B, 64, 512, 512]
+
+            # Output: upsample 512→1024, optional H&E input skip
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+            if self.input_skip:
+                x = torch.cat([x, he_images], dim=1)
+            x = self.output(x)  # [B, 3, 1024, 1024]
+        else:
+            # D1: upsample 256→512, optional skip from H&E input + edge@512
+            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+            skip1 = [x]
+            if self.input_skip:
+                skip1.append(he_images)
+            if edge_maps and 512 in edge_maps:
+                skip1.append(edge_maps[512])
+            x = torch.cat(skip1, dim=1) if len(skip1) > 1 else x
+            x = self.output(x)  # [B, 3, 512, 512]
+
+        return x
diff --git a/src/models/losses.py b/src/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfbe32e0650b115b360b56cb5526249249adc3fa
--- /dev/null
+++ b/src/models/losses.py
@@ -0,0 +1,154 @@
+"""
+Loss functions for UNIStainNet.
+
+- VGGFeatureExtractor: intermediate VGG16 features for Gram-matrix style loss
+- gram_matrix: compute Gram matrix of feature maps
+- PatchNCELoss: contrastive loss between H&E input and generated output (alignment-free)
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class VGGFeatureExtractor(nn.Module):
+    """Extract intermediate VGG16 features for Gram-matrix style loss.
+
+    Uses early VGG layers (relu1_2, relu2_2, relu3_3) which capture texture
+    at different scales. Gram matrices of these features are alignment-invariant
+    texture descriptors — they measure feature co-occurrence statistics, not
+    spatial layout (Gatys et al., 2016).
+    """
+
+    def __init__(self):
+        super().__init__()
+        from torchvision.models import vgg16, VGG16_Weights
+        vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
+        # Extract at relu1_2 (idx 4), relu2_2 (idx 9), relu3_3 (idx 16)
+        self.slice1 = nn.Sequential(*list(vgg.children())[:4])    # → relu1_2
+        self.slice2 = nn.Sequential(*list(vgg.children())[4:9])   # → relu2_2
+        self.slice3 = nn.Sequential(*list(vgg.children())[9:16])  # → relu3_3
+        # Freeze
+        for p in self.parameters():
+            p.requires_grad = False
+        self.eval()
+        # ImageNet normalization
+        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, x):
+        """
+        Args:
+            x: [B, 3, H, W] in [-1, 1]
+        Returns:
+            list of feature maps at 3 scales
+        """
+        # Normalize: [-1,1] → [0,1] → ImageNet
+        x = (x + 1) / 2
+        x = (x - self.mean) / self.std
+        f1 = self.slice1(x)
+        f2 = self.slice2(f1)
+        f3 = self.slice3(f2)
+        return [f1, f2, f3]
+
+
+def gram_matrix(feat):
+    """Compute Gram matrix of feature map.
+
+    Args:
+        feat: [B, C, H, W]
+    Returns:
+        gram: [B, C, C] — normalized by spatial size
+    """
+    B, C, H, W = feat.shape
+    feat_flat = feat.reshape(B, C, H * W)  # [B, C, N]
+    gram = torch.bmm(feat_flat, feat_flat.transpose(1, 2))  # [B, C, C]
+    return gram / (C * H * W)
+
+
+class PatchNCELoss(nn.Module):
+    """Patchwise Noise Contrastive Estimation loss.
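+
+    For a query patch feature q_i from the generated image and key features
+    {k_j} from the H&E input (both MLP-projected and L2-normalized), the
+    per-position InfoNCE term is:
+
+        loss_i = -log( exp(q_i . k_i / t) / sum_j exp(q_i . k_j / t) )
+
+    with temperature t (default 0.07), averaged over sampled positions and layers.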
+ + Compares H&E input and generated IHC through the generator's encoder. + For each spatial position in the generated features, the corresponding + position in the H&E features is the positive, and random other positions + are negatives. Never sees GT IHC. + + Reference: Park et al., "Contrastive Learning for Unpaired Image-to-Image + Translation" (ECCV 2020) — adapted for paired (misaligned) setting. + """ + + def __init__(self, layer_channels, num_patches=256, temperature=0.07): + """ + Args: + layer_channels: dict {layer_idx: channels} for each encoder layer + num_patches: number of spatial positions to sample per layer + temperature: InfoNCE temperature + """ + super().__init__() + self.num_patches = num_patches + self.temperature = temperature + + # 2-layer MLP projection head per encoder layer + self.mlps = nn.ModuleDict() + for layer_idx, ch in layer_channels.items(): + self.mlps[str(layer_idx)] = nn.Sequential( + nn.Linear(ch, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + ) + + def forward(self, feats_src, feats_tgt): + """Compute PatchNCE loss across encoder layers. + + Args: + feats_src: dict {layer_idx: [B, C, H, W]} from H&E input + feats_tgt: dict {layer_idx: [B, C, H, W]} from generated IHC + + Returns: + scalar loss + """ + total_loss = 0.0 + n_layers = 0 + + for layer_idx_str, mlp in self.mlps.items(): + layer_idx = int(layer_idx_str) + feat_src = feats_src[layer_idx] # [B, C, H, W] + feat_tgt = feats_tgt[layer_idx] # [B, C, H, W] + + B, C, H, W = feat_src.shape + n_total = H * W + + # Reshape to [B, C, H*W] then [B, H*W, C] + src_flat = feat_src.flatten(2).permute(0, 2, 1) # [B, HW, C] + tgt_flat = feat_tgt.flatten(2).permute(0, 2, 1) # [B, HW, C] + + # Sample random spatial positions + n_sample = min(self.num_patches, n_total) + idx = torch.randperm(n_total, device=feat_src.device)[:n_sample] + + src_sampled = src_flat[:, idx, :] # [B, n_sample, C] + tgt_sampled = tgt_flat[:, idx, :] # [B, n_sample, C] + + # Project through MLP + src_proj = mlp(src_sampled) # [B, n_sample, 256] + tgt_proj = mlp(tgt_sampled) # [B, n_sample, 256] + + # L2 normalize + src_proj = F.normalize(src_proj, dim=-1) + tgt_proj = F.normalize(tgt_proj, dim=-1) + + # InfoNCE: for each query (tgt), positive is matching src position + # negatives are all other src positions + # logits: [B, n_sample, n_sample] — (i,j) = similarity of tgt_i to src_j + logits = torch.bmm(tgt_proj, src_proj.transpose(1, 2)) # [B, n, n] + logits = logits / self.temperature + + # Target: diagonal (position i matches position i) + target = torch.arange(n_sample, device=logits.device).unsqueeze(0).expand(B, -1) + + loss = F.cross_entropy(logits.flatten(0, 1), target.flatten(0, 1)) + total_loss = total_loss + loss + n_layers += 1 + + return total_loss / n_layers if n_layers > 0 else total_loss diff --git a/src/models/trainer.py b/src/models/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc11fe8f1cd4235a1d3d88befc454080863319c --- /dev/null +++ b/src/models/trainer.py @@ -0,0 +1,989 @@ +""" +UNIStainNet: Pixel-Space UNI-Guided Virtual Staining Network. 
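+
+Training uses PyTorch Lightning manual optimization: each batch runs one
+generator update (reconstruction and auxiliary losses, plus adversarial terms
+after a warmup) followed by one discriminator update, with an EMA copy of the
+generator kept for validation and inference.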
+ +Architecture: + Generator: SPADE-UNet conditioned on UNI pathology features + stain/class embedding + Discriminator: Multi-scale PatchGAN (512 + 256) + Losses: LPIPS@128 + adversarial + DAB intensity + DAB contrast + +References: + - Park et al., "Semantic Image Synthesis with SPADE" (CVPR 2019) + - Chen et al., "A general-purpose self-supervised model for pathology" (Nature Medicine 2024) + - Isola et al., "Image-to-Image Translation with pix2pix" (CVPR 2017) +""" + +import copy +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import lpips +import pytorch_lightning as pl +import torchvision +import wandb + +from src.models.discriminator import ( + PatchDiscriminator, MultiScaleDiscriminator, + hinge_loss_d, hinge_loss_g, r1_gradient_penalty, feature_matching_loss, +) +from src.models.generator import SPADEUNetGenerator +from src.models.losses import VGGFeatureExtractor, gram_matrix, PatchNCELoss +from src.utils.dab import DABExtractor + + +# ====================================================================== +# Training Module +# ====================================================================== + +class UNIStainNetTrainer(pl.LightningModule): + """PyTorch Lightning training module for UNIStainNet. + + Handles GAN training with manual optimization, CFG dropout, EMA, and + all loss computations. + """ + + def __init__( + self, + # Architecture + num_classes=5, + null_class=4, + class_dim=64, + uni_dim=1024, + ndf=64, + disc_n_layers=3, + input_skip=False, + # Optimizer + gen_lr=1e-4, + disc_lr=4e-4, + warmup_steps=1000, + # Loss weights + lpips_weight=1.0, + lpips_256_weight=0.5, + lpips_512_weight=0.0, + adversarial_weight=1.0, + dab_intensity_weight=0.1, + dab_contrast_weight=0.05, + dab_sharpness_weight=0.0, + gram_style_weight=0.0, + edge_weight=0.0, + he_edge_weight=0.0, + bg_white_weight=0.0, + bg_threshold=0.85, + l1_lowres_weight=0.0, + edge_encoder=False, + edge_base_ch=32, + uni_spatial_size=4, + uncond_disc_weight=0.0, + crop_disc_weight=0.0, + crop_size=128, + feat_match_weight=0.0, + patchnce_weight=0.0, + patchnce_layers=(2, 3, 4), + patchnce_n_patches=256, + patchnce_temperature=0.07, + # Ablation + disable_uni=False, + disable_class=False, + # GAN training + r1_weight=10.0, + r1_every=16, + adversarial_start_step=2000, + # CFG + cfg_drop_class_prob=0.10, + cfg_drop_uni_prob=0.10, + cfg_drop_both_prob=0.05, + # EMA + ema_decay=0.999, + # On-the-fly UNI extraction (for crop-based training) + extract_uni_on_the_fly=False, + uni_spatial_pool_size=32, + # Resolution + image_size=512, + # 1024 architecture: extend UNI SPADE to 512 level + uni_spade_at_512=False, + # Per-label names for multi-stain logging + label_names=None, + ): + super().__init__() + self.save_hyperparameters() + self.automatic_optimization = False + + self.null_class = null_class + + # On-the-fly UNI feature extraction (loaded lazily on first use) + self._uni_model = None + self._uni_extract_on_the_fly = extract_uni_on_the_fly + + # Generator + self.generator = SPADEUNetGenerator( + num_classes=num_classes, + class_dim=class_dim, + uni_dim=uni_dim, + input_skip=input_skip, + edge_encoder=edge_encoder, + edge_base_ch=edge_base_ch, + uni_spatial_size=uni_spatial_size, + image_size=image_size, + uni_spade_at_512=uni_spade_at_512, + ) + + # Discriminator (global multi-scale) + self.discriminator = MultiScaleDiscriminator( + in_channels=6, ndf=ndf, n_layers=disc_n_layers, + ) + + # Crop discriminator (local full-res detail) + if crop_disc_weight 
> 0: + self.crop_discriminator = PatchDiscriminator( + in_channels=6, ndf=ndf, n_layers=disc_n_layers, + ) + else: + self.crop_discriminator = None + + # Unconditional discriminator (HER2-only, alignment-free texture judge) + # Also needed for feature matching loss (FM uses uncond disc features) + if uncond_disc_weight > 0 or feat_match_weight > 0: + self.uncond_discriminator = PatchDiscriminator( + in_channels=3, ndf=ndf, n_layers=disc_n_layers, + ) + else: + self.uncond_discriminator = None + + # PatchNCE loss (contrastive, alignment-free: H&E input vs generated) + if patchnce_weight > 0: + # Encoder channel dims: {1: 64, 2: 128, 3: 256, 4: 512} + enc_channels = {1: 64, 2: 128, 3: 256, 4: 512} + layer_channels = {l: enc_channels[l] for l in patchnce_layers} + self.patchnce_loss = PatchNCELoss( + layer_channels=layer_channels, + num_patches=patchnce_n_patches, + temperature=patchnce_temperature, + ) + else: + self.patchnce_loss = None + + # EMA generator + self.generator_ema = copy.deepcopy(self.generator) + self.generator_ema.requires_grad_(False) + + # Losses + self.lpips_fn = lpips.LPIPS(net='alex') + self.lpips_fn.requires_grad_(False) + self.lpips_fn.eval() + + self.dab_extractor = DABExtractor(device='cpu') + + # VGG feature extractor for Gram-matrix style loss + if gram_style_weight > 0: + self.vgg_extractor = VGGFeatureExtractor() + self.vgg_extractor.requires_grad_(False) + self.vgg_extractor.eval() + else: + self.vgg_extractor = None + + # Param counts + n_gen = sum(p.numel() for p in self.generator.parameters()) + n_disc = sum(p.numel() for p in self.discriminator.parameters()) + n_crop = sum(p.numel() for p in self.crop_discriminator.parameters()) if self.crop_discriminator else 0 + n_uncond = sum(p.numel() for p in self.uncond_discriminator.parameters()) if self.uncond_discriminator else 0 + print(f"Generator: {n_gen:,} params") + print(f"Discriminator: {n_disc:,} params (global) + {n_crop:,} (crop) + {n_uncond:,} (uncond)") + + def configure_optimizers(self): + gen_params = list(self.generator.parameters()) + if self.patchnce_loss is not None: + gen_params += list(self.patchnce_loss.parameters()) + opt_g = torch.optim.Adam( + gen_params, + lr=self.hparams.gen_lr, + betas=(0.0, 0.999), + ) + # All discriminator params in one optimizer + disc_params = list(self.discriminator.parameters()) + if self.crop_discriminator is not None: + disc_params += list(self.crop_discriminator.parameters()) + if self.uncond_discriminator is not None: + disc_params += list(self.uncond_discriminator.parameters()) + opt_d = torch.optim.Adam( + disc_params, + lr=self.hparams.disc_lr, + betas=(0.0, 0.999), + ) + return [opt_g, opt_d] + + def _get_lr_scale(self): + """Linear warmup.""" + if self.global_step < self.hparams.warmup_steps: + return self.global_step / max(1, self.hparams.warmup_steps) + return 1.0 + + @torch.no_grad() + def _update_ema(self): + """Update EMA generator weights.""" + decay = self.hparams.ema_decay + for p_ema, p in zip(self.generator_ema.parameters(), self.generator.parameters()): + p_ema.data.mul_(decay).add_(p.data, alpha=1 - decay) + + def on_save_checkpoint(self, checkpoint): + """Exclude frozen UNI model from checkpoint (it's reloaded on-the-fly).""" + state_dict = checkpoint.get('state_dict', {}) + keys_to_remove = [k for k in state_dict if k.startswith('_uni_model.')] + for k in keys_to_remove: + del state_dict[k] + + def on_load_checkpoint(self, checkpoint): + """Filter out UNI model keys from old checkpoints that included them.""" + state_dict = 
checkpoint.get('state_dict', {}) + keys_to_remove = [k for k in state_dict if k.startswith('_uni_model.')] + for k in keys_to_remove: + del state_dict[k] + + def _load_uni_model(self): + """Lazily load UNI ViT-L/16 for on-the-fly feature extraction.""" + if self._uni_model is None: + import timm + self._uni_model = timm.create_model( + "hf-hub:MahmoodLab/uni", + pretrained=True, + init_values=1e-5, + dynamic_img_size=True, + ) + self._uni_model.eval() + self._uni_model.requires_grad_(False) + self._uni_model = self._uni_model.to(self.device) + n_params = sum(p.numel() for p in self._uni_model.parameters()) + print(f"UNI model loaded for on-the-fly extraction: {n_params:,} params") + return self._uni_model + + @torch.no_grad() + def _extract_uni_from_sub_crops(self, uni_sub_crops): + """Extract UNI features from pre-prepared sub-crops on GPU. + + Args: + uni_sub_crops: [B, 16, 3, 224, 224] — batch of 4x4 sub-crop grids, + already normalized with ImageNet stats. + + Returns: + uni_features: [B, S*S, 1024] where S = uni_spatial_pool_size (default 32) + """ + uni_model = self._load_uni_model() + B = uni_sub_crops.shape[0] + spatial_size = self.hparams.uni_spatial_pool_size + num_crops = 4 # 4x4 grid + patches_per_side = 14 # 224/16 + + # Batched UNI forward: [B, 16, 3, 224, 224] -> [B*16, 3, 224, 224] + all_crops = uni_sub_crops.reshape(B * 16, 3, 224, 224).to(self.device) + all_feats = uni_model.forward_features(all_crops) # [B*16, 197, 1024] + patch_tokens = all_feats[:, 1:, :] # [B*16, 196, 1024] + + # Reshape back to per-sample grids: [B, 4, 4, 14, 14, 1024] + patch_tokens = patch_tokens.reshape( + B, num_crops, num_crops, + patches_per_side, patches_per_side, 1024 + ) + # Interleave to spatial grid: [B, 56, 56, 1024] + full_size = num_crops * patches_per_side # 56 + full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5) + full_grid = full_grid.reshape(B, full_size, full_size, 1024) + + # Pool to target spatial size (batched) + if spatial_size < full_size: + grid_bchw = full_grid.permute(0, 3, 1, 2) # [B, 1024, 56, 56] + pooled = F.adaptive_avg_pool2d(grid_bchw, spatial_size) # [B, 1024, S, S] + result = pooled.permute(0, 2, 3, 1) # [B, S, S, 1024] + else: + result = full_grid + + S = result.shape[1] + return result.reshape(B, S * S, 1024) # [B, S*S, 1024] + + def _apply_cfg_dropout(self, labels, uni_features): + """Apply classifier-free guidance dropout during training (vectorized).""" + B = labels.shape[0] + device = labels.device + + new_labels = labels.clone() + new_uni = uni_features.clone() + + r = torch.rand(B, device=device) + p_both = self.hparams.cfg_drop_both_prob + p_class = p_both + self.hparams.cfg_drop_class_prob + p_uni = p_class + self.hparams.cfg_drop_uni_prob + + drop_both = r < p_both + drop_class = (r >= p_both) & (r < p_class) + drop_uni = (r >= p_class) & (r < p_uni) + + new_labels[drop_both | drop_class] = self.null_class + new_uni[drop_both | drop_uni] = 0.0 + + return new_labels, new_uni + + def compute_dab_intensity_loss(self, generated, target): + """Top-10% percentile matching for DAB intensity.""" + with torch.amp.autocast('cuda', enabled=False): + gen = generated.float() + tgt = target.float() + + dab_gen = self.dab_extractor.extract_dab_intensity(gen, normalize="none") + dab_tgt = self.dab_extractor.extract_dab_intensity(tgt, normalize="none") + + def _batched_top10_mean(dab): + """Compute mean of top-10% DAB intensity per sample (vectorized).""" + B = dab.shape[0] + flat = dab.reshape(B, -1) # [B, H*W] + p99 = torch.quantile(flat, 0.99, dim=1, keepdim=True) + 
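+                # Clip at p99 before taking the top-10% mean so a handful of
+                # saturated pixels cannot dominate the per-image score; the
+                # statistic below is the upper-decile mean of the capped values.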
+                flat = flat.clamp(max=p99)
+                p90 = torch.quantile(flat, 0.9, dim=1, keepdim=True)
+                mask = flat >= p90  # [B, H*W]
+                # Use masked mean: sum(vals * mask) / sum(mask), fallback to flat mean
+                masked_sum = (flat * mask).sum(dim=1)
+                mask_count = mask.sum(dim=1).clamp(min=1)
+                return masked_sum / mask_count  # [B]
+
+            gen_scores = _batched_top10_mean(dab_gen)
+            tgt_scores = _batched_top10_mean(dab_tgt)
+            return F.l1_loss(gen_scores, tgt_scores)
+
+    def compute_dab_contrast_loss(self, generated, labels):
+        """Class-ordering hinge loss: DAB(3+) > DAB(2+) > DAB(1+) > DAB(0)."""
+        with torch.amp.autocast('cuda', enabled=False):
+            gen = generated.float()
+            # Only use non-null labels
+            valid = labels < self.null_class
+            if valid.sum() < 2:
+                return torch.tensor(0.0, device=self.device, requires_grad=True)
+
+            gen_valid = gen[valid]
+            labels_valid = labels[valid]
+
+            dab_gen = self.dab_extractor.extract_dab_intensity(gen_valid, normalize="none")
+
+            B = dab_gen.shape[0]
+            flat = dab_gen.reshape(B, -1)
+            p99 = torch.quantile(flat, 0.99, dim=1, keepdim=True)
+            flat = flat.clamp(max=p99)
+            p90 = torch.quantile(flat, 0.9, dim=1, keepdim=True)
+            mask = flat >= p90
+            masked_sum = (flat * mask).sum(dim=1)
+            mask_count = mask.sum(dim=1).clamp(min=1)
+            dab_scores = masked_sum / mask_count
+
+            class_pairs = [
+                (3, 0, 0.20), (3, 1, 0.15),
+                (2, 0, 0.08), (3, 2, 0.10),
+            ]
+
+            losses = []
+            for high_cls, low_cls, margin in class_pairs:
+                high_mask = labels_valid == high_cls
+                low_mask = labels_valid == low_cls
+                if high_mask.sum() > 0 and low_mask.sum() > 0:
+                    high_score = dab_scores[high_mask].mean()
+                    low_score = dab_scores[low_mask].mean()
+                    losses.append(F.relu(margin - (high_score - low_score)))
+
+            if losses:
+                return torch.stack(losses).mean()
+            return torch.tensor(0.0, device=self.device, requires_grad=True)
+
+    def compute_edge_loss(self, generated, target):
+        """Fourier spectral loss at 256x256 for boundary sharpness.
+
+        Compares power spectrum magnitudes between generated and target.
+        The Fourier magnitude is inherently translation-invariant — shifting
+        an image doesn't change its frequency content — so this is robust to
+        the ~30px misalignment in consecutive-cut BCI pairs.
+
+        Focuses on high-frequency bands (outer 75% of spectrum) where
+        blurriness manifests as reduced power.
+        """
+        with torch.amp.autocast('cuda', enabled=False):
+            gen = F.interpolate(generated.float(), size=256, mode='bilinear', align_corners=False)
+            tgt = F.interpolate(target.float(), size=256, mode='bilinear', align_corners=False)
+
+            # Grayscale
+            gen_gray = gen.mean(dim=1, keepdim=True)
+            tgt_gray = tgt.mean(dim=1, keepdim=True)
+
+            # 2D FFT -> power spectrum (log-scale for stability), fftshifted
+            # so DC sits at the array center and the radial mask below truly
+            # selects high frequencies
+            gen_fft = torch.fft.fft2(gen_gray)
+            tgt_fft = torch.fft.fft2(tgt_gray)
+            gen_mag = torch.fft.fftshift(torch.log1p(gen_fft.abs()), dim=(-2, -1))
+            tgt_mag = torch.fft.fftshift(torch.log1p(tgt_fft.abs()), dim=(-2, -1))
+
+            # High-frequency mask: keep outer 75% of spectrum
+            H, W = gen_mag.shape[-2], gen_mag.shape[-1]
+            cy, cx = H // 2, W // 2
+            y = torch.arange(H, device=gen.device).float() - cy
+            x = torch.arange(W, device=gen.device).float() - cx
+            dist = (y[:, None] ** 2 + x[None, :] ** 2).sqrt()
+            max_dist = (cy ** 2 + cx ** 2) ** 0.5
+            hf_mask = (dist > 0.25 * max_dist).float()
+
+            # L1 on high-frequency magnitudes
+            return F.l1_loss(gen_mag * hf_mask, tgt_mag * hf_mask)
+
+    def compute_dab_sharpness_loss(self, generated, target):
+        """DAB spatial sharpness loss: penalizes diffuse brown, rewards membrane-localized DAB.
+
+        Two components:
+        1.
DAB gradient magnitude: mean Sobel gradient magnitude per image. + 2. DAB local variance distribution: sorted-L1 (Wasserstein-1) on + patch variance vectors. + """ + with torch.amp.autocast('cuda', enabled=False): + gen = generated.float() + tgt = target.float() + + dab_gen = self.dab_extractor.extract_dab_intensity(gen, normalize="none") + dab_tgt = self.dab_extractor.extract_dab_intensity(tgt, normalize="none") + + # Ensure [B, 1, H, W] + if dab_gen.dim() == 3: + dab_gen = dab_gen.unsqueeze(1) + if dab_tgt.dim() == 3: + dab_tgt = dab_tgt.unsqueeze(1) + + B = dab_gen.shape[0] + + # --- Component 1: Gradient magnitude (batched) --- + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=torch.float32, device=gen.device).view(1, 1, 3, 3) + sobel_y = sobel_x.transpose(-1, -2) + + gx_gen = F.conv2d(dab_gen, sobel_x, padding=1) + gy_gen = F.conv2d(dab_gen, sobel_y, padding=1) + grad_gen = (gx_gen**2 + gy_gen**2 + 1e-8).sqrt() + + gx_tgt = F.conv2d(dab_tgt, sobel_x, padding=1) + gy_tgt = F.conv2d(dab_tgt, sobel_y, padding=1) + grad_tgt = (gx_tgt**2 + gy_tgt**2 + 1e-8).sqrt() + + # Match mean gradient magnitude per image + grad_loss = F.l1_loss(grad_gen.mean(dim=[1, 2, 3]), grad_tgt.mean(dim=[1, 2, 3])) + + # --- Component 2: Local variance distribution (sorted-L1) --- + ps = 16 # patch size + var_losses = [] + for i in range(B): + g = dab_gen[i, 0] # [H, W] + t = dab_tgt[i, 0] + + H, W = g.shape + nH, nW = H // ps, W // ps + g_patches = g[:nH*ps, :nW*ps].reshape(nH, ps, nW, ps).permute(0, 2, 1, 3).reshape(-1, ps*ps) + t_patches = t[:nH*ps, :nW*ps].reshape(nH, ps, nW, ps).permute(0, 2, 1, 3).reshape(-1, ps*ps) + + g_var = g_patches.var(dim=1) + t_var = t_patches.var(dim=1) + + g_sorted, _ = g_var.sort() + t_sorted, _ = t_var.sort() + var_losses.append(F.l1_loss(g_sorted, t_sorted.detach())) + + var_loss = torch.stack(var_losses).mean() + + return grad_loss + var_loss + + def compute_he_edge_loss(self, generated, he_input): + """H&E edge structure preservation loss. + + Extracts Sobel edges from H&E input and generated output, then + computes L1 loss between edge maps at multiple scales. 
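+
+        Example (illustrative only; `trainer` is assumed to be a
+        UNIStainNetTrainer instance and shapes follow the 512 config):
+            >>> gen = torch.rand(2, 3, 512, 512) * 2 - 1  # generated IHC
+            >>> he = torch.rand(2, 3, 512, 512) * 2 - 1   # H&E input
+            >>> loss = trainer.compute_he_edge_loss(gen, he)  # scalar tensor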
+ """ + with torch.amp.autocast('cuda', enabled=False): + gen = generated.float() + he = he_input.float() + + gen_gray = ((gen + 1) / 2).mean(dim=1, keepdim=True) + he_gray = ((he + 1) / 2).mean(dim=1, keepdim=True) + + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], + dtype=torch.float32, device=gen.device).view(1, 1, 3, 3) + sobel_y = sobel_x.transpose(-1, -2) + + loss = 0.0 + full_size = gen_gray.shape[-1] + scales = [full_size, full_size // 2] + for size in scales: + if size < full_size: + g = F.interpolate(gen_gray, size=size, mode='bilinear', align_corners=False) + h = F.interpolate(he_gray, size=size, mode='bilinear', align_corners=False) + else: + g, h = gen_gray, he_gray + + gx_gen = F.conv2d(g, sobel_x, padding=1) + gy_gen = F.conv2d(g, sobel_y, padding=1) + edge_gen = (gx_gen**2 + gy_gen**2 + 1e-8).sqrt() + + gx_he = F.conv2d(h, sobel_x, padding=1) + gy_he = F.conv2d(h, sobel_y, padding=1) + edge_he = (gx_he**2 + gy_he**2 + 1e-8).sqrt() + + loss = loss + F.l1_loss(edge_gen, edge_he.detach()) + + return loss / 2.0 + + def compute_background_loss(self, generated, he_input): + """Background white loss: push background regions toward white.""" + with torch.amp.autocast('cuda', enabled=False): + gen = generated.float() + he = he_input.float() + + he_bright = ((he + 1) / 2).mean(dim=1, keepdim=True) + + threshold = self.hparams.bg_threshold + mask = torch.sigmoid((he_bright - threshold) * 20.0) + + white_target = torch.ones_like(gen) + diff = (gen - white_target).abs() + + weighted_diff = diff * mask + + mask_sum = mask.sum() * 3 + if mask_sum > 0: + return weighted_diff.sum() / mask_sum + return torch.tensor(0.0, device=gen.device, requires_grad=True) + + def compute_gram_style_loss(self, generated, target): + """Gram-matrix style loss: match texture statistics via VGG feature correlations.""" + with torch.amp.autocast('cuda', enabled=False): + gen = generated.float() + tgt = target.float() + + gen_256 = F.interpolate(gen, size=256, mode='bilinear', align_corners=False) + tgt_256 = F.interpolate(tgt, size=256, mode='bilinear', align_corners=False) + + gen_feats = self.vgg_extractor(gen_256) + tgt_feats = self.vgg_extractor(tgt_256) + + loss = 0.0 + for gf, tf in zip(gen_feats, tgt_feats): + gram_gen = gram_matrix(gf) + gram_tgt = gram_matrix(tf) + loss = loss + F.l1_loss(gram_gen, gram_tgt.detach()) + + return loss / len(gen_feats) + + def training_step(self, batch, batch_idx): + he, her2, uni_or_crops, labels, fnames = batch + opt_g, opt_d = self.optimizers() + + # On-the-fly UNI extraction: dataset returns [B, 16, 3, 224, 224] sub-crops + if self._uni_extract_on_the_fly: + uni = self._extract_uni_from_sub_crops(uni_or_crops) + else: + uni = uni_or_crops + + # Apply CFG dropout + labels_dropped, uni_dropped = self._apply_cfg_dropout(labels, uni) + + # Ablation: zero out UNI features + if self.hparams.disable_uni: + uni_dropped = torch.zeros_like(uni_dropped) + + # Ablation: force all labels to null class + if self.hparams.disable_class: + labels_dropped = torch.full_like(labels_dropped, self.hparams.null_class) + + # ---------------------------------------------------------------- + # Generator step + # ---------------------------------------------------------------- + generated = self.generator(he, uni_dropped, labels_dropped) + + # LPIPS main: 4x downsample (128 for 512 input, 256 for 1024) + lpips_main_size = self.hparams.image_size // 4 + gen_lpips = F.interpolate(generated, size=lpips_main_size, mode='bilinear', align_corners=False) + her2_lpips = 
F.interpolate(her2, size=lpips_main_size, mode='bilinear', align_corners=False) + loss_lpips = self.lpips_fn(gen_lpips, her2_lpips).mean() + + loss_g = self.hparams.lpips_weight * loss_lpips + + # LPIPS fine: 2x downsample (256 for 512 input, 512 for 1024) + if self.hparams.lpips_256_weight > 0: + lpips_fine_size = self.hparams.image_size // 2 + gen_fine = F.interpolate(generated, size=lpips_fine_size, mode='bilinear', align_corners=False) + her2_fine = F.interpolate(her2, size=lpips_fine_size, mode='bilinear', align_corners=False) + loss_lpips_256 = self.lpips_fn(gen_fine, her2_fine).mean() + loss_g = loss_g + self.hparams.lpips_256_weight * loss_lpips_256 + self.log('train/lpips_fine', loss_lpips_256, prog_bar=False) + + # LPIPS at full resolution (expensive) + if self.hparams.lpips_512_weight > 0: + loss_lpips_512 = self.lpips_fn(generated, her2).mean() + loss_g = loss_g + self.hparams.lpips_512_weight * loss_lpips_512 + self.log('train/lpips_fullres', loss_lpips_512, prog_bar=False) + + # Low-resolution L1 (color fidelity, misalignment-robust at 64x64) + if self.hparams.l1_lowres_weight > 0: + gen_64 = F.interpolate(generated, size=64, mode='bilinear', align_corners=False) + her2_64 = F.interpolate(her2, size=64, mode='bilinear', align_corners=False) + loss_l1_lowres = F.l1_loss(gen_64, her2_64) + loss_g = loss_g + self.hparams.l1_lowres_weight * loss_l1_lowres + self.log('train/l1_lowres', loss_l1_lowres, prog_bar=False) + + # DAB losses (use original labels, not dropped) + if self.hparams.dab_intensity_weight > 0: + loss_dab = self.compute_dab_intensity_loss(generated, her2) + loss_g = loss_g + self.hparams.dab_intensity_weight * loss_dab + self.log('train/dab_intensity', loss_dab, prog_bar=False) + + if self.hparams.dab_contrast_weight > 0: + loss_dab_contrast = self.compute_dab_contrast_loss(generated, labels) + loss_g = loss_g + self.hparams.dab_contrast_weight * loss_dab_contrast + self.log('train/dab_contrast', loss_dab_contrast, prog_bar=False) + + # Edge loss (boundary sharpness) + if self.hparams.edge_weight > 0: + loss_edge = self.compute_edge_loss(generated, her2) + loss_g = loss_g + self.hparams.edge_weight * loss_edge + self.log('train/edge_loss', loss_edge, prog_bar=False) + + # DAB sharpness loss (membrane-localized vs diffuse brown) + if self.hparams.dab_sharpness_weight > 0: + loss_dab_sharp = self.compute_dab_sharpness_loss(generated, her2) + loss_g = loss_g + self.hparams.dab_sharpness_weight * loss_dab_sharp + self.log('train/dab_sharpness', loss_dab_sharp, prog_bar=False) + + # Gram-matrix style loss + if self.hparams.gram_style_weight > 0 and self.vgg_extractor is not None: + loss_gram = self.compute_gram_style_loss(generated, her2) + loss_g = loss_g + self.hparams.gram_style_weight * loss_gram + self.log('train/gram_style', loss_gram, prog_bar=False) + + # H&E edge structure preservation (pixel-aligned) + if self.hparams.he_edge_weight > 0: + loss_he_edge = self.compute_he_edge_loss(generated, he) + loss_g = loss_g + self.hparams.he_edge_weight * loss_he_edge + self.log('train/he_edge', loss_he_edge, prog_bar=False) + + # Background white loss + if self.hparams.bg_white_weight > 0: + loss_bg = self.compute_background_loss(generated, he) + loss_g = loss_g + self.hparams.bg_white_weight * loss_bg + self.log('train/bg_white', loss_bg, prog_bar=False) + + # PatchNCE loss (contrastive: H&E input vs generated, never sees GT) + if self.hparams.patchnce_weight > 0 and self.patchnce_loss is not None: + feats_he = self.generator.encode(he) + feats_gen = 
self.generator.encode(generated) + loss_nce = self.patchnce_loss(feats_he, feats_gen) + loss_g = loss_g + self.hparams.patchnce_weight * loss_nce + self.log('train/patchnce', loss_nce, prog_bar=False) + + # Adversarial losses (after warmup) + loss_adv = torch.tensor(0.0, device=self.device) + loss_feat_match = torch.tensor(0.0, device=self.device) + loss_crop_adv = torch.tensor(0.0, device=self.device) + loss_uncond_adv = torch.tensor(0.0, device=self.device) + any_adv = (self.hparams.adversarial_weight > 0 or + self.hparams.uncond_disc_weight > 0 or + self.hparams.crop_disc_weight > 0 or + self.hparams.feat_match_weight > 0) + img_sz = self.hparams.image_size + # Pre-compute disc-resolution tensors (512 for 1024 input, identity for 512) + if img_sz == 1024: + he_for_disc = F.interpolate(he, size=512, mode='bilinear', align_corners=False) + her2_for_disc = F.interpolate(her2, size=512, mode='bilinear', align_corners=False) + else: + he_for_disc = he + her2_for_disc = her2 + if self.global_step >= self.hparams.adversarial_start_step and any_adv: + if img_sz == 1024: + gen_for_disc = F.interpolate(generated, size=512, mode='bilinear', align_corners=False) + else: + gen_for_disc = generated + + # Conditional discriminator (paired: generated+HE vs real_HER2+HE) + if self.hparams.adversarial_weight > 0: + fake_input = torch.cat([gen_for_disc, he_for_disc], dim=1) + disc_outputs = self.discriminator(fake_input) + loss_adv = sum(hinge_loss_g(out) for out in disc_outputs) / len(disc_outputs) + loss_g = loss_g + self.hparams.adversarial_weight * loss_adv + + # Feature matching from unconditional disc + if (self.hparams.feat_match_weight > 0 and + self.uncond_discriminator is not None): + _, fake_feats = self.uncond_discriminator(gen_for_disc, return_features=True) + with torch.no_grad(): + _, real_feats = self.uncond_discriminator(her2_for_disc, return_features=True) + loss_feat_match = feature_matching_loss(fake_feats, real_feats) + loss_g = loss_g + self.hparams.feat_match_weight * loss_feat_match + + # Crop discriminator: random crops at full resolution + if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0: + fake_input_crop = torch.cat([generated, he], dim=1) + cs = self.hparams.crop_size + top = torch.randint(0, img_sz - cs, (1,)).item() + left = torch.randint(0, img_sz - cs, (1,)).item() + fake_crop = fake_input_crop[:, :, top:top+cs, left:left+cs] + loss_crop_adv = hinge_loss_g(self.crop_discriminator(fake_crop)) + loss_g = loss_g + self.hparams.crop_disc_weight * loss_crop_adv + + # Unconditional discriminator: HER2-only adversarial + if self.uncond_discriminator is not None and self.hparams.uncond_disc_weight > 0: + loss_uncond_adv = hinge_loss_g(self.uncond_discriminator(gen_for_disc)) + loss_g = loss_g + self.hparams.uncond_disc_weight * loss_uncond_adv + + # Generator backward + step + lr_scale = self._get_lr_scale() + for pg in opt_g.param_groups: + pg['lr'] = self.hparams.gen_lr * lr_scale + + opt_g.zero_grad() + self.manual_backward(loss_g) + torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0) + opt_g.step() + + # Update EMA + self._update_ema() + + # ---------------------------------------------------------------- + # Discriminator step + # ---------------------------------------------------------------- + loss_d = torch.tensor(0.0, device=self.device) + loss_crop_d = torch.tensor(0.0, device=self.device) + loss_uncond_d = torch.tensor(0.0, device=self.device) + if self.global_step >= self.hparams.adversarial_start_step and any_adv: + with 
torch.no_grad(): + fake_detached = self.generator(he, uni_dropped, labels_dropped) + + # For 1024, downsample for disc + if img_sz == 1024: + fake_det_disc = F.interpolate(fake_detached, size=512, mode='bilinear', align_corners=False) + else: + fake_det_disc = fake_detached + + # Conditional discriminator + if self.hparams.adversarial_weight > 0: + real_input = torch.cat([her2_for_disc, he_for_disc], dim=1) + fake_input = torch.cat([fake_det_disc, he_for_disc], dim=1) + + disc_real = self.discriminator(real_input) + disc_fake = self.discriminator(fake_input) + + loss_d = sum( + hinge_loss_d(dr, df) + for dr, df in zip(disc_real, disc_fake) + ) / len(disc_real) + + # Crop discriminator + if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0: + real_input_c = torch.cat([her2, he], dim=1) + fake_input_c = torch.cat([fake_detached, he], dim=1) + cs = self.hparams.crop_size + top = torch.randint(0, img_sz - cs, (1,)).item() + left = torch.randint(0, img_sz - cs, (1,)).item() + real_crop = real_input_c[:, :, top:top+cs, left:left+cs] + fake_crop = fake_input_c[:, :, top:top+cs, left:left+cs] + loss_crop_d = hinge_loss_d( + self.crop_discriminator(real_crop), + self.crop_discriminator(fake_crop), + ) + loss_d = loss_d + self.hparams.crop_disc_weight * loss_crop_d + + # Unconditional discriminator + if self.uncond_discriminator is not None and ( + self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0): + uncond_real_out = self.uncond_discriminator(her2_for_disc) + uncond_fake_out = self.uncond_discriminator(fake_det_disc) + loss_uncond_d = hinge_loss_d(uncond_real_out, uncond_fake_out) + loss_d = loss_d + max(self.hparams.uncond_disc_weight, 1.0) * loss_uncond_d + + # R1 gradient penalty + loss_r1 = torch.tensor(0.0, device=self.device) + if self.global_step % self.hparams.r1_every == 0: + with torch.amp.autocast('cuda', enabled=False): + if self.hparams.adversarial_weight > 0: + real_input_r1 = torch.cat([her2_for_disc, he_for_disc], dim=1).float().detach().requires_grad_(True) + for disc in [self.discriminator.disc_512]: + d_real = disc(real_input_r1) + grad_real = torch.autograd.grad( + outputs=d_real.sum(), inputs=real_input_r1, + create_graph=True, + )[0] + loss_r1 = loss_r1 + self.hparams.r1_weight * grad_real.pow(2).mean() + if self.uncond_discriminator is not None and ( + self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0): + her2_r1 = her2_for_disc.float().detach().requires_grad_(True) + d_real_uncond = self.uncond_discriminator(her2_r1) + grad_uncond = torch.autograd.grad( + outputs=d_real_uncond.sum(), inputs=her2_r1, + create_graph=True, + )[0] + loss_r1 = loss_r1 + self.hparams.r1_weight * grad_uncond.pow(2).mean() + loss_d = loss_d + loss_r1 + self.log('train/r1_penalty', loss_r1, prog_bar=False) + + opt_d.zero_grad() + self.manual_backward(loss_d) + if self.hparams.adversarial_weight > 0: + torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0) + if self.crop_discriminator is not None: + torch.nn.utils.clip_grad_norm_(self.crop_discriminator.parameters(), 1.0) + if self.uncond_discriminator is not None: + torch.nn.utils.clip_grad_norm_(self.uncond_discriminator.parameters(), 1.0) + opt_d.step() + + # Logging + self.log('train/loss_g', loss_g, prog_bar=True) + self.log('train/loss_d', loss_d, prog_bar=True) + self.log('train/lpips', loss_lpips, prog_bar=True) + self.log('train/adversarial', loss_adv, prog_bar=False) + self.log('train/lr_scale', lr_scale, prog_bar=False) + if self.crop_discriminator is not 
None: + self.log('train/crop_adv_g', loss_crop_adv, prog_bar=False) + self.log('train/crop_adv_d', loss_crop_d, prog_bar=False) + if self.uncond_discriminator is not None: + self.log('train/uncond_adv_g', loss_uncond_adv, prog_bar=False) + self.log('train/uncond_adv_d', loss_uncond_d, prog_bar=False) + if self.hparams.feat_match_weight > 0: + self.log('train/feat_match', loss_feat_match, prog_bar=False) + + def on_validation_epoch_start(self): + # Pick a random batch index for the second sample grid + n_val_batches = max(1, len(self.trainer.val_dataloaders)) + self._random_val_batch_idx = torch.randint(1, max(2, n_val_batches), (1,)).item() + # Per-label sample collectors (for multi-stain visual grids) + self._val_per_label_samples = {} + + def _log_sample_grid(self, he, her2_01, gen_01, key): + """Log H&E | Real | Gen grid to wandb.""" + n = min(4, len(he)) + he_01 = ((he[:n].cpu() + 1) / 2).clamp(0, 1) + grid_images = [] + for i in range(n): + grid_images.extend([ + he_01[i], + her2_01[i].cpu(), + gen_01[i].cpu(), + ]) + grid = torchvision.utils.make_grid(grid_images, nrow=3, padding=2) + if self.logger: + self.logger.experiment.log({ + key: [wandb.Image(grid, caption='H&E | Real | Gen')], + 'global_step': self.global_step, + }) + + def validation_step(self, batch, batch_idx): + he, her2, uni_or_crops, labels, fnames = batch + + # On-the-fly UNI extraction + if self._uni_extract_on_the_fly: + uni = self._extract_uni_from_sub_crops(uni_or_crops) + else: + uni = uni_or_crops + + if self.hparams.disable_uni: + uni = torch.zeros_like(uni) + + if self.hparams.disable_class: + labels = torch.full_like(labels, self.hparams.null_class) + + # Use EMA generator + with torch.no_grad(): + generated = self.generator_ema(he, uni, labels) + + # LPIPS (4x downsample: 128 for 512, 256 for 1024) + lpips_size = self.hparams.image_size // 4 + gen_lpips = F.interpolate(generated, size=lpips_size, mode='bilinear', align_corners=False) + her2_lpips = F.interpolate(her2, size=lpips_size, mode='bilinear', align_corners=False) + lpips_val = self.lpips_fn(gen_lpips, her2_lpips).mean() + + # SSIM + gen_01 = ((generated + 1) / 2).clamp(0, 1) + her2_01 = ((her2 + 1) / 2).clamp(0, 1) + from torchmetrics.functional.image import structural_similarity_index_measure + ssim_val = structural_similarity_index_measure(gen_01, her2_01, data_range=1.0) + + # DAB MAE (canonical: mean of top-10%) + dab_gen = self.dab_extractor.extract_dab_intensity(generated.float().cpu(), normalize="none") + dab_real = self.dab_extractor.extract_dab_intensity(her2.float().cpu(), normalize="none") + + def p90_score(dab): + flat = dab.flatten() + p90 = torch.quantile(flat, 0.9) + mask = flat >= p90 + return flat[mask].mean().item() if mask.sum() > 0 else flat.mean().item() + + dab_mae = sum( + abs(p90_score(dab_gen[i]) - p90_score(dab_real[i])) + for i in range(len(dab_gen)) + ) / len(dab_gen) + + self.log('val/lpips', lpips_val, prog_bar=True, sync_dist=True) + self.log('val/ssim', ssim_val, prog_bar=True, sync_dist=True) + self.log('val/dab_mae', dab_mae, prog_bar=True, sync_dist=True) + + # Collect per-label samples for visual grids (multi-stain only) + if hasattr(self, '_val_per_label_samples'): + for i in range(len(labels)): + lbl = labels[i].item() + if lbl == self.hparams.null_class: + continue + if lbl not in self._val_per_label_samples: + self._val_per_label_samples[lbl] = {'he': [], 'real': [], 'gen': []} + bucket = self._val_per_label_samples[lbl] + if len(bucket['he']) < 4: + bucket['he'].append(he[i].cpu()) + 
bucket['real'].append(her2_01[i].cpu()) + bucket['gen'].append(gen_01[i].cpu()) + + # Log sample grids: first batch (fixed) + one random batch + if batch_idx == 0: + self._log_sample_grid(he, her2_01, gen_01, 'val/samples_fixed') + elif batch_idx == self._random_val_batch_idx: + self._log_sample_grid(he, her2_01, gen_01, 'val/samples_random') + + def on_validation_epoch_end(self): + """Log per-label sample grids if multiple labels are present.""" + if not hasattr(self, '_val_per_label_samples') or len(self._val_per_label_samples) <= 1: + return + + label_names = getattr(self.hparams, 'label_names', None) + + for lbl, bucket in sorted(self._val_per_label_samples.items()): + if not bucket['he'] or not self.logger: + continue + name = label_names[lbl] if label_names and lbl < len(label_names) else str(lbl) + self._log_sample_grid( + torch.stack(bucket['he']), + torch.stack(bucket['real']), + torch.stack(bucket['gen']), + f'val/samples_{name}', + ) + + self._val_per_label_samples = {} + + @torch.no_grad() + def generate(self, he_images, uni_features, labels, + num_inference_steps=None, guidance_scale=1.0, seed=None): + """Generate IHC images from H&E input. + + Args: + he_images: [B, 3, H, H] where H=512 or H=1024 + uni_features: [B, N, 1024] where N=16 (4x4 CLS) or N=1024 (32x32 patch) + labels: [B] class/stain labels + num_inference_steps: ignored (single forward pass) + guidance_scale: CFG scale (1.0 = no guidance) + seed: random seed (for reproducibility, though model is deterministic) + """ + if seed is not None: + torch.manual_seed(seed) + + gen = self.generator_ema if hasattr(self, 'generator_ema') else self.generator + + if guidance_scale <= 1.0: + return gen(he_images, uni_features, labels) + + # Classifier-free guidance + null_labels = torch.full_like(labels, self.null_class) + + output_cond = gen(he_images, uni_features, labels) + output_uncond = gen(he_images, uni_features, null_labels) + + output = output_uncond + guidance_scale * (output_cond - output_uncond) + return output.clamp(-1, 1) diff --git a/src/models/uni_processor.py b/src/models/uni_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..a268c709c15d8fef80110c45e1ce3f41e6ab027d --- /dev/null +++ b/src/models/uni_processor.py @@ -0,0 +1,226 @@ +""" +UNI feature processors: transform UNI pathology features into multi-scale spatial maps. + +- UNIFeatureProcessor: for CLS-token features (4x4 = 16 tokens) +- UNIFeatureProcessorHighRes: for patch-token features (32x32 = 1024 tokens) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNIFeatureProcessor(nn.Module): + """Process UNI features [B, 16, 1024] → multi-scale spatial feature maps. + + UNI produces 16 spatial tokens (4x4 grid) of 1024-dim. We project to + generator channel dim and upsample to match each decoder layer resolution. 
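+
+    Minimal shape check (illustrative):
+        >>> proc = UNIFeatureProcessor(uni_dim=1024, base_channels=512)
+        >>> maps = proc(torch.randn(2, 16, 1024))
+        >>> sorted(maps.keys())
+        [16, 32, 64, 128, 256]
+        >>> tuple(maps[256].shape)
+        (2, 64, 256, 256)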
+ """ + + def __init__(self, uni_dim=1024, base_channels=512): + super().__init__() + self.base_channels = base_channels + + # Project UNI features to generator channel dim + self.proj = nn.Sequential( + nn.Linear(uni_dim, base_channels), + nn.LeakyReLU(0.2, inplace=True), + ) + + # Multi-scale upsamplers: 4×4 → {16, 32, 64, 128, 256} + # Each stage doubles spatial resolution + ch = base_channels + + # 4→8→16 + self.up_16 = nn.Sequential( + nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + # 16→32 + self.up_32 = nn.Sequential( + nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + # 32→64 + ch_64 = base_channels // 2 # 256 + self.up_64 = nn.Sequential( + nn.ConvTranspose2d(ch, ch_64, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + # 64→128 + ch_128 = base_channels // 4 # 128 + self.up_128 = nn.Sequential( + nn.ConvTranspose2d(ch_64, ch_128, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + # 128→256 + ch_256 = base_channels // 8 # 64 + self.up_256 = nn.Sequential( + nn.ConvTranspose2d(ch_128, ch_256, 4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, uni_features): + """ + Args: + uni_features: [B, 16, 1024] + + Returns: + dict of spatial feature maps at each resolution + """ + B = uni_features.shape[0] + + # Project and reshape to spatial + x = self.proj(uni_features) # [B, 16, 512] + x = x.permute(0, 2, 1).reshape(B, self.base_channels, 4, 4) # [B, 512, 4, 4] + + # Multi-scale upsampling + feat_16 = self.up_16(x) # [B, 512, 16, 16] + feat_32 = self.up_32(feat_16) # [B, 512, 32, 32] + feat_64 = self.up_64(feat_32) # [B, 256, 64, 64] + feat_128 = self.up_128(feat_64) # [B, 128, 128, 128] + feat_256 = self.up_256(feat_128) # [B, 64, 256, 256] + + return { + 16: feat_16, + 32: feat_32, + 64: feat_64, + 128: feat_128, + 256: feat_256, + } + + +class UNIFeatureProcessorHighRes(nn.Module): + """Process high-res UNI features [B, 1024, 1024] → multi-scale spatial maps. + + With patch-token extraction, UNI produces 1024 tokens (32x32 spatial grid) + of 1024-dim — 64x more spatial resolution than the CLS-only 4x4 grid. + + Since we START at 32x32, we process features with Conv2d (no hallucinated + upsampling). Every spatial feature is backed by real UNI patch tokens. 
+ + Architecture: + 32x32 input → conv process → feat_32 (512ch) + 32→64 upsample → conv → feat_64 (256ch) + 64→128 upsample → conv → feat_128 (128ch) + 128→256 upsample → conv → feat_256 (64ch) + Also: 32→16 downsample → feat_16 (512ch, for bottleneck) + """ + + def __init__(self, uni_dim=1024, base_channels=512, spatial_size=32, + output_512=False): + super().__init__() + self.base_channels = base_channels + self.spatial_size = spatial_size + self.output_512 = output_512 + ch = base_channels + + # Project UNI 1024-dim → 512-dim per token + self.proj = nn.Sequential( + nn.Linear(uni_dim, ch), + nn.LeakyReLU(0.2, inplace=True), + ) + + # Process at 32x32 (native resolution) — refine projected features + self.proc_32 = nn.Sequential( + nn.Conv2d(ch, ch, 3, padding=1), + nn.InstanceNorm2d(ch), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ch, ch, 3, padding=1), + nn.InstanceNorm2d(ch), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 32→16 downsample (for bottleneck conditioning) + self.down_16 = nn.Sequential( + nn.Conv2d(ch, ch, 4, stride=2, padding=1), + nn.InstanceNorm2d(ch), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 32→64 upsample + refine + ch_64 = ch // 2 # 256 + self.up_64 = nn.Sequential( + nn.ConvTranspose2d(ch, ch_64, 4, stride=2, padding=1), + nn.InstanceNorm2d(ch_64), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ch_64, ch_64, 3, padding=1), + nn.InstanceNorm2d(ch_64), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 64→128 upsample + refine + ch_128 = ch // 4 # 128 + self.up_128 = nn.Sequential( + nn.ConvTranspose2d(ch_64, ch_128, 4, stride=2, padding=1), + nn.InstanceNorm2d(ch_128), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ch_128, ch_128, 3, padding=1), + nn.InstanceNorm2d(ch_128), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 128→256 upsample + refine + ch_256 = ch // 8 # 64 + self.up_256 = nn.Sequential( + nn.ConvTranspose2d(ch_128, ch_256, 4, stride=2, padding=1), + nn.InstanceNorm2d(ch_256), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ch_256, ch_256, 3, padding=1), + nn.InstanceNorm2d(ch_256), + nn.LeakyReLU(0.2, inplace=True), + ) + + # 256→512 upsample (for 1024 models with SPADE at dec1) + if output_512: + ch_512 = ch // 16 # 32 + self.up_512 = nn.Sequential( + nn.ConvTranspose2d(ch_256, ch_512, 4, stride=2, padding=1), + nn.InstanceNorm2d(ch_512), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ch_512, ch_512, 3, padding=1), + nn.InstanceNorm2d(ch_512), + nn.LeakyReLU(0.2, inplace=True), + ) + + def forward(self, uni_features): + """ + Args: + uni_features: [B, S*S, 1024] where S = spatial_size (default 32) + + Returns: + dict of spatial feature maps: {16, 32, 64, 128, 256} + """ + B = uni_features.shape[0] + S = self.spatial_size + + # Project and reshape to spatial grid + x = self.proj(uni_features) # [B, S*S, 512] + x = x.permute(0, 2, 1).reshape(B, self.base_channels, S, S) # [B, 512, 32, 32] + + # Process at native 32x32 + feat_32 = self.proc_32(x) + x # residual connection + + # Downsample for bottleneck + feat_16 = self.down_16(feat_32) # [B, 512, 16, 16] + + # Upsample path — each level adds spatial detail from real UNI tokens + feat_64 = self.up_64(feat_32) # [B, 256, 64, 64] + feat_128 = self.up_128(feat_64) # [B, 128, 128, 128] + feat_256 = self.up_256(feat_128) # [B, 64, 256, 256] + + out = { + 16: feat_16, + 32: feat_32, + 64: feat_64, + 128: feat_128, + 256: feat_256, + } + + if self.output_512: + feat_512 = self.up_512(feat_256) # [B, 32, 512, 512] + out[512] = feat_512 + + return out diff --git a/src/utils/__init__.py b/src/utils/__init__.py new 
file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/utils/dab.py b/src/utils/dab.py
new file mode 100644
index 0000000000000000000000000000000000000000..469c9f8b2c0cc24316d7d464d7fa2a6ebc6f6505
--- /dev/null
+++ b/src/utils/dab.py
@@ -0,0 +1,85 @@
+"""
+DAB (3,3'-Diaminobenzidine) stain extraction via color deconvolution.
+
+Reference: Ruifrok & Johnston, "Quantification of histochemical
+staining by color deconvolution", Anal Quant Cytol Histol 2001
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+class DABExtractor:
+    """Extract DAB stain intensity from IHC images using color deconvolution.
+
+    Uses the Ruifrok & Johnston H-DAB stain matrix with softplus smoothing
+    for differentiable training loss computation.
+    """
+
+    def __init__(self, device='cuda'):
+        self.device = device
+
+        # Standard H-DAB stain matrix (Ruifrok & Johnston)
+        # Each row is a stain vector in RGB optical density space
+        self.stain_matrix = torch.tensor([
+            [0.268, 0.570, 0.776],  # DAB (brown)
+            [0.650, 0.704, 0.286],  # Hematoxylin (blue)
+        ], device=device, dtype=torch.float32)
+
+        # Pseudo-inverse for deconvolution: [2, 3]
+        self.deconv_matrix = torch.linalg.pinv(self.stain_matrix.T)
+
+    def rgb_to_od(self, rgb_images: torch.Tensor) -> torch.Tensor:
+        """Convert RGB [0,1] to optical density: OD = -log10(I/I0)."""
+        rgb_images = rgb_images.clamp(1e-6, 1.0)
+        return -torch.log10(rgb_images + 1e-6)
+
+    def extract_dab_intensity(
+        self,
+        images: torch.Tensor,
+        normalize: str = "max"
+    ) -> torch.Tensor:
+        """Extract DAB stain intensity from IHC images.
+
+        Args:
+            images: [B, 3, H, W] RGB images in [-1, 1] or [0, 1]
+            normalize: "none", "max", or "meanstd"
+
+        Returns:
+            dab_intensity: [B, H, W] DAB intensity map
+        """
+        B, C, H, W = images.shape
+        assert C == 3, "Input must be RGB images"
+
+        # Auto-convert [-1, 1] -> [0, 1] if needed
+        if images.min() < 0:
+            images = (images + 1.0) / 2.0
+
+        od = self.rgb_to_od(images)
+        od_flat = od.permute(0, 2, 3, 1).reshape(-1, 3)
+
+        # Ensure deconv_matrix is on same device as input
+        deconv_matrix = self.deconv_matrix.to(od_flat.device)
+
+        # Deconvolve: concentrations = OD @ M_inv^T
+        concentrations = od_flat @ deconv_matrix.T
+        dab_flat = concentrations[:, 0]  # DAB channel
+
+        dab_intensity = dab_flat.reshape(B, H, W)
+
+        # Softplus for smooth gradients (beta=5.0 for sharper transition)
+        dab = F.softplus(dab_intensity, beta=5.0)
+
+        if normalize == "max" or normalize is True:
+            mx = dab.amax(dim=(1, 2), keepdim=True).clamp(min=1e-6)
+            dab = dab / mx
+        elif normalize == "meanstd":
+            mean = dab.mean(dim=(1, 2), keepdim=True)
+            std = dab.std(dim=(1, 2), keepdim=True).clamp(min=1e-6)
+            dab = (dab - mean) / std
+        elif normalize == "none" or normalize is False:
+            pass
+        else:
+            raise ValueError(f"Unknown normalization: {normalize}")
+
+        return dab
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b697c25627cf54254b4576ffbf4b50c6e7a457f
--- /dev/null
+++ b/src/utils/metrics.py
@@ -0,0 +1,528 @@
+"""
+Consolidated evaluation metrics for UNIStainNet.
+
+Provides standardized metric computation for both BCI and MIST datasets.
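+
+Typical usage (illustrative sketch; `generated`/`real` are [N, 3, H, W]
+tensors in [-1, 1] as documented on each function):
+
+    from src.utils.metrics import (
+        compute_image_quality_metrics, compute_dab_metrics, compute_iod_metrics,
+    )
+    results = {}
+    results.update(compute_image_quality_metrics(generated, real))
+    results.update(compute_dab_metrics(generated, real, labels=labels))
+    results.update(compute_iod_metrics(generated, real, labels=labels))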
+ +Metrics: + - Image quality: FID (Inception + UNI), KID, LPIPS (128+512), SSIM, PSNR + - DAB staining: KL divergence, JSD, Pearson-r, MAE (per-pair, 256-bin histograms) + - Optical density: IOD, mIOD, FOD (PSPStain, MICCAI 2024) + - Downstream: AUROC, SFS via UNI linear probe (Star-Diff, arXiv 2025) + +References: + - FID: Heusel et al., "GANs Trained by a Two Time-Scale Update Rule" (NeurIPS 2017) + - KID: Binkowski et al., "Demystifying MMD GANs" (ICLR 2018) + - LPIPS: Zhang et al., "The Unreasonable Effectiveness of Deep Features" (CVPR 2018) + - DAB KL: Liu et al., "ODA-GAN" (Med Image Anal 2024) — per-pair 256-bin histograms + - IOD/mIOD/FOD: Zhan et al., "PSPStain" (MICCAI 2024) — Beer-Lambert optical density + - AUROC/SFS: Wu et al., "Star-Diff" (arXiv 2025) — UNI linear probe downstream task + - DAB deconvolution: Ruifrok & Johnston, Anal Quant Cytol Histol (2001) +""" + +import os +from pathlib import Path + +import torch +import torch.nn.functional as F +import torchvision +import numpy as np +from scipy.stats import entropy, pearsonr + +from src.utils.dab import DABExtractor + + +# ====================================================================== +# p90 DAB score (canonical: mean of top-10%) +# ====================================================================== + +def compute_p90_scores(dab_maps): + """Compute canonical p90 DAB scores: mean of pixels >= 90th percentile. + + This is the canonical p90 metric used throughout the paper. For each image, + we find the 90th percentile threshold and return the mean of all pixels + at or above that threshold — i.e., the mean of the top-10%. + + Args: + dab_maps: [B, H, W] raw DAB intensity maps (normalize="none") + + Returns: + scores: numpy array of shape [B] with per-image p90 scores + """ + scores = [] + for i in range(dab_maps.shape[0]): + flat = dab_maps[i].flatten() + p90 = torch.quantile(flat, 0.9) + mask = flat >= p90 + scores.append(flat[mask].mean().item() if mask.sum() > 0 else flat.mean().item()) + return np.array(scores) + + +# ====================================================================== +# Image quality metrics +# ====================================================================== + +def compute_image_quality_metrics(generated, real): + """FID, KID, LPIPS (full + 128px), SSIM, PSNR. 
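+
+    SSIM, PSNR, FID, and KID consume inputs rescaled to [0, 1]
+    (normalize=True for FID/KID); LPIPS consumes the raw [-1, 1] tensors.
+    KID is additionally reported x1000, the common convention for its
+    small magnitudes.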
+ + Args: + generated: [N, 3, H, W] in [-1, 1] + real: [N, 3, H, W] in [-1, 1] + + Returns: + dict with metric values + """ + from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio + from torchmetrics.image.fid import FrechetInceptionDistance + from torchmetrics.image.kid import KernelInceptionDistance + from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + + gen_01 = ((generated + 1) / 2).clamp(0, 1) + real_01 = ((real + 1) / 2).clamp(0, 1) + N = len(generated) + results = {} + + # SSIM + ssim = StructuralSimilarityIndexMeasure(data_range=1.0) + ssim_vals = [] + for i in range(0, N, 16): + batch_vals = ssim(gen_01[i:i+16], real_01[i:i+16]) + ssim_vals.append(batch_vals.item()) + results['ssim_mean'] = float(np.mean(ssim_vals)) + results['ssim_std'] = float(np.std(ssim_vals)) + + # PSNR + psnr = PeakSignalNoiseRatio(data_range=1.0) + psnr_vals = [] + for i in range(0, N, 16): + batch_vals = psnr(gen_01[i:i+16], real_01[i:i+16]) + psnr_vals.append(batch_vals.item()) + results['psnr_mean'] = float(np.mean(psnr_vals)) + results['psnr_std'] = float(np.std(psnr_vals)) + + # LPIPS (full resolution) + lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type='alex') + lpips_vals = [] + for i in range(0, N, 8): + batch_gen = generated[i:i+8].float().clamp(-1, 1) + batch_real = real[i:i+8].float().clamp(-1, 1) + val = lpips_metric(batch_gen, batch_real).item() + if not np.isnan(val): + lpips_vals.append(val) + results['lpips_mean'] = float(np.mean(lpips_vals)) if lpips_vals else float('nan') + + # LPIPS downsampled (128x128) — more robust for weakly paired consecutive sections + lpips_ds_vals = [] + for i in range(0, N, 8): + batch_gen = F.interpolate(generated[i:i+8].float().clamp(-1, 1), size=128, mode='bilinear', align_corners=False) + batch_real = F.interpolate(real[i:i+8].float().clamp(-1, 1), size=128, mode='bilinear', align_corners=False) + val = lpips_metric(batch_gen, batch_real).item() + if not np.isnan(val): + lpips_ds_vals.append(val) + results['lpips_128_mean'] = float(np.mean(lpips_ds_vals)) if lpips_ds_vals else float('nan') + + # FID (Inception) + fid = FrechetInceptionDistance(feature=2048, normalize=True) + for i in range(0, N, 16): + fid.update(real_01[i:i+16], real=True) + fid.update(gen_01[i:i+16], real=False) + results['fid_inception'] = float(fid.compute().item()) + + # KID (Kernel Inception Distance) — unbiased, better for small N + kid = KernelInceptionDistance(feature=2048, normalize=True, subset_size=min(N, 100)) + for i in range(0, N, 16): + kid.update(real_01[i:i+16], real=True) + kid.update(gen_01[i:i+16], real=False) + kid_mean, kid_std = kid.compute() + results['kid_mean'] = float(kid_mean.item()) + results['kid_std'] = float(kid_std.item()) + results['kid_mean_x1000'] = float(kid_mean.item() * 1000) + results['kid_std_x1000'] = float(kid_std.item() * 1000) + + return results + + +# ====================================================================== +# UNI-FID (pathology-native Frechet distance) +# ====================================================================== + +def compute_uni_fid(generated, real): + """Frechet distance in UNI ViT-L/16 feature space. + + Uses CLS token features from UNI (Chen et al., Nature Medicine 2024) + as a pathology-specific alternative to Inception FID. 
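+
+    The value is the standard Frechet distance between Gaussians fitted to
+    the two CLS-feature sets:
+
+        d^2 = ||mu_g - mu_r||^2 + Tr(S_g + S_r - 2 * sqrtm(S_g @ S_r))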
+ + Args: + generated, real: [N, 3, H, W] in [-1, 1] + + Returns: + float: UNI-FID value + """ + import timm + import torchvision.transforms as transforms + from scipy.linalg import sqrtm + + uni_model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, + init_values=1e-5, dynamic_img_size=True) + uni_model = uni_model.cuda().eval() + + transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def extract_cls_features(images): + feats = [] + for i in range(0, len(images), 16): + batch = images[i:i+16] + batch_01 = ((batch + 1) / 2).clamp(0, 1) + batch_norm = torch.stack([transform(img) for img in batch_01]) + with torch.no_grad(): + out = uni_model.forward_features(batch_norm.cuda()) + feats.append(out[:, 0, :].cpu()) + return torch.cat(feats).numpy() + + feats_gen = extract_cls_features(generated) + feats_real = extract_cls_features(real) + + mu_gen, mu_real = feats_gen.mean(0), feats_real.mean(0) + sigma_gen = np.cov(feats_gen, rowvar=False) + sigma_real = np.cov(feats_real, rowvar=False) + + diff = mu_gen - mu_real + covmean = sqrtm(sigma_gen @ sigma_real) + if np.iscomplexobj(covmean): + covmean = covmean.real + + del uni_model + torch.cuda.empty_cache() + + return float(diff @ diff + np.trace(sigma_gen + sigma_real - 2 * covmean)) + + +# ====================================================================== +# DAB metrics +# ====================================================================== + +def compute_dab_metrics(generated, real, labels=None, dab_extractor=None): + """DAB intensity metrics: Pearson-r, MAE, KL, JSD. + + Uses canonical p90 scoring (mean of top-10%) for Pearson-r and MAE. + Uses per-pair 256-bin histograms for KL/JSD (ODA-GAN methodology). 
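+
+    With per-pair 256-bin, epsilon-smoothed histograms p (generated) and
+    q (real):
+
+        KL(p || q) = sum_i p_i * log(p_i / q_i)
+        JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m),  m = (p + q) / 2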
+ + Args: + generated, real: [N, 3, H, W] in [-1, 1] + labels: [N] int class labels, or None for classless evaluation + dab_extractor: DABExtractor instance (created if None) + + Returns: + dict with DAB metric values + """ + if dab_extractor is None: + dab_extractor = DABExtractor(device='cpu') + + dab_gen = dab_extractor.extract_dab_intensity(generated.float(), normalize="none") + dab_real = dab_extractor.extract_dab_intensity(real.float(), normalize="none") + + gen_scores = compute_p90_scores(dab_gen) + real_scores = compute_p90_scores(dab_real) + + results = {} + results['dab_mae_overall'] = float(np.mean(np.abs(gen_scores - real_scores))) + + # Pearson-R + if len(gen_scores) > 2: + r, p_val = pearsonr(gen_scores, real_scores) + results['dab_pearson_r'] = float(r) + results['dab_pearson_p'] = float(p_val) + + # Per-pair DAB KL/JSD (ODA-GAN: 256-bin histogram per pair, averaged) + n_bins = 256 + eps = 1e-10 + pair_kls = [] + pair_jsds = [] + for i in range(dab_gen.shape[0]): + g = dab_gen[i].flatten().numpy() + r = dab_real[i].flatten().numpy() + hist_range = (0, max(g.max(), r.max()) + 1e-6) + hg, _ = np.histogram(g, bins=n_bins, range=hist_range, density=True) + hr, _ = np.histogram(r, bins=n_bins, range=hist_range, density=True) + hg = hg + eps; hr = hr + eps + hg = hg / hg.sum(); hr = hr / hr.sum() + pair_kls.append(float(entropy(hg, hr))) + m = 0.5 * (hg + hr) + pair_jsds.append(float(0.5 * entropy(hg, m) + 0.5 * entropy(hr, m))) + results['dab_kl'] = float(np.mean(pair_kls)) + results['dab_kl_std'] = float(np.std(pair_kls)) + results['dab_jsd'] = float(np.mean(pair_jsds)) + results['dab_jsd_std'] = float(np.std(pair_jsds)) + + # Pooled DAB KL (for reference) + dab_gen_flat = dab_gen.flatten().numpy() + dab_real_flat = dab_real.flatten().numpy() + hist_range = (0, max(dab_gen_flat.max(), dab_real_flat.max()) + 1e-6) + hist_gen, _ = np.histogram(dab_gen_flat, bins=n_bins, range=hist_range, density=True) + hist_real, _ = np.histogram(dab_real_flat, bins=n_bins, range=hist_range, density=True) + hist_gen = hist_gen + eps; hist_real = hist_real + eps + hist_gen = hist_gen / hist_gen.sum(); hist_real = hist_real / hist_real.sum() + results['dab_kl_pooled'] = float(entropy(hist_gen, hist_real)) + + # Mean DAB levels + results['dab_gen_mean'] = float(np.mean(gen_scores)) + results['dab_real_mean'] = float(np.mean(real_scores)) + + # Per-class metrics (BCI only — MIST passes labels=None) + if labels is not None: + class_names = {0: '0', 1: '1+', 2: '2+', 3: '3+'} + within_rs = [] + for cls, name in class_names.items(): + mask = (labels == cls).numpy() if isinstance(labels, torch.Tensor) else (labels == cls) + if mask.sum() > 0: + results[f'dab_real_class_{name}'] = float(np.mean(real_scores[mask])) + results[f'dab_gen_class_{name}'] = float(np.mean(gen_scores[mask])) + results[f'dab_mae_class_{name}'] = float(np.mean(np.abs( + gen_scores[mask] - real_scores[mask]))) + results[f'n_samples_class_{name}'] = int(mask.sum()) + # Within-class Pearson-R + if mask.sum() > 5: + r_cls, _ = pearsonr(gen_scores[mask], real_scores[mask]) + results[f'dab_pearson_r_class_{name}'] = float(r_cls) + within_rs.append(r_cls) + if within_rs: + results['dab_pearson_r_within_class'] = float(np.mean(within_rs)) + + # Ordering violation rate + class_gen_means = {} + for cls in range(4): + mask = (labels == cls).numpy() if isinstance(labels, torch.Tensor) else (labels == cls) + if mask.sum() > 0: + class_gen_means[cls] = float(np.mean(gen_scores[mask])) + ordered_pairs = [(3, 2), (3, 1), (3, 0), (2, 1), (2, 0), (1, 
+
+
+# ======================================================================
+# IOD / mIOD / FOD metrics (PSPStain)
+# ======================================================================
+
+def compute_iod_metrics(generated, real, labels=None):
+    """Compute Integrated Optical Density metrics (PSPStain methodology).
+
+    Beer-Lambert law: OD = -log10(I / I_0) with I_0 = 255.
+    IOD = sum(OD), mIOD = mean(OD), FOD = OD^alpha with alpha = 1.8.
+
+    Args:
+        generated, real: [N, 3, H, W] in [-1, 1]
+        labels: [N] optional class labels for per-class breakdown
+
+    Returns:
+        dict with IOD metric values
+    """
+    # Map [-1, 1] -> [0, 255]; clamp to a minimum of 1.0 so log10 stays finite
+    gen_255 = (((generated + 1) / 2).clamp(0, 1) * 255.0).clamp(min=1.0)
+    real_255 = (((real + 1) / 2).clamp(0, 1) * 255.0).clamp(min=1.0)
+
+    od_gen = -torch.log10(gen_255 / 255.0)
+    od_real = -torch.log10(real_255 / 255.0)
+
+    miod_gen = od_gen.mean(dim=(1, 2, 3)).numpy()
+    miod_real = od_real.mean(dim=(1, 2, 3)).numpy()
+
+    iod_gen = od_gen.sum(dim=(1, 2, 3)).numpy()
+    iod_real = od_real.sum(dim=(1, 2, 3)).numpy()
+
+    alpha = 1.8
+    fod_gen = od_gen.pow(alpha).mean(dim=(1, 2, 3)).numpy()
+    fod_real = od_real.pow(alpha).mean(dim=(1, 2, 3)).numpy()
+
+    results = {}
+    results['miod_diff'] = float(np.mean(miod_gen) - np.mean(miod_real))
+    results['miod_abs_diff'] = float(np.mean(np.abs(miod_gen - miod_real)))
+    results['miod_gen_mean'] = float(np.mean(miod_gen))
+    results['miod_real_mean'] = float(np.mean(miod_real))
+    results['iod_diff'] = float(np.mean(iod_gen) - np.mean(iod_real))
+    results['iod_diff_1e7'] = float(results['iod_diff'] / 1e7)
+    results['mfod_diff'] = float(np.mean(fod_gen) - np.mean(fod_real))
+    results['mfod_abs_diff'] = float(np.mean(np.abs(fod_gen - fod_real)))
+
+    if len(miod_gen) > 2:
+        r, _ = pearsonr(miod_gen, miod_real)
+        results['iod_pearson_r'] = float(r)
+
+    # Per-class mIOD (BCI only)
+    if labels is not None:
+        class_names = {0: '0', 1: '1+', 2: '2+', 3: '3+'}
+        for cls, name in class_names.items():
+            mask = (labels == cls).numpy() if isinstance(labels, torch.Tensor) else (labels == cls)
+            if mask.sum() > 0:
+                results[f'miod_gen_class_{name}'] = float(np.mean(miod_gen[mask]))
+                results[f'miod_real_class_{name}'] = float(np.mean(miod_real[mask]))
+                results[f'miod_diff_class_{name}'] = float(
+                    np.mean(miod_gen[mask]) - np.mean(miod_real[mask]))
+
+    return results
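+
+
+# Beer-Lambert sanity check (an illustrative, standalone sketch; the sample
+# intensities are assumptions chosen so the arithmetic is easy to verify by
+# hand). OD = -log10(I / 255): a white pixel gives 0, darker pixels give
+# larger optical densities.
+def _demo_beer_lambert():
+    import math
+    for intensity in (255.0, 128.0, 64.0):
+        od = -math.log10(intensity / 255.0)
+        print(f"I={intensity:.0f} -> OD={od:.3f}")  # 0.000, 0.299, 0.600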
+
+
+# ======================================================================
+# Downstream classifier (AUROC / SFS)
+# ======================================================================
+
+def compute_downstream_metrics(generated, real, labels, train_ihc_dir):
+    """AUROC and SFS via UNI linear probe (Star-Diff methodology).
+
+    1. Extract UNI CLS features from real HER2 training images
+    2. Train logistic regression on the real train features
+    3. Evaluate on generated test images (AUROC, SFS)
+    4. Evaluate on real test images as a reference baseline
+
+    Args:
+        generated, real: [N, 3, H, W] in [-1, 1]
+        labels: [N] class labels
+        train_ihc_dir: path to real HER2 IHC training images
+
+    Returns:
+        dict with downstream metric values
+    """
+    from sklearn.linear_model import LogisticRegression
+    from sklearn.metrics import roc_auc_score, balanced_accuracy_score
+    from sklearn.preprocessing import label_binarize
+    import timm
+    import torchvision.transforms as transforms
+    from PIL import Image
+
+    results = {}
+
+    # Load UNI model
+    print("  Loading UNI model for downstream evaluation...")
+    uni_model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True,
+                                  init_values=1e-5, dynamic_img_size=True)
+    uni_model = uni_model.cuda().eval()
+
+    uni_transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+
+    def extract_features_from_tensors(images):
+        feats = []
+        for i in range(0, len(images), 16):
+            batch = images[i:i+16]
+            batch_01 = ((batch + 1) / 2).clamp(0, 1)
+            batch_norm = torch.stack([uni_transform(img) for img in batch_01])
+            with torch.no_grad():
+                out = uni_model.forward_features(batch_norm.cuda())
+            feats.append(out[:, 0, :].cpu())  # CLS token
+        return torch.cat(feats).numpy()
+
+    def extract_features_from_dir(img_dir):
+        img_dir = Path(img_dir)
+        filenames = sorted([f for f in os.listdir(img_dir) if f.endswith('.png')])
+        feats, labs = [], []
+        label_map = {'0': 0, '1+': 1, '2+': 2, '3+': 3}
+        batch_imgs, batch_labs = [], []
+        for fn in filenames:
+            img = Image.open(img_dir / fn).convert('RGB')
+            img_t = transforms.ToTensor()(img)
+            img_n = uni_transform(img_t)
+            batch_imgs.append(img_n)
+            # Third underscore-separated field of the filename is the HER2 class
+            parts = fn.replace('.png', '').split('_')
+            batch_labs.append(label_map[parts[2]])
+            if len(batch_imgs) == 16:
+                batch = torch.stack(batch_imgs)
+                with torch.no_grad():
+                    out = uni_model.forward_features(batch.cuda())
+                feats.append(out[:, 0, :].cpu())
+                labs.extend(batch_labs)
+                batch_imgs, batch_labs = [], []
+        if batch_imgs:
+            batch = torch.stack(batch_imgs)
+            with torch.no_grad():
+                out = uni_model.forward_features(batch.cuda())
+            feats.append(out[:, 0, :].cpu())
+            labs.extend(batch_labs)
+        return torch.cat(feats).numpy(), np.array(labs)
+
+    # Extract features from real training images
+    print("  Extracting features from training IHC images...")
+    train_feats, train_labels = extract_features_from_dir(train_ihc_dir)
+
+    # Train linear probe
+    print(f"  Training linear probe on {len(train_labels)} samples...")
+    clf = LogisticRegression(max_iter=1000, C=1.0, solver='lbfgs',
+                             multi_class='multinomial', random_state=42)
+    clf.fit(train_feats, train_labels)
+    results['probe_train_acc'] = float(clf.score(train_feats, train_labels))
+
+    # Evaluate on generated
+    gen_feats = extract_features_from_tensors(generated)
+    test_labels = labels.numpy() if isinstance(labels, torch.Tensor) else labels
+    gen_probs = clf.predict_proba(gen_feats)
+    gen_preds = clf.predict(gen_feats)
+
+    test_labels_bin = label_binarize(test_labels, classes=[0, 1, 2, 3])
+    try:
+        results['auroc'] = float(roc_auc_score(test_labels_bin, gen_probs,
+                                               multi_class='ovr', average='macro'))
+    except ValueError:
+        pass  # AUROC is undefined when a class is absent from the test labels
+
+    results['sfs'] = float(balanced_accuracy_score(test_labels, gen_preds))
+
+    # Evaluate on real (reference baseline)
+    real_feats = extract_features_from_tensors(real)
+    real_probs = clf.predict_proba(real_feats)
+    real_preds = clf.predict(real_feats)
+    try:
+        results['auroc_real_baseline'] = float(roc_auc_score(
+            test_labels_bin, real_probs, multi_class='ovr', average='macro'))
+    except ValueError:
+        pass
+    results['sfs_real_baseline'] = float(balanced_accuracy_score(test_labels, real_preds))
+
+    del uni_model
+    torch.cuda.empty_cache()
+
+    return results
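+
+
+# Minimal sketch of the probe + macro AUROC pattern used above, run on
+# synthetic features so it needs no UNI model (the `_demo_` function, shapes,
+# and values are assumptions for illustration only).
+def _demo_linear_probe():
+    from sklearn.linear_model import LogisticRegression
+    from sklearn.metrics import roc_auc_score
+    from sklearn.preprocessing import label_binarize
+    rng = np.random.default_rng(0)
+    feats = rng.normal(size=(200, 32))       # stand-in for UNI CLS features
+    labs = rng.integers(0, 4, size=200)      # stand-in HER2 classes
+    clf = LogisticRegression(max_iter=1000).fit(feats, labs)
+    probs = clf.predict_proba(feats)
+    auroc = roc_auc_score(label_binarize(labs, classes=[0, 1, 2, 3]), probs,
+                          multi_class='ovr', average='macro')
+    # Evaluated on its own training set, so above-chance AUROC is expected
+    print(f"macro AUROC on random features: {auroc:.3f}")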
+
+
+# ======================================================================
+# Visualization
+# ======================================================================
+
+def save_sample_grid(he, real, generated, path, n=16):
+    """Save an H&E | Real IHC | Generated grid for visual inspection."""
+    n = min(n, len(he))
+    grid_images = []
+    for i in range(n):
+        grid_images.extend([
+            ((he[i] + 1) / 2).clamp(0, 1),
+            ((real[i] + 1) / 2).clamp(0, 1),
+            ((generated[i] + 1) / 2).clamp(0, 1),
+        ])
+    # nrow=3 places one (H&E, real, generated) triplet per row
+    grid = torchvision.utils.make_grid(grid_images, nrow=3, padding=2)
+    torchvision.utils.save_image(grid, path)
+    print(f"  Sample grid ({n} samples): {path}")
+
+
+def composite_background(generated, he_images, threshold=0.85):
+    """Replace background regions in generated images with white.
+
+    Uses H&E brightness to identify background (glass slide), then
+    forces those regions to white in the generated output.
+    """
+    he_01 = ((he_images + 1) / 2).clamp(0, 1)
+    brightness = he_01.mean(dim=1, keepdim=True)
+    tissue = (brightness < threshold).float()
+    # Blur + re-threshold removes speckle in the binary mask; the second
+    # blur feathers the mask edges for a soft composite
+    tissue = F.avg_pool2d(tissue, kernel_size=7, stride=1, padding=3)
+    tissue = (tissue > 0.3).float()
+    tissue = F.avg_pool2d(tissue, kernel_size=11, stride=1, padding=5)
+    # White is 1.0 in [-1, 1] space
+    return generated * tissue + 1.0 * (1.0 - tissue)
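+
+
+# Illustrative usage sketch (the `_demo_` function and synthetic tensors are
+# assumptions; nothing calls this): composite the background of a generated
+# batch using its paired H&E brightness mask, then write a comparison grid.
+def _demo_visualization():
+    he = torch.rand(4, 3, 64, 64) * 2 - 1
+    real = torch.rand(4, 3, 64, 64) * 2 - 1
+    gen = torch.rand(4, 3, 64, 64) * 2 - 1
+    composited = composite_background(gen, he)   # background pixels -> 1.0
+    save_sample_grid(he, real, composited, 'demo_grid.png', n=4)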