zoo3d / exts /maskclustering_runner.py
drozdgk's picture
init
4eeefd1
import os
import sys
from typing import Optional
import argparse
import json
"""
Python runner for the Indoor/MaskClustering main pipeline that runs it in-process, without spawning a subprocess.
It reuses the project's own main(args); argument parsing (get_args) is defined locally in this module.
"""
_MK_PATH = None
_get_args = None
def make_maskclustering_dir(MK_PATH: str) -> None:
    """
    Register the Indoor/MaskClustering repo and build the argument helpers.

    Puts ``MK_PATH`` at the front of ``sys.path`` (unless it is already
    listed) so that the repo's modules can be imported directly, records it
    in the module-level ``_MK_PATH``, and publishes the locally defined
    ``get_args`` through the module-level ``_get_args`` hook.

    Args:
        MK_PATH: Path to the Indoor/MaskClustering repo root.

    Raises:
        ImportError: If ``dataset.scannet`` cannot be imported once
            ``MK_PATH`` is on ``sys.path``.
    """
    global _MK_PATH, _get_args
    _MK_PATH = MK_PATH
    if MK_PATH not in sys.path:
        sys.path.insert(0, MK_PATH)

    # Must come after the sys.path update above.
    from dataset.scannet import WildDataset

    # Historical hard-coded location of the config files; kept only as a
    # fallback so existing setups that relied on it keep working.
    legacy_config_dir = '/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering/configs'

    def update_args(args):
        """Copy every key of the JSON config selected by ``args.config`` onto ``args``.

        The file is looked up as ``<MK_PATH>/configs/<name>.json`` first and in
        the legacy hard-coded directory second. ``'scannet18'`` shares the
        ``'scannet'`` config file.

        Raises:
            FileNotFoundError: If the config file exists in neither location.
        """
        config_file = args.config
        if config_file in ('scannet', 'scannet18'):
            config_file = 'scannet'

        config_path = os.path.join(MK_PATH, 'configs', f'{config_file}.json')
        if not os.path.isfile(config_path):
            config_path = f'{legacy_config_dir}/{config_file}.json'

        with open(config_path, 'r') as f:
            config_data = json.load(f)
        for key, value in config_data.items():
            setattr(args, key, value)
        return args

    def get_args(argv=None):
        """Parse command-line style arguments and merge in the JSON config.

        Args:
            argv: Optional list of argument strings. ``None`` (the default)
                lets argparse read ``sys.argv[1:]``, which is the original
                behavior.

        Returns:
            The parsed namespace, extended with every key of the config file.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--seq_name', type=str)
        parser.add_argument('--seq_name_list', type=str)
        parser.add_argument('--config', type=str, default='scannet')
        parser.add_argument('--debug', action="store_true")
        parser.add_argument('--root', type=str)
        parser.add_argument('-d', '--devices', type=int, nargs='+', default=[0, 1, 2, 3])
        return update_args(parser.parse_args(argv))

    # NOTE(review): get_dataset is not referenced anywhere else in this
    # function; kept in case it is meant to be exposed later.
    def get_dataset(args):
        """Build the dataset for ``args.dataset``; only ``'wild'`` is handled.

        Raises:
            ValueError: For any other dataset name (previously this surfaced
                as an unbound-local error).
        """
        if args.dataset == 'wild':
            return WildDataset(args.seq_name, root=args.root)
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    _get_args = get_args
def run_mask_clustering(
    config: str,
    root: str,
    seq_name_list: str,
    step: Optional[int] = None,
    view_consensus_threshold: Optional[float] = None,
    debug: Optional[bool] = None,
) -> None:
    """
    Execute the MaskClustering pipeline in-process for one or more sequences.

    Equivalent to:
        python main.py --config {config} --root {root} --seq_name_list {seq_name_list}
    with optional overrides for step, view_consensus_threshold, and debug.

    Args:
        config: Config name; selects the JSON config file loaded by the
            argument helper.
        root: Dataset root passed through to the pipeline as ``args.root``.
        seq_name_list: Sequence names joined with ``+``; the pipeline's
            ``main`` is called once per name.
        step: If given, overrides ``args.step`` from the config.
        view_consensus_threshold: If given, overrides
            ``args.view_consensus_threshold`` from the config.
        debug: If given, overrides ``args.debug``.

    Raises:
        ImportError: If the MaskClustering ``main`` module cannot be imported.
        RuntimeError: If the argument helper was never initialised, i.e.
            ``make_maskclustering_dir`` has not run and
            ``MASKCLUSTERING_PATH`` is not set.
    """
    if _MK_PATH is None or _MK_PATH not in sys.path:
        env_mk = os.environ.get("MASKCLUSTERING_PATH")
        if env_mk:
            make_maskclustering_dir(env_mk)
        elif _MK_PATH is not None:
            # The repo was registered earlier but has since been dropped from
            # sys.path; put it back so the import below can resolve.
            sys.path.insert(0, _MK_PATH)
        # Otherwise proceed: imports might still work if paths are set
        # globally elsewhere.

    # Lazy import to avoid the cost at module import time.
    from main import main as mk_main  # type: ignore

    if _get_args is None:
        raise RuntimeError(
            "MaskClustering is not initialised: call make_maskclustering_dir() "
            "or set the MASKCLUSTERING_PATH environment variable first."
        )

    # The argument helper parses sys.argv and loads the JSON config for the
    # --config it finds there. Feed it exactly the command line documented
    # above so that (a) the host program's own arguments are never parsed and
    # (b) the config file matching `config` is the one that gets loaded.
    # The --flag=value form keeps values that start with '-' from being read as options.
    cli_args = [
        f"{flag}={value}"
        for flag, value in (
            ("--config", config),
            ("--root", root),
            ("--seq_name_list", seq_name_list),
        )
        if value is not None
    ]
    if debug:
        cli_args.append("--debug")

    # sys.argv is process-global, so this swap is not thread-safe; it is
    # restored even if argument parsing fails.
    saved_argv = sys.argv
    sys.argv = [saved_argv[0] if saved_argv else "main.py", *cli_args]
    try:
        args = _get_args()
    finally:
        sys.argv = saved_argv

    # Explicit keyword arguments take precedence over parsed/config values.
    args.config = config
    args.root = root
    args.seq_name_list = seq_name_list
    if step is not None:
        args.step = step
    if view_consensus_threshold is not None:
        args.view_consensus_threshold = view_consensus_threshold
    if debug is not None:
        args.debug = debug

    # Emulate the original __main__ loop over sequences.
    for seq_name in args.seq_name_list.split("+"):
        args.seq_name = seq_name
        mk_main(args)