import functools
import os
|
|
|
|
| def _normalize_thread_env_var(name): |
| raw_value = os.getenv(name) |
| if raw_value is None: |
| return |
|
|
| value = raw_value.strip() |
| if value.endswith("m"): |
| try: |
| os.environ[name] = str(max(1, int(value[:-1]) // 1000)) |
| except ValueError: |
| os.environ.pop(name, None) |
| return |
|
|
| try: |
| os.environ[name] = str(max(1, int(value))) |
| except ValueError: |
| os.environ.pop(name, None) |
|
|
|
|
# Sanitise the thread-count variables read by the numerical libraries that
# are imported further down in this module.
for env_name in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    _normalize_thread_env_var(env_name)
|
|
try:
    import onnxruntime as ort

    # Set onnxruntime's default logger severity to 3 to quiet its console
    # output (presumably ERROR on onnxruntime's 0-4 severity scale — verify).
    # The hasattr() check guards against onnxruntime versions lacking it.
    if hasattr(ort, "set_default_logger_severity"):
        ort.set_default_logger_severity(3)
except Exception:
    # onnxruntime is treated as optional: any failure while importing or
    # configuring it leaves ``ort`` set to None.
    ort = None
|
|
| import rembg |
| import random |
| import torch |
| import numpy as np |
| from PIL import Image, ImageOps |
| import PIL |
| from typing import Any |
| import matplotlib.pyplot as plt |
| import io |
|
|
def _pad_center(arr: np.ndarray, side: int) -> np.ndarray:
    """Zero-pad an (H, W, C) array to (side, side, C), keeping the content centred."""
    h, w = arr.shape[:2]
    top, left = (side - h) // 2, (side - w) // 2
    return np.pad(
        arr,
        ((top, side - h - top), (left, side - w - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )


def resize_foreground(
    image: Image.Image,
    ratio: float,
) -> Image.Image:
    """Crop an RGBA image to its foreground and re-pad it to a square canvas.

    The foreground is the bounding box of all pixels whose alpha is > 0.
    That crop is zero-padded, centred, to a square whose side is the crop's
    longer side. It is then padded again, centred, so the square's side is
    ``int(side / ratio)``. The foreground therefore spans roughly ``ratio`` of
    the output width/height. All added pixels are zero, i.e. transparent black.

    Args:
        image: 4-channel (RGBA) image; anything ``np.array`` converts to an
            (H, W, 4) array is accepted.
        ratio: Fraction of the output side taken by the foreground's longer
            side. Must satisfy ``0 < ratio <= 1``.

    Returns:
        A new square ``PIL.Image.Image`` built from the padded array.

    Raises:
        ValueError: if the image is not (H, W, 4), if it has no pixel with
            alpha > 0, or if ``ratio`` is outside ``(0, 1]``.
    """
    arr = np.array(image)
    if arr.ndim != 3 or arr.shape[-1] != 4:
        raise ValueError(
            f"resize_foreground expects an RGBA image, got array of shape {arr.shape}"
        )
    if not 0 < ratio <= 1:
        # ratio > 1 would need negative padding; ratio <= 0 has no meaning.
        raise ValueError(f"ratio must be in (0, 1], got {ratio!r}")

    rows, cols = np.nonzero(arr[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("resize_foreground: image has no pixels with alpha > 0")

    # min()/max() are the first and last foreground indices, so the slice end
    # needs +1 to keep the last row/column (and a 1-pixel foreground).
    fg = arr[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    side = max(fg.shape[0], fg.shape[1])
    square = _pad_center(fg, side)

    # For ratio in (0, 1], int(side / ratio) >= side, so padding is never negative.
    canvas = _pad_center(square, int(side / ratio))
    return Image.fromarray(canvas)
|
|
def remove_background(
    image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Run rembg on ``image`` unless it already carries transparency.

    An RGBA image whose alpha band has a minimum below 255 is treated as
    already matted and returned untouched, unless ``force`` is truthy. Every
    other image is passed to ``rembg.remove`` with ``rembg_session`` and any
    extra keyword arguments.
    """
    # getextrema() gives per-band (min, max); band 3 of an RGBA image is alpha.
    already_transparent = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_transparent:
        return rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
|
|
@functools.lru_cache(maxsize=1)
def _get_rembg_session():
    """Return the shared CPU-only rembg session, creating it on first use."""
    return rembg.new_session(providers=["CPUExecutionProvider"])


def background_preprocess(input_image, do_remove_background):
    """Optionally remove an image's background and recentre its foreground.

    Args:
        input_image: Image to process, or None.
        do_remove_background: When truthy, run ``remove_background`` with a
            CPU rembg session, then ``resize_foreground`` with ratio 0.85.

    Returns:
        None when ``input_image`` is None. Otherwise the processed image, or
        the unchanged input when ``do_remove_background`` is falsy.
    """
    if input_image is None:
        return None
    if not do_remove_background:
        return input_image

    # Reuse one cached session instead of constructing a fresh rembg session
    # on every call; a failed construction is not cached and will be retried.
    input_image = remove_background(input_image, _get_rembg_session())
    return resize_foreground(input_image, 0.85)
|
|
def axis_angle_rotation_batch(axis: torch.Tensor, theta: torch.Tensor, homogeneous: bool = False) -> torch.Tensor:
    """Batched axis-angle (Rodrigues) rotation matrices.

    Args:
        axis: Rotation axis, shape (3,) or (N, 3). It is normalised
            internally, so a zero-length axis yields NaN entries.
        theta: Rotation angle in radians, scalar or shape (N,).
        homogeneous: If True, return 4x4 homogeneous matrices.

    A single axis is broadcast against N angles, and a single angle against
    N axes.

    Returns:
        Tensor of shape (N, 3, 3), or (N, 4, 4) when ``homogeneous`` is True,
        on ``axis``'s device, in float32.

    Raises:
        RuntimeError: if the batch sizes of ``axis`` and ``theta`` are not
            broadcast-compatible.
    """
    axis = torch.as_tensor(axis).float()
    theta = torch.as_tensor(theta).float()

    if axis.ndim == 1:
        axis = axis.unsqueeze(0)
    if theta.ndim == 0:
        theta = theta.unsqueeze(0)

    # Batch size comes from both inputs. Taking it from the axis alone broke
    # the documented case of one (3,) axis combined with an (N,) theta.
    N = torch.broadcast_shapes(axis.shape[:1], theta.shape)[0]
    axis = axis.expand(N, 3)
    theta = theta.expand(N)

    axis = axis / torch.norm(axis, dim=1, keepdim=True)

    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    one_minus_cos = 1 - cos_t

    rot = torch.zeros((N, 3, 3), dtype=axis.dtype, device=axis.device)
    rot[:, 0, 0] = cos_t + x * x * one_minus_cos
    rot[:, 0, 1] = x * y * one_minus_cos - z * sin_t
    rot[:, 0, 2] = x * z * one_minus_cos + y * sin_t
    rot[:, 1, 0] = y * x * one_minus_cos + z * sin_t
    rot[:, 1, 1] = cos_t + y * y * one_minus_cos
    rot[:, 1, 2] = y * z * one_minus_cos - x * sin_t
    rot[:, 2, 0] = z * x * one_minus_cos - y * sin_t
    rot[:, 2, 1] = z * y * one_minus_cos + x * sin_t
    rot[:, 2, 2] = cos_t + z * z * one_minus_cos

    if homogeneous:
        rot_homo = torch.eye(4, dtype=axis.dtype, device=axis.device).unsqueeze(0).repeat(N, 1, 1)
        rot_homo[:, :3, :3] = rot
        return rot_homo

    return rot
|
|
def azi_ele_rot_to_Obj_Rmatrix_batch(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor) -> torch.Tensor:
    """Batched (azi, ele, rot) in degrees -> object rotation matrices (N, 3, 3).

    The result is ``R_x(rot) @ R_y(ele) @ R_z(-azi)``: each factor is an
    axis-angle rotation about the named unit axis.

    Args:
        azi, ele, rot: Angles in degrees. Each may be a scalar or a 1-D tensor.
            Scalars and length-1 tensors are broadcast against the others, so
            N is the common (broadcast) batch size.

    Returns:
        Tensor of shape (N, 3, 3) in float32, on ``azi``'s device.

    Raises:
        RuntimeError: if the three batch sizes are not broadcast-compatible.
    """
    azi = torch.as_tensor(azi).float() * torch.pi / 180.
    ele = torch.as_tensor(ele).float() * torch.pi / 180.
    rot = torch.as_tensor(rot).float() * torch.pi / 180.

    if azi.ndim == 0:
        azi = azi.unsqueeze(0)
    if ele.ndim == 0:
        ele = ele.unsqueeze(0)
    if rot.ndim == 0:
        rot = rot.unsqueeze(0)

    # Take N from all three inputs, not just azi. This way a scalar azi
    # combined with batched ele/rot no longer hands mismatched axis/angle
    # batches to axis_angle_rotation_batch.
    N = torch.broadcast_shapes(azi.shape, ele.shape, rot.shape)[0]
    azi, ele, rot = azi.expand(N), ele.expand(N), rot.expand(N)

    device = azi.device
    dtype = azi.dtype

    z0_axis = torch.tensor([0., 0., 1.], device=device, dtype=dtype).expand(N, -1)
    y0_axis = torch.tensor([0., 1., 0.], device=device, dtype=dtype).expand(N, -1)
    x0_axis = torch.tensor([1., 0., 0.], device=device, dtype=dtype).expand(N, -1)

    R_azi = axis_angle_rotation_batch(z0_axis, -1 * azi)
    R_ele = axis_angle_rotation_batch(y0_axis, ele)
    R_rot = axis_angle_rotation_batch(x0_axis, rot)

    return R_rot @ R_ele @ R_azi
|
|
def Cam_Rmatrix_to_azi_ele_rot_batch(R: torch.Tensor):
    """Batched rotation matrix -> (azi, ele, rot) in degrees.

    Decodes ``R = R_z(azi) @ R_y(-ele) @ R_x(-rot)``, i.e. the transpose of
    ``R_x(rot) @ R_y(ele) @ R_z(-azi)``. Its columns are then:

        r0 = (cos(azi) cos(ele), sin(azi) cos(ele), sin(ele))
        r1[2] = -cos(ele) sin(rot),  r2[2] = cos(ele) cos(rot)

    Args:
        R: Tensor of shape (3, 3) or (N, 3, 3).

    Returns:
        Tuple ``(azi, ele, rot)`` of float32 tensors of shape (N,), in degrees.
        ``ele`` lies in [-90, 90].

    At gimbal lock (|cos(ele)| < 1e-6) azimuth and roll cannot be separated.
    In that case ``azi`` is reported as 0 and the whole remaining rotation is
    folded into ``rot``.
    """
    R = torch.as_tensor(R).float()

    if R.ndim == 2:
        R = R.unsqueeze(0)

    r0 = R[:, :, 0]
    r1 = R[:, :, 1]
    r2 = R[:, :, 2]

    # Float round-off can push |sin(ele)| slightly past 1, where asin gives NaN.
    sin_ele = r0[:, 2].clamp(-1.0, 1.0)
    ele = torch.asin(sin_ele)
    cos_ele = torch.cos(ele)

    azi = torch.zeros_like(ele)
    rot = torch.zeros_like(ele)

    normal_mask = cos_ele.abs() >= 1e-6
    if normal_mask.any():
        azi[normal_mask] = torch.atan2(r0[normal_mask, 1], r0[normal_mask, 0])
        rot[normal_mask] = torch.atan2(-r1[normal_mask, 2], r2[normal_mask, 2])

    gimbal_mask = ~normal_mask
    if gimbal_mask.any():
        # With azi fixed to 0, column 1 is (sin(ele) * sin(rot), cos(rot), 0).
        # Multiplying r1[0] by sign(sin(ele)) therefore recovers rot at both
        # poles. The previous -r1[0] was only correct at ele = -90 degrees.
        pole_sign = torch.sign(sin_ele[gimbal_mask])
        azi[gimbal_mask] = 0.0
        rot[gimbal_mask] = torch.atan2(pole_sign * r1[gimbal_mask, 0], r1[gimbal_mask, 1])

    azi = azi * 180. / torch.pi
    ele = ele * 180. / torch.pi
    rot = rot * 180. / torch.pi

    return azi, ele, rot
|
|
def Get_target_azi_ele_rot(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor, rel_azi: torch.Tensor, rel_ele: torch.Tensor, rel_rot: torch.Tensor):
    """Compose a relative (azi, ele, rot) rotation with an absolute one.

    Both triples are converted to object rotation matrices and combined as
    ``R_rel @ R_0``. The transpose of that product is then decoded back into
    angles.

    Returns:
        Tuple ``(azi, ele, rot)`` of batched tensors, in degrees.
    """
    base_matrix = azi_ele_rot_to_Obj_Rmatrix_batch(azi=azi, ele=ele, rot=rot)
    delta_matrix = azi_ele_rot_to_Obj_Rmatrix_batch(azi=rel_azi, ele=rel_ele, rot=rel_rot)

    # Swap the last two axes (per-item transpose) before decoding.
    composed_t = (delta_matrix @ base_matrix).transpose(1, 2)
    target_azi, target_ele, target_rot = Cam_Rmatrix_to_azi_ele_rot_batch(composed_t)
    return target_azi, target_ele, target_rot
|
|