"""
PatchGAN discriminator for HER2 image realism scoring.
Supports both unconditional (3ch HER2 only) and conditional (6ch H&E+HER2) modes.
Returns patch-level logits AND intermediate features for feature matching loss.
Architecture:
C64(SN) -> C128(SN+IN) -> C256(SN+IN) -> C512(SN+IN,s1) -> 1ch(SN,s1)
70x70 receptive field; output [B, 1, 62, 62] for 512x512 input (30x30 for 256x256).
~2.8M params in either mode (the extra 3 input channels add only ~3K params).
References:
- Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks" (CVPR 2017)
- Miyato et al., "Spectral Normalization for GANs" (ICLR 2018)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.nn.utils import spectral_norm
class PatchDiscriminator(nn.Module):
    """70x70 PatchGAN discriminator with spectral normalization.

    Produces patch-level real/fake logits and, on request, the per-stage
    feature maps consumed by a feature-matching loss.

    Args:
        in_channels: 3 for unconditional (HER2 only), 6 for conditional
            (H&E + HER2 concatenated).
        ndf: base number of discriminator filters.
        n_layers: number of intermediate conv layers.
    """
    def __init__(self, in_channels=3, ndf=64, n_layers=3):
        super().__init__()
        self.n_layers = n_layers
        # Stage 1: spectral norm only — no instance norm on the raw input.
        blocks = [nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels, ndf, 4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
        )]
        # Stride-2 downsampling stages: width doubles each stage, capped at 8x ndf.
        width = 1
        for stage in range(1, n_layers):
            prev, width = width, min(2 ** stage, 8)
            blocks.append(self._sn_in_block(ndf * prev, ndf * width, stride=2))
        # Penultimate stage keeps spatial resolution (stride 1).
        prev, width = width, min(2 ** n_layers, 8)
        blocks.append(self._sn_in_block(ndf * prev, ndf * width, stride=1))
        # Head: 1-channel logits, no activation (hinge loss expects raw scores).
        blocks.append(nn.Sequential(
            spectral_norm(nn.Conv2d(ndf * width, 1, 4, stride=1, padding=1)),
        ))
        # ModuleList (not Sequential) so forward() can tap every stage output.
        self.layers = nn.ModuleList(blocks)

    @staticmethod
    def _sn_in_block(c_in, c_out, stride):
        """One spectral-norm conv -> instance norm -> LeakyReLU stage."""
        return nn.Sequential(
            spectral_norm(nn.Conv2d(c_in, c_out, 4, stride=stride, padding=1)),
            nn.InstanceNorm2d(c_out),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, return_features=False):
        """
        Args:
            x: [B, C, H, W] in [-1, 1]. C=3 (unconditional) or C=6 (conditional).
            return_features: if True, also return intermediate features for FM loss.
        Returns:
            logits: [B, 1, H', W'] patch-level real/fake logits
            features: list of per-stage feature maps (only if return_features=True)
        """
        taps = []
        out = x
        for block in self.layers:
            out = block(out)
            if return_features:
                taps.append(out)
        return (out, taps) if return_features else out
# ======================================================================
# Loss functions
# ======================================================================
def hinge_loss_d(d_real, d_fake):
    """Discriminator hinge loss: average of the real and fake hinge terms.

    Real logits are pushed above +1, fake logits below -1; gradients vanish
    once a sample is beyond its margin.
    """
    real_term = F.relu(1.0 - d_real).mean()
    fake_term = F.relu(1.0 + d_fake).mean()
    return 0.5 * (real_term + fake_term)
def hinge_loss_g(d_fake):
    """Generator hinge loss: maximize fake logits (non-saturating, unbounded)."""
    return -torch.mean(d_fake)
def r1_gradient_penalty(discriminator, real_images, weight=10.0):
    """R1 gradient penalty (Mescheder et al., 2018).

    Penalizes the squared norm of the discriminator's gradient w.r.t. *real*
    inputs, which keeps D from becoming over-confident and stabilizes GAN
    training.

    Args:
        discriminator: callable mapping images to (patch) logits.
        real_images: [B, C, H, W] batch of real samples.
        weight: penalty coefficient applied to the mean squared grad norm.
    Returns:
        Scalar tensor: weight * E_batch[ ||grad_x D(x)||^2 ].
    """
    # Detach from any upstream graph, then re-enable grad so we can
    # differentiate D's output w.r.t. the images themselves.
    reals = real_images.detach().requires_grad_(True)
    scores = discriminator(reals)
    # create_graph=True so the penalty itself is differentiable for D's update.
    (grads,) = autograd.grad(
        outputs=scores.sum(),
        inputs=reals,
        create_graph=True,
    )
    per_sample = grads.pow(2).flatten(start_dim=1).sum(dim=1)
    return weight * per_sample.mean()
def feature_matching_loss(d_feats_fake, d_feats_real):
    """Feature matching loss: mean L1 distance between discriminator
    features of fake vs. real images, averaged over layers.

    Real features are detached so gradients flow only through the fake
    (generator) path. Alignment-free because it compares feature
    statistics rather than pixel-level correspondence.
    """
    total = sum(
        torch.nn.functional.l1_loss(fake, real.detach())
        for fake, real in zip(d_feats_fake, d_feats_real)
    )
    return total / len(d_feats_fake)
class MultiScaleDiscriminator(nn.Module):
    """Two PatchGAN discriminators at different scales (full and half res).

    Args:
        in_channels: input channels per discriminator; 6 for conditional
            concat(H&E, HER2), 3 for unconditional.
        ndf: base filter count for each sub-discriminator.
        n_layers: number of intermediate conv layers per sub-discriminator.
    """
    def __init__(self, in_channels=6, ndf=64, n_layers=3):
        super().__init__()
        # Names kept for checkpoint compatibility; "512"/"256" refer to the
        # nominal training resolution, not a hard requirement.
        self.disc_512 = PatchDiscriminator(in_channels, ndf, n_layers)
        self.disc_256 = PatchDiscriminator(in_channels, ndf, n_layers)

    def forward(self, x, return_features=False):
        """
        Args:
            x: [B, C, H, W] input, e.g. concat(output, H&E) at 512x512.
            return_features: if True, return per-layer features for FM loss.
        Returns:
            If return_features: list of (logits, features) tuples, one per scale.
            Otherwise: list of logits tensors, one per scale.
        """
        # Downsample by a factor of 2 rather than hard-coding size=256, so the
        # half-scale branch stays correct for inputs that are not 512x512.
        x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear',
                               align_corners=False)
        if return_features:
            out_full, feats_full = self.disc_512(x, return_features=True)
            out_half, feats_half = self.disc_256(x_half, return_features=True)
            return [(out_full, feats_full), (out_half, feats_half)]
        return [self.disc_512(x), self.disc_256(x_half)]
|