""" PatchGAN discriminator for HER2 image realism scoring. Supports both unconditional (3ch HER2 only) and conditional (6ch H&E+HER2) modes. Returns patch-level logits AND intermediate features for feature matching loss. Architecture: C64(SN) -> C128(SN+IN) -> C256(SN+IN) -> C512(SN+IN,s1) -> 1ch(SN,s1) 70x70 receptive field, output [B, 1, 30, 30] for 512x512 input. ~2.8M params (3ch) or ~2.8M params (6ch). References: - Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks" (CVPR 2017) - Miyato et al., "Spectral Normalization for GANs" (ICLR 2018) """ import torch import torch.nn as nn import torch.nn.functional as F import torch.autograd as autograd from torch.nn.utils import spectral_norm class PatchDiscriminator(nn.Module): """PatchGAN discriminator with spectral normalization. Returns both logits and intermediate features (for feature matching loss). Args: in_channels: 3 for unconditional (HER2 only), 6 for conditional (H&E + HER2) ndf: base number of discriminator filters n_layers: number of intermediate conv layers """ def __init__(self, in_channels=3, ndf=64, n_layers=3): super().__init__() self.n_layers = n_layers # Build layers as a list (not sequential) so we can extract features self.layers = nn.ModuleList() # First layer: spectral norm, no instance norm self.layers.append(nn.Sequential( spectral_norm(nn.Conv2d(in_channels, ndf, 4, stride=2, padding=1)), nn.LeakyReLU(0.2, inplace=True), )) # Intermediate layers: spectral norm + instance norm nf_mult = 1 for n in range(1, n_layers): nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) self.layers.append(nn.Sequential( spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 4, stride=2, padding=1)), nn.InstanceNorm2d(ndf * nf_mult), nn.LeakyReLU(0.2, inplace=True), )) # Penultimate layer: stride 1 nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) self.layers.append(nn.Sequential( spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 4, stride=1, padding=1)), nn.InstanceNorm2d(ndf * nf_mult), nn.LeakyReLU(0.2, inplace=True), )) # Final layer: 1-channel output, no activation (hinge loss uses raw logits) self.layers.append(nn.Sequential( spectral_norm(nn.Conv2d(ndf * nf_mult, 1, 4, stride=1, padding=1)), )) def forward(self, x, return_features=False): """ Args: x: [B, C, H, W] in [-1, 1]. C=3 (unconditional) or C=6 (conditional). return_features: if True, also return intermediate features for FM loss. Returns: logits: [B, 1, H', W'] patch-level real/fake logits features: list of intermediate feature maps (only if return_features=True) """ features = [] h = x for layer in self.layers: h = layer(h) if return_features: features.append(h) if return_features: return h, features return h # ====================================================================== # Loss functions # ====================================================================== def hinge_loss_d(d_real, d_fake): """Discriminator hinge loss.""" return (torch.relu(1.0 - d_real).mean() + torch.relu(1.0 + d_fake).mean()) / 2 def hinge_loss_g(d_fake): """Generator hinge loss.""" return -d_fake.mean() def r1_gradient_penalty(discriminator, real_images, weight=10.0): """R1 gradient penalty (Mescheder et al., 2018). Regularizes discriminator to have small gradients on real data, which prevents the discriminator from becoming too confident and stabilizes GAN training. """ real_images = real_images.detach().requires_grad_(True) d_real = discriminator(real_images) grad_real = autograd.grad( outputs=d_real.sum(), inputs=real_images, create_graph=True, )[0] penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() return weight * penalty def feature_matching_loss(d_feats_fake, d_feats_real): """Feature matching loss: L1 between discriminator features of fake vs real. Matches statistics at each discriminator layer. Alignment-free because it compares feature distributions, not pixel-level correspondence. """ loss = 0.0 for feat_fake, feat_real in zip(d_feats_fake, d_feats_real): loss += torch.nn.functional.l1_loss(feat_fake, feat_real.detach()) return loss / len(d_feats_fake) class MultiScaleDiscriminator(nn.Module): """Two PatchGAN discriminators at different scales.""" def __init__(self, in_channels=6, ndf=64, n_layers=3): super().__init__() self.disc_512 = PatchDiscriminator(in_channels, ndf, n_layers) self.disc_256 = PatchDiscriminator(in_channels, ndf, n_layers) def forward(self, x, return_features=False): """ Args: x: [B, 6, 512, 512] concat(output, H&E) Returns: list of (logits, [features]) from each scale """ x_256 = F.interpolate(x, size=256, mode='bilinear', align_corners=False) if return_features: out_512, feats_512 = self.disc_512(x, return_features=True) out_256, feats_256 = self.disc_256(x_256, return_features=True) return [(out_512, feats_512), (out_256, feats_256)] else: out_512 = self.disc_512(x) out_256 = self.disc_256(x_256) return [out_512, out_256]