| """ |
| Image encoder for the image-input SPARK variants (CV and TPD). |
| |
| Replaces the per-scan 1-D `SignalEncoder` with a small 2-D CNN that maps a |
| single rasterized plot image (grayscale, 224x224 by default) to a context |
| vector of the same dimensionality `d_context`. Everything downstream of the |
| per-scan branch (`cv_augment`, SAB, PMA, classifier, flow heads, OOD head) |
| stays unchanged, so this is a drop-in replacement. |
| |
| Input: [B, 1, H, W] grayscale plot image, values in [0, 1]. |
| Output: [B, d_context] |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class ImageEncoder(nn.Module):
    """Small 2-D CNN encoder for plot images.

    Architecture (~0.40M trainable params at the defaults
    in_channels=1, d_model=128, d_context=128; spatial sizes shown for a
    224x224 input):
        Conv 1->32         k=7 s=2  BN + GELU  -> 112x112
        Conv 32->64        k=5 s=2  BN + GELU  -> 56x56
        Conv 64->96        k=3 s=2  BN + GELU  -> 28x28
        Conv 96->128       k=3 s=2  BN + GELU  -> 14x14
        Conv 128->d_model  k=3 s=2  BN + GELU  -> 7x7
        Adaptive avg pool                      -> [B, d_model]
        MLP d_model -> d_context  (Linear, GELU, Dropout, Linear)

    Every conv uses padding k // 2 with an odd kernel, so each stride-2
    stage maps a spatial size n to ceil(n / 2). The adaptive average pool
    collapses whatever spatial size remains, so the listed sizes only hold
    for a 224x224 input.

    Designed to be light enough to train from scratch alongside the existing
    classifier and flow heads. The global pool lets the final features draw
    on the whole image.

    Args:
        in_channels: Number of input image channels (1 for grayscale).
        d_model: Channel width of the last conv stage, which is also the
            size of the pooled feature vector.
        d_context: Size of the output context vector.
        dropout: Dropout probability inside the projection MLP.
    """

    def __init__(self, in_channels: int = 1, d_model: int = 128,
                 d_context: int = 128, dropout: float = 0.1):
        super().__init__()
        self.in_channels = in_channels
        self.d_model = d_model
        self.d_context = d_context

        # NOTE: the attribute names and nesting below (stem.<i>.<j>, pool,
        # proj.<i>) determine the state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.stem = nn.Sequential(
            self._conv_block(in_channels, 32, k=7, s=2),
            self._conv_block(32, 64, k=5, s=2),
            self._conv_block(64, 96, k=3, s=2),
            self._conv_block(96, 128, k=3, s=2),
            self._conv_block(128, d_model, k=3, s=2),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_context, d_context),
        )

    @staticmethod
    def _conv_block(c_in: int, c_out: int, k: int, s: int) -> nn.Sequential:
        """Build one Conv -> BatchNorm -> GELU stage with padding k // 2."""
        return nn.Sequential(
            # bias is redundant here: the following BatchNorm adds its own
            # learned shift.
            nn.Conv2d(c_in, c_out, kernel_size=k, stride=s,
                      padding=k // 2, bias=False),
            nn.BatchNorm2d(c_out),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into context vectors.

        Args:
            x: [B, in_channels, H, W] image tensor in [0, 1].

        Returns:
            context: [B, d_context]
        """
        h = self.stem(x)               # [B, d_model, H', W']
        h = self.pool(h).flatten(1)    # [B, d_model]
        return self.proj(h)            # [B, d_context]
|
|
|
|
def count_parameters(module: nn.Module) -> int:
    """Return how many scalar values across *module*'s parameters are trainable.

    Only parameters with ``requires_grad=True`` are counted; frozen
    parameters are skipped.
    """
    trainable = 0
    for param in module.parameters():
        if not param.requires_grad:
            continue
        trainable += param.numel()
    return trainable
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the default encoder and push a dummy batch through it
    # to report the parameter count and confirm the output shape.
    enc = ImageEncoder(in_channels=1, d_model=128, d_context=128)
    # eval(): use BatchNorm running stats (instead of updating them) and
    # disable dropout, so the check does not mutate the module.
    enc.eval()
    x = torch.zeros(2, 1, 224, 224)
    with torch.no_grad():  # shape check only; no autograd graph needed
        out = enc(x)
    print(f"ImageEncoder params: {count_parameters(enc):,}")
    print(f"Input shape: {tuple(x.shape)}")
    print(f"Output shape: {tuple(out.shape)}")
|
|