"""
Loss functions for UNIStainNet.
- VGGFeatureExtractor: intermediate VGG16 features for Gram-matrix style loss
- gram_matrix: compute Gram matrix of feature maps
- PatchNCELoss: contrastive loss between H&E input and generated output (alignment-free)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class VGGFeatureExtractor(nn.Module):
    """Frozen VGG16 feature pyramid for Gram-matrix style losses.

    Taps the network after relu1_2, relu2_2 and relu3_3, which describe
    texture at three receptive-field sizes. Gram matrices built from these
    activations are alignment-invariant texture descriptors — they capture
    feature co-occurrence statistics, not spatial layout (Gatys et al., 2016).
    """

    def __init__(self):
        super().__init__()
        from torchvision.models import vgg16, VGG16_Weights
        layers = list(vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.children())
        # Slice boundaries: relu1_2 -> idx 4, relu2_2 -> idx 9, relu3_3 -> idx 16.
        self.slice1 = nn.Sequential(*layers[:4])
        self.slice2 = nn.Sequential(*layers[4:9])
        self.slice3 = nn.Sequential(*layers[9:16])
        # Fixed feature bank: never fine-tuned.
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        # ImageNet channel statistics, broadcastable over [B, 3, H, W].
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Map an image batch to the three tapped feature maps.

        Args:
            x: [B, 3, H, W] tensor with values in [-1, 1]
        Returns:
            list of the three feature maps (relu1_2, relu2_2, relu3_3)
        """
        # Undo generator output range ([-1,1] -> [0,1]), then ImageNet-normalize.
        x = (x + 1) / 2
        x = (x - self.mean) / self.std
        out1 = self.slice1(x)
        out2 = self.slice2(out1)
        out3 = self.slice3(out2)
        return [out1, out2, out3]
def gram_matrix(feat):
    """Return the spatially-normalized Gram matrix of a feature map.

    Args:
        feat: [B, C, H, W] activation tensor
    Returns:
        gram: [B, C, C] channel co-occurrence matrix, divided by C*H*W so
        the scale is independent of feature-map size
    """
    b, c, h, w = feat.shape
    flat = feat.reshape(b, c, -1)  # [B, C, N] with N = H*W
    # G[i, j] = <channel_i, channel_j> summed over all spatial positions.
    return flat @ flat.transpose(1, 2) / (c * h * w)
class PatchNCELoss(nn.Module):
    """Patchwise Noise Contrastive Estimation loss (alignment-free).

    Both the H&E input and the generated IHC output are viewed through the
    generator's encoder. For each sampled spatial position in the generated
    features, the same position in the H&E features is the positive and the
    other sampled positions are negatives. Ground-truth IHC is never used.

    Reference: Park et al., "Contrastive Learning for Unpaired Image-to-Image
    Translation" (ECCV 2020) — adapted for the paired (misaligned) setting.
    """

    def __init__(self, layer_channels, num_patches=256, temperature=0.07):
        """
        Args:
            layer_channels: dict {layer_idx: channels} for each encoder layer
            num_patches: number of spatial positions to sample per layer
            temperature: InfoNCE temperature
        """
        super().__init__()
        self.num_patches = num_patches
        self.temperature = temperature
        # One small projection head (2-layer MLP) per tapped encoder layer.
        self.mlps = nn.ModuleDict()
        for layer_idx, ch in layer_channels.items():
            head = nn.Sequential(
                nn.Linear(ch, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 256),
            )
            self.mlps[str(layer_idx)] = head

    def forward(self, feats_src, feats_tgt):
        """Average the per-layer InfoNCE losses.

        Args:
            feats_src: dict {layer_idx: [B, C, H, W]} from H&E input
            feats_tgt: dict {layer_idx: [B, C, H, W]} from generated IHC
        Returns:
            scalar loss (0.0 when no layers are registered)
        """
        accumulated = 0.0
        layers_seen = 0
        for key, head in self.mlps.items():
            layer = int(key)
            src = feats_src[layer]  # [B, C, H, W]
            tgt = feats_tgt[layer]  # [B, C, H, W]
            batch = src.shape[0]
            positions = src.shape[2] * src.shape[3]

            # [B, C, H, W] -> [B, HW, C] token sequences.
            src_tokens = src.flatten(2).permute(0, 2, 1)
            tgt_tokens = tgt.flatten(2).permute(0, 2, 1)

            # Shared random subset of spatial positions for both images,
            # so index i in src and tgt refers to the same location.
            keep = min(self.num_patches, positions)
            chosen = torch.randperm(positions, device=src.device)[:keep]
            src_tokens = src_tokens[:, chosen, :]  # [B, keep, C]
            tgt_tokens = tgt_tokens[:, chosen, :]  # [B, keep, C]

            # Project through the layer's MLP and move onto the unit sphere.
            src_emb = F.normalize(head(src_tokens), dim=-1)
            tgt_emb = F.normalize(head(tgt_tokens), dim=-1)

            # Pairwise cosine similarities: row i compares generated patch i
            # against every sampled source patch; the diagonal is its positive.
            logits = torch.bmm(tgt_emb, src_emb.transpose(1, 2)) / self.temperature
            labels = torch.arange(keep, device=logits.device).unsqueeze(0).expand(batch, -1)
            accumulated = accumulated + F.cross_entropy(
                logits.flatten(0, 1), labels.flatten(0, 1)
            )
            layers_seen += 1

        return accumulated / layers_seen if layers_seen > 0 else accumulated