|
|
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
|
|
|
def chunk_sequence(
    data,
    indices,
    *,
    names=None,
    max_length=100,
    min_length=1,
    max_delay_s=None,
    max_inter_dist=None,
    max_total_dist=None,
):
    """Split frames into ordered chunks bounded in length, time gap and distance.

    The frames selected by ``indices`` are sorted, then walked in order. A new
    chunk is started whenever appending the next frame to the current chunk
    would violate one of the configured limits.

    Args:
        data: Mapping of per-frame fields, indexed by the values in
            ``indices``. ``data["t_c2w"][i][:2]`` (a tensor) is used as the
            planar position of frame ``i``. Optional keys: ``"capture_time"``
            (used as the sort key and as the source of time gaps; divided by
            1e3 before comparing with ``max_delay_s``) and ``"index"`` (sort
            key used when ``"capture_time"`` is absent).
        indices: Iterable of frame indices to chunk.
        names: Optional indexable sort keys (``names[i]``), used only when
            ``data`` has neither ``"capture_time"`` nor ``"index"``. When this
            is also None, frames are sorted by their index value.
        max_length: Maximum number of frames in a chunk.
        min_length: Chunks with fewer frames than this are dropped.
        max_delay_s: Maximum time gap between consecutive frames, or None.
        max_inter_dist: Maximum planar distance between consecutive frames,
            or None.
        max_total_dist: Maximum planar distance accumulated within a chunk,
            or None.

    Returns:
        A list of chunks (lists of indices in sorted order), longest first.
        Returns an empty list when ``indices`` is empty.
    """

    def _sort_key(value):
        # Tensors, arrays and numpy scalars expose ``tolist`` and are turned
        # into plain Python objects so they compare cleanly. Values without
        # it (plain ``int`` indices, ``str`` names) are used unchanged.
        to_list = getattr(value, "tolist", None)
        return value if to_list is None else to_list()

    sort_array = data.get("capture_time", data.get("index"))
    if sort_array is None:
        sort_array = names
    if sort_array is None:
        # No external ordering available: order by the index values themselves.
        indices = sorted(indices, key=_sort_key)
    else:
        indices = sorted(indices, key=lambda i: _sort_key(sort_array[i]))
    if not indices:
        return []

    # Planar position (first two coordinates) of every frame, in sorted order.
    centers = torch.stack([data["t_c2w"][i][:2] for i in indices]).numpy()
    # dists[k] is the distance between sorted frames k and k + 1.
    dists = np.linalg.norm(np.diff(centers, axis=0), axis=-1)
    if "capture_time" in data:
        times = torch.stack([data["capture_time"][i] for i in indices])
        # Scaled by 1e3 before comparison with ``max_delay_s``
        # (presumably milliseconds -> seconds; confirm against the data).
        times = times.double() / 1e3
        delays = np.diff(times.numpy(), axis=0)
    else:
        # Without timestamps every delay is zero.
        delays = np.zeros_like(dists)

    chunks = [[indices[0]]]
    dist_total = 0.0  # distance accumulated along the current chunk
    for dist, delay, idx in zip(dists, delays, indices[1:]):
        dist_total += dist
        if (
            (max_inter_dist is not None and dist > max_inter_dist)
            or (max_total_dist is not None and dist_total > max_total_dist)
            or (max_delay_s is not None and delay > max_delay_s)
            or len(chunks[-1]) >= max_length
        ):
            # ``idx`` opens a new chunk, whose accumulated distance is zero.
            chunks.append([])
            dist_total = 0.0
        chunks[-1].append(idx)

    chunks = [chunk for chunk in chunks if len(chunk) >= min_length]
    # Stable sort: chunks of equal length keep their sorted-order position.
    chunks.sort(key=len, reverse=True)
    return chunks