# WJAD/src/wjad/data/transforms.py
# Last commit: "Sync WJAD codebase" (0cfefd2, verified), by fuzirui.
"""图像与运动学的数据增广。"""
from __future__ import annotations
import numpy as np
import torch
# ImageNet normalisation statistics for DINOv3, stored as (3, 1, 1) tensors
# so they broadcast against the trailing (C, H, W) axes of an image.
DINOV3_MEAN = torch.tensor((0.485, 0.456, 0.406)).reshape(3, 1, 1)
DINOV3_STD = torch.tensor((0.229, 0.224, 0.225)).reshape(3, 1, 1)
def crop_top_half(image: torch.Tensor) -> torch.Tensor:
    """Drop the upper half of an image (mostly sky) and keep the lower rows.

    Args:
        image: Tensor whose last two axes are ``(H, W)``, for example
            ``[3, H, W]``, ``[T, 3, H, W]`` or ``[B, T, 3, H, W]``.

    Returns:
        ``image`` sliced along H to rows ``H // 2`` onward. Every other axis
        is untouched. The result has ``H - H // 2`` rows, so an odd ``H``
        keeps the extra middle row.

    Raises:
        ValueError: If ``image`` has fewer than three dimensions.
    """
    if image.dim() < 3:
        raise ValueError(f"unsupported image dim: {image.dim()}")
    # H is always the second-to-last axis, so index it from the end instead of
    # special-casing each supported rank.
    height = image.shape[-2]
    return image[..., height // 2 :, :]
def normalize_image(
    image: torch.Tensor,
    mean: torch.Tensor = DINOV3_MEAN,
    std: torch.Tensor = DINOV3_STD,
) -> torch.Tensor:
    """Standardise an image in the ``[0, 1]`` range as ``(image - mean) / std``.

    Args:
        image: ``[3, H, W]``, ``[T, 3, H, W]`` or ``[B, T, 3, H, W]`` tensor.
        mean: Per-channel mean. Either a flat ``(C,)`` vector or a tensor that
            already broadcasts against the trailing ``(C, H, W)`` axes, such
            as the default ``(3, 1, 1)``.
        std: Per-channel standard deviation, same conventions as ``mean``.

    Returns:
        The standardised image, computed in ``image``'s dtype and on its device.
    """
    # A flat per-channel vector would otherwise be aligned with W by
    # broadcasting; lift it onto the channel axis.
    if mean.dim() == 1:
        mean = mean.reshape(-1, 1, 1)
    if std.dim() == 1:
        std = std.reshape(-1, 1, 1)
    # Broadcasting is right-aligned, so (C, 1, 1) statistics already match any
    # number of leading batch/time axes. No explicit unsqueezing is needed,
    # and mean and std no longer have to share the same rank.
    mean = mean.to(image.device, image.dtype)
    std = std.to(image.device, image.dtype)
    return (image - mean) / std
def add_gaussian_noise(image: torch.Tensor, std: float = 0.01) -> torch.Tensor:
    """Augment ``image`` with zero-mean Gaussian noise scaled by ``std``.

    ``image`` is expected to be already normalised (zero mean, unit variance).
    A non-positive ``std`` disables the augmentation and the input tensor is
    returned as-is (the same object, not a copy).
    """
    if std <= 0:
        return image
    noise = torch.randn_like(image)
    return image + noise * std
def _gaussian_delta(
    rng: np.random.Generator,
    std: float,
    size: tuple[int, ...],
    like: torch.Tensor,
) -> torch.Tensor:
    """Draw ``N(0, std**2)`` samples from ``rng``, cast to ``like``'s dtype and device."""
    return torch.from_numpy(rng.normal(0.0, std, size=size)).to(like)


def perturb_kinematics(
    ego_6d: torch.Tensor,  # [T, 6]
    intr_vec: torch.Tensor,  # [14]
    extr_6d: torch.Tensor,  # [6]
    translation_std_m: float,
    rotation_std_deg: float,
    intrinsic_std: float,
    extrinsic_std: float,
    rng: np.random.Generator,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Add small Gaussian perturbations to ego motion, intrinsics and extrinsics.

    Used as a calibration-training augmentation (mid Stage 1 according to the
    original note). The inputs are not modified; new tensors are returned.

    Args:
        ego_6d: ``[T, 6]`` ego motion. Columns ``0:3`` receive translation
            noise and columns ``3:6`` receive rotation noise.
        intr_vec: 1-D intrinsics vector (``[14]`` per the signature comment).
        extr_6d: ``[6]`` extrinsics. Entries ``0:3`` receive noise with
            ``extrinsic_std``, entries ``3:6`` receive rotation noise.
        translation_std_m: Std of the ego translation noise.
        rotation_std_deg: Rotation noise std in degrees. It is converted to
            radians and shared by the ego and extrinsic rotation components.
        intrinsic_std: Std of the noise added to every intrinsic entry.
        extrinsic_std: Std of the noise added to the first three extrinsic
            entries only.
        rng: NumPy generator that supplies all random draws.

    Returns:
        ``(perturbed_ego, perturbed_intr, perturbed_extr, delta)``. ``delta``
        is the applied perturbation, concatenated as
        ``[ego (row-major, T*6), intrinsics, extrinsics]`` in ``ego_6d``'s
        dtype and on its device. It is the perturbation itself, not its
        negation. Per the original note, the calibration target is
        ``-delta`` and is built by the trainer.
    """
    rot_std_rad = np.deg2rad(rotation_std_deg)
    num_steps = ego_6d.shape[0]

    # The draw order (ego translation, ego rotation, intrinsics, extrinsic
    # translation, extrinsic rotation) is fixed so that a seeded generator
    # reproduces the same perturbations. Sampling directly, instead of going
    # through ``tensor.numpy()``, keeps this working for tensors that live on
    # an accelerator, require grad, or use a dtype NumPy lacks (bfloat16).
    delta_ego = torch.cat(
        [
            _gaussian_delta(rng, translation_std_m, (num_steps, 3), ego_6d),
            _gaussian_delta(rng, rot_std_rad, (num_steps, 3), ego_6d),
        ],
        dim=1,
    )
    delta_intr = _gaussian_delta(rng, intrinsic_std, (intr_vec.shape[0],), intr_vec)
    delta_extr = torch.cat(
        [
            _gaussian_delta(rng, extrinsic_std, (3,), extr_6d),
            _gaussian_delta(rng, rot_std_rad, (3,), extr_6d),
        ]
    )

    # Report exactly what was added, unified to ego's dtype and device.
    delta_all = torch.cat(
        [delta_ego.reshape(-1), delta_intr.to(ego_6d), delta_extr.to(ego_6d)]
    )
    return (
        ego_6d + delta_ego,
        intr_vec + delta_intr,
        extr_6d + delta_extr,
        delta_all,
    )