# Orient-Anything-V2 / app_utils.py
# Last commit c011205 by choephix: "fix early startup thread and ORT warnings"
import os
def _normalize_thread_env_var(name):
    """Rewrite the environment variable ``name`` as a plain positive integer.

    - Unset variables are left untouched.
    - Surrounding whitespace is ignored.
    - A trailing ``"m"`` marks a milli-unit quantity (e.g. ``"2500m"``). Its
      number is floor-divided by 1000.
    - The final count is clamped to at least 1 and written back as a string.
    - A value that does not parse as an integer is removed from ``os.environ``.
    """
    current = os.getenv(name)
    if current is None:
        return
    text = current.strip()
    is_milli = text.endswith("m")
    digits = text[:-1] if is_milli else text
    try:
        count = int(digits)
    except ValueError:
        # Drop unparseable values rather than leave them for native libraries.
        os.environ.pop(name, None)
        return
    if is_milli:
        count //= 1000
    os.environ[name] = str(max(1, count))
# Sanitize the thread-count variables at import time, before the
# onnxruntime / torch / numpy imports further down this module. The sanitized
# values are then in place if those libraries' native thread pools read them
# while loading (presumably the source of the startup warnings this addresses).
for env_name in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    _normalize_thread_env_var(env_name)
# Best-effort import of onnxruntime, used here only to lower its default
# logger severity to 3 (ERROR), which hides ORT warnings. If the import fails
# for any reason, ``ort`` is None.
try:
    import onnxruntime as ort
except Exception:
    ort = None
else:
    # Quieting the logger is purely cosmetic. A failure here must not throw
    # away the successfully imported module, so it is guarded separately.
    try:
        if hasattr(ort, "set_default_logger_severity"):
            ort.set_default_logger_severity(3)
    except Exception:
        pass
import rembg
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import PIL
from typing import Any
import matplotlib.pyplot as plt
import io
def _pad_center(arr, target):
    """Zero-pad an (H, W, C) array to (target, target, C), centring the content.

    When the padding is odd, the extra pixel goes on the bottom/right.
    """
    top = (target - arr.shape[0]) // 2
    left = (target - arr.shape[1]) // 2
    return np.pad(
        arr,
        (
            (top, target - arr.shape[0] - top),
            (left, target - arr.shape[1] - left),
            (0, 0),
        ),
        mode="constant",
        constant_values=0,
    )


def resize_foreground(
    image: Image.Image,
    ratio: float,
) -> Image.Image:
    """Crop an RGBA image to its foreground and re-pad it to a square.

    The foreground is every pixel with alpha > 0. The tight crop around it is
    padded to a square. That square is then padded again so that its side is
    ``ratio`` of the output side. All added pixels are zero in every channel,
    i.e. fully transparent.

    Args:
        image: RGBA image (anything ``np.array`` turns into an (H, W, 4) array).
        ratio: fraction of the output side that the foreground square fills.
            Must lie in (0, 1].

    Returns:
        A square ``PIL.Image`` built from the padded array.

    Raises:
        ValueError: if the image is not 4-channel, contains no pixel with
            alpha > 0, or ``ratio`` is outside (0, 1].
    """
    arr = np.array(image)
    if arr.ndim != 3 or arr.shape[-1] != 4:
        raise ValueError(
            f"resize_foreground expects an RGBA image, got array shape {arr.shape}"
        )
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in (0, 1], got {ratio}")

    rows, cols = np.nonzero(arr[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("resize_foreground: image has no pixel with alpha > 0")
    # min()/max() are inclusive indices, so the slice end needs +1. Without it
    # the last foreground row and column would be cropped away.
    fg = arr[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    # Centre the crop in a square, then centre that square in the final canvas.
    size = max(fg.shape[0], fg.shape[1])
    square = _pad_center(fg, size)
    new_size = int(size / ratio)
    return Image.fromarray(_pad_center(square, new_size))
def remove_background(
    image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Run rembg on ``image`` unless it already carries transparency.

    An image counts as already matted when it is RGBA and its alpha minimum
    (``getextrema()[3][0]``) is below 255. Such an image is returned unchanged
    unless ``force`` is true. Any other image is passed to ``rembg.remove``
    with ``rembg_session`` and any extra keyword arguments, and that result is
    returned.
    """
    has_transparency = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if has_transparency and not force:
        return image
    return rembg.remove(image, session=rembg_session, **rembg_kwargs)
# Shared CPU rembg session, created on first use. The original code called
# rembg.new_session on every request, constructing a fresh session each time.
_REMBG_SESSION = None


def _get_rembg_session():
    """Return the module-wide CPU rembg session, creating it on first use."""
    global _REMBG_SESSION
    if _REMBG_SESSION is None:
        _REMBG_SESSION = rembg.new_session(providers=["CPUExecutionProvider"])
    return _REMBG_SESSION


def background_preprocess(input_image, do_remove_background):
    """Prepare an uploaded PIL image for the model.

    Args:
        input_image: PIL image, or None.
        do_remove_background: if true, run ``remove_background`` with a cached
            CPU rembg session first.

    Returns:
        None if ``input_image`` is None. Otherwise the RGBA result of
        ``resize_foreground(image, 0.85)``.
    """
    if input_image is None:
        return None
    if do_remove_background:
        input_image = remove_background(input_image, _get_rembg_session())
    # resize_foreground needs an alpha channel. Without this conversion, an
    # RGB/L/P upload with background removal disabled would fail there.
    if input_image.mode != "RGBA":
        input_image = input_image.convert("RGBA")
    return resize_foreground(input_image, 0.85)
def axis_angle_rotation_batch(axis: torch.Tensor, theta: torch.Tensor, homogeneous: bool = False) -> torch.Tensor:
    """Batched axis-angle (Rodrigues) rotation matrices.

    Args:
        axis: rotation axis, shape (3,) or (N, 3). It is normalized here, so it
            need not be unit length. A zero-length axis yields NaNs.
        theta: rotation angle in radians, scalar or shape (N,).
        homogeneous: if True, embed each rotation in a 4x4 homogeneous matrix.

    Returns:
        float32 tensor of shape (N, 3, 3), or (N, 4, 4) when ``homogeneous``.
        A single axis broadcasts against a batch of angles and vice versa.
    """
    axis = torch.as_tensor(axis).float()
    theta = torch.as_tensor(theta).float()
    if axis.ndim == 1:
        axis = axis.unsqueeze(0)  # (1, 3)
    if theta.ndim == 0:
        theta = theta.unsqueeze(0)  # (1,)
    # Take the batch size from both inputs. Previously it came from `axis`
    # alone, so e.g. one axis with N angles failed on assignment.
    n = torch.broadcast_shapes(axis.shape[:1], theta.shape[:1])[0]
    axis = axis.expand(n, -1)
    theta = theta.expand(n)

    axis = axis / torch.norm(axis, dim=1, keepdim=True)
    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    one_minus_cos = 1 - cos_t

    # Rodrigues' formula written out entry by entry.
    rot = torch.zeros((n, 3, 3), dtype=axis.dtype, device=axis.device)
    rot[:, 0, 0] = cos_t + x * x * one_minus_cos
    rot[:, 0, 1] = x * y * one_minus_cos - z * sin_t
    rot[:, 0, 2] = x * z * one_minus_cos + y * sin_t
    rot[:, 1, 0] = y * x * one_minus_cos + z * sin_t
    rot[:, 1, 1] = cos_t + y * y * one_minus_cos
    rot[:, 1, 2] = y * z * one_minus_cos - x * sin_t
    rot[:, 2, 0] = z * x * one_minus_cos - y * sin_t
    rot[:, 2, 1] = z * y * one_minus_cos + x * sin_t
    rot[:, 2, 2] = cos_t + z * z * one_minus_cos

    if homogeneous:
        rot_homo = torch.eye(4, dtype=axis.dtype, device=axis.device).unsqueeze(0).repeat(n, 1, 1)
        rot_homo[:, :3, :3] = rot
        return rot_homo
    return rot
def azi_ele_rot_to_Obj_Rmatrix_batch(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor) -> torch.Tensor:
    """Batched (azimuth, elevation, rotation) in degrees -> object rotation matrices.

    The result is ``R = R_rot @ R_ele @ R_azi``, where:
    - ``R_rot`` rotates about x by ``rot``;
    - ``R_ele`` rotates about y by ``ele``;
    - ``R_azi`` rotates about z by ``-azi``.

    All three are built with ``axis_angle_rotation_batch``.

    Args:
        azi, ele, rot: angles in degrees, each a scalar or shape (N,).
            Scalars and length-1 inputs broadcast against batched ones.

    Returns:
        float32 tensor of shape (N, 3, 3).
    """
    # Degrees -> radians.
    azi = torch.as_tensor(azi).float() * torch.pi / 180.
    ele = torch.as_tensor(ele).float() * torch.pi / 180.
    rot = torch.as_tensor(rot).float() * torch.pi / 180.
    # Give every angle a batch dimension, then broadcast to a common length.
    # Previously N came from `azi` alone, so e.g. a scalar azi with N
    # elevations failed.
    azi, ele, rot = torch.broadcast_tensors(
        torch.atleast_1d(azi), torch.atleast_1d(ele), torch.atleast_1d(rot)
    )

    N = azi.shape[0]
    device = azi.device
    dtype = azi.dtype
    z0_axis = torch.tensor([0., 0., 1.], device=device, dtype=dtype).expand(N, -1)
    y0_axis = torch.tensor([0., 1., 0.], device=device, dtype=dtype).expand(N, -1)
    x0_axis = torch.tensor([1., 0., 0.], device=device, dtype=dtype).expand(N, -1)

    R_azi = axis_angle_rotation_batch(z0_axis, -1 * azi)
    R_ele = axis_angle_rotation_batch(y0_axis, ele)
    R_rot = axis_angle_rotation_batch(x0_axis, rot)
    return R_rot @ R_ele @ R_azi
def Cam_Rmatrix_to_azi_ele_rot_batch(R: torch.Tensor):
    """Batched rotation matrix -> (azi, ele, rot) in degrees.

    Decomposes matrices of the form ``Rz(azi) @ Ry(-ele) @ Rx(-rot)``. That is
    the transpose of the matrix built by ``azi_ele_rot_to_Obj_Rmatrix_batch``,
    which is how ``Get_target_azi_ele_rot`` calls this function. Here Rz/Ry/Rx
    are right-handed rotations about the z/y/x axes.

    Args:
        R: tensor of shape (3, 3) or (N, 3, 3).

    Returns:
        Tuple ``(azi, ele, rot)`` of float32 tensors of shape (N,), in degrees.
        ``ele`` lies in [-90, 90]. At ele = +/-90 (gimbal lock) only a
        combination of azi and rot is determined, so azi is reported as 0.
    """
    R = torch.as_tensor(R).float()
    if R.ndim == 2:
        R = R.unsqueeze(0)
    col0 = R[:, :, 0]  # (cos(azi)cos(ele), sin(azi)cos(ele), sin(ele))
    col1 = R[:, :, 1]
    col2 = R[:, :, 2]

    # Float error can push |sin(ele)| slightly above 1, where asin returns NaN.
    sin_ele = col0[:, 2].clamp(-1.0, 1.0)
    ele = torch.asin(sin_ele)
    cos_ele = torch.cos(ele)

    azi = torch.zeros_like(ele)
    rot = torch.zeros_like(ele)

    normal_mask = cos_ele.abs() >= 1e-6
    if normal_mask.any():
        azi[normal_mask] = torch.atan2(col0[normal_mask, 1], col0[normal_mask, 0])
        # Bottom row is (sin(ele), -cos(ele)sin(rot), cos(ele)cos(rot)).
        rot[normal_mask] = torch.atan2(-col1[normal_mask, 2], col2[normal_mask, 2])

    gimbal_mask = ~normal_mask
    if gimbal_mask.any():
        # With azi fixed to 0, column 1 equals (s*sin(rot), cos(rot), 0), where
        # s = sign(sin(ele)). The old code always negated the first entry,
        # which flipped rot's sign at ele = +90.
        sign = torch.sign(sin_ele[gimbal_mask])
        rot[gimbal_mask] = torch.atan2(sign * col1[gimbal_mask, 0], col1[gimbal_mask, 1])

    # Radians -> degrees.
    azi = azi * 180. / torch.pi
    ele = ele * 180. / torch.pi
    rot = rot * 180. / torch.pi
    return azi, ele, rot
def Get_target_azi_ele_rot(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor, rel_azi: torch.Tensor, rel_ele: torch.Tensor, rel_rot: torch.Tensor):
    """Apply a relative rotation to a starting orientation and return the result as angles.

    Steps:
    1. Convert both (azi, ele, rot) triples, in degrees, to object rotation
       matrices with ``azi_ele_rot_to_Obj_Rmatrix_batch``.
    2. Left-multiply the start matrix by the relative one.
    3. Decompose the transpose of the product with
       ``Cam_Rmatrix_to_azi_ele_rot_batch``.

    Returns:
        Tuple ``(azi, ele, rot)`` of tensors in degrees.
    """
    start = azi_ele_rot_to_Obj_Rmatrix_batch(azi=azi, ele=ele, rot=rot)
    delta = azi_ele_rot_to_Obj_Rmatrix_batch(azi=rel_azi, ele=rel_ele, rot=rel_rot)
    target = delta @ start
    return Cam_Rmatrix_to_azi_ele_rot_batch(target.transpose(1, 2))