| | import functools |
| | from dataclasses import dataclass |
| | from typing import Literal |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision |
| | from einops import rearrange |
| | from jaxtyping import Float |
| | from torch import Tensor, nn |
| | from torchvision.models import ResNet |
| |
|
| | from src.dataset.types import BatchedViews |
| | from .backbone import Backbone |
| |
|
| |
|
@dataclass
class BackboneResnetCfg:
    """Settings consumed by ``BackboneResnet``."""

    # Fixed tag for this config; presumably used by the config loader to
    # pick the backbone implementation (confirm against the caller).
    name: Literal["resnet"]

    # Architecture to instantiate. The plain "resnetXX" names are looked up
    # on ``torchvision.models``; "dino_resnet50" is fetched via ``torch.hub``.
    model: Literal[
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
        "dino_resnet50",
    ]

    # Number of feature levels to extract, counting the stem ("layer0").
    # ``BackboneResnet`` iterates ``layer1 .. layer{num_layers - 1}``.
    num_layers: int

    # Flag read by ``BackboneResnet.forward`` in connection with the ResNet
    # max-pool layer; see that method for how it is consumed.
    use_first_pool: bool

    # Channel count every feature level is projected to (and the channel
    # count of the backbone's output).
    d_out: int
| |
|
| |
|
class BackboneResnet(Backbone[BackboneResnetCfg]):
    """Multi-scale ResNet feature extractor.

    The stem (``conv1 -> bn1 -> relu``) and the first ``num_layers - 1``
    residual stages of a ResNet are evaluated in sequence. The output of each
    level is projected to ``cfg.d_out`` channels with a 1x1 convolution,
    bilinearly resized to the input resolution, and all levels are summed.
    """

    model: ResNet

    def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None:
        """Build the ResNet and one 1x1 projection per extracted level.

        Args:
            cfg: Backbone configuration (architecture, number of levels,
                output width).
            d_in: Number of input image channels. Must be 3.

        Raises:
            AssertionError: If ``d_in`` is not 3.
            AttributeError: If a requested ``layer{i}`` does not exist on the
                model or its last block exposes no ``conv1``.
        """
        super().__init__(cfg)

        # Kept as an assert so the exception type seen by callers is unchanged.
        assert d_in == 3, f"BackboneResnet expects 3 input channels, got {d_in}"

        # Instance norm without learned affine parameters or running stats.
        norm_layer = functools.partial(
            nn.InstanceNorm2d,
            affine=False,
            track_running_stats=False,
        )

        if cfg.model == "dino_resnet50":
            # The hub entry point builds its own network, so ``norm_layer``
            # is not applied on this path.
            self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50")
        else:
            self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer)

        # Set up one projection per residual stage. Insertion order
        # (layer1.., then layer0) is preserved so parameter ordering stays
        # the same as before.
        self.projections = nn.ModuleDict({})
        for index in range(1, cfg.num_layers):
            key = f"layer{index}"
            stage = getattr(self.model, key)
            d_layer_out = self._last_conv_out_channels(stage[-1])
            self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1)

        # Projection for the stem output (conv1 -> bn1 -> relu).
        self.projections["layer0"] = nn.Conv2d(
            self.model.conv1.out_channels, cfg.d_out, 1
        )

    @staticmethod
    def _last_conv_out_channels(residual_block: nn.Module) -> int:
        """Return ``out_channels`` of the highest-numbered ``conv{i}`` attribute.

        Walks ``conv1``, ``conv2``, ... until an index is missing, so it works
        for blocks with either two or three convolutions.
        """
        d_layer_out = None
        conv_index = 1
        while (conv := getattr(residual_block, f"conv{conv_index}", None)) is not None:
            d_layer_out = conv.out_channels
            conv_index += 1
        if d_layer_out is None:
            # Previously this case left the channel count unbound (or silently
            # reused the value from the preceding stage).
            raise AttributeError(
                f"{type(residual_block).__name__} has no 'conv1'; cannot infer "
                "its output channel count"
            )
        return d_layer_out

    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute per-pixel features for every view in ``context``.

        Args:
            context: Mapping whose ``"image"`` entry is a 5-D tensor laid out
                as ``(batch, view, channel, height, width)``.

        Returns:
            Tensor of shape ``(batch, view, d_out, height, width)``: the sum
            over all extracted levels of the projected, upsampled features.
        """
        # Merge the batch and view dimensions so the ResNet sees a 4-D input.
        b, v, _, h, w = context["image"].shape
        x = rearrange(context["image"], "b v c h w -> (b v) c h w")

        # Run the images through the ResNet stem.
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        features = [self.projections["layer0"](x)]

        # Run the residual stages, projecting each stage's output.
        for index in range(1, self.cfg.num_layers):
            key = f"layer{index}"
            # The optional max-pool sits between the stem and layer1. The
            # check used to be ``index == 0``, which can never be true inside
            # this loop, so ``use_first_pool`` had no effect.
            if index == 1 and self.cfg.use_first_pool:
                x = self.model.maxpool(x)
            x = getattr(self.model, key)(x)
            features.append(self.projections[key](x))

        # Upscale every level to the input resolution and combine them.
        features = [
            F.interpolate(f, (h, w), mode="bilinear", align_corners=True)
            for f in features
        ]
        features = torch.stack(features).sum(dim=0)

        # Separate the batch and view dimensions again.
        return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v)

    @property
    def d_out(self) -> int:
        """Number of channels in the features returned by ``forward``."""
        return self.cfg.d_out
| |
|