Spaces:
Build error
Build error
| """ | |
Adapted from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
This file contains the definition of GraphCNN.
GraphCNN includes a ResNet34 image encoder as a submodule.
| """ | |
| from __future__ import division | |
| import torch | |
| import torch.nn as nn | |
| # from .resnet import resnet50 | |
| import torchvision.models as models | |
| import os | |
| import sys | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) | |
| from src.graph_networks.graphcmr.utils_mesh import Mesh | |
| from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear | |
class GraphCNN(nn.Module):
    """ResNet34 image encoder followed by a graph-convolutional head on a mesh.

    The image is encoded into one global feature vector by a ResNet34 whose
    first convolution and final fully connected layer are replaced. That vector
    is broadcast to every template vertex and concatenated with the vertex
    coordinates. A stack of graph residual blocks then feeds two heads:
    per-vertex ground contact and a per-sample ground-flatness value.

    Args:
        A: mesh adjacency passed to every ``GraphResBlock``. ``A.shape[0]`` is
            used as the input size of the final flatness ``nn.Linear``, so it
            must equal the number of vertices V.
        ref_vertices: template vertex coordinates, shape (3, V). The first
            ``GraphLinear`` expects ``3 + n_resnet_out`` input channels.
        n_resnet_in: number of input image channels (replaces ResNet ``conv1``).
        n_resnet_out: size of the global image feature produced by the ResNet fc.
        num_layers: number of extra ``GraphResBlock(num_channels, num_channels)``.
        num_channels: width of the graph-convolution trunk.
    """

    def __init__(self, A, ref_vertices, n_resnet_in, n_resnet_out, num_layers=5, num_channels=512):
        super(GraphCNN, self).__init__()
        self.A = A
        # Plain attribute (not a registered buffer): it is not in state_dict and
        # is not moved by ``module.to(device)``; forward() aligns its device.
        self.ref_vertices = ref_vertices

        # Image encoder: randomly initialised ResNet34. GraphCMR drops the last
        # fully connected layer; here it is kept and resized to n_resnet_out.
        # NOTE: torchvision >= 0.13 deprecates ``pretrained`` in favour of
        # ``weights=None``; kept as-is for compatibility with older versions.
        self.resnet = models.resnet34(pretrained=False)
        # Replace the first layer so the backbone accepts n_resnet_in channels.
        self.resnet.conv1 = nn.Conv2d(
            n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        # Replace the last layer; take its input width from the backbone
        # (512 for ResNet34) instead of hard-coding it.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, n_resnet_out)

        # Graph trunk: (xyz + broadcast image feature) -> 2*C -> C -> ... -> C.
        # Built before the heads, as in the original, so that random
        # initialisation consumes the RNG in the same order.
        layers = [GraphLinear(3 + n_resnet_out, 2 * num_channels)]
        layers.append(GraphResBlock(2 * num_channels, num_channels, A))
        for _ in range(num_layers):
            layers.append(GraphResBlock(num_channels, num_channels, A))

        # Ground-contact head: two output values per vertex.
        self.n_out_gc = 2
        self.gc = nn.Sequential(GraphResBlock(num_channels, 64, A),
                                GraphResBlock(64, 32, A),
                                nn.GroupNorm(32 // 8, 32),
                                nn.ReLU(inplace=True),
                                GraphLinear(32, self.n_out_gc))
        # Registered after ``self.gc`` (original order) so parameters() and
        # state_dict ordering are unchanged for existing checkpoints/optimizers.
        self.gcnn = nn.Sequential(*layers)

        # Ground-flatness head: one scalar per vertex, then an nn.Linear over
        # the vertex dimension (size A.shape[0]) down to one value per sample.
        self.n_out_flatground = 1
        self.flat_ground = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels),
                                         nn.ReLU(inplace=True),
                                         GraphLinear(num_channels, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(A.shape[0], self.n_out_flatground))

    def forward(self, image):
        """Predict ground contact and ground flatness from a batch of images.

        Inputs:
            image: tensor of shape (B, n_resnet_in, H, W) accepted by the
                ResNet34 backbone.
        Returns:
            ground_contact: output of the contact head, (B, 2, V); raw
                ``GraphLinear`` outputs with no activation applied here.
            ground_flatness: output of the flatness head, (B, 1).
        """
        batch_size = image.shape[0]
        # Align the template with the input's device; Tensor.to returns the
        # same tensor (no copy) when it is already there.
        ref_vertices = self.ref_vertices.to(image.device)
        ref_vertices = ref_vertices[None, :, :].expand(batch_size, -1, -1)  # (B, 3, V)
        image_resnet = self.resnet(image)  # (B, n_resnet_out)
        # Broadcast the global image feature to every vertex.
        image_enc = image_resnet.view(batch_size, -1, 1).expand(
            -1, -1, ref_vertices.shape[-1]
        )  # (B, n_resnet_out, V)
        x = torch.cat([ref_vertices, image_enc], dim=1)  # (B, 3 + n_resnet_out, V)
        # Assumes GraphResBlock keeps the (B, C, V) layout - see graph_layers.
        x = self.gcnn(x)  # (B, num_channels, V)
        ground_contact = self.gc(x)  # (B, 2, V)
        ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground)  # (B, 1)
        return ground_contact, ground_flatness
# How to use it:
#
# from src.graph_networks.graphcmr.utils_mesh import Mesh
#
# Create the Mesh object:
# self.mesh = Mesh()
# self.faces = self.mesh.faces.to(self.device)
#
# Create the GraphCNN (n_resnet_in = number of input image channels,
# n_resnet_out = size of the global image feature):
# self.graph_cnn = GraphCNN(self.mesh.adjmat,
#                           self.mesh.ref_vertices.t(),
#                           n_resnet_in=<input channels>,
#                           n_resnet_out=<image feature size>,
#                           num_channels=self.options.num_channels,
#                           num_layers=self.options.num_layers
#                           ).to(self.device)
# ------------
#
# Feed the images into the GraphCNN. It returns per-vertex ground-contact
# predictions on the subsampled mesh, (B, 2, V), and a per-sample
# ground-flatness prediction, (B, 1):
# ground_contact, ground_flatness = self.graph_cnn(images)
#
# Optionally, upsample the per-vertex predictions to the original mesh size
# (presumably Mesh.upsample expects (B, V, C) - verify against utils_mesh):
# ground_contact_full = self.mesh.upsample(ground_contact.transpose(1, 2))
#