Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from ...core import register | |
| __all__ = ["Classification", "ClassHead"] | |
| class Classification(torch.nn.Module): | |
| __inject__ = ["backbone", "head"] | |
| def __init__(self, backbone: nn.Module, head: nn.Module = None): | |
| super().__init__() | |
| self.backbone = backbone | |
| self.head = head | |
| def forward(self, x): | |
| x = self.backbone(x) | |
| if self.head is not None: | |
| x = self.head(x) | |
| return x | |
| class ClassHead(nn.Module): | |
| def __init__(self, hidden_dim, num_classes): | |
| super().__init__() | |
| self.pool = nn.AdaptiveAvgPool2d(1) | |
| self.proj = nn.Linear(hidden_dim, num_classes) | |
| def forward(self, x): | |
| x = x[0] if isinstance(x, (list, tuple)) else x | |
| x = self.pool(x) | |
| x = x.reshape(x.shape[0], -1) | |
| x = self.proj(x) | |
| return x | |