| | |
| | |
| | |
| | |
| | |
| | |
| | import numpy as np |
| | import torch |
| |
|
| |
|
def _scene_graph_winsize(scene_graph, default=3):
    """Return the window size encoded as the second '-'-separated field.

    Falls back to ``default`` when the field is missing (e.g. ``'swin'``) or
    is not an integer (e.g. ``'swin-noncyclic'``).
    """
    try:
        return int(scene_graph.split('-')[1])
    except (IndexError, ValueError):
        return default


def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
    """Build the list of image pairs described by ``scene_graph``.

    Args:
        imgs: sequence of images; elements are only accessed by position here.
            When ``prefilter`` is used, each element must also support
            ``img['idx']`` (read by ``filter_pairs_seq``).
        scene_graph: pairing strategy, one of
            * ``'complete'``: every unordered pair ``(imgs[i], imgs[j])``, j < i.
            * ``'swin[-K][-noncyclic]'``: each image with its next K images
              (K defaults to 3); indices wrap around unless the string ends
              with ``'noncyclic'``.
            * ``'logwin[-K][-noncyclic]'``: each image with the images at
              offsets +/- 2**0 .. 2**(K-1) (K defaults to 3), same wrapping rule.
            * ``'oneref[-R]'``: image R (default 0) paired with every other image.
            Any other value produces no pairs.
        prefilter: ``None``, ``'seq<N>'`` or ``'cyc<N>'``; forwards the pairs to
            ``filter_pairs_seq`` with threshold N (cyclic for ``'cyc'``).
        symmetrize: if True, also append the reversed version of every pair.

    Returns:
        A list of ``(img_a, img_b)`` tuples.
    """
    n_imgs = len(imgs)
    pairs = []
    if scene_graph == 'complete':
        pairs = [(imgs[i], imgs[j]) for i in range(n_imgs) for j in range(i)]
    elif scene_graph.startswith('swin'):
        iscyclic = not scene_graph.endswith('noncyclic')
        winsize = _scene_graph_winsize(scene_graph)
        pairsid = set()
        for i in range(n_imgs):
            for step in range(1, winsize + 1):
                idx = i + step
                if iscyclic:
                    idx %= n_imgs  # wrap around to close the loop
                # Skip out-of-range neighbours (non-cyclic) and self-pairs,
                # which appear when the window wraps all the way back to i.
                if idx >= n_imgs or idx == i:
                    continue
                pairsid.add((i, idx) if i < idx else (idx, i))
        pairs = [(imgs[i], imgs[j]) for i, j in pairsid]
    elif scene_graph.startswith('logwin'):
        iscyclic = not scene_graph.endswith('noncyclic')
        winsize = _scene_graph_winsize(scene_graph)
        offsets = [2**k for k in range(winsize)]
        pairsid = set()
        for i in range(n_imgs):
            neighbours = [i - off for off in offsets] + [i + off for off in offsets]
            for j in neighbours:
                if iscyclic:
                    j = j % n_imgs
                if j < 0 or j >= n_imgs or j == i:
                    continue
                pairsid.add((i, j) if i < j else (j, i))
        pairs = [(imgs[i], imgs[j]) for i, j in pairsid]
    elif scene_graph.startswith('oneref'):
        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
        pairs = [(imgs[refid], imgs[j]) for j in range(n_imgs) if j != refid]

    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # Optionally drop pairs whose images are too far apart in the sequence.
    if isinstance(prefilter, str) and prefilter.startswith('seq'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))

    if isinstance(prefilter, str) and prefilter.startswith('cyc'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs
| |
|
| |
|
def sel(x, kept):
    """Select the entries listed in ``kept`` from ``x``.

    * dict: the selection is applied recursively to every value.
    * torch.Tensor / np.ndarray: indexed directly with ``kept``.
    * tuple / list: rebuilt as the same type from the elements at ``kept``.

    Any other type yields ``None``.
    """
    if isinstance(x, dict):
        selected = {}
        for key, value in x.items():
            selected[key] = sel(value, kept)
        return selected

    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]

    if isinstance(x, (tuple, list)):
        picked = [x[pos] for pos in kept]
        return type(x)(picked)

    return None
| |
|
| |
|
def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
    """Return the positions of the edges whose endpoints are close in sequence.

    Args:
        edges: sequence of ``(i, j)`` index pairs.
        seq_dis_thr: maximum allowed distance ``|i - j|`` (inclusive).
        cyclic: if True, the distance is also measured across the wrap-around.

    Returns:
        List of positions into ``edges`` (in order) of the edges that are kept.
        An empty ``edges`` yields an empty list.
    """
    if len(edges) == 0:
        # max() below would raise ValueError on an empty sequence.
        return []

    # Sequence length used for the wrap-around distance. It is inferred from
    # the largest index present in `edges`, so it is only a lower bound on the
    # real number of images if the last images appear in no edge.
    n = max(max(e) for e in edges) + 1 if cyclic else None

    kept = []
    for pos, (i, j) in enumerate(edges):
        dis = abs(i - j)
        if cyclic:
            dis = min(dis, abs(i + n - j), abs(i - n - j))
        if dis <= seq_dis_thr:
            kept.append(pos)
    return kept
| |
|
| |
|
def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    """Keep only the pairs whose two images are close in the sequence.

    Args:
        pairs: list of ``(img1, img2)``; each image must support ``img['idx']``,
            which is used as its position in the sequence.
        seq_dis_thr: maximum allowed index distance (inclusive).
        cyclic: if True, the distance is also measured across the wrap-around.

    Returns:
        A new list with the retained pairs, in their original order.
    """
    if not pairs:
        # Nothing to filter; also avoids asking the helper for the maximum
        # index of an empty edge list.
        return []
    edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[i] for i in kept]
| |
|
| |
|
def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    """Drop the edges whose two views are too far apart in the sequence.

    The edge list is built from ``view1['idx']`` and ``view2['idx']`` (paired
    element-wise and converted to int). The surviving positions are then
    applied to all four inputs with ``sel``, and a summary line is printed.

    Returns:
        Tuple ``(view1, view2, pred1, pred2)`` restricted to the kept edges.
    """
    idx_pairs = zip(view1['idx'], view2['idx'])
    edges = [(int(a), int(b)) for a, b in idx_pairs]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
    return tuple(sel(item, kept) for item in (view1, view2, pred1, pred2))
| |
|