| """
|
| CNN Feature Extractor — Modified ResNet-18 for grayscale thermal images.
|
|
|
| Takes single-channel (grayscale) 224×224 images and outputs 256-dim
|
| feature embeddings suitable for downstream sequence analysis.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torchvision.models as models
|
|
|
|
|
class ThermalFeatureExtractor(nn.Module):
    """
    Modified ResNet-18 that accepts 1-channel grayscale input (or any other
    channel count) and produces a compact feature embedding.

    Architecture:
        Input (in_channels, 224, 224)
        → Conv1 (in_channels→64, 7×7)   (replaces the default 3→64)
        → ResNet-18 layers 1-4
        → AdaptiveAvgPool → (512,)
        → FC(512→embedding_dim) + BatchNorm + ReLU + Dropout
        → embedding_dim-dim embedding

    Note:
        The projection head contains ``BatchNorm1d``; in training mode it
        needs more than one sample per batch.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        pretrained: bool = True,
        in_channels: int = 1,
        dropout: float = 0.3,
    ):
        """
        Build the backbone and the projection head.

        Args:
            embedding_dim: Size of the output embedding.
            pretrained: Load the default torchvision ResNet-18 weights and
                initialise the replacement first convolution from them.
            in_channels: Number of channels of the input images.
            dropout: Dropout probability applied at the end of the
                projection head.
        """
        super().__init__()
        self.embedding_dim = embedding_dim

        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        backbone = models.resnet18(weights=weights)

        # Replacement stem: same geometry as the stock ResNet conv1, but with
        # a configurable number of input channels.
        self.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        if pretrained:
            # Copy in place so the parameter keeps the shape the layer was
            # declared with, whatever ``in_channels`` is.
            with torch.no_grad():
                self.conv1.weight.copy_(
                    self._adapt_conv1_weight(
                        backbone.conv1.weight, in_channels
                    )
                )

        # Reuse the remaining backbone stages unchanged.
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        self.avgpool = backbone.avgpool

        # Head width is taken from the classifier being replaced
        # (512 for ResNet-18) instead of being hard-coded.
        self.projection = nn.Sequential(
            nn.Linear(backbone.fc.in_features, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

    @staticmethod
    def _adapt_conv1_weight(
        rgb_weight: torch.Tensor, in_channels: int
    ) -> torch.Tensor:
        """
        Adapt pretrained first-layer kernels to ``in_channels`` inputs.

        Args:
            rgb_weight: Pretrained kernels of shape (out, C_src, kH, kW).
            in_channels: Channel count of the replacement convolution.

        Returns:
            A new tensor of shape (out, in_channels, kH, kW). When the
            channel counts already match the kernels are copied as-is.
            Otherwise the channel-mean kernel is replicated and divided by
            ``in_channels``; for ``in_channels == 1`` this is exactly the
            channel mean.
        """
        source = rgb_weight.detach()
        if in_channels == source.shape[1]:
            return source.clone()
        mean_kernel = source.mean(dim=1, keepdim=True)
        return mean_kernel.repeat(1, in_channels, 1, 1) / in_channels

    @classmethod
    def from_config(cls, config) -> "ThermalFeatureExtractor":
        """Construct from a Config object."""
        fe = config.model.feature_extractor
        return cls(
            embedding_dim=fe.embedding_dim,
            pretrained=fe.pretrained,
            in_channels=fe.in_channels,
            # NOTE(review): dropout is read from the sequence_analyzer
            # section, not from feature_extractor -- confirm this sharing
            # is intentional.
            dropout=config.model.sequence_analyzer.dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Tensor of shape (B, in_channels, H, W), e.g. (B, 1, 224, 224).

        Returns:
            Embedding tensor of shape (B, embedding_dim).
        """
        # Stem
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Pool, flatten everything but the batch axis, then project.
        x = torch.flatten(self.avgpool(x), 1)
        return self.projection(x)

    def extract_features_from_sequence(
        self, sequence: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract features for a batch of sequences.

        Args:
            sequence: (B, T, C, H, W) — batch of image sequences. Does not
                need to be contiguous.

        Returns:
            (B, T, embedding_dim)

        Raises:
            ValueError: If ``sequence`` is not a 5-D tensor.
        """
        if sequence.dim() != 5:
            raise ValueError(
                "expected a 5-D (B, T, C, H, W) tensor, got shape "
                f"{tuple(sequence.shape)}"
            )
        batch, steps, channels, height, width = sequence.shape

        # reshape (not view) so transposed / sliced inputs are accepted;
        # it only copies when the memory layout requires it.
        frames = sequence.reshape(batch * steps, channels, height, width)
        features = self.forward(frames)
        return features.reshape(batch, steps, self.embedding_dim)
|
|
|