#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import RoIAlign

from . import box_utils
from .graph import GraphTripleConv, GraphTripleConvNet
from .layout import boxes_to_layout, masks_to_layout, boxes_to_seg, masks_to_seg
from .layers import build_mlp, build_cnn
from .utils import vocab


class Model(nn.Module):
    def __init__(self,
                 embedding_dim=128,
                 image_size=(128, 128),
                 input_dim=3,
                 attribute_dim=35,
                 # graph_net
                 gconv_dim=128,
                 gconv_hidden_dim=512,
                 gconv_num_layers=5,
                 # inside_cnn
                 inside_cnn_arch="C3-32-2,C3-64-2,C3-128-2,C3-256-2",
                 # refinement_net
                 refinement_dims=(1024, 512, 256, 128, 64),
                 # box_refine
                 box_refine_arch="I15,C3-64-2,C3-128-2,C3-256-2",
                 roi_output_size=(8, 8),
                 roi_spatial_scale=1.0 / 8.0,
                 roi_cat_feature=True,
                 # others
                 mlp_activation='leakyrelu',
                 mlp_normalization='none',
                 cnn_activation='leakyrelu',
                 cnn_normalization='batch'):
        super(Model, self).__init__()

        ''' embedding '''
        self.vocab = vocab
        num_objs = len(vocab['object_idx_to_name'])
        num_preds = len(vocab['pred_idx_to_name'])
        num_doors = len(vocab['door_idx_to_name'])
        self.obj_embeddings = nn.Embedding(num_objs, embedding_dim)
        self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)
        self.image_size = image_size
        self.feature_dim = embedding_dim + attribute_dim

        ''' graph_net '''
        self.gconv = GraphTripleConv(
            embedding_dim,
            attributes_dim=attribute_dim,
            output_dim=gconv_dim,
            hidden_dim=gconv_hidden_dim,
            mlp_normalization=mlp_normalization
        )
        self.gconv_net = GraphTripleConvNet(
            gconv_dim,
            num_layers=gconv_num_layers - 1,
            mlp_normalization=mlp_normalization
        )

        ''' inside_cnn '''
        inside_cnn, inside_feat_dim = build_cnn(
            f'I{input_dim},{inside_cnn_arch}',
            padding='valid'
        )
        self.inside_cnn = nn.Sequential(
            inside_cnn,
            nn.AdaptiveAvgPool2d(1)
        )
        inside_output_dim = inside_feat_dim
        obj_vecs_dim = gconv_dim + inside_output_dim

        ''' box_net '''
        box_net_dim = 4
        box_net_layers = [obj_vecs_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(
            box_net_layers,
            activation=mlp_activation,
            batch_norm=mlp_normalization
        )

        ''' relationship_net '''
        rel_aux_layers = [obj_vecs_dim, gconv_hidden_dim, num_doors]
        self.rel_aux_net = build_mlp(
            rel_aux_layers,
            activation=mlp_activation,
            batch_norm=mlp_normalization
        )
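        # NOTE: rel_aux_net is not called in forward() at the moment; it is
        # kept for the auxiliary door-position prediction head (see the
        # commented-out call in forward below).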

        ''' refinement_net '''
        # NOTE: refinement_dims only gates whether the refinement net is
        # built; the layer sizes below are hard-coded.
        if refinement_dims is not None:
            self.refinement_net, _ = build_cnn(f"I{obj_vecs_dim},C3-128,C3-64,C3-{num_objs}")
        else:
            self.refinement_net = None

        ''' roi '''
        self.box_refine_backbone = None
        self.roi_cat_feature = roi_cat_feature
        if box_refine_arch is not None:
            box_refine_cnn, box_feat_dim = build_cnn(
                box_refine_arch,
                padding='valid'
            )
            self.box_refine_backbone = box_refine_cnn
            self.roi_align = RoIAlign(roi_output_size, roi_spatial_scale, sampling_ratio=-1)  # pooled ROI features: (num_rois, 256, 8, 8)
            self.down_sample = nn.AdaptiveAvgPool2d(1)
            box_refine_layers = [(obj_vecs_dim + 256) if self.roi_cat_feature else 256, 512, 4]
            self.box_reg = build_mlp(
                box_refine_layers,
                activation=mlp_activation,
                batch_norm=mlp_normalization
            )

    def forward(
        self,
        objs,
        triples,
        boundary,
        obj_to_img=None,
        attributes=None,
        boxes_gt=None,
        generate=False,
        refine=False,
        relative=False,
        inside_box=None
    ):
| """ | |
| Required Inputs: | |
| - objs: LongTensor of shape (O,) giving categories for all objects | |
| - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o] | |
| means that there is a triple (objs[s], p, objs[o]) | |
| Optional Inputs: | |
| - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i | |
| means that objects[o] is an object in image i. If not given then | |
| all objects are assumed to belong to the same image. | |
| - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for computing | |
| the spatial layout; if not given then use predicted boxes. | |
| """ | |
        # input size
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)  # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)  # Shape is (T, 2)
        B = boundary.size(0)
        H, W = self.image_size
        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        ''' embedding '''
        obj_vecs = self.obj_embeddings(objs)
        pred_vecs = self.pred_embeddings(p)

        ''' attribute '''
        if attributes is not None:
            obj_vecs = torch.cat([obj_vecs, attributes], 1)
        obj_vecs_orig = obj_vecs

        ''' gconv '''
        obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        ''' inside '''
        inside_vecs = self.inside_cnn(boundary).view(B, -1)
        obj_vecs = torch.cat([obj_vecs, inside_vecs[obj_to_img]], dim=1)

        ''' box '''
        boxes_pred = self.box_net(obj_vecs)
        if relative:
            boxes_pred = box_utils.box_rel2abs(boxes_pred, inside_box, obj_to_img)

        ''' relation '''
        # unused, for door position prediction
        # rel_scores = self.rel_aux_net(obj_vecs)

        ''' generate '''
        gene_layout = None
        boxes_refine = None
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt
        if generate:
            layout_features = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
            gene_layout = self.refinement_net(layout_features)

        ''' box refine '''
        if refine:
            gene_feat = self.box_refine_backbone(gene_layout)
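            # RoIAlign expects rois as (batch_index, x1, y1, x2, y2) rows in
            # the input image's coordinate frame; layout boxes are kept in
            # normalized center format, hence the conversion and scaling by H
            # (H == W for the default square image size).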
            rois = torch.cat([
                obj_to_img.float().view(-1, 1),
                box_utils.centers_to_extents(layout_boxes) * H
            ], -1)
            roi_feat = self.down_sample(self.roi_align(gene_feat, rois)).flatten(1)
            if self.roi_cat_feature:
                # box_reg was sized for the concatenated feature in __init__,
                # so only append the graph feature when roi_cat_feature is set.
                roi_feat = torch.cat([
                    roi_feat,
                    obj_vecs
                ], -1)
            boxes_refine = self.box_reg(roi_feat)
            if relative:
                boxes_refine = box_utils.box_rel2abs(boxes_refine, inside_box, obj_to_img)
        return boxes_pred, gene_layout, boxes_refine
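

# ----------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original repo):
# shapes follow the constructor defaults (3x128x128 boundary map, 35-dim
# attributes); object/predicate counts come from the package vocab. Because
# this module uses relative imports, run it as `python -m <package>.<module>`.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    num_objs = len(vocab['object_idx_to_name'])
    num_preds = len(vocab['pred_idx_to_name'])
    model = Model()

    O, T = 5, 4  # a toy scene graph: 5 objects, 4 relations
    objs = torch.randint(0, num_objs, (O,))
    s = torch.randint(0, O, (T,))          # subject indices into objs
    p = torch.randint(0, num_preds, (T,))  # predicate categories
    o = torch.randint(0, O, (T,))          # object indices into objs
    triples = torch.stack([s, p, o], dim=1)

    boundary = torch.randn(1, 3, 128, 128)  # single boundary image
    attributes = torch.randn(O, 35)
    obj_to_img = torch.zeros(O, dtype=torch.long)

    boxes_pred, gene_layout, boxes_refine = model(
        objs, triples, boundary,
        obj_to_img=obj_to_img,
        attributes=attributes,
    )
    print(boxes_pred.shape)  # (O, 4); gene_layout and boxes_refine stay None
    # unless generate=True / refine=True are requested.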