UniSHARP / scripts /infer_unisharp.py
Insta360-Research's picture
Use foreground p25 depth for near-scene demo motion scaling
8fff9f2 verified
from __future__ import annotations
import argparse
import json
import logging
import math
import os
import re
import sys
from pathlib import Path
from typing import Any, Literal
import numpy as np
import torch
from PIL import Image, ImageOps
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))
from unisharp.cli.unified_trainer import UnifiedTrainer # noqa: E402
from unisharp.models.unisharp_feature import UnisharpFeatureConfig, UnisharpFeatureModel # noqa: E402
from unisharp.utils.camera_utils import transform_gaussians_to_world # noqa: E402
from unisharp.utils.color_space import linearRGB2sRGB # noqa: E402
from unisharp.utils.fisheye_geer import render_gaussians_fisheye624 # noqa: E402
from unisharp.utils.gaussians import save_ply # noqa: E402
from unisharp.utils.gsplat import GSplatRenderer # noqa: E402
from unisharp.utils.camera_projection import build_extrinsics_w2c # noqa: E402
from unisharp.utils.rayfit_camera import fit_fisheye624_params_from_rays, fit_pinhole_intrinsics_from_rays # noqa: E402
LOGGER = logging.getLogger("infer_unisharp")
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".PNG", ".JPG", ".JPEG", ".WEBP"}
CameraKind = Literal["perspective", "fisheye", "panorama"]
FACE_NAMES = ["up", "back", "left", "front", "right", "down"]
MAX_LONG_EDGE = 0
PERSPECTIVE_MAX_LONG_EDGE = 0
PANORAMA_MAX_LONG_EDGE = 0
FORWARD_VIEWS = 10
FORWARD_DISTANCE_M = 0.2
ROTATE_VIEWS = 10
ROTATE_RADIUS_M = 0.1
GIF_DURATION_MS = 300
# Shrink demo view motion when predicted scene depth is near (e.g. WildRGBD tabletop).
VIEW_MOTION_NEAR_SCENE_DEPTH_M = 1.5
VIEW_MOTION_MIN_SCALE = 0.12
# Above this median depth, trust median only (outdoor / large-scale scenes).
VIEW_MOTION_FAR_SCENE_MEDIAN_M = 2.5
VIEW_MOTION_FOREGROUND_DEPTH_QUANTILE = 0.25
FISHEYE_FOV_THRESHOLD_DEG = 120.0
FISHEYE_DIAG_THRESHOLD_DEG = 150.0
FISHEYE_VFOV_MIN_DEG = 80.0
FISHEYE_MAX_ASPECT = 1.65
PANORAMA_HFOV_THRESHOLD_DEG = 300.0
PANORAMA_VFOV_THRESHOLD_DEG = 120.0
PANORAMA_ASPECT_MIN = 1.9
PANORAMA_ASPECT_MAX = 2.1
def _configure_torchhub_cache() -> None:
torchhub_dir = REPO_ROOT / "checkpoints" / "torchhub"
torchhub_dir.mkdir(parents=True, exist_ok=True)
os.environ["TORCH_HOME"] = str(torchhub_dir)
torch.hub.set_dir(str(torchhub_dir))
def _feature_config_from_checkpoint(checkpoint_path: Path, ckpt: dict[str, Any]) -> UnisharpFeatureConfig:
cfg = UnisharpFeatureConfig()
merged: dict[str, Any] = {}
cfg_payload = ckpt.get("config", {})
if isinstance(cfg_payload, dict):
merged.update(cfg_payload)
for key in cfg.__dict__.keys():
if key in ckpt:
merged[key] = ckpt[key]
config_path = checkpoint_path.parent / "config.json"
if config_path.exists():
try:
sidecar = json.loads(config_path.read_text(encoding="utf-8"))
except Exception:
sidecar = None
if isinstance(sidecar, dict):
merged.update({k: v for k, v in sidecar.items() if k in cfg.__dict__})
for key in cfg.__dict__.keys():
if key in merged:
setattr(cfg, key, merged[key])
return cfg
def _load_model(checkpoint_path: Path, device: torch.device) -> tuple[UnisharpFeatureModel, int]:
try:
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
except TypeError:
ckpt = torch.load(checkpoint_path, map_location="cpu")
if not isinstance(ckpt, dict):
raise ValueError(f"Expected checkpoint dict, got {type(ckpt)} from {checkpoint_path}")
cfg = _feature_config_from_checkpoint(checkpoint_path, ckpt)
model = UnisharpFeatureModel(cfg).to(device)
missing, unexpected = model.load_from_checkpoint(str(checkpoint_path), strict=False)
if missing or unexpected:
LOGGER.warning("Loaded checkpoint with missing=%s unexpected=%s", missing[:20], unexpected[:20])
model.eval()
return model, int(ckpt.get("step", 0))
def _collect_image_paths(args: argparse.Namespace) -> list[Path]:
paths: list[Path] = []
if args.image is not None:
paths.append(Path(args.image))
if args.image_list is not None:
for raw in Path(args.image_list).read_text(encoding="utf-8").splitlines():
line = raw.strip()
if line and not line.startswith("#"):
paths.append(Path(line))
if args.image_dir is not None:
root = Path(args.image_dir)
paths.extend(sorted(p for p in root.iterdir() if p.is_file() and p.suffix in IMAGE_SUFFIXES))
if not paths:
raise ValueError("Provide --image, --image-list, or --image-dir.")
return paths[: int(args.max_images)] if int(args.max_images) > 0 else paths
def _perspective_max_long_edge() -> int:
return int(PERSPECTIVE_MAX_LONG_EDGE)
def _panorama_max_long_edge() -> int:
return int(PANORAMA_MAX_LONG_EDGE)
def _image_hw_from_path(image_path: Path) -> tuple[int, int]:
with Image.open(image_path) as raw:
image = ImageOps.exif_transpose(raw)
w, h = image.size
return int(h), int(w)
def _should_load_panorama_native(
*,
image_path: Path,
args: argparse.Namespace,
camera_json_entry: dict[str, Any] | None,
) -> bool:
forced = str(args.camera).strip().lower()
if forced in {"panorama", "erp"}:
return True
if forced in {"perspective", "pinhole", "fisheye"}:
return False
json_camera_name = _camera_name_from_json(camera_json_entry)
if json_camera_name in {"panorama", "erp", "spherical"}:
return True
if json_camera_name in {"perspective", "pinhole", "fisheye", "fisheye624", "opencv_fisheye"}:
return False
image_h, image_w = _image_hw_from_path(image_path)
return _camera_name_from_aspect(image_h=image_h, image_w=image_w) == "panorama"
def _initial_max_long_edge(
*,
image_path: Path,
args: argparse.Namespace,
camera_json_entry: dict[str, Any] | None,
) -> int:
if _should_load_panorama_native(image_path=image_path, args=args, camera_json_entry=camera_json_entry):
return _panorama_max_long_edge()
return _perspective_max_long_edge()
def _load_rgb_u8(image_path: Path, max_long_edge: int) -> torch.Tensor:
with Image.open(image_path) as raw:
image = ImageOps.exif_transpose(raw).convert("RGB")
if int(max_long_edge) > 0:
w, h = image.size
scale = min(1.0, float(max_long_edge) / float(max(h, w)))
if scale < 1.0:
image = image.resize(
(max(1, int(round(w * scale))), max(1, int(round(h * scale)))),
resample=Image.BILINEAR,
)
arr = np.asarray(image, dtype=np.uint8).copy()
return torch.from_numpy(arr).permute(2, 0, 1).contiguous()
def _to_u8_hwc(img_chw: torch.Tensor) -> np.ndarray:
if img_chw.dtype == torch.uint8:
return img_chw.permute(1, 2, 0).detach().cpu().numpy()
x = img_chw.detach().to(torch.float32).clamp(0.0, 1.0)
return (x * 255.0).round().to(torch.uint8).permute(1, 2, 0).cpu().numpy()
def _crop_border_u8(frame: np.ndarray, fraction: float) -> np.ndarray:
if float(fraction) <= 0.0:
return frame
if frame.ndim < 2:
return frame
h, w = int(frame.shape[0]), int(frame.shape[1])
crop_y = int(round(float(h) * float(fraction)))
crop_x = int(round(float(w) * float(fraction)))
if crop_y <= 0 and crop_x <= 0:
return frame
if crop_y * 2 >= h or crop_x * 2 >= w:
return frame
return frame[crop_y : h - crop_y, crop_x : w - crop_x].copy()
def _save_gif(frames: list[np.ndarray], out_file: Path, duration_ms: int) -> None:
if not frames:
raise ValueError(f"No frames to save for {out_file}")
out_file.parent.mkdir(parents=True, exist_ok=True)
pil_frames = [Image.fromarray(frame) for frame in frames]
pil_frames[0].save(
out_file,
save_all=True,
append_images=pil_frames[1:],
duration=int(duration_ms),
loop=0,
disposal=2,
)
def _slug_from_path(image_path: Path) -> str:
raw = f"{image_path.parent.name}_{image_path.stem}"
return re.sub(r"[^A-Za-z0-9_.-]+", "_", raw)
def _normalize_rays(rays: torch.Tensor) -> torch.Tensor:
rays_f = rays.detach().to(torch.float32)
return rays_f / torch.linalg.vector_norm(rays_f, dim=1, keepdim=True).clamp(min=1e-6)
def _angular_span_deg(a: np.ndarray) -> float:
a = a[np.isfinite(a)]
if a.size < 2:
return 0.0
return float(np.degrees(np.nanpercentile(a, 99.0) - np.nanpercentile(a, 1.0)))
def _angle_between_deg(a: np.ndarray, b: np.ndarray) -> float:
denom = max(float(np.linalg.norm(a) * np.linalg.norm(b)), 1e-8)
return float(np.degrees(np.arccos(np.clip(float(np.dot(a, b)) / denom, -1.0, 1.0))))
def _ray_fov_stats(rays_b3hw: torch.Tensor) -> dict[str, float]:
rays = _normalize_rays(rays_b3hw)[0].detach().cpu().numpy()
_, h, w = rays.shape
rows = [max(0, min(h - 1, int(round(h * q)))) for q in (0.25, 0.5, 0.75)]
cols = [max(0, min(w - 1, int(round(w * q)))) for q in (0.25, 0.5, 0.75)]
h_spans = []
for row in rows:
lon = np.unwrap(np.arctan2(rays[0, row], rays[2, row]))
h_spans.append(_angular_span_deg(lon))
v_spans = []
for col in cols:
x = rays[0, :, col]
y = rays[1, :, col]
z = rays[2, :, col]
lat = np.arctan2(y, np.sqrt(x * x + z * z))
v_spans.append(_angular_span_deg(lat))
corners = [rays[:, 0, 0], rays[:, 0, w - 1], rays[:, h - 1, 0], rays[:, h - 1, w - 1]]
diag = max(_angle_between_deg(corners[i], corners[j]) for i in range(4) for j in range(i + 1, 4))
return {
"horizontal_fov_deg": float(np.median(h_spans)),
"vertical_fov_deg": float(np.median(v_spans)),
"diagonal_fov_deg": float(diag),
"aspect": float(w) / float(max(h, 1)),
}
def _classify_camera(stats: dict[str, float], args: argparse.Namespace) -> CameraKind:
forced = str(args.camera).strip().lower()
if forced != "auto":
return {"pinhole": "perspective", "erp": "panorama"}.get(forced, forced) # type: ignore[return-value]
aspect = float(stats["aspect"])
h_fov = float(stats["horizontal_fov_deg"])
v_fov = float(stats["vertical_fov_deg"])
diag_fov = float(stats["diagonal_fov_deg"])
if (
PANORAMA_ASPECT_MIN <= aspect <= PANORAMA_ASPECT_MAX
and h_fov >= PANORAMA_HFOV_THRESHOLD_DEG
and v_fov >= PANORAMA_VFOV_THRESHOLD_DEG
):
return "panorama"
fishlike_aspect = aspect <= FISHEYE_MAX_ASPECT
fishlike_fov = (
max(h_fov, v_fov) >= FISHEYE_FOV_THRESHOLD_DEG
or (diag_fov >= FISHEYE_DIAG_THRESHOLD_DEG and v_fov >= FISHEYE_VFOV_MIN_DEG)
)
if fishlike_aspect and fishlike_fov:
return "fisheye"
return "perspective"
def _empty_ray_stats() -> dict[str, float]:
return {
"horizontal_fov_deg": float("nan"),
"vertical_fov_deg": float("nan"),
"diagonal_fov_deg": float("nan"),
"aspect": float("nan"),
}
def _pinhole_intrinsics_from_values(values: list[float] | None, *, device: torch.device) -> torch.Tensor | None:
if values is None:
return None
vals = [float(v) for v in values]
if len(vals) == 4:
fx, fy, cx, cy = vals
k = torch.tensor(
[[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]],
dtype=torch.float32,
device=device,
)
elif len(vals) == 9:
k = torch.tensor(vals, dtype=torch.float32, device=device).reshape(3, 3)
else:
raise ValueError("--camera-intrinsics expects 4 values (fx fy cx cy) or 9 row-major K values.")
return k.unsqueeze(0)
def _fisheye624_params_from_values(values: list[float] | None, *, device: torch.device) -> torch.Tensor | None:
if values is None:
return None
vals = [float(v) for v in values]
if len(vals) == 8:
vals = vals + [0.0] * 8
if len(vals) != 16:
raise ValueError("--camera-params expects 8 or 16 Fisheye624 values.")
return torch.tensor(vals, dtype=torch.float32, device=device).reshape(1, 16)
def _load_camera_json(path: Path | None) -> Any:
if path is None:
return None
payload = json.loads(Path(path).read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise ValueError("--camera-json must point to a JSON object.")
return payload
def _camera_json_for_image(payload: Any, image_path: Path) -> dict[str, Any] | None:
if not isinstance(payload, dict):
return None
images = payload.get("images", None)
if isinstance(images, dict):
keys = [
str(image_path),
image_path.as_posix(),
image_path.name,
image_path.stem,
]
for key in keys:
value = images.get(key, None)
if isinstance(value, dict):
base = payload.get("default", {})
merged = dict(base) if isinstance(base, dict) else {}
merged.update(value)
return merged
if isinstance(payload.get("default", None), dict):
return dict(payload["default"])
return dict(payload)
def _values_from_camera_json(entry: dict[str, Any] | None, *names: str) -> list[float] | None:
if not isinstance(entry, dict):
return None
for name in names:
value = entry.get(name, None)
if value is None:
continue
if isinstance(value, dict):
if all(k in value for k in ("fx", "fy", "cx", "cy")):
return [float(value["fx"]), float(value["fy"]), float(value["cx"]), float(value["cy"])]
if "K" in value:
value = value["K"]
else:
continue
if isinstance(value, (list, tuple)):
if len(value) == 3 and all(isinstance(row, (list, tuple)) for row in value):
flat = [float(x) for row in value for x in row]
else:
flat = [float(x) for x in value]
return flat
return None
def _camera_name_from_json(entry: dict[str, Any] | None) -> str | None:
if not isinstance(entry, dict):
return None
value = entry.get("camera", entry.get("camera_model", entry.get("type", None)))
return str(value).strip().lower() if value is not None and str(value).strip() else None
def _camera_name_from_aspect(image_h: int, image_w: int) -> str | None:
aspect = float(image_w) / float(max(image_h, 1))
if PANORAMA_ASPECT_MIN <= aspect <= PANORAMA_ASPECT_MAX:
return "panorama"
return None
@torch.no_grad()
def _predict_unik3d_rays(
model: UnisharpFeatureModel,
image_u8: torch.Tensor,
*,
image_h: int,
image_w: int,
) -> torch.Tensor:
model.feature_extractor.forward(
rgb_u8=image_u8,
target_h=int(image_h),
target_w=int(image_w),
use_predicted_rays=True,
)
output = model.feature_extractor._unisharp_last_unik3d_output
if not isinstance(output, dict) or not torch.is_tensor(output.get("rays", None)):
raise RuntimeError("UniK3D did not return predicted rays for camera classification.")
return output["rays"]
def _build_forward_poses(num_views: int, distance_m: float, device: torch.device) -> list[torch.Tensor]:
poses = []
r_c2w = torch.eye(3, dtype=torch.float32, device=device)
views = max(1, int(num_views))
for idx in range(views):
alpha = float(idx + 1) / float(views)
eye = torch.tensor([0.0, 0.0, float(distance_m) * alpha], dtype=torch.float32, device=device)
poses.append(build_extrinsics_w2c(r_c2w, eye, "c2w"))
return poses
def _build_rotate_poses(num_views: int, radius_m: float, device: torch.device) -> list[torch.Tensor]:
poses = []
src_r_c2w = torch.eye(3, dtype=torch.float32, device=device)
views = max(1, int(num_views))
for idx in range(views):
theta = -2.0 * math.pi * float(idx) / float(views)
eye = torch.tensor(
[
float(radius_m) * math.sin(theta),
float(radius_m) * math.cos(theta),
0.0,
],
dtype=torch.float32,
device=device,
)
poses.append(build_extrinsics_w2c(src_r_c2w, eye, "c2w"))
return poses
def _predicted_depth_samples_m(model_output: dict[str, Any]) -> torch.Tensor | None:
depth = model_output.get("unik3d_distance")
if not torch.is_tensor(depth):
layers = model_output.get("distance_layers")
if torch.is_tensor(layers) and layers.ndim >= 4 and int(layers.shape[1]) >= 1:
depth = layers[:, 0:1]
if torch.is_tensor(depth) and depth.numel() > 0:
values = depth.detach().reshape(-1).to(torch.float32)
valid = values[torch.isfinite(values) & (values > 1e-3) & (values < 1e4)]
if int(valid.numel()) > 0:
return valid
gaussians = model_output.get("gaussians")
if gaussians is not None and hasattr(gaussians, "mean_vectors"):
z = gaussians.mean_vectors.detach().reshape(-1, 3)[..., 2].reshape(-1).to(torch.float32)
valid = z[torch.isfinite(z) & (z > 1e-3) & (z < 1e4)]
if int(valid.numel()) > 0:
return valid
return None
def _scene_depth_for_motion_m(model_output: dict[str, Any]) -> tuple[float | None, float | None, float | None]:
"""Return (effective_depth, median_depth, foreground_depth_p25) in meters."""
valid = _predicted_depth_samples_m(model_output)
if valid is None or int(valid.numel()) == 0:
return None, None, None
median_depth_m = float(torch.median(valid).item())
q = float(VIEW_MOTION_FOREGROUND_DEPTH_QUANTILE)
foreground_depth_m = float(torch.quantile(valid, q).item())
if float(median_depth_m) >= float(VIEW_MOTION_FAR_SCENE_MEDIAN_M):
effective_depth_m = float(median_depth_m)
else:
effective_depth_m = float(min(median_depth_m, foreground_depth_m))
return effective_depth_m, median_depth_m, foreground_depth_m
def _adaptive_view_motion_distances(
model_output: dict[str, Any],
*,
default_forward_m: float,
default_radius_m: float,
) -> tuple[float, float, float | None, float, float | None, float | None]:
effective_depth_m, median_depth_m, foreground_depth_m = _scene_depth_for_motion_m(model_output)
near_threshold_m = float(VIEW_MOTION_NEAR_SCENE_DEPTH_M)
if (
effective_depth_m is None
or not math.isfinite(effective_depth_m)
or float(effective_depth_m) >= near_threshold_m
):
return (
float(default_forward_m),
float(default_radius_m),
effective_depth_m,
1.0,
median_depth_m,
foreground_depth_m,
)
scale = max(float(VIEW_MOTION_MIN_SCALE), float(effective_depth_m) / near_threshold_m)
forward_m = float(default_forward_m) * scale
radius_m = float(default_radius_m) * scale
return forward_m, radius_m, effective_depth_m, scale, median_depth_m, foreground_depth_m
def _render_pinhole_frame(
renderer: GSplatRenderer,
gaussians: Any,
*,
extr_w2c: torch.Tensor,
intrinsics: torch.Tensor,
image_h: int,
image_w: int,
) -> np.ndarray:
out = renderer(
gaussians,
extrinsics=extr_w2c[None],
intrinsics=intrinsics[None],
image_width=int(image_w),
image_height=int(image_h),
)
alpha = out.alpha.detach().to(torch.float32).clamp(0.0, 1.0)
rgb = linearRGB2sRGB((out.color / alpha.clamp(min=1e-4)).clamp(0.0, 1.0)).clamp(0.0, 1.0)
return _to_u8_hwc(rgb[0])
def _render_fisheye_frame(
gaussians: Any,
*,
extr_w2c: torch.Tensor,
camera_params: torch.Tensor,
image_h: int,
image_w: int,
) -> np.ndarray:
out = render_gaussians_fisheye624(
gaussians,
extrinsics_w2c=extr_w2c[None],
camera_params=camera_params,
image_h=int(image_h),
image_w=int(image_w),
valid_mask=None,
)
alpha = out["alpha"].detach().to(torch.float32).clamp(0.0, 1.0)
rgb = linearRGB2sRGB((out["color"] / alpha.clamp(min=1e-4)).clamp(0.0, 1.0)).clamp(0.0, 1.0)
return _to_u8_hwc(rgb[0])
def _render_panorama_frame_and_faces(
trainer: UnifiedTrainer,
gaussians: Any,
*,
extr_w2c: torch.Tensor,
equ_h: int,
equ_w: int,
face_w: int,
) -> tuple[np.ndarray, dict[str, np.ndarray]]:
cube_color, _, cube_alpha = trainer._render_cubemap(gaussians, extr_w2c, face_w=int(face_w))
erp_color = trainer._cube_to_erp(cube_color, equ_h=int(equ_h), equ_w=int(equ_w), face_w=int(face_w))
erp_alpha = trainer._cube_to_erp(cube_alpha, equ_h=int(equ_h), equ_w=int(equ_w), face_w=int(face_w))
erp = linearRGB2sRGB((erp_color / erp_alpha.clamp(min=1e-4)).clamp(0.0, 1.0)).clamp(0.0, 1.0)
face_views: dict[str, np.ndarray] = {}
for face_idx, face_name in enumerate(FACE_NAMES):
face = linearRGB2sRGB(
(cube_color[face_idx : face_idx + 1] / cube_alpha[face_idx : face_idx + 1].clamp(min=1e-4)).clamp(0.0, 1.0)
).clamp(0.0, 1.0)
face_views[face_name] = _to_u8_hwc(face[0])
return _to_u8_hwc(erp[0]), face_views
@torch.no_grad()
def _run_model_pinhole(
model: UnisharpFeatureModel,
image: torch.Tensor,
image_u8: torch.Tensor,
*,
intrinsics: torch.Tensor,
distance_init_cap_m: float,
) -> dict[str, Any]:
return model(
image=image,
image_u8=image_u8,
camera_intrinsics=intrinsics,
camera_params=None,
camera_model="pinhole",
depth_gt=None,
distance_init_cap_m=(float(distance_init_cap_m) if float(distance_init_cap_m) > 0.0 else None),
return_aux=True,
)
@torch.no_grad()
def _run_model_fisheye(
model: UnisharpFeatureModel,
image: torch.Tensor,
image_u8: torch.Tensor,
*,
camera_params: torch.Tensor,
distance_init_cap_m: float,
) -> dict[str, Any]:
return model(
image=image,
image_u8=image_u8,
camera_intrinsics=None,
camera_params=camera_params,
camera_model="fisheye624",
depth_gt=None,
distance_init_cap_m=(float(distance_init_cap_m) if float(distance_init_cap_m) > 0.0 else None),
return_aux=True,
)
@torch.no_grad()
def _run_model_panorama(
model: UnisharpFeatureModel,
image: torch.Tensor,
image_u8: torch.Tensor,
*,
distance_init_cap_m: float,
) -> dict[str, Any]:
return model(
image=image,
image_u8=image_u8,
camera_intrinsics=None,
camera_params=None,
camera_model="spherical",
depth_gt=None,
distance_init_cap_m=(float(distance_init_cap_m) if float(distance_init_cap_m) > 0.0 else None),
return_aux=True,
)
def _save_ply_if_requested(gaussians: Any, path: Path, f_px: float, image_h: int, image_w: int, enabled: bool) -> None:
if not enabled:
return
path.parent.mkdir(parents=True, exist_ok=True)
save_ply(gaussians, f_px=float(f_px), image_shape=(int(image_h), int(image_w)), path=path)
@torch.no_grad()
def _process_one(
*,
model: UnisharpFeatureModel,
renderer: GSplatRenderer,
train_renderer: UnifiedTrainer,
image_path: Path,
out_root: Path,
step: int,
args: argparse.Namespace,
) -> None:
native_h, native_w = _image_hw_from_path(image_path)
camera_json_entry = _camera_json_for_image(getattr(args, "_camera_json_data", None), image_path)
load_max_long_edge = _initial_max_long_edge(
image_path=image_path,
args=args,
camera_json_entry=camera_json_entry,
)
for reload_attempt in range(2):
rgb_u8 = _load_rgb_u8(image_path, max_long_edge=load_max_long_edge)
_, h, w = rgb_u8.shape
if h < 4 or w < 4:
raise ValueError(f"Invalid image size for {image_path}: {tuple(rgb_u8.shape)}")
device = next(model.parameters()).device
image_u8 = rgb_u8.unsqueeze(0).to(device=device)
image = image_u8.to(torch.float32) / 255.0
json_camera_name = _camera_name_from_json(camera_json_entry)
aspect_camera_name = _camera_name_from_aspect(image_h=h, image_w=w)
forced_camera_name = str(args.camera).strip().lower()
forced_camera_name = None if forced_camera_name == "auto" else {"pinhole": "perspective", "erp": "panorama"}.get(forced_camera_name, forced_camera_name)
json_intrinsics = _values_from_camera_json(camera_json_entry, "intrinsics", "camera_intrinsics", "K")
json_camera_params = _values_from_camera_json(camera_json_entry, "camera_params", "fisheye624_params", "params")
explicit_intrinsics = _pinhole_intrinsics_from_values(json_intrinsics or args.camera_intrinsics, device=device)
explicit_camera_params = _fisheye624_params_from_values(json_camera_params or args.camera_params, device=device)
if explicit_intrinsics is not None and explicit_camera_params is not None:
raise ValueError("Use only one of --camera-intrinsics or --camera-params.")
rays: torch.Tensor | None
render_intrinsics: torch.Tensor | None = None
render_camera_params: torch.Tensor | None = None
if explicit_intrinsics is not None:
camera_kind: CameraKind = "panorama" if json_camera_name in {"panorama", "erp", "spherical"} else "perspective"
render_intrinsics = explicit_intrinsics
if camera_kind == "panorama":
out = _run_model_panorama(model, image, image_u8, distance_init_cap_m=0.0)
else:
out = _run_model_pinhole(
model,
image,
image_u8,
intrinsics=explicit_intrinsics,
distance_init_cap_m=0.0,
)
rays = out.get("geometry_rays", out.get("unik3d_gt_rays", out.get("unik3d_rays", None)))
stats = _ray_fov_stats(rays) if torch.is_tensor(rays) else _empty_ray_stats()
elif explicit_camera_params is not None:
camera_kind = "fisheye"
render_camera_params = explicit_camera_params
out = _run_model_fisheye(
model,
image,
image_u8,
camera_params=explicit_camera_params,
distance_init_cap_m=0.0,
)
rays = out.get("geometry_rays", out.get("unik3d_gt_rays", out.get("unik3d_rays", None)))
stats = _ray_fov_stats(rays) if torch.is_tensor(rays) else _empty_ray_stats()
elif forced_camera_name == "panorama" or (
forced_camera_name is None and (json_camera_name in {"panorama", "erp", "spherical"} or aspect_camera_name == "panorama")
):
camera_kind = "panorama"
out = _run_model_panorama(model, image, image_u8, distance_init_cap_m=0.0)
rays = out.get("geometry_rays", out.get("unik3d_gt_rays", out.get("unik3d_rays", None)))
stats = _ray_fov_stats(rays) if torch.is_tensor(rays) else _empty_ray_stats()
else:
rays = _predict_unik3d_rays(model, image_u8, image_h=h, image_w=w)
stats = _ray_fov_stats(rays)
if forced_camera_name == "fisheye":
camera_kind = "fisheye"
elif forced_camera_name == "perspective":
camera_kind = "perspective"
elif json_camera_name in {"fisheye", "fisheye624", "opencv_fisheye"}:
camera_kind = "fisheye"
elif json_camera_name in {"perspective", "pinhole"}:
camera_kind = "perspective"
else:
camera_kind = _classify_camera(stats, args)
if camera_kind == "panorama":
out = _run_model_panorama(model, image, image_u8, distance_init_cap_m=0.0)
elif camera_kind == "fisheye":
render_camera_params = fit_fisheye624_params_from_rays(rays).detach().to(device=device, dtype=torch.float32)
out = _run_model_fisheye(
model,
image,
image_u8,
camera_params=render_camera_params,
distance_init_cap_m=0.0,
)
else:
render_intrinsics = fit_pinhole_intrinsics_from_rays(rays).detach().to(device=device, dtype=torch.float32)
out = _run_model_pinhole(
model,
image,
image_u8,
intrinsics=render_intrinsics,
distance_init_cap_m=0.0,
)
needs_native_panorama = (
camera_kind == "panorama"
and (h < native_h or w < native_w)
and load_max_long_edge != _panorama_max_long_edge()
)
if needs_native_panorama and reload_attempt == 0:
load_max_long_edge = _panorama_max_long_edge()
continue
break
LOGGER.info(
"%s -> %s | hfov=%.1f vfov=%.1f diag=%.1f aspect=%.3f",
image_path,
camera_kind,
stats["horizontal_fov_deg"],
stats["vertical_fov_deg"],
stats["diagonal_fov_deg"],
stats["aspect"],
)
src_w2c = torch.eye(4, dtype=torch.float32, device=device)
gaussians_world = transform_gaussians_to_world(out["gaussians"], src_w2c)
model_output = out if isinstance(out, dict) else {"gaussians": out}
(
forward_distance_m,
rotate_radius_m,
scene_depth_m,
motion_scale,
median_depth_m,
foreground_depth_m,
) = _adaptive_view_motion_distances(
model_output,
default_forward_m=FORWARD_DISTANCE_M,
default_radius_m=ROTATE_RADIUS_M,
)
if float(motion_scale) < 0.999:
LOGGER.info(
"Near-scene view motion | depth_eff=%.3fm median=%.3fm p25=%.3fm scale=%.3f forward=%.3fm orbit=%.3fm",
float(scene_depth_m) if scene_depth_m is not None else float("nan"),
float(median_depth_m) if median_depth_m is not None else float("nan"),
float(foreground_depth_m) if foreground_depth_m is not None else float("nan"),
float(motion_scale),
float(forward_distance_m),
float(rotate_radius_m),
)
forward_poses = _build_forward_poses(
num_views=FORWARD_VIEWS,
distance_m=forward_distance_m,
device=device,
)
rotate_poses = _build_rotate_poses(
num_views=ROTATE_VIEWS,
radius_m=rotate_radius_m,
device=device,
)
sample_dir = out_root / _slug_from_path(image_path)
sample_dir.mkdir(parents=True, exist_ok=True)
output_crop_border_fraction = 0.0 if camera_kind == "panorama" else 0.05
forward_frames: list[np.ndarray] = []
rotate_frames: list[np.ndarray] = []
if camera_kind == "panorama":
face_w = max(16, int(min(h, w // 4)))
forward_dir = sample_dir / "forward_erp"
rotate_dir = sample_dir / "rotate_erp"
rotate_faces_dir = sample_dir / "rotate_cubemap_faces"
forward_dir.mkdir(parents=True, exist_ok=True)
rotate_dir.mkdir(parents=True, exist_ok=True)
for face_name in FACE_NAMES:
(rotate_faces_dir / face_name).mkdir(parents=True, exist_ok=True)
for pose in forward_poses:
erp_u8, _ = _render_panorama_frame_and_faces(
train_renderer,
gaussians_world,
extr_w2c=pose,
equ_h=h,
equ_w=w,
face_w=face_w,
)
forward_dir.joinpath(f"forward_{len(forward_frames):02d}.png").parent.mkdir(parents=True, exist_ok=True)
Image.fromarray(erp_u8).save(forward_dir / f"forward_{len(forward_frames):02d}.png")
forward_frames.append(erp_u8)
for pose in rotate_poses:
erp_u8, face_views = _render_panorama_frame_and_faces(
train_renderer,
gaussians_world,
extr_w2c=pose,
equ_h=h,
equ_w=w,
face_w=face_w,
)
frame_idx = len(rotate_frames)
Image.fromarray(erp_u8).save(rotate_dir / f"rotate_{frame_idx:02d}.png")
for face_name, face_u8 in face_views.items():
Image.fromarray(face_u8).save(rotate_faces_dir / face_name / f"rotate_{frame_idx:02d}_{face_name}.png")
rotate_frames.append(erp_u8)
f_px = float(w) / (2.0 * math.pi)
elif camera_kind == "fisheye":
if render_camera_params is None:
if not torch.is_tensor(rays):
raise RuntimeError("Fisheye ray fitting requires model rays.")
render_camera_params = fit_fisheye624_params_from_rays(rays)
params = render_camera_params
params = params.detach().to(device=device, dtype=torch.float32)
for pose in forward_poses:
forward_frames.append(_render_fisheye_frame(gaussians_world, extr_w2c=pose, camera_params=params, image_h=h, image_w=w))
for pose in rotate_poses:
rotate_frames.append(_render_fisheye_frame(gaussians_world, extr_w2c=pose, camera_params=params, image_h=h, image_w=w))
f_px = float(0.5 * (float(params[0, 0].detach().cpu()) + float(params[0, 1].detach().cpu())))
else:
if render_intrinsics is None:
if not torch.is_tensor(rays):
raise RuntimeError("Pinhole ray fitting requires model rays.")
render_intrinsics = fit_pinhole_intrinsics_from_rays(rays)
intrinsics = render_intrinsics
k3 = intrinsics.detach().to(device=device, dtype=torch.float32)[0]
for pose in forward_poses:
forward_frames.append(_render_pinhole_frame(renderer, gaussians_world, extr_w2c=pose, intrinsics=k3, image_h=h, image_w=w))
for pose in rotate_poses:
rotate_frames.append(_render_pinhole_frame(renderer, gaussians_world, extr_w2c=pose, intrinsics=k3, image_h=h, image_w=w))
f_px = float(0.5 * (float(k3[0, 0].detach().cpu()) + float(k3[1, 1].detach().cpu())))
if output_crop_border_fraction > 0.0:
forward_frames = [_crop_border_u8(frame, output_crop_border_fraction) for frame in forward_frames]
rotate_frames = [_crop_border_u8(frame, output_crop_border_fraction) for frame in rotate_frames]
_save_gif(forward_frames, sample_dir / "forward.gif", duration_ms=GIF_DURATION_MS)
_save_gif(rotate_frames, sample_dir / "rotate.gif", duration_ms=GIF_DURATION_MS)
_save_ply_if_requested(gaussians_world, sample_dir / "gaussians.ply", f_px=f_px, image_h=h, image_w=w, enabled=bool(args.save_ply))
metadata = {
"checkpoint": str(args.checkpoint),
"checkpoint_step": int(step),
"image": str(image_path),
"camera_kind": camera_kind,
"ray_stats": stats,
"camera_json": str(args.camera_json) if args.camera_json is not None else None,
"camera_json_entry": camera_json_entry,
"aspect_camera_name": aspect_camera_name,
"explicit_camera_intrinsics": args.camera_intrinsics,
"explicit_camera_params": args.camera_params,
"forward_distance_m": float(forward_distance_m),
"rotate_radius_m": float(rotate_radius_m),
"forward_distance_m_default": float(FORWARD_DISTANCE_M),
"rotate_radius_m_default": float(ROTATE_RADIUS_M),
"scene_depth_for_motion_m": scene_depth_m,
"median_predicted_depth_m": median_depth_m,
"foreground_depth_p25_m": foreground_depth_m,
"view_motion_scale": float(motion_scale),
"rotate_path": "clockwise_camera_xy_orbit_fixed_source_orientation",
"panorama_renderer": "unisharp.cli.unified_trainer.UnifiedTrainer._render_cubemap/_cube_to_erp",
"low_pass_filter_eps": float(args.low_pass_filter_eps),
"output_crop_border_fraction": float(output_crop_border_fraction),
"height": int(h),
"width": int(w),
}
(sample_dir / "metadata.json").write_text(json.dumps(metadata, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
LOGGER.info("Saved outputs -> %s", sample_dir)
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="UniSharp single-image inference with automatic camera-type detection.")
p.add_argument("--checkpoint", type=Path, required=True)
p.add_argument("--image", type=Path, default=None)
p.add_argument("--image-list", type=Path, default=None)
p.add_argument("--image-dir", type=Path, default=None)
p.add_argument("--out-dir", type=Path, default=REPO_ROOT / "outputs" / "inference")
p.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
p.add_argument("--max-images", type=int, default=0)
p.add_argument("--save-ply", action="store_true")
p.add_argument(
"--camera-json",
type=Path,
default=None,
help="JSON file with calibrated camera parameters. Supports a global object or an images mapping keyed by path/name/stem.",
)
p.add_argument(
"--camera-intrinsics",
type=float,
nargs="+",
default=None,
help="Explicit pinhole intrinsics. Pass fx fy cx cy or 9 row-major K values. If omitted, intrinsics are fitted from rays.",
)
p.add_argument(
"--camera-params",
type=float,
nargs="+",
default=None,
help="Explicit Fisheye624 parameters. Pass 8 values (fx fy cx cy k1 k2 k3 k4) or all 16 values. If omitted, parameters are fitted from rays.",
)
p.add_argument(
"--camera",
type=str,
default="auto",
choices=["auto", "perspective", "pinhole", "fisheye", "panorama", "erp"],
help="Override automatic ray-range camera classification.",
)
p.add_argument("--low-pass-filter-eps", type=float, default=0.0)
return p
def main() -> None:
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
_configure_torchhub_cache()
args = _build_argparser().parse_args()
args._camera_json_data = _load_camera_json(args.camera_json)
device = torch.device(str(args.device))
model, step = _load_model(Path(args.checkpoint), device=device)
renderer = GSplatRenderer(
color_space="sRGB",
background_color="black",
low_pass_filter_eps=float(args.low_pass_filter_eps),
).to(device)
train_renderer = UnifiedTrainer(
model=model,
renderer=renderer,
loss_fn=None,
device=device,
)
image_paths = _collect_image_paths(args)
Path(args.out_dir).mkdir(parents=True, exist_ok=True)
LOGGER.info("Rendering %d image(s) to %s", len(image_paths), args.out_dir)
for image_path in image_paths:
_process_one(
model=model,
renderer=renderer,
train_renderer=train_renderer,
image_path=Path(image_path),
out_root=Path(args.out_dir),
step=int(step),
args=args,
)
if device.type == "cuda":
torch.cuda.empty_cache()
if __name__ == "__main__":
main()