# trace/image_encoder.py
# bingyan user
# Rebrand TRACE -> SPARK
# 8619a66
"""
Image encoder for the image-input SPARK variants (CV and TPD).
Replaces the per-scan 1-D `SignalEncoder` with a small 2-D CNN that maps a
single rasterized plot image (grayscale, 224x224 by default) to a context
vector of the same dimensionality `d_context`. Everything downstream of the
per-scan branch (`cv_augment`, SAB, PMA, classifier, flow heads, OOD head)
stays unchanged, so this is a drop-in replacement.
Input: [B, 1, H, W] grayscale plot image, values in [0, 1].
Output: [B, d_context]
"""
import torch
import torch.nn as nn
class ImageEncoder(nn.Module):
    """Small 2-D CNN that encodes a plot image into a context vector.

    Architecture (400,032 trainable parameters, i.e. about 0.4M, at the
    default settings; spatial sizes are shown for a 224x224 input):

        Conv 1->32        k=7 s=2 -> BN -> GELU   -> 112x112
        Conv 32->64       k=5 s=2 -> BN -> GELU   -> 56x56
        Conv 64->96       k=3 s=2 -> BN -> GELU   -> 28x28
        Conv 96->128      k=3 s=2 -> BN -> GELU   -> 14x14
        Conv 128->d_model k=3 s=2 -> BN -> GELU   -> 7x7
        AdaptiveAvgPool2d(1) + flatten            -> [B, d_model]
        Linear -> GELU -> Dropout -> Linear       -> [B, d_context]

    Designed to be light enough to train from scratch alongside the existing
    classifier and flow heads. The global average pool collapses whatever
    spatial grid the stem produces, so the output shape does not depend on
    the input height or width.

    Args:
        in_channels: Number of channels of the input image (1 = grayscale).
        d_model: Channel width of the last conv stage, and therefore the
            size of the pooled feature vector.
        d_context: Size of the returned context vector.
        dropout: Dropout probability used between the two linear layers of
            the projection head.
    """

    def __init__(self, in_channels: int = 1, d_model: int = 128,
                 d_context: int = 128, dropout: float = 0.1):
        super().__init__()
        self.in_channels = in_channels
        self.d_model = d_model
        self.d_context = d_context

        def block(c_in: int, c_out: int, k: int, s: int) -> nn.Sequential:
            """Build one Conv2d -> BatchNorm2d -> GELU stage."""
            return nn.Sequential(
                # padding=k // 2 with an odd kernel keeps "same"-style
                # borders, so stride 2 halves the spatial size (rounded up).
                # The conv bias is dropped because the BatchNorm that
                # follows has its own learnable shift.
                nn.Conv2d(c_in, c_out, kernel_size=k, stride=s,
                          padding=k // 2, bias=False),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            )

        # Five stride-2 stages: total spatial downsampling of 32x.
        self.stem = nn.Sequential(
            block(in_channels, 32, k=7, s=2),
            block(32, 64, k=5, s=2),
            block(64, 96, k=3, s=2),
            block(96, 128, k=3, s=2),
            block(128, d_model, k=3, s=2),
        )
        # Global average pool -> one value per channel.
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Two-layer MLP head mapping pooled features to the context size.
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: [B, in_channels, H, W] image tensor, expected in [0, 1].

        Returns:
            context: [B, d_context]
        """
        h = self.stem(x)                 # [B, d_model, H', W']
        h = self.pool(h).flatten(1)      # [B, d_model, 1, 1] -> [B, d_model]
        return self.proj(h)
def count_parameters(module: nn.Module) -> int:
    """Return the number of trainable scalar parameters in ``module``.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == "__main__":
    # Smoke test: build a default-sized encoder, push one all-zero batch of
    # two 224x224 grayscale images through it, and report the sizes.
    encoder = ImageEncoder(in_channels=1, d_model=128, d_context=128)
    dummy_batch = torch.zeros(2, 1, 224, 224)
    context = encoder(dummy_batch)

    n_trainable = count_parameters(encoder)
    in_shape = tuple(dummy_batch.shape)
    out_shape = tuple(context.shape)

    print(f"ImageEncoder params: {n_trainable:,}")
    print(f"Input shape: {in_shape}")
    print(f"Output shape: {out_shape}")