| import torch | |
| from torch import nn | |
| from .backbone import build_backbone | |
| from .fpn import build_fpn | |
| from .pan import PAN | |
| from .segment_predictor import SegmentPredictor | |
| from .divide_predictor import HeadBodyDividePredictor | |
| from .cells_extractor import CellsExtractor | |
| from .decoder import Decoder | |
| from .utils import extend_segments, spatial_att_to_spans | |
class Model(nn.Module):
    """End-to-end model built from the project's sub-networks.

    Presumably a table-structure recognizer (row/column segments, cells,
    layouts) -- inferred from names, TODO confirm. The visible pipeline is:
    backbone -> FPN; row and column segment predictors on ``feats[0]``; a
    head/body divide predictor on ``feats[0]`` averaged over dim 3; a cells
    extractor on PAN-refined FPN features; and a decoder over the extracted
    cell feature maps.
    """

    def __init__(self, cfg, norm_layer=nn.BatchNorm2d):
        """Build every sub-network from ``cfg``.

        Args:
            cfg: configuration object; every attribute read below must exist on it.
            norm_layer: normalization layer class forwarded to ``build_backbone``.
        """
        super().__init__()
        self.backbone = build_backbone(cfg.arch, cfg.pretrained_backbone, norm_layer=norm_layer)
        self.fpn = build_fpn(cfg.backbone_out_channels, cfg.fpn_out_channels)
        self.pan = PAN(cfg.pan_num_levels, cfg.pan_in_dim, cfg.pan_out_dim)
        self.row_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.rs_scale, type='row')
        self.col_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.cs_scale, type='col')
        self.divide_predictor = HeadBodyDividePredictor(cfg.fpn_out_channels, cfg.dp_head_nums, scale=cfg.dp_scale)
        self.cells_extractor = CellsExtractor(cfg.fpn_out_channels, cfg.ce_dim, cfg.ce_heads, cfg.ce_head_nums,
                                              cfg.ce_pool_size, cfg.ce_scale)
        # NOTE: ``spatial_att_weight_loss_wight`` (sic) is the config attribute's
        # actual name; it must match the config, so it is not "fixed" here.
        self.decoder = Decoder(cfg.vocab, cfg.embed_dim, cfg.feat_dim, cfg.lm_state_dim, cfg.proj_dim,
                               cfg.cover_kernel, cfg.att_threshold, cfg.spatial_att_weight_loss_wight)

    @staticmethod
    def _prefix_keys(prefix, info):
        """Return a copy of ``info`` with every key renamed to ``'<prefix>_<key>'``."""
        return {'%s_%s' % (prefix, key): val for key, val in info.items()}

    def forward(self, images, images_size, cls_labels=None, labels_mask=None, layouts=None, rows_fg_spans=None,
                rows_bg_spans=None, cols_fg_spans=None, cols_bg_spans=None, cells_spans=None, divide_labels=None):
        """Run the full pipeline.

        The label arguments (everything after ``images_size``) are passed through
        to the sub-networks. In training mode ``layouts`` must be provided, because
        its last two dimensions are checked against the extracted feature maps.

        Returns:
            Training mode: ``((row_segments, col_segments, divide_preds), result_info)``.
            Eval mode: ``((row_segments, col_segments, divide_preds, de_recog_spans), result_info)``.
            ``result_info`` merges the info dicts returned by the sub-networks. Keys
            from the row/column segment predictors are prefixed with ``row_``/``col_``.

        Raises:
            ValueError: in training mode, if the cell feature maps and ``layouts``
                differ in their last two dimensions; in eval mode, if the decoder
                output batch size is not 1.
        """
        feats = self.fpn(self.backbone(images))
        # Mean over dim 3 of feats[0]; presumably the width axis of an NCHW map,
        # giving one feature per row position -- TODO confirm layout.
        row_feats = torch.mean(feats[0], dim=3)
        result_info = {}

        row_segments, rs_result_info, rs_ext_info = self.row_segment_predictor(
            feats[0], images_size, rows_fg_spans, rows_bg_spans)
        result_info.update(self._prefix_keys('row', rs_result_info))

        col_segments, cs_result_info, cs_ext_info = self.col_segment_predictor(
            feats[0], images_size, cols_fg_spans, cols_bg_spans)
        result_info.update(self._prefix_keys('col', cs_result_info))

        if self.training:
            # During training the predicted segments are extended using the
            # predictors' extra segments, and the labels are updated to match.
            row_segments, col_segments, cells_spans, layouts, divide_labels = extend_segments(
                row_segments, rs_ext_info['ext_segments'], col_segments, cs_ext_info['ext_segments'],
                cells_spans, layouts, divide_labels)

        divide_preds, dp_result_info, _ = self.divide_predictor(row_feats, row_segments, divide_labels=divide_labels)
        result_info.update(dp_result_info)

        feat_maps, feats_masks = self.cells_extractor(self.pan(feats), row_segments, col_segments, images_size)
        if self.training and feat_maps.shape[-2:] != layouts.shape[-2:]:
            # Explicit raise instead of ``assert``: it survives ``python -O`` and
            # carries a real message (the old ``assert ..., print(...)`` had None).
            raise ValueError('feat_maps spatial shape %s does not match layouts shape %s'
                             % (tuple(feat_maps.shape[-2:]), tuple(layouts.shape[-2:])))

        de_preds, de_result_info = self.decoder(feat_maps, feats_masks.unsqueeze(1), cls_labels, labels_mask, layouts)
        result_info.update(de_result_info)

        if self.training:
            return (row_segments, col_segments, divide_preds), result_info

        # Span recovery below only handles the first sample, so eval mode
        # requires a batch of exactly one.
        if de_preds.shape[0] != 1:
            raise ValueError('batch size should be 1 in eval mode, got %d' % de_preds.shape[0])
        de_recog_spans = spatial_att_to_spans(de_preds[0])
        return (row_segments, col_segments, divide_preds, de_recog_spans), result_info