File size: 5,804 Bytes
4db9215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
PatchGAN discriminator for HER2 image realism scoring.

Supports both unconditional (3ch HER2 only) and conditional (6ch H&E+HER2) modes.
Returns patch-level logits AND intermediate features for feature matching loss.

Architecture:
    C64(SN) -> C128(SN+IN) -> C256(SN+IN) -> C512(SN+IN,s1) -> 1ch(SN,s1)
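    70x70 receptive field. Output is [B, 1, 62, 62] for 512x512 input
    and [B, 1, 30, 30] for 256x256 input.
    ~2.8M params for both 3ch and 6ch inputs (the extra input channels
    add only ~3K weights to the first conv).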

References:
  - Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks" (CVPR 2017)
  - Miyato et al., "Spectral Normalization for GANs" (ICLR 2018)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.nn.utils import spectral_norm


class PatchDiscriminator(nn.Module):
    """Spectrally-normalized PatchGAN discriminator.

    The network is a stack of 4x4 convolutions (padding 1) kept in an
    ``nn.ModuleList`` so that every stage's activation can be collected for
    a feature matching loss.

    Stage layout for ``n_layers`` = N:
        * 1 input stage: SN-conv (stride 2) + LeakyReLU(0.2)
        * N-1 downsampling stages: SN-conv (stride 2) + InstanceNorm + LeakyReLU(0.2)
        * 1 penultimate stage: SN-conv (stride 1) + InstanceNorm + LeakyReLU(0.2)
        * 1 output stage: SN-conv (stride 1) to a single channel, no activation

    Channel width doubles per stage starting from ``ndf`` and is capped at
    ``8 * ndf``.

    Args:
        in_channels: number of input channels (3 for HER2 only,
            6 for H&E + HER2 concatenated).
        ndf: channel count produced by the first convolution.
        n_layers: depth parameter N described above.
    """

    def __init__(self, in_channels=3, ndf=64, n_layers=3):
        super().__init__()
        self.n_layers = n_layers

        def sn_conv(c_in, c_out, stride):
            # Every convolution in the network shares kernel 4 / padding 1.
            return spectral_norm(
                nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1)
            )

        # Input stage: no normalization on raw pixels.
        stages = [
            nn.Sequential(
                sn_conv(in_channels, ndf, 2),
                nn.LeakyReLU(0.2, inplace=True),
            )
        ]

        # Strided stages: width multiplier doubles each time, capped at 8.
        width = 1
        for depth in range(1, n_layers):
            prev_width, width = width, min(2 ** depth, 8)
            stages.append(
                nn.Sequential(
                    sn_conv(ndf * prev_width, ndf * width, 2),
                    nn.InstanceNorm2d(ndf * width),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )

        # Penultimate stage keeps spatial resolution (stride 1).
        prev_width, width = width, min(2 ** n_layers, 8)
        stages.append(
            nn.Sequential(
                sn_conv(ndf * prev_width, ndf * width, 1),
                nn.InstanceNorm2d(ndf * width),
                nn.LeakyReLU(0.2, inplace=True),
            )
        )

        # Output stage: raw single-channel logits (no activation).
        stages.append(nn.Sequential(sn_conv(ndf * width, 1, 1)))

        self.layers = nn.ModuleList(stages)

    def forward(self, x, return_features=False):
        """Run the discriminator.

        Args:
            x: [B, C, H, W] input, documented as lying in [-1, 1];
                C must equal the ``in_channels`` given at construction.
            return_features: when True, additionally return the output of
                every stage (the last entry is the logits tensor itself).

        Returns:
            ``logits`` of shape [B, 1, H', W'] when ``return_features`` is
            False, otherwise the tuple ``(logits, features)`` where
            ``features`` is a list with one tensor per stage.
        """
        if not return_features:
            out = x
            for stage in self.layers:
                out = stage(out)
            return out

        activations = []
        out = x
        for stage in self.layers:
            out = stage(out)
            activations.append(out)
        return out, activations


# ======================================================================
# Loss functions
# ======================================================================

def hinge_loss_d(d_real, d_fake):
    """Hinge loss for the discriminator.

    Penalizes real logits below +1 and fake logits above -1, then averages
    the two mean penalties.

    Args:
        d_real: discriminator logits for real samples.
        d_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor: (mean(relu(1 - d_real)) + mean(relu(1 + d_fake))) / 2.
    """
    real_term = F.relu(1.0 - d_real).mean()
    fake_term = F.relu(1.0 + d_fake).mean()
    return (real_term + fake_term) / 2


def hinge_loss_g(d_fake):
    """Hinge loss for the generator.

    Args:
        d_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor equal to the negated mean of ``d_fake``; minimizing it
        pushes the fake logits upward.
    """
    mean_logit = d_fake.mean()
    return torch.neg(mean_logit)


def r1_gradient_penalty(discriminator, real_images, weight=10.0):
    """R1 gradient penalty (Mescheder et al., 2018).

    Regularizes the discriminator to have small gradients on real data,
    which prevents it from becoming too confident and stabilizes GAN
    training.

    The penalty is ``weight * mean_b(||d/dx sum(D(x))||^2)``, where the
    squared norm is taken per sample over all non-batch dimensions.

    Args:
        discriminator: callable mapping an image batch to logits. It may
            return a single tensor, or a list/tuple of tensors (e.g. one per
            scale of a multi-scale discriminator); in the latter case the
            logits of all entries are summed before differentiating.
        real_images: batch of real images, first dimension is the batch.
            It is detached from any existing graph before use.
        weight: multiplier applied to the mean squared gradient norm.

    Returns:
        Scalar tensor that remains attached to the autograd graph
        (``create_graph=True``), so it can be back-propagated into the
        discriminator parameters.
    """
    # Fresh leaf tensor so the gradient is taken w.r.t. the pixels only.
    real_images = real_images.detach().requires_grad_(True)
    d_real = discriminator(real_images)

    # Multi-output discriminators return a list of logit tensors; a plain
    # ``.sum()`` would fail on them, so fold every output into one scalar.
    if isinstance(d_real, (list, tuple)):
        total_logit = sum(scale_logits.sum() for scale_logits in d_real)
    else:
        total_logit = d_real.sum()

    (grad_real,) = autograd.grad(
        outputs=total_logit,
        inputs=real_images,
        create_graph=True,  # the penalty itself must be differentiable
    )
    penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
    return weight * penalty


def feature_matching_loss(d_feats_fake, d_feats_real):
    """Feature matching loss: mean L1 distance between discriminator features.

    For each pair of corresponding feature maps the mean absolute difference
    is computed, and the per-layer values are averaged. The real features
    are detached, so gradients flow only through the fake features.

    Args:
        d_feats_fake: iterable of feature tensors for generated samples.
        d_feats_real: iterable of feature tensors for real samples, one per
            entry of ``d_feats_fake`` and with matching shapes.

    Returns:
        Scalar tensor with the layer-averaged L1 distance.

    Raises:
        ValueError: if no features are given, or if the two inputs contain a
            different number of feature maps.
    """
    # Materialize so generators work and lengths can be checked up front.
    fake_feats = list(d_feats_fake)
    real_feats = list(d_feats_real)

    if not fake_feats:
        raise ValueError("feature_matching_loss requires at least one feature map")
    if len(fake_feats) != len(real_feats):
        # zip() would silently drop the surplus layers while the divisor
        # still counted them, yielding a wrongly scaled loss.
        raise ValueError(
            "feature list length mismatch: "
            f"{len(fake_feats)} fake vs {len(real_feats)} real"
        )

    total = sum(
        F.l1_loss(feat_fake, feat_real.detach())
        for feat_fake, feat_real in zip(fake_feats, real_feats)
    )
    return total / len(fake_feats)


class MultiScaleDiscriminator(nn.Module):
    """Pair of PatchGAN discriminators operating at two resolutions.

    ``disc_512`` sees the input unchanged; ``disc_256`` sees the input
    bilinearly resized to 256x256. Both sub-networks are built with the
    same constructor arguments but have independent weights.

    Args:
        in_channels: input channel count forwarded to both discriminators.
        ndf: base filter count forwarded to both discriminators.
        n_layers: depth parameter forwarded to both discriminators.
    """

    def __init__(self, in_channels=6, ndf=64, n_layers=3):
        super().__init__()
        self.disc_512 = PatchDiscriminator(in_channels, ndf, n_layers)
        self.disc_256 = PatchDiscriminator(in_channels, ndf, n_layers)

    def forward(self, x, return_features=False):
        """Evaluate both scales.

        Args:
            x: image batch, documented as [B, 6, 512, 512]
                (concat(output, H&E)).
            return_features: when True each scale also yields its list of
                intermediate feature maps.

        Returns:
            A two-element list ordered full-resolution first. Entries are
            logit tensors, or ``(logits, features)`` tuples when
            ``return_features`` is True.
        """
        # The second branch always works on a fixed 256x256 resize.
        downsampled = F.interpolate(
            x, size=256, mode='bilinear', align_corners=False
        )
        branches = ((self.disc_512, x), (self.disc_256, downsampled))

        if not return_features:
            return [disc(inp) for disc, inp in branches]

        results = []
        for disc, inp in branches:
            logits, feats = disc(inp, return_features=True)
            results.append((logits, feats))
        return results