# UNIStainNet — src/models/losses.py
# (uploaded via huggingface_hub; revision 4db9215, verified)
"""
Loss functions for UNIStainNet.
- VGGFeatureExtractor: intermediate VGG16 features for Gram-matrix style loss
- gram_matrix: compute Gram matrix of feature maps
- PatchNCELoss: contrastive loss between H&E input and generated output (alignment-free)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class VGGFeatureExtractor(nn.Module):
    """Frozen VGG16 trunk yielding multi-scale texture features.

    Features are tapped at relu1_2, relu2_2 and relu3_3 — three scales of
    texture. Gram matrices of these maps are alignment-invariant texture
    descriptors: they capture feature co-occurrence statistics rather than
    spatial layout (Gatys et al., 2016), which makes the downstream style
    loss robust to misalignment.
    """
    def __init__(self):
        super().__init__()
        from torchvision.models import vgg16, VGG16_Weights
        layers = list(vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.children())
        # Slice boundaries: relu1_2 ends at idx 4, relu2_2 at 9, relu3_3 at 16.
        self.slice1 = nn.Sequential(*layers[:4])
        self.slice2 = nn.Sequential(*layers[4:9])
        self.slice3 = nn.Sequential(*layers[9:16])
        # Fixed loss network: freeze every weight and switch to eval mode.
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        # ImageNet channel statistics, shaped to broadcast over [B, 3, H, W].
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, x):
        """Run the three VGG slices in sequence.

        Args:
            x: [B, 3, H, W] image tensor in [-1, 1].
        Returns:
            List of three feature maps (relu1_2, relu2_2, relu3_3).
        """
        # Map [-1, 1] -> [0, 1], then apply ImageNet normalization.
        x = ((x + 1) / 2 - self.mean) / self.std
        out1 = self.slice1(x)
        out2 = self.slice2(out1)
        out3 = self.slice3(out2)
        return [out1, out2, out3]
def gram_matrix(feat):
    """Gram matrix of a batch of feature maps.

    Args:
        feat: [B, C, H, W] feature tensor.
    Returns:
        [B, C, C] tensor of channel co-occurrence statistics, normalized
        by C*H*W so magnitudes are comparable across layers of different
        spatial/channel size.
    """
    batch, channels, height, width = feat.shape
    flat = feat.reshape(batch, channels, -1)        # [B, C, N]
    gram = flat @ flat.transpose(1, 2)              # [B, C, C]
    return gram / (channels * height * width)
class PatchNCELoss(nn.Module):
"""Patchwise Noise Contrastive Estimation loss.
Compares H&E input and generated IHC through the generator's encoder.
For each spatial position in the generated features, the corresponding
position in the H&E features is the positive, and random other positions
are negatives. Never sees GT IHC.
Reference: Park et al., "Contrastive Learning for Unpaired Image-to-Image
Translation" (ECCV 2020) β€” adapted for paired (misaligned) setting.
"""
def __init__(self, layer_channels, num_patches=256, temperature=0.07):
"""
Args:
layer_channels: dict {layer_idx: channels} for each encoder layer
num_patches: number of spatial positions to sample per layer
temperature: InfoNCE temperature
"""
super().__init__()
self.num_patches = num_patches
self.temperature = temperature
# 2-layer MLP projection head per encoder layer
self.mlps = nn.ModuleDict()
for layer_idx, ch in layer_channels.items():
self.mlps[str(layer_idx)] = nn.Sequential(
nn.Linear(ch, 256),
nn.ReLU(inplace=True),
nn.Linear(256, 256),
)
def forward(self, feats_src, feats_tgt):
"""Compute PatchNCE loss across encoder layers.
Args:
feats_src: dict {layer_idx: [B, C, H, W]} from H&E input
feats_tgt: dict {layer_idx: [B, C, H, W]} from generated IHC
Returns:
scalar loss
"""
total_loss = 0.0
n_layers = 0
for layer_idx_str, mlp in self.mlps.items():
layer_idx = int(layer_idx_str)
feat_src = feats_src[layer_idx] # [B, C, H, W]
feat_tgt = feats_tgt[layer_idx] # [B, C, H, W]
B, C, H, W = feat_src.shape
n_total = H * W
# Reshape to [B, C, H*W] then [B, H*W, C]
src_flat = feat_src.flatten(2).permute(0, 2, 1) # [B, HW, C]
tgt_flat = feat_tgt.flatten(2).permute(0, 2, 1) # [B, HW, C]
# Sample random spatial positions
n_sample = min(self.num_patches, n_total)
idx = torch.randperm(n_total, device=feat_src.device)[:n_sample]
src_sampled = src_flat[:, idx, :] # [B, n_sample, C]
tgt_sampled = tgt_flat[:, idx, :] # [B, n_sample, C]
# Project through MLP
src_proj = mlp(src_sampled) # [B, n_sample, 256]
tgt_proj = mlp(tgt_sampled) # [B, n_sample, 256]
# L2 normalize
src_proj = F.normalize(src_proj, dim=-1)
tgt_proj = F.normalize(tgt_proj, dim=-1)
# InfoNCE: for each query (tgt), positive is matching src position
# negatives are all other src positions
# logits: [B, n_sample, n_sample] β€” (i,j) = similarity of tgt_i to src_j
logits = torch.bmm(tgt_proj, src_proj.transpose(1, 2)) # [B, n, n]
logits = logits / self.temperature
# Target: diagonal (position i matches position i)
target = torch.arange(n_sample, device=logits.device).unsqueeze(0).expand(B, -1)
loss = F.cross_entropy(logits.flatten(0, 1), target.flatten(0, 1))
total_loss = total_loss + loss
n_layers += 1
return total_loss / n_layers if n_layers > 0 else total_loss