File size: 3,922 Bytes
49c96b6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | """
CNN Feature Extractor — Modified ResNet-18 for grayscale thermal images.
Takes single-channel (grayscale) 224×224 images and outputs 256-dim
feature embeddings suitable for downstream sequence analysis.
"""
import torch
import torch.nn as nn
import torchvision.models as models
class ThermalFeatureExtractor(nn.Module):
    """
    Modified ResNet-18 that accepts grayscale (by default 1-channel) input
    and produces a compact feature embedding.

    Architecture:
        Input (in_channels, 224, 224)
        → Conv1 (in_channels→64, 7×7)   (replaces the default 3→64)
        → ResNet-18 layers 1-4
        → AdaptiveAvgPool → (512,)
        → FC(512→embedding_dim) + BatchNorm + ReLU + Dropout
        → embedding_dim-dim embedding
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        pretrained: bool = True,
        in_channels: int = 1,
        dropout: float = 0.3,
    ) -> None:
        """
        Build the backbone and the projection head.

        Args:
            embedding_dim: Size of the output embedding.
            pretrained: If True, load the default torchvision ResNet-18
                weights and initialise the new first conv layer from them.
            in_channels: Number of channels of the input images.
            dropout: Dropout probability applied at the end of the
                projection head.
        """
        super().__init__()
        self.embedding_dim = embedding_dim

        # Load ResNet-18 (optionally with pretrained weights).
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Replace the first conv layer: 3-channel → in_channels-channel.
        self.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # If pretrained, initialise from the mean of the RGB weights. The
        # kernel is adapted to ``in_channels`` so that its shape always
        # matches the layer created above (a fixed 1-channel kernel would
        # make the layer unusable for any other channel count).
        if pretrained:
            with torch.no_grad():
                self.conv1.weight.copy_(
                    self._adapt_conv_weight(resnet.conv1.weight, in_channels)
                )

        # Keep the rest of ResNet-18 up to (and including) avgpool.
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

        # Projection head: backbone width (512 for ResNet-18) → embedding_dim.
        # The width is read from the original classifier instead of being
        # hard-coded.
        self.projection = nn.Sequential(
            nn.Linear(resnet.fc.in_features, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    @staticmethod
    def _adapt_conv_weight(
        rgb_weight: torch.Tensor, in_channels: int
    ) -> torch.Tensor:
        """
        Adapt a pretrained first-layer kernel to ``in_channels`` inputs.

        Args:
            rgb_weight: Kernel of shape (out_channels, C_src, kH, kW).
            in_channels: Desired number of input channels.

        Returns:
            A detached tensor of shape (out_channels, in_channels, kH, kW).
            If the channel count already matches, the kernel is returned
            unchanged (as a copy). Otherwise the mean over the source
            channels is replicated ``in_channels`` times and divided by
            ``in_channels``; for ``in_channels == 1`` this is exactly the
            channel mean.
        """
        source = rgb_weight.detach()
        if source.shape[1] == in_channels:
            return source.clone()
        mean_kernel = source.mean(dim=1, keepdim=True)
        # Dividing by in_channels keeps the response to identical channels
        # equal to the response of the single-channel (mean) kernel.
        return mean_kernel.repeat(1, in_channels, 1, 1) / in_channels

    @classmethod
    def from_config(cls, config) -> "ThermalFeatureExtractor":
        """
        Construct from a Config object.

        Reads ``embedding_dim``, ``pretrained`` and ``in_channels`` from
        ``config.model.feature_extractor`` and ``dropout`` from
        ``config.model.sequence_analyzer``.
        """
        fe = config.model.feature_extractor
        return cls(
            embedding_dim=fe.embedding_dim,
            pretrained=fe.pretrained,
            in_channels=fe.in_channels,
            # NOTE(review): dropout comes from the sequence-analyzer section,
            # not the feature-extractor one — confirm this sharing is intended.
            dropout=config.model.sequence_analyzer.dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Tensor of shape (B, in_channels, H, W), e.g. (B, 1, 224, 224).

        Returns:
            Embedding tensor of shape (B, embedding_dim).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (B, 512)
        x = self.projection(x)  # (B, embedding_dim)
        return x

    def extract_features_from_sequence(
        self, sequence: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract features for a batch of sequences.

        Args:
            sequence: (B, T, C, H, W) — batch of image sequences.

        Returns:
            (B, T, embedding_dim)

        Raises:
            ValueError: If ``sequence`` is not 5-dimensional.
        """
        if sequence.dim() != 5:
            raise ValueError(
                "sequence must have shape (B, T, C, H, W), got "
                f"{tuple(sequence.shape)}"
            )
        B, T, C, H, W = sequence.shape
        # Flatten batch and time → (B*T, C, H, W). reshape (unlike view) also
        # accepts non-contiguous inputs such as transposed/permuted tensors.
        x = sequence.reshape(B * T, C, H, W)
        features = self.forward(x)  # (B*T, D)
        return features.reshape(B, T, self.embedding_dim)
|