| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import torch |
| import torchvision |
|
|
| from ..core import register |
| from .utils import IntermediateLayerGetter |
|
|
| __all__ = ['TorchVisionModel'] |
|
|
@register()
class TorchVisionModel(torch.nn.Module):
    """Wrap a model built by ``torchvision.models.get_model``.

    The constructed torchvision model (or its ``features`` attribute, when
    the model has one) is handed to ``IntermediateLayerGetter``; ``forward``
    simply delegates to that wrapper.

    Args:
        name: Model name passed to ``torchvision.models.get_model`` and,
            when ``weights`` is given, to
            ``torchvision.models.get_model_weights``.
        return_layers: Forwarded unchanged to ``IntermediateLayerGetter``
            (presumably selects which intermediate outputs are returned;
            confirm against that helper).
        weights: Name of the attribute to read from the object returned by
            ``get_model_weights(name)``. ``None`` (the default) skips the
            lookup and passes ``weights=None`` to ``get_model``.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.
    """

    def __init__(self, name, return_layers, weights=None, **kwargs) -> None:
        super().__init__()

        # Turn the weights name into the matching member of the model's
        # weights collection; None is passed through untouched.
        if weights is not None:
            available_weights = torchvision.models.get_model_weights(name)
            weights = getattr(available_weights, weights)

        net = torchvision.models.get_model(name, weights=weights, **kwargs)

        # Models exposing a `features` attribute are wrapped at that
        # attribute; any other model is wrapped as a whole.
        trunk = net.features if hasattr(net, 'features') else net
        self.model = IntermediateLayerGetter(trunk, return_layers)

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)
|
|
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
|
|