File size: 1,664 Bytes
e321b92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os.path as osp
import sys
import numpy as np
import re
from pathlib import Path
from itertools import permutations as pmt

# Put the directory one level above this file on sys.path so that the
# project-local imports below (shape_data, utils) resolve when this module is
# imported from outside the repository root.
ROOT_DIR = osp.join(
    osp.abspath(osp.dirname(__file__)),
    '../',
)
if not any(entry == ROOT_DIR for entry in sys.path):
    sys.path.append(ROOT_DIR)

from shape_data.faust import ShapeDataset as FaustShapeDataset
from shape_data.faust import ShapePairDataset as FaustShapePairDataset
from utils.io import list_files

def contains_any_regex(substrings, ext, texts):
    """Return the entries of *texts* that contain *ext* and any of *substrings*.

    Args:
        substrings: Iterable of literal substrings; regex metacharacters are
            escaped, so each entry is matched verbatim. A single ``str`` is
            treated as one substring (not as a set of characters). ``None``
            disables the substring filter, keeping every text that contains
            *ext*. An empty iterable also keeps every such text, because the
            resulting empty pattern matches anywhere.
        ext: Substring (e.g. a file extension such as ``".off"``) that each
            kept text must contain.
        texts: Iterable of strings to filter.

    Returns:
        List of matching texts, in their original order.
    """
    if substrings is None:
        # No category restriction: only the extension filter applies.
        return [text for text in texts if ext in text]
    if isinstance(substrings, str):
        # Avoid iterating a str character-by-character into an alternation.
        substrings = [substrings]
    # Compile once; alternation of escaped literals == "contains any of".
    pattern = re.compile('|'.join(map(re.escape, substrings)))
    return [text for text in texts if ext in text and pattern.search(text)]


class ShapeDataset(FaustShapeDataset):
    """Shape dataset that selects ``.off`` meshes from ``self.shape_dir`` by mode.

    In ``train*`` modes every ``.off`` mesh is used. In ``test*`` modes only
    meshes whose path contains one of ``'cat'``, ``'dog'``, ``'horse'`` or
    ``'wolf'`` are kept.
    """
    # Set to None here; presumably the base class uses these for index-based
    # splits, which this dataset replaces with name-based selection — TODO confirm.
    TRAIN_IDX = None
    TEST_IDX = None

    def _get_file_list(self):
        """Return the ``.off`` mesh files for the current ``self.mode``.

        Returns:
            List of entries from ``list_files`` (alphanumerically sorted),
            filtered by extension and, in test modes, by category.

        Raises:
            RuntimeError: if ``self.mode`` starts with neither ``'train'``
                nor ``'test'``.
        """
        if self.mode.startswith('train'):
            categories = None  # no category filter: use every mesh
        elif self.mode.startswith('test'):
            categories = ['cat', 'dog', 'horse', 'wolf']
        else:
            raise RuntimeError(f'Mode {self.mode} is not supported.')
        file_list = list_files(self.shape_dir, '*.off', alphanum_sort=True)
        if categories is None:
            # Skip the category regex entirely; only the extension filter applies.
            return [fn for fn in file_list if '.off' in fn]
        return contains_any_regex(categories, '.off', file_list)


class ShapePairDataset(FaustShapePairDataset):
    """Pair dataset that only pairs shapes belonging to the same category.

    Supported in test modes only.
    """
    # A shape belongs to a category when its name contains the category string.
    categories = ['cat', 'dog', 'horse', 'wolf']

    def _init(self):
        """Build ``self.name_id_map`` and ``self.pair_indices``.

        ``self.pair_indices`` holds, for each category, every ordered pair of
        distinct entries (2-permutations) among the shape IDs whose name
        contains that category string. Both ``(a, b)`` and ``(b, a)`` are
        included.

        Raises:
            RuntimeError: if ``self.mode`` does not start with ``'test'``.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not self.mode.startswith('test'):
            raise RuntimeError(f'Mode {self.mode} is not supported.')
        self.name_id_map = self.shape_data.get_name_id_map()
        self.pair_indices = list()
        for cat in self.categories:
            # Shape IDs whose name contains the category string (substring match).
            shape_ids = [sid for name, sid in self.name_id_map.items() if cat in name]
            self.pair_indices += list(pmt(shape_ids, 2))