| | |
| | |
| |
|
| | |
| | |
| |
|
| | """Backbones from the TIMM library.""" |
| |
|
| | from typing import List, Tuple |
| |
|
| | import torch |
| |
|
| | from timm.models import create_model |
| | from torch import nn |
| |
|
| |
|
class TimmBackbone(nn.Module):
    """Multi-scale feature extractor backed by a TIMM model.

    The wrapped model is built with ``timm``'s ``create_model`` in
    ``features_only`` mode, restricted to the stages named in ``features``.

    Attributes:
        channel_list: The channel counts reported by the backbone's
            ``feature_info.channels()``, stored in REVERSED order.
        body: The model returned by ``create_model``.

    NOTE(review): ``channel_list`` is reversed, while ``forward`` returns the
    feature maps in the order the backbone yields them. Presumably a consumer
    expects the channel counts in the opposite order -- confirm against the
    callers.
    """

    def __init__(
        self,
        name: str,
        features: Tuple[str, ...],
        *,
        pretrained: bool = True,
        in_chans: int = 3,
    ):
        """Build the backbone.

        Args:
            name: Model name forwarded to ``create_model``.
            features: Names of the feature stages to expose, written as
                ``"layer<N>"`` (e.g. ``("layer1", "layer2")``); ``N`` is
                forwarded as an entry of ``out_indices``.
            pretrained: Forwarded to ``create_model``. Defaults to ``True``,
                the value that used to be hard-coded.
            in_chans: Forwarded to ``create_model``. Defaults to ``3``, the
                value that used to be hard-coded.

        Raises:
            ValueError: If an entry of ``features`` has no integer index
                after its first ``len("layer")`` characters.
        """
        super().__init__()

        backbone = create_model(
            name,
            pretrained=pretrained,
            in_chans=in_chans,
            features_only=True,
            out_indices=self._parse_out_indices(features),
        )

        # Stored reversed relative to the backbone's own ordering (see the
        # review note in the class docstring).
        self.channel_list = backbone.feature_info.channels()[::-1]
        self.body = backbone

    @staticmethod
    def _parse_out_indices(features: Tuple[str, ...]) -> Tuple[int, ...]:
        """Convert ``"layer<N>"`` feature names into a tuple of ints ``N``."""
        prefix_len = len("layer")
        indices = []
        for feature in features:
            # The first len("layer") characters are dropped without checking
            # what they are; only the remainder has to be an integer.
            try:
                indices.append(int(feature[prefix_len:]))
            except ValueError as err:
                raise ValueError(
                    f"Invalid feature name {feature!r}: expected 'layer<N>' "
                    "with an integer N"
                ) from err
        return tuple(indices)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run the backbone on ``x`` and return its outputs as a new list.

        The outputs keep the order in which ``self.body`` yields them.
        """
        return list(self.body(x))
| |
|