"""Image encoder for the image-input SPARK variants (CV and TPD).

Replaces the per-scan 1-D `SignalEncoder` with a small 2-D CNN that maps a
single rasterized plot image (grayscale, 224x224 by default) to a context
vector of the same dimensionality `d_context`. Everything downstream of the
per-scan branch (`cv_augment`, SAB, PMA, classifier, flow heads, OOD head)
stays unchanged, so this is a drop-in replacement.

Input:  [B, 1, H, W] grayscale plot image, values in [0, 1].
Output: [B, d_context]
"""

import torch
import torch.nn as nn


class ImageEncoder(nn.Module):
    """Small 2-D CNN encoder for plot images.

    Architecture (400,032 trainable params, i.e. ~0.4M, at default settings;
    spatial sizes below are for a 224x224 input):

        Conv   1 -> 32       k=7 s=2   BN + GELU   -> 112x112
        Conv  32 -> 64       k=5 s=2   BN + GELU   ->  56x56
        Conv  64 -> 96       k=3 s=2   BN + GELU   ->  28x28
        Conv  96 -> 128      k=3 s=2   BN + GELU   ->  14x14
        Conv 128 -> d_model  k=3 s=2   BN + GELU   ->   7x7
        Adaptive avg pool                          -> [B, d_model]
        MLP  d_model -> d_context -> d_context     -> [B, d_context]

    Each stage is Conv -> BatchNorm -> GELU, in that order. The convolutions
    carry no bias because the BatchNorm that follows supplies its own shift.

    Designed to be light enough to train from scratch alongside the existing
    classifier and flow heads. The five stride-2 stages followed by global
    average pooling let the context vector summarize curve shape across the
    whole image.

    Args:
        in_channels: Number of channels in the input image (1 = grayscale).
        d_model: Channel width of the last conv stage, and therefore the size
            of the pooled feature vector fed to the projection MLP.
        d_context: Size of the returned context vector.
        dropout: Dropout probability applied inside the projection MLP.
    """

    # Output widths and kernel sizes of the conv stages that precede the final
    # `d_model`-wide stage. Every stage uses stride 2.
    _HIDDEN_WIDTHS = (32, 64, 96, 128)
    _KERNEL_SIZES = (7, 5, 3, 3, 3)

    def __init__(self, in_channels: int = 1, d_model: int = 128,
                 d_context: int = 128, dropout: float = 0.1):
        super().__init__()
        self.in_channels = in_channels
        self.d_model = d_model
        self.d_context = d_context

        def block(c_in, c_out, k, s):
            # padding = k // 2 keeps "same"-style alignment for odd kernels,
            # so a stride-2 stage exactly halves an even spatial size.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=k, stride=s,
                          padding=k // 2, bias=False),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            )

        # Channel progression: in_channels -> 32 -> 64 -> 96 -> 128 -> d_model.
        # Built in order so that submodule indices (stem.0 ... stem.4) and the
        # weight-initialization order match existing checkpoints and seeds.
        widths = (in_channels, *self._HIDDEN_WIDTHS, d_model)
        self.stem = nn.Sequential(*[
            block(c_in, c_out, k=k, s=2)
            for c_in, c_out, k in zip(widths[:-1], widths[1:],
                                      self._KERNEL_SIZES)
        ])

        # Global average pooling makes the encoder independent of the exact
        # input resolution: any H x W that survives the stem is reduced to 1x1.
        self.pool = nn.AdaptiveAvgPool2d(1)

        self.proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of plot images into context vectors.

        Args:
            x: [B, in_channels, H, W] image tensor, expected in [0, 1].

        Returns:
            context: [B, d_context]
        """
        h = self.stem(x)
        # [B, d_model, 1, 1] -> [B, d_model]
        h = self.pool(h).flatten(1)
        return self.proj(h)


def count_parameters(module: nn.Module, *, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters in `module`.

    Args:
        module: Module whose parameters are counted (recursively).
        trainable_only: If True (default), count only parameters with
            `requires_grad=True`. If False, frozen parameters are included.

    Returns:
        Total element count over the selected parameters.
    """
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


if __name__ == "__main__":
    # Shape smoke test on a dummy batch of two blank 224x224 grayscale images.
    enc = ImageEncoder(in_channels=1, d_model=128, d_context=128)
    x = torch.zeros(2, 1, 224, 224)
    out = enc(x)
    print(f"ImageEncoder params: {count_parameters(enc):,}")
    print(f"Input shape: {tuple(x.shape)}")
    print(f"Output shape: {tuple(out.shape)}")