LongStream / longstream /core /infer.py
Cc
init
e340a84
import argparse
import os
import yaml
import cv2
import numpy as np
import torch
from PIL import Image
from longstream.core.model import LongStreamModel
from longstream.data.dataloader import LongStreamDataLoader
from longstream.streaming.keyframe_selector import KeyframeSelector
from longstream.streaming.refresh import run_batch_refresh, run_streaming_refresh
from longstream.utils.vendor.models.components.utils.pose_enc import (
pose_encoding_to_extri_intri,
)
from longstream.utils.camera import compose_abs_from_rel
from longstream.utils.depth import colorize_depth, unproject_depth_to_points
from longstream.utils.sky_mask import compute_sky_mask
from longstream.io.save_points import save_pointcloud
from longstream.io.save_poses_txt import save_w2c_txt, save_intri_txt, save_rel_pose_txt
from longstream.io.save_images import save_image_sequence, save_video
def _to_uint8_rgb(images):
imgs = images.detach().cpu().numpy()
imgs = np.clip(imgs, 0.0, 1.0)
imgs = (imgs * 255.0).astype(np.uint8)
return imgs
def _ensure_dir(path):
os.makedirs(path, exist_ok=True)
def _apply_sky_mask(depth, mask):
if mask is None:
return depth
m = (mask > 0).astype(np.float32)
return depth * m
def _camera_points_to_world(points, extri):
pts = np.asarray(points, dtype=np.float64).reshape(-1, 3)
R = np.asarray(extri[:3, :3], dtype=np.float64)
t = np.asarray(extri[:3, 3], dtype=np.float64)
world = (R.T @ (pts.T - t[:, None])).T
return world.astype(np.float32, copy=False)
def _mask_points_and_colors(points, colors, mask):
pts = points.reshape(-1, 3)
cols = None if colors is None else colors.reshape(-1, 3)
if mask is None:
return pts, cols
valid = mask.reshape(-1) > 0
pts = pts[valid]
if cols is not None:
cols = cols[valid]
return pts, cols
def _resize_long_edge(arr, long_edge_size, interpolation):
    """Resize ``arr`` so that its longer edge measures ``long_edge_size`` pixels.

    The aspect ratio is preserved up to rounding. Both target dimensions are
    clamped to at least one pixel: for very elongated inputs the shorter edge
    would otherwise round down to zero, and ``cv2.resize`` rejects an empty
    target size.

    Args:
        arr: Image-like array whose first two axes are ``(height, width)``.
        long_edge_size: Desired length of the longer edge, in pixels.
        interpolation: OpenCV interpolation flag, forwarded to ``cv2.resize``.

    Returns:
        The resized array as produced by ``cv2.resize``.
    """
    h, w = arr.shape[:2]
    scale = float(long_edge_size) / float(max(h, w))
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return cv2.resize(arr, (new_w, new_h), interpolation=interpolation)
def _prepare_mask_for_model(
mask, size, crop, patch_size, target_shape, square_ok=False
):
if mask is None:
return None
long_edge = (
round(size * max(mask.shape[1] / mask.shape[0], mask.shape[0] / mask.shape[1]))
if size == 224
else size
)
mask = _resize_long_edge(mask, long_edge, cv2.INTER_NEAREST)
h, w = mask.shape[:2]
cx, cy = w // 2, h // 2
if size == 224:
half = min(cx, cy)
target_w = 2 * half
target_h = 2 * half
if crop:
mask = mask[cy - half : cy + half, cx - half : cx + half]
else:
mask = cv2.resize(
mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
)
else:
halfw = ((2 * cx) // patch_size) * (patch_size // 2)
halfh = ((2 * cy) // patch_size) * (patch_size // 2)
if not square_ok and w == h:
halfh = int(3 * halfw / 4)
target_w = 2 * halfw
target_h = 2 * halfh
if crop:
mask = mask[cy - halfh : cy + halfh, cx - halfw : cx + halfw]
else:
mask = cv2.resize(
mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST
)
if mask.shape[:2] != tuple(target_shape):
mask = cv2.resize(
mask, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_NEAREST
)
return mask
def _save_full_pointcloud(path, point_chunks, color_chunks, max_points=None, seed=0):
if not point_chunks:
return
points = np.concatenate(point_chunks, axis=0)
colors = None
if color_chunks and len(color_chunks) == len(point_chunks):
colors = np.concatenate(color_chunks, axis=0)
if max_points is not None and len(points) > max_points:
rng = np.random.default_rng(seed)
keep = rng.choice(len(points), size=max_points, replace=False)
points = points[keep]
if colors is not None:
colors = colors[keep]
np.save(os.path.splitext(path)[0] + ".npy", points.astype(np.float32, copy=False))
save_pointcloud(path, points, colors=colors, max_points=None, seed=seed)
def _decode_sequence_poses(outputs, keyframe_indices, image_hw):
    """Decode camera extrinsics/intrinsics for one sequence from model outputs.

    ``"rel_pose_enc"`` takes precedence: it is composed into absolute pose
    encodings with ``compose_abs_from_rel`` before decoding. Otherwise
    ``"pose_enc"`` is decoded directly. Only batch element 0 is used.

    Returns:
        ``(extri, intri, rel_pose_enc)`` with the batch dimension removed;
        ``rel_pose_enc`` is ``None`` when the poses came from ``"pose_enc"``.
        ``None`` is returned when ``outputs`` carries no pose encoding at all.
    """
    rel_pose_enc = None
    if "rel_pose_enc" in outputs:
        rel_pose_enc = outputs["rel_pose_enc"][0]
        pose_enc = compose_abs_from_rel(rel_pose_enc, keyframe_indices[0])
    elif "pose_enc" in outputs:
        pose_enc = outputs["pose_enc"][0]
    else:
        return None
    extri, intri = pose_encoding_to_extri_intri(
        pose_enc[None], image_size_hw=image_hw
    )
    return extri[0], intri[0], rel_pose_enc


def _iter_unprojected_world_points(depth, extri, intri, num_frames):
    """Yield, frame by frame, depth maps unprojected into world-space points.

    Each depth map is lifted to camera-frame points with
    ``unproject_depth_to_points`` and then moved to the world frame with
    ``R^T @ (x_cam - t)``. The result is yielded as an ``(N, 3)`` NumPy array.
    """
    for i in range(num_frames):
        pts_cam = unproject_depth_to_points(depth[i][None], intri[i : i + 1])[0]
        R = extri[i, :3, :3]
        t = extri[i, :3, 3]
        pts_world = (R.t() @ (pts_cam.reshape(-1, 3).t() - t[:, None])).t()
        yield pts_world.cpu().numpy().reshape(-1, 3)


def _export_pointclouds(
    frame_dir,
    full_path,
    frame_points,
    rgb,
    sky_masks,
    *,
    save_frames,
    max_frame_points,
    max_full_points,
    seed,
):
    """Mask, optionally save, and merge per-frame world-space point clouds.

    ``frame_points`` is an iterable yielding one point array per frame, in
    frame order. Every frame is filtered with its sky mask (if any) and
    coloured from ``rgb``; non-empty frames are merged and written through
    ``_save_full_pointcloud`` to ``full_path``. Per-frame ``.ply`` files are
    written into ``frame_dir`` only when ``save_frames`` is true; the
    directory itself is always created.
    """
    _ensure_dir(frame_dir)
    kept_pts = []
    kept_cols = []
    for i, pts_world in enumerate(frame_points):
        pts_i, cols_i = _mask_points_and_colors(
            pts_world,
            rgb[i],
            None if sky_masks is None else sky_masks[i],
        )
        if save_frames:
            save_pointcloud(
                os.path.join(frame_dir, f"frame_{i:06d}.ply"),
                pts_i,
                colors=cols_i,
                max_points=max_frame_points,
                seed=i,
            )
        if len(pts_i):
            kept_pts.append(pts_i)
            kept_cols.append(cols_i)
    _save_full_pointcloud(
        full_path,
        kept_pts,
        kept_cols,
        max_points=max_full_points,
        seed=seed,
    )


def run_inference_cfg(cfg: dict):
    """Run inference on every sequence of the data loader and export results.

    Args:
        cfg: Configuration mapping. The following top-level keys are read:

            * ``device``: torch device string; defaults to ``"cuda"`` when
              available, else ``"cpu"``.
            * ``model``: passed to ``LongStreamModel``.
            * ``data``: passed to ``LongStreamDataLoader``; ``size``, ``crop``
              and ``patch_size`` are also used to align sky masks.
            * ``inference``: ``mode`` (``"batch_refresh"``,
              ``"streaming_refresh"`` or its alias ``"streaming"``),
              ``streaming_mode``, ``window_size``, ``keyframe_stride``,
              ``keyframe_mode``, ``refresh`` (defaults to
              ``keyframes_per_batch + 1``) and ``rel_pose_head_cfg``.
            * ``output``: ``root``, the ``save_*`` / ``mask_sky`` switches,
              the ``max_*_pointcloud_points`` limits and ``skyseg_path``.

    For each sequence a directory ``<root>/<sequence name>`` is created and,
    depending on the switches and on which outputs the model produced, filled
    with ``poses/``, ``images/``, ``sky_masks/``, ``depth/`` and ``points/``.

    Raises:
        ValueError: If ``refresh`` is smaller than 2, or if the inference
            mode is not supported (raised once the first sequence is reached).
    """
    device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    device_type = torch.device(device).type
    model_cfg = cfg.get("model", {})
    data_cfg = cfg.get("data", {})
    infer_cfg = cfg.get("inference", {})
    output_cfg = cfg.get("output", {})

    print(f"[longstream] device={device}", flush=True)
    model = LongStreamModel(model_cfg).to(device)
    model.eval()
    print("[longstream] model ready", flush=True)
    loader = LongStreamDataLoader(data_cfg)

    keyframe_stride = int(infer_cfg.get("keyframe_stride", 8))
    keyframe_mode = infer_cfg.get("keyframe_mode", "fixed")
    refresh = int(
        infer_cfg.get("refresh", int(infer_cfg.get("keyframes_per_batch", 3)) + 1)
    )
    if refresh < 2:
        raise ValueError(
            "refresh must be >= 2 because it counts both keyframe endpoints"
        )
    mode = infer_cfg.get("mode", "streaming_refresh")
    if mode == "streaming":
        # "streaming" is accepted as an alias.
        mode = "streaming_refresh"
    streaming_mode = infer_cfg.get("streaming_mode", "causal")
    window_size = int(infer_cfg.get("window_size", 5))
    # Loop-invariant, so read it once instead of once per sequence.
    rel_pose_cfg = infer_cfg.get("rel_pose_head_cfg", {"num_iterations": 4})

    selector = KeyframeSelector(
        min_interval=keyframe_stride,
        max_interval=keyframe_stride,
        force_first=True,
        mode="random" if keyframe_mode == "random" else "fixed",
    )

    out_root = output_cfg.get("root", "outputs")
    _ensure_dir(out_root)
    save_videos = bool(output_cfg.get("save_videos", True))
    save_points = bool(output_cfg.get("save_points", True))
    save_frame_points = bool(output_cfg.get("save_frame_points", True))
    save_depth = bool(output_cfg.get("save_depth", True))
    save_images = bool(output_cfg.get("save_images", True))
    mask_sky = bool(output_cfg.get("mask_sky", True))
    max_full_pointcloud_points = output_cfg.get("max_full_pointcloud_points", None)
    if max_full_pointcloud_points is not None:
        max_full_pointcloud_points = int(max_full_pointcloud_points)
    max_frame_pointcloud_points = output_cfg.get("max_frame_pointcloud_points", None)
    if max_frame_pointcloud_points is not None:
        max_frame_pointcloud_points = int(max_frame_pointcloud_points)
    skyseg_path = output_cfg.get(
        "skyseg_path",
        os.path.join(os.path.dirname(__file__), "..", "..", "skyseg.onnx"),
    )

    with torch.no_grad():
        for seq in loader:
            images = seq.images
            B, S, C, H, W = images.shape
            print(
                f"[longstream] sequence {seq.name}: inference start ({S} frames)",
                flush=True,
            )
            is_keyframe, keyframe_indices = selector.select_keyframes(
                S, B, images.device
            )
            if mode == "batch_refresh":
                outputs = run_batch_refresh(
                    model,
                    images,
                    is_keyframe,
                    keyframe_indices,
                    streaming_mode,
                    keyframe_stride,
                    refresh,
                    rel_pose_cfg,
                )
            elif mode == "streaming_refresh":
                outputs = run_streaming_refresh(
                    model,
                    images,
                    is_keyframe,
                    keyframe_indices,
                    streaming_mode,
                    window_size,
                    refresh,
                    rel_pose_cfg,
                )
            else:
                raise ValueError(f"Unsupported inference mode: {mode}")
            print(f"[longstream] sequence {seq.name}: inference done", flush=True)
            if device_type == "cuda":
                torch.cuda.empty_cache()

            seq_dir = os.path.join(out_root, seq.name)
            _ensure_dir(seq_dir)
            frame_ids = list(range(S))
            # Only batch element 0 is exported; channels move last for saving.
            rgb = _to_uint8_rgb(images[0].permute(0, 2, 3, 1))

            # Decode the poses once; the pose files and both point-cloud
            # exports below share the result.
            poses = _decode_sequence_poses(outputs, keyframe_indices, (H, W))
            if poses is not None:
                extri, intri, rel_pose_enc = poses
                pose_dir = os.path.join(seq_dir, "poses")
                _ensure_dir(pose_dir)
                save_w2c_txt(
                    os.path.join(pose_dir, "abs_pose.txt"),
                    extri.detach().cpu().numpy(),
                    frame_ids,
                )
                save_intri_txt(
                    os.path.join(pose_dir, "intri.txt"),
                    intri.detach().cpu().numpy(),
                    frame_ids,
                )
                if rel_pose_enc is not None:
                    save_rel_pose_txt(
                        os.path.join(pose_dir, "rel_pose.txt"),
                        rel_pose_enc,
                        frame_ids,
                    )

            if save_images:
                print(f"[longstream] sequence {seq.name}: saving rgb", flush=True)
                rgb_dir = os.path.join(seq_dir, "images", "rgb")
                save_image_sequence(rgb_dir, list(rgb))
                # The video is assembled from the frames written just above.
                if save_videos:
                    save_video(
                        os.path.join(seq_dir, "images", "rgb.mp4"),
                        os.path.join(rgb_dir, "frame_*.png"),
                    )

            sky_masks = None
            if mask_sky:
                raw_sky_masks = compute_sky_mask(
                    seq.image_paths, skyseg_path, os.path.join(seq_dir, "sky_masks")
                )
                if raw_sky_masks is not None:
                    mask_size = int(data_cfg.get("size", 518))
                    mask_crop = bool(data_cfg.get("crop", False))
                    mask_patch = int(data_cfg.get("patch_size", 14))
                    sky_masks = [
                        _prepare_mask_for_model(
                            mask,
                            size=mask_size,
                            crop=mask_crop,
                            patch_size=mask_patch,
                            target_shape=(H, W),
                        )
                        for mask in raw_sky_masks
                    ]

            if save_depth and "depth" in outputs:
                print(f"[longstream] sequence {seq.name}: saving depth", flush=True)
                depth_np = outputs["depth"][0, :, :, :, 0].detach().cpu().numpy()
                depth_dir = os.path.join(seq_dir, "depth", "dpt")
                _ensure_dir(depth_dir)
                color_dir = os.path.join(seq_dir, "depth", "dpt_plasma")
                _ensure_dir(color_dir)
                for i in range(S):
                    d = depth_np[i]
                    if sky_masks is not None and sky_masks[i] is not None:
                        d = _apply_sky_mask(d, sky_masks[i])
                    np.save(os.path.join(depth_dir, f"frame_{i:06d}.npy"), d)
                    # Written straight to disk; colourised frames are not
                    # kept in memory.
                    colored = colorize_depth(d, cmap="plasma")
                    Image.fromarray(colored).save(
                        os.path.join(color_dir, f"frame_{i:06d}.png")
                    )
                if save_videos:
                    save_video(
                        os.path.join(seq_dir, "depth", "dpt_plasma.mp4"),
                        os.path.join(color_dir, "frame_*.png"),
                    )

            if save_points:
                print(
                    f"[longstream] sequence {seq.name}: saving point clouds", flush=True
                )
                if "world_points" in outputs:
                    if poses is None:
                        # Same policy as the depth unprojection below: without
                        # poses the points cannot be placed in the world frame.
                        print(
                            f"[longstream] sequence {seq.name}: "
                            "no pose output, skipping point_head point clouds",
                            flush=True,
                        )
                    else:
                        head_pts = outputs["world_points"][0].detach().cpu().numpy()
                        extri_np = extri.detach().cpu().numpy()
                        _export_pointclouds(
                            os.path.join(seq_dir, "points", "point_head"),
                            os.path.join(seq_dir, "points", "point_head_full.ply"),
                            (
                                _camera_points_to_world(head_pts[i], extri_np[i])
                                for i in range(S)
                            ),
                            rgb,
                            sky_masks,
                            save_frames=save_frame_points,
                            max_frame_points=max_frame_pointcloud_points,
                            max_full_points=max_full_pointcloud_points,
                            seed=0,
                        )
                if "depth" in outputs and poses is not None:
                    depth = outputs["depth"][0, :, :, :, 0]
                    _export_pointclouds(
                        os.path.join(seq_dir, "points", "dpt_unproj"),
                        os.path.join(seq_dir, "points", "dpt_unproj_full.ply"),
                        _iter_unprojected_world_points(depth, extri, intri, S),
                        rgb,
                        sky_masks,
                        save_frames=save_frame_points,
                        max_frame_points=max_frame_pointcloud_points,
                        max_full_points=max_full_pointcloud_points,
                        seed=1,
                    )

            # Drop the reference to this sequence's outputs before the next
            # sequence runs.
            del outputs
            if device_type == "cuda":
                torch.cuda.empty_cache()
def run_inference(config_path: str):
    """Load a YAML configuration file and run inference with it.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        ValueError: If the file does not contain a mapping at its top level
            (for example, when it is empty).
    """
    # Read as UTF-8 explicitly so parsing does not depend on the locale's
    # default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a top-level mapping, "
            f"got {type(cfg).__name__}"
        )
    run_inference_cfg(cfg)
def main():
    """Command-line entry point: parse ``--config`` and run inference."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True)
    options = cli.parse_args()
    run_inference(options.config)
# Run the command-line entry point only when this module is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()