| | import os |
| | from tqdm import tqdm |
| | import pickle |
| | import argparse |
| | import pathlib |
| | import json |
| | import time |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.parallel |
| | import torch.utils.data |
| | import numpy as np |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from metrics import ConfusionMatrix |
| | import data_transforms |
| | import argparse |
| | import random |
| | import traceback |
| |
|
| | """ |
| | Model |
| | """ |
| | class STN3d(nn.Module): |
| | def __init__(self, in_channels): |
| | super(STN3d, self).__init__() |
| | self.conv_layers = nn.Sequential( |
| | nn.Conv1d(in_channels, 64, 1), |
| | nn.BatchNorm1d(64), |
| | nn.ReLU(inplace=True), |
| | nn.Conv1d(64, 128, 1), |
| | nn.BatchNorm1d(128), |
| | nn.ReLU(inplace=True), |
| | nn.Conv1d(128, 1024, 1), |
| | nn.BatchNorm1d(1024), |
| | nn.ReLU(inplace=True) |
| | ) |
| | self.linear_layers = nn.Sequential( |
| | nn.Linear(1024, 512), |
| | nn.BatchNorm1d(512), |
| | nn.ReLU(inplace=True), |
| | nn.Linear(512, 256), |
| | nn.BatchNorm1d(256), |
| | nn.ReLU(inplace=True), |
| | nn.Linear(256, 9) |
| | ) |
| | self.iden = torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32)).reshape(1, 9) |
| |
|
| | def forward(self, x): |
| | batchsize = x.size()[0] |
| | x = self.conv_layers(x) |
| | x = torch.max(x, 2, keepdim=True)[0] |
| | x = x.view(-1, 1024) |
| |
|
| | x = self.linear_layers(x) |
| | iden = self.iden.repeat(batchsize, 1).to(x.device) |
| | x = x + iden |
| | x = x.view(-1, 3, 3) |
| | return x |
| |
|
| |
|
| | class STNkd(nn.Module): |
| | def __init__(self, k=64): |
| | super(STNkd, self).__init__() |
| | self.conv_layers = nn.Sequential( |
| | nn.Conv1d(k, 64, 1), |
| | nn.BatchNorm1d(64), |
| | nn.ReLU(inplace=True), |
| | nn.Conv1d(64, 128, 1), |
| | nn.BatchNorm1d(128), |
| | nn.ReLU(inplace=True), |
| | nn.Conv1d(128, 1024, 1), |
| | nn.BatchNorm1d(1024), |
| | nn.ReLU(inplace=True) |
| | ) |
| | self.linear_layers = nn.Sequential( |
| | nn.Linear(1024, 512), |
| | nn.BatchNorm1d(512), |
| | nn.ReLU(inplace=True), |
| | nn.Linear(512, 256), |
| | nn.BatchNorm1d(256), |
| | nn.ReLU(inplace=True), |
| | nn.Linear(256, k * k) |
| | ) |
| | self.k = k |
| | self.iden = torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)).reshape(1, self.k * self.k) |
| |
|
| | def forward(self, x): |
| | batchsize = x.size()[0] |
| | x = self.conv_layers(x) |
| | x = torch.max(x, 2, keepdim=True)[0] |
| | x = x.view(-1, 1024) |
| | x = self.linear_layers(x) |
| | iden = self.iden.repeat(batchsize, 1).to(x.device) |
| | x = x + iden |
| | x = x.view(-1, self.k, self.k) |
| | return x |
| |
|
| |
|
class PointNetEncoder(nn.Module):
    """PointNet encoder: input transform, shared MLPs, max-pooled global
    feature, plus a graph-based context-aggregation step.

    Args:
        global_feat: if True, return only the (B, 1024) global descriptor;
            otherwise concatenate it with per-point features (segmentation).
        feature_transform: apply the learned 64x64 feature-space alignment.
        in_channels: channels of the input points (3 = xyz, 6 = xyz+normals).
        k: neighbor count for the context graph.  If None (the default), the
            module-level CLI ``args.k`` is read at call time, preserving the
            original behavior — the original forward() depended on that
            hidden global, which made the class unusable outside this script.
    """

    def __init__(self, global_feat=True, feature_transform=False, in_channels=3, k=None):
        super(PointNetEncoder, self).__init__()
        self.stn = STN3d(in_channels)
        self.conv_layer1 = nn.Sequential(
            nn.Conv1d(in_channels, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True)
        )
        self.conv_layer2 = nn.Sequential(
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True)
        )
        self.conv_layer3 = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True)
        )
        self.conv_layer4 = nn.Sequential(
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024)
        )
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        self.k = k
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        """Encode (B, D, N) points; returns (features, trans, trans_feat)."""
        B, D, N = x.size()
        trans = self.stn(x)
        x = x.transpose(2, 1)
        if D > 3:
            # Only the xyz channels are spatially transformed; extra channels
            # (e.g. normals) pass through untouched.
            feature = x[:, :, 3:]
            x = x[:, :, :3]
        x = torch.bmm(x, trans)
        if D > 3:
            x = torch.cat([x, feature], dim=2)
        x = x.transpose(2, 1)
        x = self.conv_layer1(x)

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2, 1)
        else:
            trans_feat = None

        pointfeat = x
        x = self.conv_layer2(x)
        x = self.conv_layer3(x)
        x = self.conv_layer4(x)
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        # Context aggregation over the pooled (B, 1024) descriptors.
        # NOTE(review): the k-NN graph is built across SAMPLES in the batch
        # (rows of the pooled feature matrix), not across points — confirm
        # this batch-level feature mixing is intended.
        k = self.k if self.k is not None else args.k
        graph = construct_graph(x, k)
        context_features = compute_context_aware_features(x, graph)
        x = x + context_features

        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, N)
            return torch.cat([x, pointfeat], 1), trans, trans_feat
| |
|
| |
|
| |
|
def construct_graph(points, k):
    """Build a k-nearest-neighbor graph over the rows of ``points``.

    For every row, the ``k`` rows with the smallest Euclidean distance
    (the row itself included, at distance 0) become its neighbors.

    Returns a (num_rows, k) tensor of neighbor indices.
    """
    pairwise_dist = torch.cdist(points, points)
    neighbor_idx = torch.topk(pairwise_dist, k, dim=1, largest=False)[1]
    return neighbor_idx
| |
|
| | def compute_context_aware_features(points, graph, normalization_method='mean'): |
| | """ |
| | Compute context-aware feature adjustments using the constructed graph. |
| | """ |
| | |
| | context_features = torch.zeros_like(points) |
| | for i in range(points.size(0)): |
| | neighbors = graph[i] |
| | if normalization_method == 'mean': |
| | context_features[i] = points[neighbors].mean(dim=0) |
| | elif normalization_method == 'max': |
| | context_features[i] = points[neighbors].max(dim=0)[0] |
| | elif normalization_method == 'min': |
| | context_features[i] = points[neighbors].min(dim=0)[0] |
| | elif normalization_method == 'std': |
| | context_features[i] = points[neighbors].std(dim=0) |
| | else: |
| | raise ValueError("Unknown normalization method: {}".format(normalization_method)) |
| | return context_features |
| |
|
def feature_transform_reguliarzer(trans):
    """Orthogonality regularizer for predicted transform matrices.

    Penalizes ``|| T @ T^t - I ||_F`` averaged over the batch, encouraging
    each transform to stay close to a rotation.

    Args:
        trans: (B, d, d) batch of transform matrices.

    Returns:
        Scalar tensor: mean Frobenius-norm deviation from the identity.
    """
    d = trans.size()[1]
    # Build the identity directly on the input's device/dtype.  The original
    # only special-cased CUDA via ``is_cuda`` and would fail for any other
    # accelerator (e.g. mps); this form is device-agnostic.
    I = torch.eye(d, device=trans.device, dtype=trans.dtype)[None, :, :]
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
    return loss
| |
|
class Model(nn.Module):
    """PointNet classifier: encoder backbone plus an MLP head.

    ``forward`` returns the combined training loss (NLL + scaled
    feature-transform orthogonality regularizer) and the log-probabilities.
    """

    def __init__(self, in_channels=3, num_classes=40, scale=0.001):
        super().__init__()
        self.mat_diff_loss_scale = scale
        self.backbone = PointNetEncoder(global_feat=True, feature_transform=True, in_channels=in_channels)
        # Classification head: 1024 -> 512 -> 256 -> num_classes.
        head_layers = [
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.Dropout(p=0.4),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        ]
        self.cls_head = nn.Sequential(*head_layers)

    def forward(self, x, gts):
        """Compute (total_loss, log_probs) for points ``x`` and labels ``gts``."""
        global_feats, trans, trans_feat = self.backbone(x)
        logits = self.cls_head(global_feats)
        log_probs = F.log_softmax(logits, dim=1)
        nll = F.nll_loss(log_probs, gts)
        # Regularize the 64x64 feature transform toward orthogonality.
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)
        total_loss = nll + mat_diff_loss * self.mat_diff_loss_scale
        return total_loss, log_probs
| |
|
| |
|
| | """ |
| | dataset and normalization |
| | """ |
def pc_normalize(pc):
    """Center a point cloud at the origin and scale it into the unit sphere.

    Args:
        pc: (N, 3) numpy array of xyz coordinates.

    Returns:
        (N, 3) array with zero centroid and max point norm of 1.
    """
    centered = pc - pc.mean(axis=0)
    radius = np.sqrt((centered ** 2).sum(axis=1)).max()
    return centered / radius
| |
|
| |
|
class ModelNetDataset(Dataset):
    """ModelNet10/40 point-cloud classification dataset.

    Reads the class-name list and the official train/test split files, then
    loads the preprocessed ``(list_of_points, list_of_labels)`` pickle
    (farthest-point-sampled points, since ``uniform`` is True).
    ``__getitem__`` returns a normalized point set and its integer label.

    Args:
        data_root: directory containing the ModelNet metadata and .dat files.
        num_category: 10 or 40 — selects which ModelNet variant to load.
        num_points: number of points per sample (selects the .dat file).
        split: 'train' or 'test'.
    """

    def __init__(self, data_root, num_category, num_points, split='train'):
        self.root = data_root
        self.npoints = num_points
        self.uniform = True       # use the FPS-preprocessed pickle
        self.use_normals = True   # keep the normal channels (columns 3:6)
        self.num_category = num_category

        # ModelNet10 and ModelNet40 ship parallel metadata files differing
        # only in the "modelnet<N>_" prefix; the original duplicated the
        # read logic for both variants and leaked the open() handles.
        prefix = 'modelnet10' if self.num_category == 10 else 'modelnet40'
        self.catfile = os.path.join(self.root, prefix + '_shape_names.txt')

        self.cat = self._read_lines(self.catfile)
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        shape_ids = {
            'train': self._read_lines(os.path.join(self.root, prefix + '_train.txt')),
            'test': self._read_lines(os.path.join(self.root, prefix + '_test.txt')),
        }

        assert (split == 'train' or split == 'test')
        # A shape id looks like "chair_0001"; its class name is everything
        # before the final "_<index>" suffix.
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        if self.uniform:
            self.data_path = os.path.join(data_root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
        else:
            self.data_path = os.path.join(data_root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

        print('Load processed data from %s...' % self.data_path)
        # NOTE(review): unpickling trusts the .dat file — only load data
        # produced by this project's own preprocessing.
        with open(self.data_path, 'rb') as f:
            self.list_of_points, self.list_of_labels = pickle.load(f)

    @staticmethod
    def _read_lines(path):
        """Read a metadata file as a list of stripped lines, closing the handle."""
        with open(path) as f:
            return [line.rstrip() for line in f]

    def __len__(self):
        return len(self.datapath)

    def __getitem__(self, index):
        point_set, label = self.list_of_points[index], self.list_of_labels[index]
        # Normalize the xyz channels in place; extra channels (normals)
        # are left untouched.
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]
        return point_set, label[0]
| |
|
| |
|
def seed_everything(seed=11):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and force
    deterministic cuDNN kernels for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
| |
|
| |
|
def main(args):
    """Train the PointNet classifier on ModelNet and write the best metrics.

    Builds train/test dataloaders, trains for ``args.max_epoch`` epochs with
    Adam + StepLR, validates every ``args.val_per_epoch`` epochs (and on the
    last epoch), checkpoints the best model by overall accuracy, and dumps
    the best metrics to ``<out_dir>/final_info.json``.  Requires CUDA.
    """
    seed_everything(args.seed)

    final_infos = {}
    all_results = {}  # NOTE(review): never populated — looks like dead code.

    pathlib.Path(args.out_dir).mkdir(parents=True, exist_ok=True)

    # One dataset/dataloader pair per split; shuffling and last-batch
    # dropping are enabled only for training.
    datasets, dataloaders = {}, {}
    for split in ['train', 'test']:
        datasets[split] = ModelNetDataset(args.data_root, args.num_category, args.num_points, split)
        dataloaders[split] = DataLoader(datasets[split], batch_size=args.batch_size, shuffle=(split == 'train'),
                                        drop_last=(split == 'train'), num_workers=8)

    model = Model(in_channels=args.in_channels).cuda()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.learning_rate,
        betas=(0.9, 0.999), eps=1e-8,
        weight_decay=1e-4
    )
    # Decay LR by 0.7 every 20 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=0.7
    )
    train_losses = []
    print("Training model...")
    model.train()
    global_step = 0  # NOTE(review): never incremented — progress prints always report 0.
    cur_epoch = 0
    best_oa = 0
    best_acc = 0

    start_time = time.time()
    for epoch in tqdm(range(args.max_epoch), desc='training'):
        model.train()
        cm = ConfusionMatrix(num_classes=len(datasets['train'].classes))
        for points, target in tqdm(dataloaders['train'], desc=f'epoch {cur_epoch}/{args.max_epoch}'):
            # CPU-side numpy augmentation (dropout, scale, shift) before the
            # batch is moved to the GPU.
            points = points.data.numpy()
            points = data_transforms.random_point_dropout(points)
            points[:, :, 0:3] = data_transforms.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = data_transforms.shift_point_cloud(points[:, :, 0:3])
            # Transpose to the (B, C, N) channel-first layout the model expects.
            points = torch.from_numpy(points).transpose(2, 1).contiguous()

            points, target = points.cuda(), target.long().cuda()

            loss, logits = model(points, target)
            loss.backward()

            # Clip by global L2 norm, then step and clear gradients
            # (model.zero_grad() after the step is equivalent to calling
            # optimizer.zero_grad() before the next backward here).
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1, norm_type=2)
            optimizer.step()
            model.zero_grad()

            logs = {"loss": loss.detach().item()}  # NOTE(review): unused — logging hook never wired up.
            train_losses.append(loss.detach().item())
            cm.update(logits.argmax(dim=1), target)

        scheduler.step()
        end_time = time.time()
        training_time = end_time - start_time  # NOTE(review): computed but never reported.
        macc, overallacc, accs = cm.all_acc()
        print(f"iter: {global_step}/{args.max_epoch*len(dataloaders['train'])}, \
            train_macc: {macc}, train_oa: {overallacc}")

        # Validate every val_per_epoch epochs (skipping epoch 0) and always
        # on the final epoch.
        if (cur_epoch % args.val_per_epoch == 0 and cur_epoch != 0) or cur_epoch == (args.max_epoch - 1):
            model.eval()
            cm = ConfusionMatrix(num_classes=datasets['test'].num_category)
            pbar = tqdm(enumerate(dataloaders['test']), total=dataloaders['test'].__len__())

            # NOTE(review): evaluation runs without torch.no_grad(), so it
            # builds (and discards) autograd graphs — confirm intended.
            for idx, (points, target) in pbar:
                points, target = points.cuda(), target.long().cuda()
                points = points.transpose(2, 1).contiguous()
                loss, logits = model(points, target)
                cm.update(logits.argmax(dim=1), target)

            tp, count = cm.tp, cm.count
            macc, overallacc, accs = cm.cal_acc(tp, count)
            print(f"iter: {global_step}/{args.max_epoch*len(dataloaders['train'])}, \
                val_macc: {macc}, val_oa: {overallacc}")

            # Checkpoint whenever overall accuracy improves.
            # NOTE(review): best_epoch is assigned only inside this branch;
            # if validation never beats best_oa (== 0), the final_infos
            # construction below raises NameError — consider initializing
            # best_epoch = 0 alongside best_oa/best_acc.
            if overallacc > best_oa:
                best_oa = overallacc
                best_acc = macc
                best_epoch = cur_epoch
                torch.save(model.state_dict(), os.path.join(args.out_dir, 'best.pth'))
        cur_epoch += 1

        print(f"finish epoch {cur_epoch} training")

    # Persist the best metrics in the nested means-per-dataset layout the
    # downstream tooling reads.
    final_infos = {
        "modelnet" + str(args.num_category):{
            "means":{
                "best_oa": best_oa,
                "best_acc": best_acc,
                "epoch": best_epoch
            }
        }
    }
    with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
        json.dump(final_infos, f)
| |
|
if __name__ == "__main__":
    # CLI: training hyperparameters plus data/output locations.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--out_dir", type=str, default="run_0")
    parser.add_argument("--in_channels", type=int, default=6)
    parser.add_argument("--num_points", type=int, default=1024)
    parser.add_argument("--num_category", type=int, choices=[10, 40], default=40)
    parser.add_argument("--data_root", type=str, default='./datasets/modelnet40')
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--max_epoch", type=int, default=200)
    parser.add_argument("--val_per_epoch", type=int, default=5)
    parser.add_argument("--k", type=int, default=5, help="Number of neighbors for graph construction")
    parser.add_argument("--seed", type=int, default=666)
    args = parser.parse_args()

    try:
        main(args)
    except Exception as e:
        print("Original error in subprocess:", flush=True)
        # Ensure the log directory exists even if main() failed before
        # creating it, and close the log file deterministically — the
        # original passed an inline open() to print_exc and leaked the
        # handle (and crashed if out_dir was missing).
        pathlib.Path(args.out_dir).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as log_file:
            traceback.print_exc(file=log_file)
        raise