from .pointnet2 import PointNet2
from .cluster_refine import ClusterRefineNet
from .edge_pred_net import EdgeAttentionNet
import torch.nn as nn
from sklearn.cluster import DBSCAN  # NOTE(review): unused in this module; kept in case other code imports it from here


class RoofNet(nn.Module):
    """Three-stage network: keypoint detection, cluster refinement, edge attention.

    The stages are chained as ``PointNet2`` -> ``ClusterRefineNet`` ->
    ``EdgeAttentionNet``. Each stage's input width is taken from the previous
    stage's ``num_output_feature``.

    Only the keypoint stage always runs. The cluster-refine and
    edge-attention stages run, and contribute to the loss, only when
    ``use_edge`` is true. All three are always constructed, so their
    parameters are registered as submodules regardless of the flag.
    """

    def __init__(self, model_cfg, input_channel=3, use_edge=False):
        """Build the three sub-networks.

        Args:
            model_cfg: Config object exposing ``PointNet2``, ``ClusterRefineNet``
                and ``EdgeAttentionNet`` attributes. Each is forwarded to the
                matching sub-network constructor.
            input_channel: Number of per-point input channels fed to
                ``PointNet2``. Defaults to 3.
            use_edge: Whether ``forward`` runs (and trains) the cluster-refine
                and edge-attention stages. Defaults to ``False``, matching the
                previous hard-coded behaviour. It can still be toggled after
                construction through the ``use_edge`` attribute.
        """
        super().__init__()
        self.use_edge = use_edge
        self.model_cfg = model_cfg
        self.keypoint_det_net = PointNet2(model_cfg.PointNet2, input_channel)
        self.cluster_refine_net = ClusterRefineNet(
            model_cfg.ClusterRefineNet,
            input_channel=self.keypoint_det_net.num_output_feature,
        )
        self.edge_att_net = EdgeAttentionNet(
            model_cfg.EdgeAttentionNet,
            input_channel=self.cluster_refine_net.num_output_feature,
        )

    def _active_stages(self):
        """Return the sub-networks enabled by ``use_edge``, in execution order."""
        stages = [self.keypoint_det_net]
        if self.use_edge:
            stages += [self.cluster_refine_net, self.edge_att_net]
        return stages

    def forward(self, batch_dict):
        """Run every active stage on ``batch_dict``.

        Args:
            batch_dict: Dict passed through each active stage in turn. Each
                stage returns the (possibly updated) dict. Its keys are
                defined by the sub-networks, not in this module.

        Returns:
            In training mode, a tuple ``(loss, loss_dict, disp_dict)``:
            ``loss`` is the sum of the ``loss`` of every active stage, and the
            two dicts are threaded through and filled in by those ``loss``
            methods. In eval mode, the ``batch_dict`` returned by the last
            active stage.
        """
        stages = self._active_stages()
        for stage in stages:
            batch_dict = stage(batch_dict)

        if not self.training:
            return batch_dict

        loss = 0
        loss_dict = {}
        disp_dict = {}
        for stage in stages:
            # loss() receives no batch data, so it presumably relies on state
            # cached by the stage during its forward pass — keep this after
            # the forward loop above.
            stage_loss, loss_dict, disp_dict = stage.loss(loss_dict, disp_dict)
            loss += stage_loss
        return loss, loss_dict, disp_dict