File size: 455 Bytes
1dc2504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from __future__ import annotations

import timm
import torch
import torch.nn as nn


def build_backbone(name: str, pretrained: bool = True) -> nn.Module:
    """Create the timm architecture ``name`` configured as a feature extractor.

    Args:
        name: timm model identifier passed straight to ``timm.create_model``.
        pretrained: Forwarded to timm; whether to load pretrained weights.

    Returns:
        The model built with ``num_classes=0`` (per timm's documented
        behaviour this omits the classifier head) and average global pooling.
    """
    return timm.create_model(
        name,
        pretrained=pretrained,
        num_classes=0,  # headless: emit pooled features instead of logits
        global_pool="avg",
    )


@torch.no_grad()
def infer_feature_dim(backbone: nn.Module, image_size: int = 224, in_chans: int = 3) -> int:
    """Return the size of the last axis of ``backbone``'s output for a square image.

    A single random probe of shape ``(1, in_chans, image_size, image_size)`` is
    pushed through the model under ``torch.no_grad()``.

    The probe runs in eval mode. Normalization layers therefore do not fold the
    random input into their running statistics, and layers that reject a batch
    of one while training (e.g. ``BatchNorm1d``) still work. Every submodule's
    original training flag is restored afterwards, even if the forward pass raises.

    Args:
        backbone: Model to probe. Assumed to return a tensor whose last axis is
            the feature dimension, e.g. ``(batch, features)``. TODO confirm for
            backbones built without global pooling.
        image_size: Height and width of the probe image.
        in_chans: Number of input channels of the probe (3 for RGB models).

    Returns:
        ``int(output.shape[-1])``.
    """
    # Build the probe on the weights' device and floating dtype. Otherwise a
    # model moved to GPU or cast to half/double would reject a CPU float32 input.
    first_param = next(backbone.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dtype = (
        first_param.dtype
        if first_param is not None and first_param.dtype.is_floating_point
        else torch.float32
    )
    x = torch.randn(1, in_chans, image_size, image_size, device=device, dtype=dtype)

    # Save each submodule's own flag. A blanket backbone.train(was_training)
    # afterwards would clobber submodules that were deliberately kept in eval mode.
    saved_modes = [(module, module.training) for module in backbone.modules()]
    backbone.eval()
    try:
        y = backbone(x)
    finally:
        for module, mode in saved_modes:
            module.training = mode
    return int(y.shape[-1])