File size: 1,976 Bytes
e321b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os.path as osp
import sys
import numpy as np
import itertools
from pathlib import Path
from collections import defaultdict

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from .dt4dintra import IGNORED_CATEGORIES
from .dt4dintra import ShapeDataset
from .faust import ShapePairDataset as FaustShapePairDataset
from utils.mesh import list_files

# Categories skipped when ShapePairDataset._init enumerates cross-category
# pairs. NOTE: this assignment rebinds, and therefore overrides, the
# IGNORED_CATEGORIES imported from .dt4dintra above.
# Broader alternative list, currently disabled:
#   ["drake", "mannequin", "ninja", "prisoner", "zlorp", "pumpkinhulk"]
IGNORED_CATEGORIES = [
    "pumpkinhulk",
]
class ShapePairDataset(FaustShapePairDataset):
    """Shape pairs drawn from two *different* categories, with ground truth.

    Extends the FAUST pair dataset. ``_init`` enumerates only pairs whose
    categories have a ``cross_category_corres/<cat0>_<cat1>.vts`` file under
    ``self.corr_dir``. ``_load_corr_gt`` composes the per-shape ``.vts``
    files with that inter-category file.
    """

    def _init(self):
        """Populate ``self.name_id_map`` and ``self.pair_indices``.

        Shape names are split on ``'/'`` and the first component is used as
        the category. For every file
        ``<corr_dir>/cross_category_corres/<cat0>_<cat1>.vts``, each shape of
        ``cat0`` is paired with each shape of ``cat1``. Pairs are stored as
        ``(id0, id1)`` tuples taken from ``name_id_map``. Files naming a
        category in the module-level ``IGNORED_CATEGORIES`` are skipped.

        Raises:
            ValueError: if a cross-category file name does not split into
                exactly two category names on ``'_'``.
        """
        self.name_id_map = self.shape_data.get_name_id_map()

        # Group shape ids by category (text before the first '/'), preserving
        # the iteration order of name_id_map.
        categories = defaultdict(list)
        for sname, sid in self.name_id_map.items():
            categories[sname.split('/')[0]].append(sid)

        self.pair_indices = []
        cross_dir = osp.join(self.corr_dir, 'cross_category_corres')
        for filename in list_files(cross_dir, '*.vts', alphanum_sort=False):
            parts = osp.splitext(filename)[0].split('_')
            if len(parts) != 2:
                # The file name is the only source of the category pair, so a
                # name with zero or several '_' is ambiguous.
                raise ValueError(
                    f'Expected a file named "<category0>_<category1>.vts", '
                    f'got {filename!r}'
                )
            cname0, cname1 = parts
            if cname0 in IGNORED_CATEGORIES or cname1 in IGNORED_CATEGORIES:
                continue
            # Cartesian product: cname0 shapes in the outer loop, cname1 shapes
            # in the inner loop. A category with no shapes yields no pairs.
            self.pair_indices.extend(
                itertools.product(categories[cname0], categories[cname1])
            )

    def _load_corr_gt(self, sdict0, sdict1):
        """Return the ground-truth correspondence for a cross-category pair.

        Args:
            sdict0: Shape dict of the first shape; only ``'name'`` is read.
            sdict1: Shape dict of the second shape; only ``'name'`` is read.

        Returns:
            ``np.stack((corr0, corr1[lmk01]), axis=1)``. Column 0 is shape 0's
            ``.vts`` indices. Column 1 is shape 1's ``.vts`` indices re-indexed
            through ``cross_category_corres/<cat0>_<cat1>.vts``. This yields a
            ``(K, 2)`` array, assuming each ``.vts`` file holds one index per
            line (TODO confirm).
        """
        sname0 = sdict0['name']
        sname1 = sdict1['name']
        cname0 = sname0.split('/')[0]
        cname1 = sname1.split('/')[0]
        # Internal invariant: _init only pairs shapes from different categories.
        assert cname0 != cname1
        lmk01 = self._load_corr_file(f'cross_category_corres/{cname0}_{cname1}')
        corr0 = self._load_corr_file(sname0)
        corr1 = self._load_corr_file(sname1)
        # lmk01 is already 0-based (see _load_corr_file), so it indexes corr1.
        corr_gt = np.stack((corr0, corr1[lmk01]), axis=1)
        return corr_gt

    def _load_corr_file(self, sname):
        """Load ``<corr_dir>/<sname>.vts`` as int32 indices, shifted by -1.

        The file's indices are treated as 1-based and returned 0-based.

        ``ndmin=1`` makes a single-entry file load as a 1-element array
        rather than a 0-d scalar. Without it, the indexing and
        ``np.stack(..., axis=1)`` in ``_load_corr_gt`` would fail.
        """
        corr_path = osp.join(self.corr_dir, f'{sname}.vts')
        corr = np.loadtxt(corr_path, dtype=np.int32, ndmin=1)
        return corr - 1