Spaces:
Sleeping
Sleeping
File size: 455 Bytes
1dc2504 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | from __future__ import annotations
import timm
import torch
import torch.nn as nn
def build_backbone(name: str, pretrained: bool = True) -> nn.Module:
    """Create a timm model configured to emit globally pooled features.

    Args:
        name: timm model identifier passed straight to ``timm.create_model``.
        pretrained: Whether timm should load pretrained weights.

    Returns:
        The module returned by ``timm.create_model``.
    """
    # num_classes=0 asks timm for a model without a classification head, and
    # global_pool="avg" selects average pooling, so the forward pass yields
    # pooled feature vectors rather than logits.
    return timm.create_model(
        name,
        pretrained=pretrained,
        num_classes=0,
        global_pool="avg",
    )
@torch.no_grad()
def infer_feature_dim(
    backbone: nn.Module,
    image_size: int = 224,
    *,
    in_chans: int = 3,
) -> int:
    """Return the size of the last dimension of ``backbone``'s output.

    Runs one dummy forward pass on a ``(1, in_chans, image_size, image_size)``
    tensor of zeros and reports ``output.shape[-1]``. This equals the feature
    dimension only when the backbone returns pooled ``(N, C)`` features, as
    ``build_backbone`` in this module configures. A model that returns an
    unpooled ``(N, C, H, W)`` map would report ``W`` instead.

    The probe avoids observable side effects for the caller:

    * Every submodule is switched to eval mode for the pass, and its original
      ``training`` flag is restored afterwards, even if the forward pass
      raises. As a result, normalization layers do not update their running
      statistics from the dummy input.
    * The input is built with ``torch.zeros``, so the global RNG state is not
      advanced.
    * The input uses the device and floating dtype of the backbone's first
      parameter, so models on CUDA or in half/double precision work.
      Parameterless modules fall back to torch's default device and dtype.

    Args:
        backbone: Module to probe.
        image_size: Height and width of the dummy input.
        in_chans: Number of input channels (keyword-only, default 3).

    Returns:
        ``int(output.shape[-1])`` of the dummy forward pass.
    """
    first_param = next(backbone.parameters(), None)
    device = first_param.device if first_param is not None else None
    dtype = (
        first_param.dtype
        if first_param is not None and first_param.is_floating_point()
        else None  # None -> torch's default dtype
    )

    # Snapshot each submodule's flag individually: Module.train(True) would
    # recursively force every child into training, clobbering layers the
    # caller had deliberately put in eval mode (e.g. frozen BatchNorm).
    saved_modes = [(module, module.training) for module in backbone.modules()]
    backbone.eval()
    try:
        dummy = torch.zeros(
            1, in_chans, image_size, image_size, device=device, dtype=dtype
        )
        features = backbone(dummy)
    finally:
        for module, mode in saved_modes:
            module.training = mode
    return int(features.shape[-1])
|