File size: 2,953 Bytes
3f4e2ae
8619a66
3f4e2ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
Image encoder for the image-input SPARK variants (CV and TPD).

Replaces the per-scan 1-D `SignalEncoder` with a small 2-D CNN that maps a
single rasterized plot image (grayscale, 224x224 by default) to a context
vector of the same dimensionality `d_context`. Everything downstream of the
per-scan branch (`cv_augment`, SAB, PMA, classifier, flow heads, OOD head)
stays unchanged, so this is a drop-in replacement.

Input:  [B, 1, H, W] grayscale plot image, values in [0, 1].
Output: [B, d_context]
"""

import torch
import torch.nn as nn


class ImageEncoder(nn.Module):
    """Small 2-D CNN that encodes a plot image into a context vector.

    Architecture (400,032 trainable parameters, i.e. ~0.4M, at the default
    settings; spatial sizes below are for a 224x224 input):
        Conv in_channels->32 k=7 s=2   BN + GELU  -> 112x112
        Conv 32->64          k=5 s=2   BN + GELU  -> 56x56
        Conv 64->96          k=3 s=2   BN + GELU  -> 28x28
        Conv 96->128         k=3 s=2   BN + GELU  -> 14x14
        Conv 128->d_model    k=3 s=2   BN + GELU  -> 7x7
        Adaptive avg pool -> [B, d_model]
        MLP: Linear(d_model, d_context) -> GELU -> Dropout
             -> Linear(d_context, d_context)

    Designed to be light enough to train from scratch alongside the existing
    classifier and flow heads. Each position of the final feature map has a
    71x71-pixel receptive field; the global average pool is what aggregates
    information across the whole image.

    Args:
        in_channels: Number of channels of the input image.
        d_model: Channel width of the last convolutional stage, and the size
            of the pooled feature vector.
        d_context: Size of the returned context vector.
        dropout: Dropout probability applied inside the projection MLP.
    """

    def __init__(self, in_channels: int = 1, d_model: int = 128,
                 d_context: int = 128, dropout: float = 0.1):
        super().__init__()
        self.in_channels = in_channels
        self.d_model = d_model
        self.d_context = d_context

        # (out_channels, kernel_size) for each stage. Every stage has stride
        # 2, so five stages downsample the input by a factor of 32.
        stages = ((32, 7), (64, 5), (96, 3), (128, 3), (d_model, 3))
        blocks = []
        c_in = in_channels
        for c_out, k in stages:
            blocks.append(self._conv_block(c_in, c_out, k, 2))
            c_in = c_out

        # Submodule names and indices (stem.<i>.<j>, pool, proj.<j>) are kept
        # stable so that previously saved state dicts continue to load.
        self.stem = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_context, d_context),
        )

    @staticmethod
    def _conv_block(c_in: int, c_out: int, k: int, s: int) -> nn.Sequential:
        """Build one Conv2d -> BatchNorm2d -> GELU stage."""
        return nn.Sequential(
            # padding=k // 2 keeps the output at ceil(size / stride) for the
            # odd kernel sizes used here; the conv bias is omitted because
            # the BatchNorm that follows supplies its own learnable shift.
            nn.Conv2d(c_in, c_out, kernel_size=k, stride=s,
                      padding=k // 2, bias=False),
            nn.BatchNorm2d(c_out),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: [B, in_channels, H, W] image tensor. The module docstring
                specifies values in [0, 1]; no normalisation is applied here.

        Returns:
            context: [B, d_context]
        """
        h = self.stem(x)
        # Global average over the spatial grid: [B, d_model, h, w] -> [B, d_model].
        h = self.pool(h).flatten(1)
        return self.proj(h)


def count_parameters(module: nn.Module) -> int:
    """Return the number of trainable scalar parameters in *module*.

    Only parameters with ``requires_grad`` set are counted, so frozen
    parameters do not contribute to the total.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == "__main__":
    # Smoke test: build the encoder with its default sizes, push a dummy
    # batch of two blank 224x224 single-channel images through it, and
    # report the parameter count and tensor shapes.
    encoder = ImageEncoder(in_channels=1, d_model=128, d_context=128)
    dummy_batch = torch.zeros(2, 1, 224, 224)
    context = encoder(dummy_batch)

    n_trainable = count_parameters(encoder)
    print(f"ImageEncoder params: {n_trainable:,}")
    print(f"Input shape:  {tuple(dummy_batch.shape)}")
    print(f"Output shape: {tuple(context.shape)}")