Spaces:
Paused
Paused
| import torch.nn as nn | |
| # Local imports (Assuming these contain necessary custom modules) | |
| from models.modules import * | |
| from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 | |
class FFBaseline(nn.Module):
    """
    Feed-forward baseline.

    Wrapper that lets us use the same feature-extraction backbone as the CTM
    and LSTM baselines, with a workaround that projects the final layer to a
    ``d_model``-sized space so that parameter-matching against those models
    is plausible.

    Args:
        d_model (int): Width of the hidden projection in the output head;
            scale this up until the parameter count matches the CTM.
        backbone_type (str): Type of feature extraction backbone
            (e.g. 'resnet18-2', 'none').
        out_dims (int): Dimensionality of the final output projection.
        dropout (float): Dropout probability applied in the output head.
    """

    def __init__(self,
                 d_model,
                 backbone_type,
                 out_dims,
                 dropout=0,
                 ):
        super().__init__()

        # --- Core Parameters ---
        self.d_model = d_model
        self.backbone_type = backbone_type
        self.out_dims = out_dims

        # --- Input Validation ---
        # Kept as an assert (not ValueError) to preserve the exception type
        # existing callers may rely on; note asserts are stripped under -O.
        assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4',
                                 'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4',
                                 'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4',
                                 'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4',
                                 'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4',
                                 'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}"

        # --- Backbone / Feature Extraction ---
        # 1x1 LazyConv2d maps whatever input channel count arrives to 3
        # channels, so the ResNet stem always sees an RGB-shaped input.
        # (A previous dead `Identity()` placeholder assignment was removed:
        # it was unconditionally overwritten on the next line.)
        self.initial_rgb = nn.LazyConv2d(3, 1, 1)

        # Select ResNet depth from the backbone_type string.
        # NOTE(review): for 'none', 'shallow-wide' and 'parity_backbone' no
        # substring matches, so a resnet18 is built anyway — looks like those
        # options are accepted but not specially handled here; confirm intent.
        resnet_family = resnet18  # Default
        if '34' in self.backbone_type: resnet_family = resnet34
        if '50' in self.backbone_type: resnet_family = resnet50
        if '101' in self.backbone_type: resnet_family = resnet101
        if '152' in self.backbone_type: resnet_family = resnet152

        # Determine which ResNet blocks to keep: 'resnetXX-k' keeps blocks
        # 1..k; any non-numeric suffix keeps all four blocks.
        block_num_str = self.backbone_type.split('-')[-1]
        hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]

        self.backbone = resnet_family(
            3,  # initial_rgb handles input channels now
            hyper_blocks_to_keep,
            stride=2,
            pretrained=False,
            progress=True,
            device="cpu",  # Initialise on CPU, move later via .to(device)
            do_initial_max_pool=True,
        )

        # The backbone emits a 4D feature tensor [B, C, H, W]; pool it to
        # [B, C], then project through a d_model-wide hidden layer so the
        # head's parameter count can be scaled to match the CTM.
        self.output_projector = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Squeeze(-1),
            Squeeze(-1),
            nn.LazyLinear(d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, out_dims),
        )

    def forward(self, x):
        """Map input images to logits: channel-adapt, extract features, project.

        Args:
            x: Image batch tensor; channel count is adapted lazily by
               ``initial_rgb`` (assumes 4D [B, C, H, W] — TODO confirm).

        Returns:
            Tensor of shape [B, out_dims].
        """
        return self.output_projector(self.backbone(self.initial_rgb(x)))