File size: 3,442 Bytes
4eeefd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import sys
from typing import Optional
import argparse
import json

"""
Python runner for Indoor/MaskClustering main pipeline without spawning a process.
It reuses the project's own main(args) and get_args() utilities.
"""

_MK_PATH = None
_get_args = None


def make_maskclustering_dir(MK_PATH: str) -> None:
    """
    Register the Indoor/MaskClustering repository and install its arg builder.

    Puts ``MK_PATH`` at the front of ``sys.path`` (if not already present) so
    that ``utils.*``, ``graph.*``, ``dataset.*`` and ``main`` can be imported
    directly, imports ``dataset.scannet.WildDataset`` from it, and stores a
    ``get_args`` function in the module-level ``_get_args``.

    Args:
        MK_PATH: Root directory of the MaskClustering checkout. Config files
            are read from ``<MK_PATH>/configs/<config name>.json``.

    Raises:
        ImportError: if ``dataset.scannet`` cannot be imported from MK_PATH.
    """
    global _MK_PATH, _get_args
    _MK_PATH = MK_PATH
    if MK_PATH not in sys.path:
        sys.path.insert(0, MK_PATH)
    from dataset.scannet import WildDataset

    # Resolve configs relative to the registered checkout rather than a
    # hard-coded, machine-specific absolute path.
    configs_dir = os.path.join(MK_PATH, 'configs')

    def update_args(args):
        """Merge every key of the JSON config selected by ``args.config`` into ``args``."""
        config = args.config
        # 'scannet18' shares the base 'scannet' config file; every other
        # config name maps to a file of the same name.
        config_file = 'scannet' if config in ('scannet', 'scannet18') else config
        config_path = os.path.join(configs_dir, f'{config_file}.json')
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
        for key, value in config_data.items():
            setattr(args, key, value)
        return args

    def get_args(argv=None):
        """
        Parse MaskClustering CLI arguments and merge the selected config file.

        Args:
            argv: Argument list to parse. ``None`` (the default) parses
                ``sys.argv[1:]``, matching the previous behaviour.

        Returns:
            The parsed ``argparse.Namespace`` with config-file keys merged in.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--seq_name', type=str)
        parser.add_argument('--seq_name_list', type=str)
        parser.add_argument('--config', type=str, default='scannet')
        parser.add_argument('--debug', action="store_true")
        parser.add_argument('--root', type=str)
        parser.add_argument('-d', '--devices', type=int, nargs='+', default=[0, 1, 2, 3])

        args = parser.parse_args(argv)
        return update_args(args)

    def get_dataset(args):
        """Build the dataset object for ``args.dataset``."""
        # NOTE(review): not referenced anywhere in this file.
        if args.dataset == 'wild':
            return WildDataset(args.seq_name, root=args.root)
        # Previously fell through to an UnboundLocalError on 'dataset'.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    _get_args = get_args

def run_mask_clustering(
    config: str,
    root: str,
    seq_name_list: str,
    step: Optional[int] = None,
    view_consensus_threshold: Optional[float] = None,
    debug: Optional[bool] = None,
) -> None:
    """
    Execute the MaskClustering pipeline for one or multiple sequences.

    Equivalent to::

        python main.py --config {config} --root {root} --seq_name_list {seq_name_list}

    with optional overrides for ``step``, ``view_consensus_threshold`` and
    ``debug`` applied after the config file has been merged into the args.

    Args:
        config: Config name; passed as ``--config`` so the matching config
            file is the one merged by the arg builder.
        root: Dataset root directory, passed as ``--root``.
        seq_name_list: '+'-separated sequence names; empty entries are skipped.
        step: If given, overrides ``args.step``.
        view_consensus_threshold: If given, overrides
            ``args.view_consensus_threshold``.
        debug: If given, overrides ``args.debug``.

    Raises:
        RuntimeError: if no arg builder is installed, i.e.
            make_maskclustering_dir() was never called and the
            MASKCLUSTERING_PATH environment variable is not set.

    Note:
        Temporarily replaces the process-global ``sys.argv`` while parsing,
        so it must not be called concurrently from several threads.
    """
    if _MK_PATH is None or _MK_PATH not in sys.path:
        # Fallback: register the repository from the environment, if set.
        env_mk = os.environ.get("MASKCLUSTERING_PATH")
        if env_mk:
            make_maskclustering_dir(env_mk)
    if _get_args is None:
        # Previously surfaced as "'NoneType' object is not callable".
        raise RuntimeError(
            "MaskClustering is not initialised: call make_maskclustering_dir() "
            "or set the MASKCLUSTERING_PATH environment variable."
        )

    # Lazy import to avoid the cost at module import time.
    from main import main as mk_main  # type: ignore

    # Parse a synthesized command line instead of the host's sys.argv: the
    # host's arguments (e.g. a Jupyter kernel's) would make argparse exit, and
    # parsing --config here makes the requested config file (not the default
    # one) get merged. "--opt=value" keeps values starting with '-' intact.
    cli_argv = [
        f"--{name}={value}"
        for name, value in (
            ("config", config),
            ("root", root),
            ("seq_name_list", seq_name_list),
        )
        if value is not None
    ]
    saved_argv = sys.argv
    sys.argv = (saved_argv[:1] or ["main.py"]) + cli_argv
    try:
        args = _get_args()
    finally:
        sys.argv = saved_argv

    # Re-apply explicitly in case the config file defines any of these keys.
    args.config = config
    args.root = root
    args.seq_name_list = seq_name_list
    if step is not None:
        args.step = step
    if view_consensus_threshold is not None:
        args.view_consensus_threshold = view_consensus_threshold
    if debug is not None:
        args.debug = debug

    # Emulate the original __main__ loop over sequences; skip empty names
    # produced by stray '+' separators.
    for seq_name in args.seq_name_list.split("+"):
        if not seq_name:
            continue
        args.seq_name = seq_name
        mk_main(args)