Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| import torch | |
| from torch import nn | |
| from .inference import make_seg_postprocessor | |
| from .loss import make_seg_loss_evaluator | |
| import time | |
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """Return a 3x3 convolution that preserves spatial size (padding=1)."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=has_bias,
    )
    return layer
def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False):
    """3x3 conv -> BatchNorm -> ReLU block.

    Fix: the original accepted ``has_bias`` but never forwarded it to the
    convolution, so passing ``has_bias=True`` was silently ignored. The
    default (False) is unchanged, which is also the sensible choice when a
    BatchNorm immediately follows the conv.
    """
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride,
            padding=1, bias=has_bias,
        ),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
class SEGHead(nn.Module):
    """
    Pixel-level segmentation head.

    Projects four FPN levels to 64 channels, upsamples them to a common
    resolution, concatenates them into a 256-channel fused feature, and
    decodes that feature into a single-channel sigmoid probability map at
    4x the fused resolution. Optionally, the coarsest used level is first
    enriched with a Pyramid Pooling Module (PPM).
    """

    def __init__(self, in_channels, cfg):
        """
        Arguments:
            in_channels (int): number of channels of the fused feature fed
                to ``seg_out`` (4 * 64 = 256 for the concatenation below).
            cfg: config node; only ``cfg.MODEL.SEG.USE_PPM`` is read here.
        """
        super(SEGHead, self).__init__()
        self.cfg = cfg
        ndim = 256  # channel width of the incoming FPN features
        # Each fpn_outN reduces channels to 64 and upsamples so that all
        # four levels match the spatial size of the finest level (x[-5]).
        self.fpn_out5 = nn.Sequential(
            conv3x3(ndim, 64), nn.Upsample(scale_factor=8, mode="nearest")
        )
        self.fpn_out4 = nn.Sequential(
            conv3x3(ndim, 64), nn.Upsample(scale_factor=4, mode="nearest")
        )
        self.fpn_out3 = nn.Sequential(
            conv3x3(ndim, 64), nn.Upsample(scale_factor=2, mode="nearest")
        )
        self.fpn_out2 = conv3x3(ndim, 64)
        # 4x decoder: conv -> 2x deconv -> 2x deconv -> sigmoid map.
        self.seg_out = nn.Sequential(
            conv3x3_bn_relu(in_channels, 64, 1),
            nn.ConvTranspose2d(64, 64, 2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 2, 2),
            nn.Sigmoid(),
        )
        if self.cfg.MODEL.SEG.USE_PPM:
            # Pyramid Pooling Module over the coarsest used level.
            pool_scales = (2, 4, 8)
            fc_dim = 256
            self.ppm_pooling = nn.ModuleList(
                [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
            )
            self.ppm_conv = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(inplace=True),
                    )
                    for _ in pool_scales
                ]
            )
            self.ppm_last_conv = conv3x3_bn_relu(
                fc_dim + len(pool_scales) * 512, ndim, 1
            )
            self.ppm_conv.apply(self.weights_init)
            self.ppm_last_conv.apply(self.weights_init)
        self.fpn_out5.apply(self.weights_init)
        self.fpn_out4.apply(self.weights_init)
        self.fpn_out3.apply(self.weights_init)
        self.fpn_out2.apply(self.weights_init)
        self.seg_out.apply(self.weights_init)

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): FPN feature maps; levels x[-5] .. x[-2] are
                consumed (x[-1] is unused here). Spatial sizes are assumed
                to halve from x[-5] to x[-2] — TODO confirm against the
                backbone's FPN output ordering.

        Returns:
            tuple: ``(out, fuse)`` — ``out`` is the 1-channel sigmoid map
            at 4x the fused resolution, ``fuse`` the 256-channel fused
            feature at the spatial size of x[-5].
        """
        if self.cfg.MODEL.SEG.USE_PPM:
            conv5 = x[-2]
            input_size = conv5.size()
            ppm_out = [conv5]
            for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
                # Pool to a fixed grid, resize back to conv5's size, then
                # project, so all branches concatenate channel-wise.
                ppm_out.append(pool_conv(nn.functional.interpolate(
                    pool_scale(conv5),
                    (input_size[2], input_size[3]),
                    mode='bilinear', align_corners=False)))
            ppm_out = torch.cat(ppm_out, 1)
            f = self.ppm_last_conv(ppm_out)
        else:
            f = x[-2]
        p5 = self.fpn_out5(f)
        p4 = self.fpn_out4(x[-3])
        p3 = self.fpn_out3(x[-4])
        p2 = self.fpn_out2(x[-5])
        fuse = torch.cat((p5, p4, p3, p2), 1)
        out = self.seg_out(fuse)
        return out, fuse

    def weights_init(self, m):
        """Kaiming-normal init for conv weights; BN weight=1, bias=1e-4."""
        classname = m.__class__.__name__
        # Matches Conv2d and ConvTranspose2d; conv biases keep the default
        # initialization, as in the original code.
        if classname.find("Conv") != -1:
            nn.init.kaiming_normal_(m.weight)
        elif classname.find("BatchNorm") != -1:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 1e-4)
class SEGModule(torch.nn.Module):
    """
    Segmentation module. Takes feature maps from the backbone/FPN,
    predicts a pixel-level segmentation map via ``SEGHead``, and converts
    it into box proposals (plus the segmentation loss during training).
    """

    def __init__(self, cfg):
        super(SEGModule, self).__init__()
        self.cfg = cfg.clone()
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        self.head = SEGHead(in_channels, cfg)
        # Separate post-processors: the train variant also consumes the
        # targets when selecting boxes.
        self.box_selector_train = make_seg_postprocessor(cfg, is_train=True)
        self.box_selector_test = make_seg_postprocessor(cfg, is_train=False)
        self.loss_evaluator = make_seg_loss_evaluator(cfg)

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want predictions
            features (list[Tensor]): FPN feature maps
            targets (Tensor): ground-truth segmentation map (training only)

        Returns:
            tuple: ``((boxes, losses), [fuse_feature])`` in training, or
            ``((boxes, seg_results), [fuse_feature])`` at test time.
        """
        preds, fuse_feature = self.head(features)
        image_shapes = images.get_sizes()
        if self.training:
            return self._forward_train(preds, targets, image_shapes), [fuse_feature]
        return self._forward_test(preds, image_shapes), [fuse_feature]

    def _forward_train(self, preds, targets, image_shapes):
        # The predicted segmentation map is turned into boxes for the
        # downstream detection stage; no gradient flows through that step.
        with torch.no_grad():
            boxes = self.box_selector_train(preds, image_shapes, targets)
        losses = {"loss_seg": self.loss_evaluator(preds, targets)}
        return boxes, losses

    def _forward_test(self, preds, image_shapes):
        boxes, rotated_boxes, polygons, scores = self.box_selector_test(
            preds, image_shapes
        )
        seg_results = {
            'rotated_boxes': rotated_boxes,
            'polygons': polygons,
            'preds': preds,
            'scores': scores,
        }
        return boxes, seg_results
def build_segmentation(cfg):
    """Construct the segmentation module (``SEGModule``) from the config."""
    return SEGModule(cfg)