| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| from ppdet.core.workspace import register, create |
| from .meta_arch import BaseArch |
|
|
| __all__ = ['FCOS'] |
|
|
|
|
| @register |
| class FCOS(BaseArch): |
| """ |
| FCOS network, see https://arxiv.org/abs/1904.01355 |
| |
| Args: |
| backbone (object): backbone instance |
| neck (object): 'FPN' instance |
| fcos_head (object): 'FCOSHead' instance |
| ssod_loss (object): 'SSODFCOSLoss' instance, only used for semi-det(ssod) |
| """ |
|
|
| __category__ = 'architecture' |
| __inject__ = ['ssod_loss'] |
|
|
| def __init__(self, |
| backbone='ResNet', |
| neck='FPN', |
| fcos_head='FCOSHead', |
| ssod_loss='SSODFCOSLoss'): |
| super(FCOS, self).__init__() |
| self.backbone = backbone |
| self.neck = neck |
| self.fcos_head = fcos_head |
|
|
| |
| self.is_teacher = False |
| self.ssod_loss = ssod_loss |
|
|
| @classmethod |
| def from_config(cls, cfg, *args, **kwargs): |
| backbone = create(cfg['backbone']) |
|
|
| kwargs = {'input_shape': backbone.out_shape} |
| neck = create(cfg['neck'], **kwargs) |
|
|
| kwargs = {'input_shape': neck.out_shape} |
| fcos_head = create(cfg['fcos_head'], **kwargs) |
|
|
| return { |
| 'backbone': backbone, |
| 'neck': neck, |
| "fcos_head": fcos_head, |
| } |
|
|
| def _forward(self): |
| body_feats = self.backbone(self.inputs) |
| fpn_feats = self.neck(body_feats) |
|
|
| self.is_teacher = self.inputs.get('is_teacher', False) |
| if self.training or self.is_teacher: |
| losses = self.fcos_head(fpn_feats, self.inputs) |
| return losses |
| else: |
| fcos_head_outs = self.fcos_head(fpn_feats) |
| bbox_pred, bbox_num = self.fcos_head.post_process( |
| fcos_head_outs, self.inputs['scale_factor']) |
| return {'bbox': bbox_pred, 'bbox_num': bbox_num} |
|
|
| def get_loss(self): |
| return self._forward() |
|
|
| def get_pred(self): |
| return self._forward() |
|
|
| def get_loss_keys(self): |
| return ['loss_cls', 'loss_box', 'loss_quality'] |
|
|
| def get_ssod_loss(self, student_head_outs, teacher_head_outs, train_cfg): |
| ssod_losses = self.ssod_loss(student_head_outs, teacher_head_outs, |
| train_cfg) |
| return ssod_losses |
|
|