import os
def _normalize_thread_env_var(name):
    """Coerce a thread-count env var to a plain positive integer string.

    Values with a trailing "m" (millicore-style, e.g. from container CPU limits)
    are divided by 1000; anything unparseable is removed from the environment.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return
    value = raw_value.strip()
    if value.endswith("m"):
        try:
            os.environ[name] = str(max(1, int(value[:-1]) // 1000))
        except ValueError:
            os.environ.pop(name, None)
        return
    try:
        os.environ[name] = str(max(1, int(value)))
    except ValueError:
        os.environ.pop(name, None)


for env_name in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
    _normalize_thread_env_var(env_name)
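# For reference, a rough sketch of how the normalization above behaves
# (derived from the function itself, not an authoritative spec):
#
#   os.environ["OMP_NUM_THREADS"] = "2500m"   # -> normalized to "2"
#   os.environ["OMP_NUM_THREADS"] = "8"       # -> stays "8"
#   os.environ["OMP_NUM_THREADS"] = "0"       # -> clamped to "1"
#   os.environ["OMP_NUM_THREADS"] = "oops"    # -> removed from the environment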
import gradio as gr
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F  # used below for F.sigmoid on the azimuth logits
import tempfile
import io
import base64
from urllib.request import urlopen
from paths import *
from vision_tower import VGGT_OriAny_Ref
from inference import *
from app_utils import *
from axis_renderer import BlendRenderer
import spaces
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)
# if torch.cuda.is_available():
#     mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
#     device = torch.device('cuda')
# else:
#     mark_dtype = torch.float16
#     device = torch.device('cpu')
mark_dtype = torch.bfloat16
# mark_dtype = torch.float16
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
print('Model loaded.')
axis_renderer = BlendRenderer(RENDER_FILE)
@spaces.GPU(duration=20)
@torch.no_grad()
def inf_single_batch(batch):
    """Run the model on a (B, S, 3, H, W) batch and decode per-view pose angles."""
    model.to(device='cuda', dtype=mark_dtype)
    batch_img_inputs = batch.to(device='cuda', dtype=mark_dtype)  # (B, S, 3, H, W)
    # print(batch_img_inputs.shape)
    B, S, C, H, W = batch_img_inputs.shape
    pose_enc = model(batch_img_inputs)  # (B, S, D); S is 1 or 2
    pose_enc = pose_enc.view(B * S, -1)
    # Bins 0:360 encode azimuth, 360:540 elevation (offset by -90), 540:900 rotation (offset by -180).
    angle_az_pred = torch.argmax(pose_enc[:, 0:360], dim=-1)
    angle_el_pred = torch.argmax(pose_enc[:, 360:360 + 180], dim=-1) - 90
    angle_ro_pred = torch.argmax(pose_enc[:, 360 + 180:360 + 180 + 360], dim=-1) - 180
    # ori_val
    # trained with BCE loss
    distribute = F.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy()
    # trained with CE loss
    # distribute = pose_enc[:, 0:360].cpu().float().numpy()
    alpha_pred = val_fit_alpha(distribute=distribute)
    # ref_val
    if S > 1:
        ref_az_pred = angle_az_pred.reshape(B, S)[:, 0]
        ref_el_pred = angle_el_pred.reshape(B, S)[:, 0]
        ref_ro_pred = angle_ro_pred.reshape(B, S)[:, 0]
        ref_alpha_pred = alpha_pred.reshape(B, S)[:, 0]
        rel_az_pred = angle_az_pred.reshape(B, S)[:, 1]
        rel_el_pred = angle_el_pred.reshape(B, S)[:, 1]
        rel_ro_pred = angle_ro_pred.reshape(B, S)[:, 1]
    else:
        ref_az_pred = angle_az_pred[0]
        ref_el_pred = angle_el_pred[0]
        ref_ro_pred = angle_ro_pred[0]
        ref_alpha_pred = alpha_pred[0]
        rel_az_pred = 0.
        rel_el_pred = 0.
        rel_ro_pred = 0.
    ans_dict = {
        'ref_az_pred': ref_az_pred,
        'ref_el_pred': ref_el_pred,
        'ref_ro_pred': ref_ro_pred,
        'ref_alpha_pred': ref_alpha_pred,
        'rel_az_pred': rel_az_pred,
        'rel_el_pred': rel_el_pred,
        'rel_ro_pred': rel_ro_pred,
    }
    return ans_dict
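# Sketch of how the 900-dim head output is decoded above (layout inferred from the
# slicing in inf_single_batch, not from a published spec):
#
#   pose_enc[:,   0:360]  -> azimuth bins,   argmax gives 0..359 degrees
#   pose_enc[:, 360:540]  -> elevation bins, argmax - 90 gives -90..89 degrees
#   pose_enc[:, 540:900]  -> rotation bins,  argmax - 180 gives -180..179 degrees
#
# With two views (S == 2) the first view supplies the reference pose and the second
# view supplies the values returned under the 'rel_*' keys.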
# input PIL Image
@torch.no_grad()
def inf_single_case(image_ref, image_tgt):
    if image_tgt is None:
        image_list = [image_ref]
    else:
        image_list = [image_ref, image_tgt]
    image_tensors = preprocess_images(image_list, mode="pad")
    ans_dict = inf_single_batch(batch=image_tensors.unsqueeze(0))
    print(ans_dict)
    return ans_dict
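# Minimal local usage sketch (file names are placeholders, not assets shipped with the app):
#
#   ref = Image.open("some_reference.jpg").convert("RGB")
#   tgt = Image.open("some_target.jpg").convert("RGB")   # or None for single-image mode
#   ans = inf_single_case(ref, tgt)
#   print(ans['ref_az_pred'], ans['rel_az_pred'])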
# ====== Utilities: safe image handling ======
def safe_image_input(image):
    """Return a valid numpy array, or None if the input cannot be converted."""
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        return image
    try:
        return np.array(image)
    except Exception:
        return None
def preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg):
    if do_rm_bkg:
        pil_ref = background_preprocess(pil_ref, True)
        if pil_tgt is not None:
            pil_tgt = background_preprocess(pil_tgt, True)
    return pil_ref, pil_tgt
def prepare_numpy_inputs(image_ref, image_tgt, do_rm_bkg):
    image_ref = safe_image_input(image_ref)
    image_tgt = safe_image_input(image_tgt)
    if image_ref is None:
        raise ValueError("Please upload a reference image before running inference.")
    pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
    pil_tgt = None
    if image_tgt is not None:
        pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")
    return preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg)
def read_api_image(image_value):
    if image_value is None:
        return None
    if image_value.startswith("data:"):
        try:
            _, encoded = image_value.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")
        except Exception as exc:
            raise ValueError("Invalid base64 image payload.") from exc
    if image_value.startswith(("http://", "https://")):
        try:
            with urlopen(image_value) as response:
                return Image.open(io.BytesIO(response.read())).convert("RGB")
        except Exception as exc:
            raise ValueError(f"Failed to fetch image URL: {image_value}") from exc
    raise ValueError("Image input must be a data URL or an http(s) URL.")
def serialize_prediction(ans_dict):
    def safe_float(val, default=0.0):
        try:
            return float(val)
        except Exception:
            return float(default)

    return {
        "ref_az_pred": safe_float(ans_dict.get("ref_az_pred", 0)),
        "ref_el_pred": safe_float(ans_dict.get("ref_el_pred", 0)),
        "ref_ro_pred": safe_float(ans_dict.get("ref_ro_pred", 0)),
        "ref_alpha_pred": int(safe_float(ans_dict.get("ref_alpha_pred", 1), 1)),
        "rel_az_pred": safe_float(ans_dict.get("rel_az_pred", 0)),
        "rel_el_pred": safe_float(ans_dict.get("rel_el_pred", 0)),
        "rel_ro_pred": safe_float(ans_dict.get("rel_ro_pred", 0)),
    }
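# Example of the JSON-friendly shape serialize_prediction returns (values are illustrative):
#
#   {
#       "ref_az_pred": 132.0, "ref_el_pred": 12.0, "ref_ro_pred": -3.0, "ref_alpha_pred": 1,
#       "rel_az_pred": 45.0, "rel_el_pred": 0.0, "rel_ro_pred": 0.0,
#   }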
# ====== Inference functions ======
@spaces.GPU(duration=20)
@torch.no_grad()
def run_inference(image_ref, image_tgt, do_rm_bkg):
    try:
        pil_ref, pil_tgt = prepare_numpy_inputs(image_ref, image_tgt, do_rm_bkg)
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    try:
        ans_dict = inf_single_case(pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")

    def safe_float(val, default=0.0):
        try:
            return float(val)
        except Exception:
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    alpha = int(ans_dict.get('ref_alpha_pred', 1))  # Note: the target defaults to alpha=1, but the reference may not

    # ===== Save rendered axes to temporary files =====
    tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_ref.close()
    tmp_tgt.close()
    try:
        # ===== Render the coordinate axes for the reference image =====
        axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
        axis_ref = Image.open(tmp_ref.name).convert("RGBA")
        # Overlay the axes onto the reference image,
        # making sure the sizes match first
        if axis_ref.size != pil_ref.size:
            pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
        pil_ref_rgba = pil_ref.convert("RGBA")
        overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")

        # ===== Process the target image (if provided) =====
        if pil_tgt is not None:
            rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
            rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
            rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
            tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
            print("Target: Azi", tgt_azi, "Ele", tgt_ele, "Rot", tgt_rot)
            # The target is rendered with alpha=1 by default
            axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
            axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")
            if axis_tgt.size != pil_tgt.size:
                pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
            pil_tgt_rgba = pil_tgt.convert("RGBA")
            overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
        else:
            overlaid_tgt = None
            rel_az = rel_el = rel_ro = 0.0
    finally:
        # Remove the temporary files (clean up even if an error occurred)
        if os.path.exists(tmp_ref.name):
            os.remove(tmp_ref.name)
            print('cleaned {}'.format(tmp_ref.name))
        if os.path.exists(tmp_tgt.name):
            os.remove(tmp_tgt.name)
            print('cleaned {}'.format(tmp_tgt.name))

    return [
        overlaid_ref,  # reference image with the rendered axes overlaid
        overlaid_tgt,  # target image with the rendered axes overlaid (may be None)
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]
def predict(
    image_ref: str,
    image_tgt: str | None = None,
    do_rm_bkg: bool = True,
) -> dict[str, float | int]:
    pil_ref = read_api_image(image_ref)
    pil_tgt = read_api_image(image_tgt)
    pil_ref, pil_tgt = preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg)
    try:
        return serialize_prediction(inf_single_case(pil_ref, pil_tgt))
    except Exception as exc:
        print("API inference error:", exc)
        raise gr.Error(f"Inference failed: {exc}") from exc
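# Hedged client-side sketch for the endpoint registered below via gr.api (the Space URL
# is a placeholder; exact argument handling may differ across gradio_client versions):
#
#   from gradio_client import Client
#   client = Client("https://<your-space>.hf.space")
#   result = client.predict(
#       image_ref="https://example.com/chair.jpg",
#       image_tgt=None,
#       do_rm_bkg=True,
#       api_name="/predict",
#   )
#   print(result)  # dict with ref_*/rel_* angles as in serialize_prediction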
# ====== Gradio Blocks UI ======
with gr.Blocks(title="Orient-Anything-V2 Demo") as demo:
    gr.Markdown("# Orient-Anything-V2 Demo")
    gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")

    with gr.Row():
        # Left: input images (reference + target in one row)
        with gr.Column():
            with gr.Row():
                ref_img = gr.Image(
                    label="Reference Image (required)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
                tgt_img = gr.Image(
                    label="Target Image (optional)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
            rm_bkg = gr.Checkbox(label="Remove Background", value=True)
            run_btn = gr.Button("Run Inference", variant="primary")

            # === Example inputs ===
            with gr.Row():
                gr.Examples(
                    examples=[
                        ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"],
                        ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label="Example Inputs (click to load)"
                )
                gr.Examples(
                    examples=[
                        ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"],
                        ["assets/examples/bottle.jpg", None],
                    ],
                    inputs=[ref_img, tgt_img],
                    examples_per_page=2,
                    label=""
                )

        # Right: result images + text outputs
        with gr.Column():
            # Result images: rendered reference + rendered target (optional)
            with gr.Row():
                res_ref_img = gr.Image(
                    label="Rendered Reference",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                res_tgt_img = gr.Image(
                    label="Rendered Target (if provided)",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
            # Text outputs below the images
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Absolute Pose (Reference)")
                    az_out = gr.Textbox(label="Azimuth (0~360°)")
                    el_out = gr.Textbox(label="Polar (-90~90°)")
                    ro_out = gr.Textbox(label="Rotation (-90~90°)")
                    alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
                with gr.Column():
                    gr.Markdown("### Relative Pose (Target w.r.t Reference)")
                    rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
                    rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
                    rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")

    # Wire up the click event
    run_btn.click(
        fn=run_inference,
        inputs=[ref_img, tgt_img, rm_bkg],
        outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
        preprocess=True,
        postprocess=True
    )

    gr.api(
        predict,
        api_name="predict",
        api_visibility="undocumented",
        queue=False,
    )

demo.launch()