|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import sys
|
| sys.path.append("./")
|
| import shutil
|
| import argparse
|
| import numpy as np
|
|
|
| from utils.log_utils import AverageMeter
|
| from utils.os_utils import isdir, mkdir_p, isfile
|
| from utils.io_utils import output_rigging
|
|
|
| import torch
|
| import torch.backends.cudnn as cudnn
|
| from torch_geometric.data import DataLoader
|
| from torch.utils.tensorboard import SummaryWriter
|
|
|
| import models
|
| from models.supplemental_layers.cross_entropy_with_probs import cross_entropy_with_probs
|
| from datasets.skin_dataset import SkinDataset
|
|
|
# Select GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None):
    """Persist training state to disk.

    Writes ``state`` to ``<checkpoint>/<filename>``. Additionally keeps an
    epoch-tagged copy every ``snapshot`` epochs (when ``snapshot`` is set),
    and copies to ``model_best.pth.tar`` when ``is_best`` is true.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)

    # Periodic snapshot: keep a copy tagged with the epoch number.
    if snapshot and state['epoch'] % snapshot == 0:
        tagged = os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch']))
        shutil.copyfile(target, tagged)

    # Track the best model seen so far in a separate, stable filename.
    if is_best:
        shutil.copyfile(target, os.path.join(checkpoint, 'model_best.pth.tar'))
|
|
|
|
|
def get_bone_names(skin_filename):
    """Parse a ``*_skin.txt`` file and return its bone list.

    Lines whose first token is 'bones' are expected to look like:
        bones <name_1> <name_2> ...
    and contribute ``[name_1, name_2]`` to the result, in file order.

    Args:
        skin_filename: path to the skin description file.

    Returns:
        List of [str, str] pairs, one per 'bones' line.
    """
    bone_names = []
    with open(skin_filename, 'r') as fin:
        for li in fin:
            words = li.strip().split()
            # Guard against blank/short lines: the original code indexed
            # words[0] unconditionally, raising IndexError on empty lines.
            if len(words) >= 3 and words[0] == 'bones':
                bone_names.append([words[1], words[2]])
    return bone_names
|
|
|
|
|
def post_filter(skin_weights, topology_edge, num_ring=1):
    """Smooth per-vertex skinning weights over the mesh topology.

    For each vertex, replaces its weight row with the mean of its neighbors'
    rows within ``num_ring`` topological rings (the vertex itself is excluded
    from the average). A vertex with no neighbors keeps its original weights.

    Args:
        skin_weights: (V, B) array of per-vertex bone weights.
        topology_edge: (2, E) array of mesh edges as vertex-index pairs.
        num_ring: number of rings used to grow the neighborhood.

    Returns:
        (V, B) array of smoothed weights.
    """
    skin_weights_new = np.zeros_like(skin_weights)
    for v in range(len(skin_weights)):
        adj_verts_multi_ring = []
        current_seeds = [v]
        for r in range(num_ring):
            adj_verts = []
            for seed in current_seeds:
                # Columns of topology_edge that touch `seed`; their entries
                # are the seed itself plus its one-ring neighbors.
                adj_edges = topology_edge[:, np.argwhere(topology_edge == seed)[:, 1]]
                adj_verts_seed = list(set(adj_edges.flatten().tolist()))
                # `seed` is absent when it has no incident edge; the original
                # unconditional remove() raised ValueError in that case.
                if seed in adj_verts_seed:
                    adj_verts_seed.remove(seed)
                adj_verts += adj_verts_seed
            adj_verts_multi_ring += adj_verts
            current_seeds = adj_verts
        adj_verts_multi_ring = list(set(adj_verts_multi_ring))
        if v in adj_verts_multi_ring:
            adj_verts_multi_ring.remove(v)
        if not adj_verts_multi_ring:
            # Isolated vertex: nothing to average. Keep its weights unchanged
            # (the original crashed on np.concatenate of an empty list).
            skin_weights_new[v, :] = skin_weights[v, :]
            continue
        skin_weights_neighbor = [skin_weights[int(i), :][np.newaxis, :] for i in adj_verts_multi_ring]
        skin_weights_neighbor = np.concatenate(skin_weights_neighbor, axis=0)
        skin_weights_new[v, :] = np.mean(skin_weights_neighbor, axis=0)
    return skin_weights_new
|
|
|
|
|
def main(args):
    """Entry point: train or evaluate the skinning-weight prediction net.

    Builds the model, optimizer and data loaders from ``args`` and
    optionally resumes from a checkpoint. With ``--evaluate`` it runs a
    single test pass (saving rig outputs) and returns; otherwise it trains
    for ``args.epochs`` epochs, validating/testing each epoch, saving a
    checkpoint every epoch (tracking the best validation loss), and
    logging losses to TensorBoard.
    """
    global device
    lowest_loss = 1e20  # best (lowest) validation loss seen so far

    # Ensure checkpoint folder exists; wipe stale logs unless resuming.
    if not isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
        mkdir_p(args.checkpoint)
    if not args.resume:
        if isdir(args.logdir):
            shutil.rmtree(args.logdir)
        mkdir_p(args.logdir)

    # The "skinnet" architecture is looked up from the models package by name.
    model = models.__dict__["skinnet"](args.nearest_bone, args.Dg, args.Lf)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    lr = args.lr
    if args.resume:
        if isfile(args.resume):
            # Restore epoch counter, best loss, and model/optimizer state.
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            lowest_loss = checkpoint['lowest_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True  # let cuDNN autotune kernels for fixed-size inputs
    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
    train_loader = DataLoader(SkinDataset(root=args.train_folder), batch_size=args.train_batch, shuffle=True)
    val_loader = DataLoader(SkinDataset(root=args.val_folder), batch_size=args.test_batch, shuffle=False)
    test_loader = DataLoader(SkinDataset(root=args.test_folder), batch_size=args.test_batch, shuffle=False)
    if args.evaluate:
        # Evaluation-only mode: single pass over the test set, saving outputs.
        print('\nEvaluation only')
        test_loss = test(test_loader, model, args, save_result=True)
        print('test_loss {:6f}'.format(test_loss))
        return
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=args.gamma)
    logger = SummaryWriter(log_dir=args.logdir)
    for epoch in range(args.start_epoch, args.epochs):
        lr = scheduler.get_last_lr()  # list: one LR per parameter group
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
        train_loss = train(train_loader, model, optimizer, args)
        val_loss = test(val_loader, model, args)
        test_loss = test(test_loader, model, args)
        scheduler.step()
        print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
        print('Epoch{:d}. val_loss: {:.6f}.'.format(epoch + 1, val_loss))
        print('Epoch{:d}. test_loss: {:.6f}.'.format(epoch + 1, test_loss))

        # Checkpoint every epoch; validation loss decides the "best" model.
        is_best = val_loss < lowest_loss
        lowest_loss = min(val_loss, lowest_loss)
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'lowest_loss': lowest_loss,
                         'optimizer': optimizer.state_dict()}, is_best, checkpoint=args.checkpoint)

        # Log all three losses to TensorBoard under the epoch index.
        info = {'train_loss': train_loss, 'val_loss': val_loss, 'test_loss': test_loss}
        for tag, value in info.items():
            logger.add_scalar(tag, value, epoch + 1)
|
|
|
|
|
def train(train_loader, model, optimizer, args):
    """Run one training epoch and return the mean masked skinning loss."""
    global device
    model.train()
    meter = AverageMeter()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = batch  # placeholder name reuse avoided; forward pass below
        pred = model(batch)
        k = args.nearest_bone
        # Restrict supervision to the K nearest-bone slots per vertex.
        target = batch.skin_label[:, 0:k]
        mask = batch.loss_mask.float()[:, 0:k]
        target = target * mask
        # Renormalize each masked row so the remaining weights sum to one.
        target = target / (torch.sum(torch.abs(target), dim=1, keepdim=True) + 1e-8)
        # A vertex counts only if its masked weights still form a distribution.
        valid = (torch.abs(target.sum(dim=1) - 1.0) < 1e-8).float().unsqueeze(1)
        per_entry = cross_entropy_with_probs(pred, target, reduction='none')
        weight = mask * valid
        loss = (per_entry * weight).sum() / weight.sum()
        loss.backward()
        optimizer.step()
        meter.update(loss.item())
    return meter.avg
|
|
|
|
|
def test(test_loader, model, args, save_result=False):
    """Evaluate on ``test_loader`` and return the average masked loss.

    When ``save_result`` is True, also writes per-sample outputs under
    ``results/<checkpoint-basename>/``: the bone-name list, the dense
    per-vertex weight matrix (.npy), and the rigged output produced by
    ``output_rigging``.
    """
    global device
    model.eval()
    loss_meter = AverageMeter()
    # Results are grouped under the checkpoint folder's basename.
    outdir = args.checkpoint.split('/')[-1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            skin_pred = model(data)
            # Same masked cross-entropy as in train(): restrict to the K
            # nearest-bone slots and renormalize the supervised rows.
            skin_gt = data.skin_label[:, 0:args.nearest_bone]
            loss_mask_batch = data.loss_mask.float()[:, 0:args.nearest_bone]
            skin_gt = skin_gt * loss_mask_batch
            skin_gt = skin_gt / (torch.sum(torch.abs(skin_gt), dim=1, keepdim=True) + 1e-8)
            # Only vertices whose masked weights still sum to 1 are counted.
            vert_mask = (torch.abs(skin_gt.sum(dim=1) - 1.0) < 1e-8).float()
            loss = cross_entropy_with_probs(skin_pred, skin_gt, reduction='none')
            loss = (loss * loss_mask_batch * vert_mask.unsqueeze(1)).sum() / (loss_mask_batch * vert_mask.unsqueeze(1)).sum()
            loss_meter.update(loss.item())

        if save_result:
            output_folder = 'results/{:s}/'.format(outdir)
            if not os.path.exists(output_folder):
                mkdir_p(output_folder)
            # data.batch maps each vertex to its sample index within the batch.
            for i in range(len(torch.unique(data.batch))):
                print('output result for model {:d}'.format(data.name[i].item()))
                skin_pred_i = skin_pred[data.batch == i]
                bone_names = get_bone_names(os.path.join(args.test_folder, "{:d}_skin.txt".format(data.name[i].item())))
                # Mesh topology edges, transposed to (2, E) for post_filter.
                tpl_e = np.loadtxt(os.path.join(args.test_folder, "{:d}_tpl_e.txt".format(data.name[i].item()))).T
                loss_mask_sample = data.loss_mask.float()[data.batch == i, 0:args.nearest_bone]
                # Logits -> probabilities, then zero out invalid bone slots.
                skin_pred_i = torch.softmax(skin_pred_i, dim=1)
                skin_pred_i = skin_pred_i * loss_mask_sample
                # skin_nn gives, per vertex, indices into the full bone list
                # (presumably its K nearest bones — verify in SkinDataset);
                # scatter the K predictions into a dense (V, B) matrix.
                skin_nn_i = data.skin_nn[data.batch == i, 0:args.nearest_bone]
                skin_pred_asarray = np.zeros((len(skin_pred_i), len(bone_names)))
                for v in range(len(skin_pred_i)):
                    for nn_id in range(len(skin_nn_i[v, :])):
                        skin_pred_asarray[v, skin_nn_i[v, nn_id]] = skin_pred_i[v, nn_id]
                # Smooth over 1-ring neighbors, prune weights below half the
                # per-vertex maximum, then renormalize each row.
                skin_pred_asarray = post_filter(skin_pred_asarray, tpl_e, num_ring=1)
                skin_pred_asarray[skin_pred_asarray < np.max(skin_pred_asarray, axis=1, keepdims=True) * 0.5] = 0.0
                skin_pred_asarray = skin_pred_asarray / (skin_pred_asarray.sum(axis=1, keepdims=True) + 1e-10)
                with open(os.path.join(output_folder, "{:d}_bone_names.txt".format(data.name[i].item())), 'w') as fout:
                    for bone_name in bone_names:
                        fout.write("{:s} {:s}\n".format(bone_name[0], bone_name[1]))
                np.save(os.path.join(output_folder, "{:d}_full_pred.npy".format(data.name[i].item())), skin_pred_asarray)
                skel_filename = os.path.join(args.info_folder, "{:d}.txt".format(data.name[i].item()))
                output_rigging(skel_filename, skin_pred_asarray, output_folder, data.name[i].item())
    return loss_meter.avg
|
|
|
|
|
if __name__ == '__main__':
    # Command-line interface for the skinning prediction network.
    parser = argparse.ArgumentParser(description='skinning prediction network')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number')
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--gamma', type=float, default=0.5, help='LR is multiplied by gamma on schedule.')
    parser.add_argument('-j', '--workers', default=1, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--schedule', type=int, nargs='+', default=[], help='Decrease learning rate at these epochs.')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on val/test set')

    parser.add_argument('--train_batch', default=2, type=int, metavar='N', help='train batchsize')
    parser.add_argument('--test_batch', default=2, type=int, metavar='N', help='test batchsize')
    parser.add_argument('-c', '--checkpoint', default='checkpoints/test', type=str, metavar='PATH',
                        help='path to save checkpoint (default: checkpoint)')
    parser.add_argument('--logdir', default='logs/test', type=str, metavar='LOG', help='directory to save logs')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--train_folder', default='/media/zhanxu/4T/ModelResource_RigNetv1_preproccessed/train/',
                        type=str, help='folder of training data')
    parser.add_argument('--val_folder', default='/media/zhanxu/4T/ModelResource_RigNetv1_preproccessed/val/',
                        type=str, help='folder of validation data')
    parser.add_argument('--test_folder', default='/media/zhanxu/4T/ModelResource_RigNetv1_preproccessed/test/',
                        type=str, help='folder of testing data')
    parser.add_argument('--nearest_bone', type=int, default=5)
    parser.add_argument('--info_folder', default='/media/zhanxu/4T/ModelResource_RigNetv1_preproccessed/rig_info_remesh/',
                        type=str, help='folder of skeleton information')
    parser.add_argument('--Dg', action='store_true', help='input inverse geodesic as additional feature')
    parser.add_argument('--Lf', action='store_true', help='input isleaf indicator as additional feature')

    # Parse exactly once: the original called parse_args() twice, so the
    # printed namespace and the one handed to main() were distinct objects.
    args = parser.parse_args()
    print(args)
    main(args)
|
|
|