import json
import os
import shutil
import tarfile
from collections import defaultdict
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import DictConfig

from ... import logger
from ..schema import MIADataConfiguration
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .dataset import MapLocDataset


def pack_dump_dict(dump):
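    """Convert the nested lists of a parsed dump.json into numpy arrays.

    Point chunks become float64 arrays, per-view rotations and angles become
    float32 arrays, camera parameters become float32 arrays, and the
    transient ``chunk_id`` key is dropped from each view.
    """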
    for per_seq in dump.values():
        if "points" in per_seq:
            for chunk in list(per_seq["points"]):
                points = per_seq["points"].pop(chunk)
                if points is not None:
                    # Convert the popped chunk; indexing the dict again after
                    # pop() would raise a KeyError.
                    per_seq["points"][chunk] = np.array(points, np.float64)
        for view in per_seq["views"].values():
            for k in ["R_c2w", "roll_pitch_yaw"]:
                view[k] = np.array(view[k], np.float32)
            for k in ["chunk_id"]:
                if k in view:
                    view.pop(k)
            if "observations" in view:
                view["observations"] = np.array(view["observations"])
        for camera in per_seq["cameras"].values():
            for k in ["params"]:
                camera[k] = np.array(camera[k], np.float32)
    return dump


class MapillaryDataModule(pl.LightningDataModule):
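    """Data module for the Mapillary-based MIA dataset.

    Expects one directory per scene under ``cfg.data_dir`` containing
    ``dump.json``, the extracted images, semantic masks, and flood-fill
    masks (or an ``images.tar.gz`` archive when a local cache is used).

    Minimal usage sketch, assuming a populated ``MIADataConfiguration``::

        dm = MapillaryDataModule(cfg)
        dm.prepare_data()
        dm.setup()
        loader, chunk_keys, chunk2idx = dm.sequence_dataloader("val")
    """
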
    dump_filename = "dump.json"
    images_archive = "images.tar.gz"
    images_dirname = "images/"
    semantic_masks_dirname = "semantic_masks/"
    flood_dirname = "flood_fill/"

    def __init__(self, cfg: MIADataConfiguration):
        super().__init__()
        self.cfg = cfg
        self.root = self.cfg.data_dir
        self.local_dir = None

    def prepare_data(self):
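        """Check that each scene is complete and extract images if needed."""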
        for scene in self.cfg.scenes:
            dump_dir = self.root / scene
            assert (dump_dir / self.dump_filename).exists(), dump_dir

            if self.local_dir is None:
                assert (dump_dir / self.images_dirname).exists(), dump_dir
                continue
            assert (dump_dir / self.semantic_masks_dirname).exists(), dump_dir
            assert (dump_dir / self.flood_dirname).exists(), dump_dir

            # Extract the image archive into a fresh local cache directory.
            local_dir = self.local_dir / scene
            if local_dir.exists():
                shutil.rmtree(local_dir)
            local_dir.mkdir(exist_ok=True, parents=True)
            images_archive = dump_dir / self.images_archive
            logger.info("Extracting the image archive %s.", images_archive)
            with tarfile.open(images_archive) as fp:
                fp.extractall(local_dir)

    def setup(self, stage: Optional[str] = None):
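        """Load the dumps, index the data on disk, and build the splits."""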
        self.dumps = {}
        self.image_dirs = {}
        self.seg_masks_dir = {}
        self.flood_masks_dir = {}
        names = []

        for scene in self.cfg.scenes:
            logger.info("Loading scene %s.", scene)
            dump_dir = self.root / scene

            logger.info("Loading dump json file %s.", self.dump_filename)
            with (dump_dir / self.dump_filename).open("r") as fp:
                self.dumps[scene] = pack_dump_dict(json.load(fp))
            for seq, per_seq in self.dumps[scene].items():
                for cam_id, cam_dict in per_seq["cameras"].items():
                    if cam_dict["model"] != "PINHOLE":
                        raise ValueError(
                            f"Unsupported camera model: {cam_dict['model']} "
                            f"for {scene},{seq},{cam_id}"
                        )

            self.image_dirs[scene] = (
                (self.local_dir or self.root) / scene / self.images_dirname
            )
            assert self.image_dirs[scene].exists(), self.image_dirs[scene]

            self.seg_masks_dir[scene] = (
                (self.local_dir or self.root) / scene / self.semantic_masks_dirname
            )
            assert self.seg_masks_dir[scene].exists(), self.seg_masks_dir[scene]

            self.flood_masks_dir[scene] = (
                (self.local_dir or self.root) / scene / self.flood_dirname
            )
            assert self.flood_masks_dir[scene].exists(), self.flood_masks_dir[scene]

            # Keep only views whose image, flood-fill mask, and semantic mask
            # all exist on disk; masks are matched on the prefix of the view
            # name before the first underscore.
            images = {x.split(".")[0] for x in os.listdir(self.image_dirs[scene])}
            flood_masks = {
                x.split(".")[0] for x in os.listdir(self.flood_masks_dir[scene])
            }
            semantic_masks = {
                x.split(".")[0] for x in os.listdir(self.seg_masks_dir[scene])
            }

            for seq, data in self.dumps[scene].items():
                for name in data["views"]:
                    key = name.split("_")[0]
                    if name in images and key in flood_masks and key in semantic_masks:
                        names.append((scene, seq, name))

        self.parse_splits(self.cfg.split, names)
        if self.cfg.filter_for is not None:
            self.filter_elements()
        self.pack_data()

    def pack_data(self):
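        """Pack per-view metadata into arrays and tensors for each split.

        Scalar per-view fields are stacked across all images of a split and
        converted to torch tensors when numeric; cameras and 3D points stay
        in nested dicts indexed by scene and sequence.
        """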
        exclude = {
            "compass_angle",
            "compass_accuracy",
            "gps_accuracy",
            "chunk_key",
            "panorama_offset",
        }
        cameras = {
            scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
            for scene, per_scene in self.dumps.items()
        }
        points = {
            scene: {
                seq: {
                    i: torch.from_numpy(p) for i, p in per_seq.get("points", {}).items()
                }
                for seq, per_seq in per_scene.items()
            }
            for scene, per_scene in self.dumps.items()
        }
        self.data = {}

        if self.cfg.split == "splits_MGL_13loc.json":
            # Carve a validation set out of the last 20% of the training
            # samples.
            num_samples_to_move = int(len(self.splits["train"]) * 0.2)
            samples_to_move = self.splits["train"][-num_samples_to_move:]
            self.splits["val"].extend(samples_to_move)
            self.splits["train"] = self.splits["train"][:-num_samples_to_move]
            logger.info(
                "Split sizes: train=%d, val=%d",
                len(self.splits["train"]),
                len(self.splits["val"]),
            )
        elif self.cfg.split == "splits_MGL_soma_70k_mappred_random.json":
            for stage, names in self.splits.items():
                logger.info("Length of split %s: %d", stage, len(names))

        for stage, names in self.splits.items():
            view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
            data = {k: [] for k in view.keys() - exclude}
            for scene, seq, name in names:
                for k in data:
                    data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
            for k in data:
                v = np.array(data[k])
                if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
                    v.dtype, np.floating
                ):
                    v = torch.from_numpy(v)
                data[k] = v
            data["cameras"] = cameras
            data["points"] = points
            self.data[stage] = data
            self.splits[stage] = np.array(names)

    def filter_elements(self):
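        """Drop views that fail the ``cfg.filter_for`` criterion."""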
        for stage, names in self.splits.items():
            names_select = []
            for scene, seq, name in names:
                view = self.dumps[scene][seq]["views"][name]
                if self.cfg.filter_for == "ground_plane":
                    # Keep cameras at a plausible height above the ground.
                    if not (1.0 <= view["height"] <= 3.0):
                        continue
                    planes = self.dumps[scene][seq].get("plane")
                    if planes is not None:
                        # The last entry of the plane fit is its inlier count.
                        inliers = planes[str(view["chunk_id"])][-1]
                        if inliers < 10:
                            continue
                    if self.cfg.filter_by_ground_angle is not None:
                        plane = np.array(view["plane_params"])
                        normal = plane[:3] / np.linalg.norm(plane[:3])
                        angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
                        if angle > self.cfg.filter_by_ground_angle:
                            continue
                elif self.cfg.filter_for == "pointcloud":
                    if len(view["observations"]) < self.cfg.min_num_points:
                        continue
                elif self.cfg.filter_for is not None:
                    raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
                names_select.append((scene, seq, name))
            logger.info(
                "%s: Keep %d/%d images after filtering for %s.",
                stage,
                len(names_select),
                len(names),
                self.cfg.filter_for,
            )
            self.splits[stage] = names_select

    def parse_splits(self, split_arg, names):
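        """Resolve ``split_arg`` into ``self.splits`` ({stage: names}).

        Accepts ``None`` (train and val share all images), an ``int``
        (number of validation images), a ``float`` (validation fraction),
        a ``DictConfig`` with per-stage scene lists, or the name of a JSON
        file mapping each stage to per-scene image IDs.
        """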
        if split_arg is None:
            self.splits = {
                "train": names,
                "val": names,
            }
        elif isinstance(split_arg, int):
            # Hold out a fixed number of images for validation.
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[split_arg:],
                "val": names[:split_arg],
            }
        elif isinstance(split_arg, float):
            # Hold out a fraction of the images for validation.
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[int(split_arg * len(names)) :],
                "val": names[: int(split_arg * len(names))],
            }
        elif isinstance(split_arg, DictConfig):
            # Split by scene.
            scenes_val = set(split_arg.val)
            scenes_train = set(split_arg.train)
            assert len(scenes_val - set(self.cfg.scenes)) == 0
            assert len(scenes_train - set(self.cfg.scenes)) == 0
            self.splits = {
                "train": [n for n in names if n[0] in scenes_train],
                "val": [n for n in names if n[0] in scenes_val],
            }
        elif isinstance(split_arg, str):
            if "/" in split_arg:
                split_path = self.root / split_arg
            else:
                split_path = Path(split_arg)
            with split_path.open("r") as fp:
                splits = json.load(fp)
            splits = {
                k: {loc: set(ids) for loc, ids in split.items()}
                for k, split in splits.items()
            }
            self.splits = {}
            for k, split in splits.items():
                # Keep a view if its scene is in the split and its image ID
                # (the view name up to the last underscore) is listed.
                self.splits[k] = [
                    n
                    for n in names
                    if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
                ]
        else:
            raise ValueError(split_arg)

    def dataset(self, stage: str):
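        """Build a ``MapLocDataset`` for the given stage."""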
        return MapLocDataset(
            stage,
            self.cfg,
            self.splits[stage],
            self.data[stage],
            self.image_dirs,
            self.seg_masks_dir,
            self.flood_masks_dir,
            image_ext=".jpg",
        )

    def sequence_dataset(self, stage: str, **kwargs):
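        """Group the views of each sequence into chunks.

        Returns the dataset and a dict mapping ``(sequence, chunk_index)``
        to the dataset indices that belong to that chunk.
        """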
        keys = self.splits[stage]
        seq2indices = defaultdict(list)
        for index, (_, seq, _) in enumerate(keys):
            seq2indices[seq].append(index)

        # Chunk each sequence into shorter sub-sequences.
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(self.data[stage], indices, **kwargs)
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices

        # Store the chunk index of each view.
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
        dataset = self.dataset(stage)
        return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
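        """Build a loader that iterates over sequence chunks one view at a time.

        Chunks can be shuffled, but the views within each chunk keep their
        temporal order.
        """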
        dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
        chunk_keys = sorted(chunk2idx)
        if shuffle:
            perm = torch.randperm(len(chunk_keys))
            chunk_keys = [chunk_keys[i] for i in perm]
        key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
        num_workers = self.cfg.loading[stage]["num_workers"]
        loader = torchdata.DataLoader(
            dataset,
            # batch_size=None disables automatic batching: the sampler yields
            # one view index at a time, in chunk order.
            batch_size=None,
            sampler=key_indices,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
        return loader, chunk_keys, chunk2idx