import os def _normalize_thread_env_var(name): raw_value = os.getenv(name) if raw_value is None: return value = raw_value.strip() if value.endswith("m"): try: os.environ[name] = str(max(1, int(value[:-1]) // 1000)) except ValueError: os.environ.pop(name, None) return try: os.environ[name] = str(max(1, int(value))) except ValueError: os.environ.pop(name, None) for env_name in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"): _normalize_thread_env_var(env_name) import gradio as gr import numpy as np from PIL import Image import torch import tempfile import io import base64 from urllib.request import urlopen from paths import * from vision_tower import VGGT_OriAny_Ref from inference import * from app_utils import * from axis_renderer import BlendRenderer import spaces torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False from huggingface_hub import hf_hub_download ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True) print(ckpt_path) # if torch.cuda.is_available(): # mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 # device = torch.device('cuda') # else: # mark_dtype = torch.float16 # device = torch.device('cpu') mark_dtype = torch.bfloat16 # mark_dtype = torch.float16 model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True) model.load_state_dict(torch.load(ckpt_path, map_location='cpu')) model.eval() print('Model loaded.') axis_renderer = BlendRenderer(RENDER_FILE) @spaces.GPU(duration=20) @torch.no_grad() def inf_single_batch(batch): model.to(device='cuda', dtype=mark_dtype) batch_img_inputs = batch.to(device='cuda', dtype=mark_dtype) # (B, S, 3, H, W) # print(batch_img_inputs.shape) B, S, C, H, W = batch_img_inputs.shape pose_enc = model(batch_img_inputs) # (B, S, D) S = 1 pose_enc = pose_enc.view(B*S, -1) angle_az_pred = torch.argmax(pose_enc[:, 0:360] , dim=-1) angle_el_pred = torch.argmax(pose_enc[:, 360:360+180] , dim=-1) - 90 angle_ro_pred = torch.argmax(pose_enc[:, 360+180:360+180+360] , dim=-1) - 180 # ori_val # trained with BCE loss distribute = F.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy() # trained with CE loss # distribute = pose_enc[:, 0:360].cpu().float().numpy() alpha_pred = val_fit_alpha(distribute = distribute) # ref_val if S > 1: ref_az_pred = angle_az_pred.reshape(B,S)[:,0] ref_el_pred = angle_el_pred.reshape(B,S)[:,0] ref_ro_pred = angle_ro_pred.reshape(B,S)[:,0] ref_alpha_pred = alpha_pred.reshape(B,S)[:,0] rel_az_pred = angle_az_pred.reshape(B,S)[:,1] rel_el_pred = angle_el_pred.reshape(B,S)[:,1] rel_ro_pred = angle_ro_pred.reshape(B,S)[:,1] else: ref_az_pred = angle_az_pred[0] ref_el_pred = angle_el_pred[0] ref_ro_pred = angle_ro_pred[0] ref_alpha_pred = alpha_pred[0] rel_az_pred = 0. rel_el_pred = 0. rel_ro_pred = 0. ans_dict = { 'ref_az_pred': ref_az_pred, 'ref_el_pred': ref_el_pred, 'ref_ro_pred': ref_ro_pred, 'ref_alpha_pred' : ref_alpha_pred, 'rel_az_pred' : rel_az_pred, 'rel_el_pred' : rel_el_pred, 'rel_ro_pred' : rel_ro_pred, } return ans_dict # input PIL Image @torch.no_grad() def inf_single_case(image_ref, image_tgt): if image_tgt is None: image_list = [image_ref] else: image_list = [image_ref, image_tgt] image_tensors = preprocess_images(image_list, mode="pad") ans_dict = inf_single_batch(batch=image_tensors.unsqueeze(0)) print(ans_dict) return ans_dict # ====== 工具函数:安全图像处理 ====== def safe_image_input(image): """确保返回合法的 numpy 数组或 None""" if image is None: return None if isinstance(image, np.ndarray): return image try: return np.array(image) except Exception: return None def preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg): if do_rm_bkg: pil_ref = background_preprocess(pil_ref, True) if pil_tgt is not None: pil_tgt = background_preprocess(pil_tgt, True) return pil_ref, pil_tgt def prepare_numpy_inputs(image_ref, image_tgt, do_rm_bkg): image_ref = safe_image_input(image_ref) image_tgt = safe_image_input(image_tgt) if image_ref is None: raise ValueError("Please upload a reference image before running inference.") pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB") pil_tgt = None if image_tgt is not None: pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB") return preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg) def read_api_image(image_value): if image_value is None: return None if image_value.startswith("data:"): try: _, encoded = image_value.split(",", 1) return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB") except Exception as exc: raise ValueError("Invalid base64 image payload.") from exc if image_value.startswith(("http://", "https://")): try: with urlopen(image_value) as response: return Image.open(io.BytesIO(response.read())).convert("RGB") except Exception as exc: raise ValueError(f"Failed to fetch image URL: {image_value}") from exc raise ValueError("Image input must be a data URL or an http(s) URL.") def serialize_prediction(ans_dict): def safe_float(val, default=0.0): try: return float(val) except Exception: return float(default) return { "ref_az_pred": safe_float(ans_dict.get("ref_az_pred", 0)), "ref_el_pred": safe_float(ans_dict.get("ref_el_pred", 0)), "ref_ro_pred": safe_float(ans_dict.get("ref_ro_pred", 0)), "ref_alpha_pred": int(safe_float(ans_dict.get("ref_alpha_pred", 1), 1)), "rel_az_pred": safe_float(ans_dict.get("rel_az_pred", 0)), "rel_el_pred": safe_float(ans_dict.get("rel_el_pred", 0)), "rel_ro_pred": safe_float(ans_dict.get("rel_ro_pred", 0)), } # ====== 推理函数 ====== @spaces.GPU(duration=20) @torch.no_grad() def run_inference(image_ref, image_tgt, do_rm_bkg): try: pil_ref, pil_tgt = prepare_numpy_inputs(image_ref, image_tgt, do_rm_bkg) except ValueError as exc: raise gr.Error(str(exc)) from exc try: ans_dict = inf_single_case(pil_ref, pil_tgt) except Exception as e: print("Inference error:", e) raise gr.Error(f"Inference failed: {str(e)}") def safe_float(val, default=0.0): try: return float(val) except: return float(default) az = safe_float(ans_dict.get('ref_az_pred', 0)) el = safe_float(ans_dict.get('ref_el_pred', 0)) ro = safe_float(ans_dict.get('ref_ro_pred', 0)) alpha = int(ans_dict.get('ref_alpha_pred', 1)) # 注意:target 默认 alpha=1,但 ref 可能不是 # ===== 用临时文件保存渲染结果 ===== tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False) tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False) tmp_ref.close() tmp_tgt.close() try: # ===== 渲染参考图的坐标轴 ===== axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name) axis_ref = Image.open(tmp_ref.name).convert("RGBA") # 叠加坐标轴到参考图 # 确保尺寸一致 if axis_ref.size != pil_ref.size: pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC) pil_ref_rgba = pil_ref.convert("RGBA") overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB") # ===== 处理目标图(如果有)===== if pil_tgt is not None: rel_az = safe_float(ans_dict.get('rel_az_pred', 0)) rel_el = safe_float(ans_dict.get('rel_el_pred', 0)) rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0)) tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro) print("Target: Azi",tgt_azi,"Ele",tgt_ele,"Rot",tgt_rot) # target 默认 alpha=1(根据你的说明) axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name) axis_tgt = Image.open(tmp_tgt.name).convert("RGBA") if axis_tgt.size != pil_tgt.size: pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC) pil_tgt_rgba = pil_tgt.convert("RGBA") overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB") else: overlaid_tgt = None rel_az = rel_el = rel_ro = 0.0 finally: # 安全删除临时文件(即使出错也清理) if os.path.exists(tmp_ref.name): os.remove(tmp_ref.name) print('cleaned {}'.format(tmp_ref.name)) if os.path.exists(tmp_tgt.name): os.remove(tmp_tgt.name) print('cleaned {}'.format(tmp_tgt.name)) return [ overlaid_ref, # 渲染+叠加后的参考图 overlaid_tgt, # 渲染+叠加后的目标图(可能为 None) f"{az:.2f}", f"{el:.2f}", f"{ro:.2f}", str(alpha), f"{rel_az:.2f}", f"{rel_el:.2f}", f"{rel_ro:.2f}", ] def predict( image_ref: str, image_tgt: str | None = None, do_rm_bkg: bool = True, ) -> dict[str, float | int]: pil_ref = read_api_image(image_ref) pil_tgt = read_api_image(image_tgt) pil_ref, pil_tgt = preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg) try: return serialize_prediction(inf_single_case(pil_ref, pil_tgt)) except Exception as exc: print("API inference error:", exc) raise gr.Error(f"Inference failed: {exc}") from exc # ====== Gradio Blocks UI ====== with gr.Blocks(title="Orient-Anything-V2 Demo") as demo: gr.Markdown("# Orient-Anything-V2 Demo") gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.") with gr.Row(): # 左侧:输入图像(参考图 + 目标图,同一行) with gr.Column(): with gr.Row(): ref_img = gr.Image( label="Reference Image (required)", type="numpy", height=256, width=256, value=None, interactive=True ) tgt_img = gr.Image( label="Target Image (optional)", type="numpy", height=256, width=256, value=None, interactive=True ) rm_bkg = gr.Checkbox(label="Remove Background", value=True) run_btn = gr.Button("Run Inference", variant="primary") # === 在这里插入示例 === with gr.Row(): gr.Examples( examples=[ ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"], ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"], ], inputs=[ref_img, tgt_img], examples_per_page=2, label="Example Inputs (click to load)" ) gr.Examples( examples=[ ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"], ["assets/examples/bottle.jpg", None], ], inputs=[ref_img, tgt_img], examples_per_page=2, label="" ) # 右侧:结果图像 + 文本输出 with gr.Column(): # 结果图像:参考结果 + 目标结果(可选) with gr.Row(): res_ref_img = gr.Image( label="Rendered Reference", type="pil", height=256, width=256, interactive=False ) res_tgt_img = gr.Image( label="Rendered Target (if provided)", type="pil", height=256, width=256, interactive=False ) # 文本输出放在图像下方 with gr.Row(): with gr.Column(): gr.Markdown("### Absolute Pose (Reference)") az_out = gr.Textbox(label="Azimuth (0~360°)") el_out = gr.Textbox(label="Polar (-90~90°)") ro_out = gr.Textbox(label="Rotation (-90~90°)") alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)") with gr.Column(): gr.Markdown("### Relative Pose (Target w.r.t Reference)") rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)") rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)") rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)") # 绑定点击事件 run_btn.click( fn=run_inference, inputs=[ref_img, tgt_img, rm_bkg], outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out], preprocess=True, postprocess=True ) gr.api( predict, api_name="predict", api_visibility="undocumented", queue=False, ) demo.launch()