# WJAD: src/wjad/modules/rays.py
# (synced from the WJAD codebase, commit 0cfefd2)
"""f-theta camera model + ray computation.

Per the Cosmos-Drive-Dreams dataset README, ``ftheta_intrinsic`` is stored as
``[cx, cy, w, h, *poly(6), is_bw_poly, *linear_cde(3)]``.

The f-theta camera model uses a degree-6 polynomial to map the pixel radius
``r_pix = ||(u - cx, v - cy)||`` to the incidence angle ``theta`` (or the
reverse). ``is_bw_poly == True`` means the polynomial is the backward
polynomial that recovers ``theta`` from ``r_pix`` (pixel -> theta); otherwise
it is the forward polynomial (theta -> r_pix). ``linear_cde`` holds affine
correction coefficients ``[c, d, e]`` that compensate for slight
non-rotationally-symmetric distortion.

For simplicity and differentiability, this module assumes the backward
polynomial by default (``is_bw_poly=True``, i.e. ``theta = poly(r_pix)``);
real data is usually in this format. For forward polynomials, Newton
iteration is used to invert.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass
class FThetaIntrinsic:
    """f-theta intrinsic parameters (as PyTorch tensors).

    All fields are scalars or 1-D vectors; callers typically broadcast
    them to a batch dimension.
    """

    cx: torch.Tensor  # principal point x, shape ()
    cy: torch.Tensor  # principal point y, shape ()
    w: int  # image width in pixels
    h: int  # image height in pixels
    poly: torch.Tensor  # (6,) polynomial coefficients, lowest order first (Horner-evaluated in FThetaCamera)
    is_bw_poly: bool  # True: poly maps pixel radius -> theta (backward); False: theta -> pixel radius (forward)
    linear_cde: torch.Tensor  # (3,) affine correction coefficients [c, d, e]
class FThetaCamera:
    """f-theta camera model: pixel coordinates -> unit ray directions in the camera frame."""

    def __init__(self, intr: FThetaIntrinsic) -> None:
        self.intr = intr

    @staticmethod
    def from_vector(vec: torch.Tensor) -> "FThetaCamera":
        """Construct from an NVIDIA ftheta vector ``[cx, cy, w, h, poly×6, is_bw_poly?, linear_cde×3?]``.

        The official layout is typically 14-dim; some clips carry only 11
        values (no ``linear_cde``), in which case ``(1, 0, 1)`` is used, which
        makes the affine correction in ``unproject`` exactly the identity.

        Raises
        ------
        ValueError
            If fewer than 10 values are present (cx, cy, w, h + 6 poly coefficients).
        """
        v = vec.flatten().float()
        n = int(v.numel())
        if n < 10:
            raise ValueError(f"ftheta intrinsic 维度 {n} < 10(至少需要 cx,cy,w,h + 6 poly)")
        cx = v[0]
        cy = v[1]
        w = int(v[2].item())
        h = int(v[3].item())
        poly = v[4:10].clone()
        # Optional trailing fields: backward-poly flag (default True, the
        # common dataset format), then the affine coefficients (default identity).
        is_bw = bool(v[10].item() > 0.5) if n >= 11 else True
        if n >= 14:
            linear_cde = v[11:14].clone()
        else:
            linear_cde = torch.tensor([1.0, 0.0, 1.0], dtype=v.dtype, device=v.device)
        return FThetaCamera(
            FThetaIntrinsic(cx=cx, cy=cy, w=w, h=h, poly=poly, is_bw_poly=is_bw, linear_cde=linear_cde)
        )

    def _eval_poly(self, r: torch.Tensor) -> torch.Tensor:
        """Evaluate poly(r) = sum_{i=0..5} c_i * r^i via Horner's scheme."""
        c = self.intr.poly
        out = torch.zeros_like(r)
        for i in range(c.numel() - 1, -1, -1):
            out = out * r + c[i]
        return out

    def _eval_poly_grad(self, r: torch.Tensor) -> torch.Tensor:
        """Derivative of the polynomial: sum_{i=1..5} i * c_i * r^(i-1), also via Horner."""
        c = self.intr.poly
        n = c.numel()
        out = torch.zeros_like(r)
        for i in range(n - 1, 0, -1):
            out = out * r + c[i] * float(i)
        return out

    def pixel_to_theta(self, r_pix: torch.Tensor) -> torch.Tensor:
        """Pixel radius -> incidence angle theta (radians)."""
        if self.intr.is_bw_poly:
            # Backward polynomial maps pixel radius directly to theta.
            return self._eval_poly(r_pix)
        # Forward polynomial: r_pix = poly(theta); invert with Newton iterations.
        # NOTE(review): the initial guess theta = r_pix is crude (theta is in
        # radians, r_pix in pixels); 8 iterations usually suffice for monotone
        # polynomials, but convergence is not guaranteed — confirm on real
        # forward-poly clips if this path is ever exercised.
        theta = r_pix.clone()
        for _ in range(8):
            f = self._eval_poly(theta) - r_pix
            # Clamp the derivative away from zero to keep the update finite.
            df = self._eval_poly_grad(theta).clamp_min(1e-6)
            theta = theta - f / df
        return theta

    def unproject(self, uv: torch.Tensor) -> torch.Tensor:
        """Pixel coordinates ``[..., 2]`` -> unit directions ``[..., 3]`` in the camera frame.

        f-theta unprojection:
            (du, dv) = (u - cx, v - cy)  (with the linear_cde affine correction)
            r_pix    = ||(du, dv)||
            theta    = poly(r_pix) or inv_poly(r_pix)
            phi      = atan2(dv, du)
            dir_cam  = (sin(theta)*cos(phi), sin(theta)*sin(phi), cos(theta))
        """
        cx = self.intr.cx
        cy = self.intr.cy
        c, d, e = self.intr.linear_cde[0], self.intr.linear_cde[1], self.intr.linear_cde[2]
        u = uv[..., 0]
        v = uv[..., 1]
        du0 = u - cx
        dv0 = v - cy
        # Affine correction: interpret [c, d, e] as the 2x2 matrix
        # [[c, d], [0, e]], so that the fallback default (1, 0, 1) is exactly
        # the identity. Bugfix: the previous `dv = e * du0 + dv0` was NOT the
        # identity under the default and corrupted rays for 11-dim intrinsics.
        # NOTE(review): NVIDIA tooling applies the affine in the projection
        # direction; applying it here (unprojection) is an approximation —
        # confirm against the dataset toolkit.
        du = c * du0 + d * dv0
        dv = e * dv0
        r_pix = torch.sqrt(du * du + dv * dv).clamp_min(1e-6)
        theta = self.pixel_to_theta(r_pix)
        sin_t = torch.sin(theta)
        cos_t = torch.cos(theta)
        # cos(phi), sin(phi) without atan2: divide by the (clamped) radius.
        cos_p = du / r_pix
        sin_p = dv / r_pix
        dir_cam = torch.stack([sin_t * cos_p, sin_t * sin_p, cos_t], dim=-1)
        return F.normalize(dir_cam, dim=-1)
def compute_ego_rays(
    intr_vec: torch.Tensor,
    cam2vehicle: torch.Tensor,
    height: int,
    width: int,
    grid_h: int,
    grid_w: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Compute unit ray directions in the ego-vehicle frame for a uniform
    ``grid_h x grid_w`` pixel grid.

    Parameters
    ----------
    intr_vec : ``[B, 14]`` or ``[14]``, f-theta intrinsic vector(s).
    cam2vehicle : ``[B, 4, 4]`` or ``[4, 4]``, camera-to-vehicle transform(s).
        Singleton batches broadcast against each other.
    height, width : image resolution in pixels; the grid samples cell
        centers inside ``[0, width] x [0, height]``.
    grid_h, grid_w : output ray-grid resolution (matches the patch grid,
        e.g. 24 x 64).

    Returns
    -------
    rays : ``[B, grid_h, grid_w, 3]``, unit directions in the vehicle frame,
        where ``B`` is the broadcast batch size.

    Raises
    ------
    ValueError
        If the batch sizes of ``intr_vec`` and ``cam2vehicle`` are
        incompatible (neither equal nor 1).
    """
    if intr_vec.dim() == 1:
        intr_vec = intr_vec.unsqueeze(0)
    if cam2vehicle.dim() == 2:
        cam2vehicle = cam2vehicle.unsqueeze(0)
    # Broadcast the two batch dimensions; previously B came from intr_vec
    # alone, silently dropping extra poses (or crashing on extra intrinsics).
    b_intr, b_pose = intr_vec.shape[0], cam2vehicle.shape[0]
    B = max(b_intr, b_pose)
    if b_intr not in (1, B) or b_pose not in (1, B):
        raise ValueError(f"batch mismatch: intr_vec={b_intr}, cam2vehicle={b_pose}")
    if b_intr != B:
        intr_vec = intr_vec.expand(B, -1)
    if b_pose != B:
        cam2vehicle = cam2vehicle.expand(B, -1, -1)
    # Sample at the center of each grid cell.
    u = (torch.arange(grid_w, device=device, dtype=dtype) + 0.5) * (width / grid_w)
    v = (torch.arange(grid_h, device=device, dtype=dtype) + 0.5) * (height / grid_h)
    vv, uu = torch.meshgrid(v, u, indexing="ij")  # [gh, gw]
    uv = torch.stack([uu, vv], dim=-1)  # [gh, gw, 2]
    out = []
    for b in range(B):
        cam = FThetaCamera.from_vector(intr_vec[b].to(dtype))
        dir_cam = cam.unproject(uv)  # [gh, gw, 3]
        # Rotate into the vehicle frame with the 3x3 rotation part; for row
        # vectors, v_veh = v_cam @ R^T is equivalent to R @ v_cam^T.
        R = cam2vehicle[b, :3, :3].to(dtype)
        dir_veh = dir_cam @ R.T  # [gh, gw, 3]
        out.append(F.normalize(dir_veh, dim=-1))
    return torch.stack(out, dim=0)