File size: 5,722 Bytes
4db9215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Loss functions for UNIStainNet.

- VGGFeatureExtractor: intermediate VGG16 features for Gram-matrix style loss
- gram_matrix: compute Gram matrix of feature maps
- PatchNCELoss: contrastive loss between H&E input and generated output (alignment-free)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class VGGFeatureExtractor(nn.Module):
    """Frozen VGG16 trunk exposing relu1_2 / relu2_2 / relu3_3 activations.

    Intended as a fixed loss network for Gram-matrix style losses: Gram
    statistics of these early feature maps describe texture co-occurrence
    at several scales while ignoring spatial layout, so they remain valid
    under misalignment (Gatys et al., 2016).
    """

    def __init__(self):
        super().__init__()
        from torchvision.models import vgg16, VGG16_Weights
        layers = list(vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.children())
        # Split the trunk right after relu1_2 (idx 4), relu2_2 (idx 9),
        # and relu3_3 (idx 16) so forward() can tap each stage.
        self.slice1 = nn.Sequential(*layers[:4])
        self.slice2 = nn.Sequential(*layers[4:9])
        self.slice3 = nn.Sequential(*layers[9:16])
        # Loss network is never trained: freeze weights and switch to eval.
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        # ImageNet channel statistics, shaped to broadcast over [B, 3, H, W].
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Map an image in [-1, 1] to its three intermediate VGG features.

        Args:
            x: [B, 3, H, W] tensor scaled to [-1, 1]
        Returns:
            list of feature maps [relu1_2, relu2_2, relu3_3]
        """
        # Undo tanh-style scaling to [0, 1], then apply ImageNet normalization.
        x = (x + 1) / 2
        x = (x - self.mean) / self.std
        feats = []
        for stage in (self.slice1, self.slice2, self.slice3):
            x = stage(x)
            feats.append(x)
        return feats


def gram_matrix(feat):
    """Return the size-normalized Gram matrix of a batch of feature maps.

    Args:
        feat: [B, C, H, W]
    Returns:
        gram: [B, C, C], divided by C * H * W so values are comparable
        across layers of different resolution/width
    """
    b, c, h, w = feat.shape
    flattened = feat.flatten(2)                    # [B, C, H*W]
    gram = flattened @ flattened.transpose(1, 2)   # batched feature co-occurrence
    return gram / (c * h * w)


class PatchNCELoss(nn.Module):
    """Patchwise Noise Contrastive Estimation loss.

    Contrasts the H&E input against the generated IHC through the
    generator's encoder features. For every sampled spatial position in
    the generated features, the same position in the H&E features is the
    positive; all other sampled positions serve as negatives. The ground
    truth IHC is never consulted, so the loss is alignment-free.

    Reference: Park et al., "Contrastive Learning for Unpaired Image-to-Image
    Translation" (ECCV 2020) — adapted for paired (misaligned) setting.
    """

    def __init__(self, layer_channels, num_patches=256, temperature=0.07):
        """
        Args:
            layer_channels: dict {layer_idx: channels} for each encoder layer
            num_patches: number of spatial positions to sample per layer
            temperature: InfoNCE temperature
        """
        super().__init__()
        self.num_patches = num_patches
        self.temperature = temperature

        # One small projection head (Linear -> ReLU -> Linear) per encoder layer.
        self.mlps = nn.ModuleDict()
        for layer_idx, ch in layer_channels.items():
            head = nn.Sequential(
                nn.Linear(ch, 256),
                nn.ReLU(inplace=True),
                nn.Linear(256, 256),
            )
            self.mlps[str(layer_idx)] = head

    def forward(self, feats_src, feats_tgt):
        """Compute the PatchNCE loss averaged over encoder layers.

        Args:
            feats_src: dict {layer_idx: [B, C, H, W]} from H&E input
            feats_tgt: dict {layer_idx: [B, C, H, W]} from generated IHC

        Returns:
            scalar loss
        """
        total = 0.0
        count = 0

        for key, head in self.mlps.items():
            src = feats_src[int(key)]  # [B, C, H, W]
            tgt = feats_tgt[int(key)]  # [B, C, H, W]

            b = src.shape[0]
            n_positions = src.shape[2] * src.shape[3]

            # [B, C, H, W] -> [B, HW, C] so each row is one spatial patch.
            src_seq = src.flatten(2).permute(0, 2, 1)
            tgt_seq = tgt.flatten(2).permute(0, 2, 1)

            # Sample the SAME random positions from both feature maps so the
            # diagonal of the logit matrix pairs corresponding locations.
            k = min(self.num_patches, n_positions)
            sample = torch.randperm(n_positions, device=src.device)[:k]

            # Project sampled patches through the shared head and L2-normalize
            # so the bmm below yields cosine similarities.
            src_emb = F.normalize(head(src_seq[:, sample, :]), dim=-1)  # [B, k, 256]
            tgt_emb = F.normalize(head(tgt_seq[:, sample, :]), dim=-1)  # [B, k, 256]

            # InfoNCE logits: entry (i, j) = similarity of tgt patch i to
            # src patch j, scaled by the temperature.
            sim = torch.bmm(tgt_emb, src_emb.transpose(1, 2)) / self.temperature

            # Each query's positive sits on the diagonal (same position index).
            labels = torch.arange(k, device=sim.device).unsqueeze(0).expand(b, -1)

            total = total + F.cross_entropy(sim.flatten(0, 1), labels.flatten(0, 1))
            count += 1

        # Average across layers; degenerate empty-head case returns 0.0 as-is.
        return total / count if count > 0 else total