S23DR-P2R / model /roofnet.py
colin1842's picture
add model
8d5039c
from .pointnet2 import PointNet2
from .cluster_refine import ClusterRefineNet
from .edge_pred_net import EdgeAttentionNet
import torch.nn as nn
from sklearn.cluster import DBSCAN
class RoofNet(nn.Module):
def __init__(self, model_cfg, input_channel=3):
super().__init__()
self.use_edge = False
self.model_cfg = model_cfg
self.keypoint_det_net = PointNet2(model_cfg.PointNet2, input_channel)
self.cluster_refine_net = ClusterRefineNet(model_cfg.ClusterRefineNet, input_channel=self.keypoint_det_net.num_output_feature)
self.edge_att_net = EdgeAttentionNet(model_cfg.EdgeAttentionNet, input_channel=self.cluster_refine_net.num_output_feature)
def forward(self, batch_dict):
batch_dict = self.keypoint_det_net(batch_dict)
if self.use_edge:
batch_dict = self.cluster_refine_net(batch_dict)
batch_dict = self.edge_att_net(batch_dict)
if self.training:
loss = 0
loss_dict = {}
disp_dict = {}
tmp_loss, loss_dict, disp_dict = self.keypoint_det_net.loss(loss_dict, disp_dict)
loss += tmp_loss
if self.use_edge:
tmp_loss, loss_dict, disp_dict = self.cluster_refine_net.loss(loss_dict, disp_dict)
loss += tmp_loss
tmp_loss, loss_dict, disp_dict = self.edge_att_net.loss(loss_dict, disp_dict)
loss += tmp_loss
return loss, loss_dict, disp_dict
else:
return batch_dict