# Source imported from a web commit page; original header preserved below.
# Author: kai-2054
# Commit: "Initial commit: add code" (cb0ad2d)
import torch
from torch import nn
from .backbone import build_backbone
from .fpn import build_fpn
from .pan import PAN
from .segment_predictor import SegmentPredictor
from .divide_predictor import HeadBodyDividePredictor
from .cells_extractor import CellsExtractor
from .decoder import Decoder
from .utils import extend_segments, spatial_att_to_spans
class Model(nn.Module):
    """End-to-end table-structure recognition model.

    Pipeline: backbone + FPN features feed row/column segment predictors and a
    head/body divide predictor; PAN-fused features feed a cell-feature
    extractor whose output is decoded by an attention-based decoder.

    In training mode ``forward`` returns
    ``((row_segments, col_segments, divide_preds), result_info)``;
    in eval mode it additionally returns the decoder's recognized spans and
    requires batch size 1.
    """

    def __init__(self, cfg, norm_layer=nn.BatchNorm2d):
        """Build all sub-modules from the config object *cfg*.

        Args:
            cfg: configuration namespace providing architecture and
                hyper-parameter fields (see attribute accesses below).
            norm_layer: normalization layer class passed to the backbone.
        """
        super().__init__()
        self.backbone = build_backbone(cfg.arch, cfg.pretrained_backbone, norm_layer=norm_layer)
        self.fpn = build_fpn(cfg.backbone_out_channels, cfg.fpn_out_channels)
        self.pan = PAN(cfg.pan_num_levels, cfg.pan_in_dim, cfg.pan_out_dim)
        self.row_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.rs_scale, type='row')
        self.col_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.cs_scale, type='col')
        self.divide_predictor = HeadBodyDividePredictor(cfg.fpn_out_channels, cfg.dp_head_nums, scale=cfg.dp_scale)
        self.cells_extractor = CellsExtractor(cfg.fpn_out_channels, cfg.ce_dim, cfg.ce_heads, cfg.ce_head_nums, cfg.ce_pool_size, cfg.ce_scale)
        self.decoder = Decoder(cfg.vocab, cfg.embed_dim, cfg.feat_dim, cfg.lm_state_dim, cfg.proj_dim, cfg.cover_kernel, cfg.att_threshold, cfg.spatial_att_weight_loss_wight)

    @staticmethod
    def _prefixed(info, prefix):
        """Return a copy of *info* with every key prefixed by *prefix*."""
        return {f'{prefix}{key}': val for key, val in info.items()}

    def forward(self, images, images_size, cls_labels=None, labels_mask=None, layouts=None, rows_fg_spans=None,
                rows_bg_spans=None, cols_fg_spans=None, cols_bg_spans=None, cells_spans=None, divide_labels=None):
        """Run the full recognition pipeline.

        Args:
            images: input image batch fed to the backbone.
            images_size: per-image sizes, forwarded to the predictors.
            cls_labels, labels_mask, layouts, *_spans, divide_labels:
                training-time supervision; may be None at inference.

        Returns:
            ``(outputs, result_info)`` where ``outputs`` is
            ``(row_segments, col_segments, divide_preds)`` during training and
            additionally contains ``de_recog_spans`` in eval mode.
        """
        feats = self.fpn(self.backbone(images))
        # Collapse the width axis; the divide predictor works on per-row features.
        row_feats = torch.mean(feats[0], dim=3)

        result_info = dict()
        ext_info = dict()

        row_segments, rs_result_info, rs_ext_info = self.row_segment_predictor(
            feats[0], images_size, rows_fg_spans, rows_bg_spans)
        result_info.update(self._prefixed(rs_result_info, 'row_'))
        rs_ext_info = self._prefixed(rs_ext_info, 'row_')
        ext_info.update(rs_ext_info)

        col_segments, cs_result_info, cs_ext_info = self.col_segment_predictor(
            feats[0], images_size, cols_fg_spans, cols_bg_spans)
        result_info.update(self._prefixed(cs_result_info, 'col_'))
        cs_ext_info = self._prefixed(cs_ext_info, 'col_')
        ext_info.update(cs_ext_info)

        if self.training:
            # Extend predicted segments (and the matching supervision targets)
            # so cells/layouts/divide labels stay aligned with the segments.
            row_segments, col_segments, cells_spans, layouts, divide_labels = extend_segments(
                row_segments, rs_ext_info['row_ext_segments'],
                col_segments, cs_ext_info['col_ext_segments'],
                cells_spans, layouts, divide_labels)

        divide_preds, dp_result_info, dp_ext_info = self.divide_predictor(
            row_feats, row_segments, divide_labels=divide_labels)
        result_info.update(dp_result_info)
        ext_info.update(dp_ext_info)

        feat_maps, feats_masks = self.cells_extractor(self.pan(feats), row_segments, col_segments, images_size)
        if self.training:
            # Fixed: original used `assert cond, print(msg)` — print() returns
            # None, so the AssertionError carried no message. NOTE(review):
            # asserts are stripped under `python -O`; these are internal
            # sanity checks only.
            assert feat_maps.shape[-2:] == layouts.shape[-2:], 'feat_maps is not the same with layouts'

        de_preds, de_result_info = self.decoder(feat_maps, feats_masks.unsqueeze(1), cls_labels, labels_mask, layouts)
        result_info.update(de_result_info)

        if not self.training:
            assert de_preds.shape[0] == 1, 'batch size should be 1'
            de_recog_spans = spatial_att_to_spans(de_preds[0])
            return (row_segments, col_segments, divide_preds, de_recog_spans), result_info
        return (row_segments, col_segments, divide_preds), result_info