| """ |
| Copyright (c) 2024 The DEIM Authors. All Rights Reserved. |
| """ |
|
|
| import torch.nn as nn |
| from ..core import register |
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ['DEIM']
|
|
|
|
@register()
class DEIM(nn.Module):
    """Model assembled from a backbone, an encoder, and a decoder.

    The three sub-networks are supplied by the caller; `forward` runs them
    in the order backbone -> encoder -> decoder.
    """

    # NOTE(review): presumably the constructor arguments that the project's
    # `register` machinery fills in -- confirm against `..core`.
    __inject__ = ['backbone', 'encoder', 'decoder']

    def __init__(
        self,
        backbone: nn.Module,
        encoder: nn.Module,
        decoder: nn.Module,
    ):
        """Attach the three sub-networks as child modules.

        Args:
            backbone: module applied first to the input of `forward`.
            encoder: module applied to the backbone's output.
            decoder: module applied to the encoder's output and the targets.
        """
        super().__init__()
        # The assignment order (backbone, decoder, encoder) is kept on
        # purpose: it fixes the order in which child modules are registered
        # and therefore the order in which they are later iterated.
        self.backbone = backbone
        self.decoder = decoder
        self.encoder = encoder

    def forward(self, x, targets=None):
        """Run the input through backbone, encoder, and decoder in turn.

        Args:
            x: input handed to the backbone.
            targets: optional value forwarded unchanged to the decoder as
                its second positional argument; defaults to ``None``.

        Returns:
            Whatever the decoder returns.
        """
        features = self.backbone(x)
        features = self.encoder(features)
        return self.decoder(features, targets)

    def deploy(self):
        """Switch to eval mode and convert sub-modules for deployment.

        Every module yielded by `self.modules()` that exposes a
        `convert_to_deploy` attribute has that method called.

        Returns:
            This instance, so the call can be chained.
        """
        self.eval()
        # Iterate lazily: the attribute check happens right before each
        # call, exactly when the module is reached in the traversal.
        for submodule in self.modules():
            if not hasattr(submodule, 'convert_to_deploy'):
                continue
            submodule.convert_to_deploy()
        return self
|
|