| import os |
|
|
|
|
| def _normalize_thread_env_var(name): |
| raw_value = os.getenv(name) |
| if raw_value is None: |
| return |
|
|
| value = raw_value.strip() |
| if value.endswith("m"): |
| try: |
| os.environ[name] = str(max(1, int(value[:-1]) // 1000)) |
| except ValueError: |
| os.environ.pop(name, None) |
| return |
|
|
| try: |
| os.environ[name] = str(max(1, int(value))) |
| except ValueError: |
| os.environ.pop(name, None) |
|
|
|
|
| for env_name in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"): |
| _normalize_thread_env_var(env_name) |
|
|
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| import torch |
| import tempfile |
| import io |
| import base64 |
| from urllib.request import urlopen |
|
|
| from paths import * |
| from vision_tower import VGGT_OriAny_Ref |
| from inference import * |
| from app_utils import * |
| from axis_renderer import BlendRenderer |
| import spaces |
|
|
# Explicitly disable TF32 matmul/conv kernels so inference runs in full
# precision. NOTE(review): presumably the binned-angle heads are sensitive
# to TF32 rounding — confirm against training settings.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


# Download the model checkpoint from the Hugging Face Hub into the current
# directory. ORIANY_V2 and REMOTE_CKPT_PATH come from `paths` (star-imported
# above).
from huggingface_hub import hf_hub_download
# NOTE(review): `resume_download` is deprecated in recent huggingface_hub
# releases (ignored with a warning) — verify against the pinned version.
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)


# Dtype used for both the model weights and the input batches (see
# inf_single_batch below).
mark_dtype = torch.bfloat16


# Build the model without downloading pretrained backbone weights
# (nopretrain=True) and load the fetched checkpoint on CPU; the model is
# moved to the GPU inside inf_single_batch.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable only because the checkpoint comes from a trusted repo.
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
print('Model loaded.')


# Renderer that draws the 3D orientation-axes overlay images.
axis_renderer = BlendRenderer(RENDER_FILE)
|
|
@spaces.GPU(duration=20)
@torch.no_grad()
def inf_single_batch(batch):
    """Run the orientation model on one batch of preprocessed images.

    Parameters
    ----------
    batch : tensor of shape (B, S, C, H, W); S is 1 (reference only) or 2
        (reference + target).

    Returns
    -------
    dict with per-batch reference angles (`ref_az_pred`, `ref_el_pred`,
    `ref_ro_pred`), the fitted symmetry count `ref_alpha_pred`, and the
    relative angles (`rel_*_pred`, zeros when S == 1).
    """
    model.to(device='cuda', dtype=mark_dtype)
    batch_img_inputs = batch.to(device='cuda', dtype=mark_dtype)

    B, S, C, H, W = batch_img_inputs.shape
    pose_enc = model(batch_img_inputs)

    # Flatten (B, S) and decode argmax bins into angles; the subtractions
    # shift the bin index into signed degrees.
    pose_enc = pose_enc.view(B * S, -1)
    angle_az_pred = torch.argmax(pose_enc[:, 0:360], dim=-1)
    angle_el_pred = torch.argmax(pose_enc[:, 360:360 + 180], dim=-1) - 90
    angle_ro_pred = torch.argmax(pose_enc[:, 360 + 180:360 + 180 + 360], dim=-1) - 180

    # Fix: use torch.sigmoid instead of the deprecated F.sigmoid — `F`
    # (torch.nn.functional) is never explicitly imported in this module,
    # so this also removes a latent NameError.
    distribute = torch.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy()

    # Fit the number of symmetry directions from the azimuth distribution.
    alpha_pred = val_fit_alpha(distribute=distribute)

    if S > 1:
        # Column 0 is the reference view, column 1 the target view.
        ref_az_pred = angle_az_pred.reshape(B, S)[:, 0]
        ref_el_pred = angle_el_pred.reshape(B, S)[:, 0]
        ref_ro_pred = angle_ro_pred.reshape(B, S)[:, 0]
        ref_alpha_pred = alpha_pred.reshape(B, S)[:, 0]
        rel_az_pred = angle_az_pred.reshape(B, S)[:, 1]
        rel_el_pred = angle_el_pred.reshape(B, S)[:, 1]
        rel_ro_pred = angle_ro_pred.reshape(B, S)[:, 1]
    else:
        # Single image: no target, so relative angles default to zero.
        ref_az_pred = angle_az_pred[0]
        ref_el_pred = angle_el_pred[0]
        ref_ro_pred = angle_ro_pred[0]
        ref_alpha_pred = alpha_pred[0]
        rel_az_pred = 0.
        rel_el_pred = 0.
        rel_ro_pred = 0.

    ans_dict = {
        'ref_az_pred': ref_az_pred,
        'ref_el_pred': ref_el_pred,
        'ref_ro_pred': ref_ro_pred,
        'ref_alpha_pred': ref_alpha_pred,
        'rel_az_pred': rel_az_pred,
        'rel_el_pred': rel_el_pred,
        'rel_ro_pred': rel_ro_pred,
    }

    return ans_dict
|
|
| |
@torch.no_grad()
def inf_single_case(image_ref, image_tgt):
    """Preprocess one (reference[, target]) pair and run batch inference."""
    images = [image_ref] if image_tgt is None else [image_ref, image_tgt]
    tensors = preprocess_images(images, mode="pad")
    result = inf_single_batch(batch=tensors.unsqueeze(0))
    print(result)
    return result
|
|
|
|
| |
def safe_image_input(image):
    """Coerce an arbitrary image payload to a numpy array.

    Returns the input unchanged when it is already ``None`` or an
    ``np.ndarray``; otherwise attempts a ``np.array`` conversion and
    returns ``None`` on failure.
    """
    if image is None or isinstance(image, np.ndarray):
        return image
    try:
        return np.array(image)
    except Exception:
        return None
|
|
|
|
def preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg):
    """Optionally strip the background from the reference/target images.

    When *do_rm_bkg* is falsy both images pass through untouched; the
    target may be ``None`` and is then left as ``None``.
    """
    if not do_rm_bkg:
        return pil_ref, pil_tgt
    processed_ref = background_preprocess(pil_ref, True)
    processed_tgt = background_preprocess(pil_tgt, True) if pil_tgt is not None else None
    return processed_ref, processed_tgt
|
|
|
|
def prepare_numpy_inputs(image_ref, image_tgt, do_rm_bkg):
    """Validate numpy image inputs and convert them to RGB PIL images.

    Raises ValueError when no usable reference image was supplied; the
    target is optional. Delegates background removal to
    preprocess_pil_inputs.
    """
    ref_array = safe_image_input(image_ref)
    tgt_array = safe_image_input(image_tgt)

    if ref_array is None:
        raise ValueError("Please upload a reference image before running inference.")

    pil_ref = Image.fromarray(ref_array.astype(np.uint8)).convert("RGB")
    pil_tgt = (
        Image.fromarray(tgt_array.astype(np.uint8)).convert("RGB")
        if tgt_array is not None
        else None
    )
    return preprocess_pil_inputs(pil_ref, pil_tgt, do_rm_bkg)
|
|
|
|
def read_api_image(image_value):
    """Decode an API-supplied image reference into an RGB PIL image.

    Accepts ``None`` (returned as-is), a ``data:`` URL carrying a base64
    payload, or an http(s) URL to fetch.

    Raises
    ------
    ValueError
        For malformed base64 payloads, fetch failures, or any other
        string format.
    """
    if image_value is None:
        return None

    if image_value.startswith("data:"):
        try:
            _, encoded = image_value.split(",", 1)
            return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")
        except Exception as exc:
            raise ValueError("Invalid base64 image payload.") from exc

    if image_value.startswith(("http://", "https://")):
        # NOTE(review): fetching caller-supplied URLs is an SSRF vector;
        # consider an allow-list if this endpoint is exposed publicly.
        try:
            # Fix: add a timeout so a slow or unresponsive host cannot
            # hang the worker indefinitely.
            with urlopen(image_value, timeout=10) as response:
                return Image.open(io.BytesIO(response.read())).convert("RGB")
        except Exception as exc:
            raise ValueError(f"Failed to fetch image URL: {image_value}") from exc

    raise ValueError("Image input must be a data URL or an http(s) URL.")
|
|
|
|
def serialize_prediction(ans_dict):
    """Flatten a prediction dict into plain JSON-serializable numbers.

    Missing angle keys default to 0.0, the symmetry count defaults to 1,
    and any value that cannot be converted falls back to its default.
    """

    def _as_float(value, fallback=0.0):
        try:
            return float(value)
        except Exception:
            return float(fallback)

    payload = {}
    for key in ("ref_az_pred", "ref_el_pred", "ref_ro_pred"):
        payload[key] = _as_float(ans_dict.get(key, 0))
    payload["ref_alpha_pred"] = int(_as_float(ans_dict.get("ref_alpha_pred", 1), 1))
    for key in ("rel_az_pred", "rel_el_pred", "rel_ro_pred"):
        payload[key] = _as_float(ans_dict.get(key, 0))
    return payload
|
|
|
|
| |
@spaces.GPU(duration=20)
@torch.no_grad()
def run_inference(image_ref, image_tgt, do_rm_bkg):
    """Gradio click handler: predict pose and render axis overlays.

    Parameters
    ----------
    image_ref, image_tgt : numpy arrays from the gr.Image widgets; the
        target may be None.
    do_rm_bkg : bool — remove backgrounds before inference.

    Returns
    -------
    9-item list matching the ``outputs=`` wiring of ``run_btn.click``:
    two overlay images (target overlay is None without a target image)
    followed by seven formatted angle/alpha strings.
    """
    try:
        pil_ref, pil_tgt = prepare_numpy_inputs(image_ref, image_tgt, do_rm_bkg)
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc

    try:
        ans_dict = inf_single_case(pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")

    def safe_float(val, default=0.0):
        try:
            return float(val)
        # Fix: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; only conversion failures should fall back.
        except (TypeError, ValueError):
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    # Fix: route alpha through safe_float like the other fields so a
    # malformed value degrades to the default instead of crashing.
    alpha = int(safe_float(ans_dict.get('ref_alpha_pred', 1), 1))

    # The renderer writes to a path by name, so delete=False temp files
    # are used; cleanup is guaranteed by the finally block below.
    tmp_ref = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_tgt = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_ref.close()
    tmp_tgt.close()

    try:
        # Render the reference axes and composite them over the photo.
        axis_renderer.render_axis(az, el, ro, alpha, save_path=tmp_ref.name)
        axis_ref = Image.open(tmp_ref.name).convert("RGBA")

        # Sizes must match before alpha-compositing.
        if axis_ref.size != pil_ref.size:
            pil_ref = pil_ref.resize(axis_ref.size, Image.BICUBIC)
        pil_ref_rgba = pil_ref.convert("RGBA")
        overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")

        if pil_tgt is not None:
            rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
            rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
            rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))

            # Compose the reference pose with the relative pose to get the
            # target's absolute pose.
            tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
            print("Target: Azi", tgt_azi, "Ele", tgt_ele, "Rot", tgt_rot)

            axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=tmp_tgt.name)
            axis_tgt = Image.open(tmp_tgt.name).convert("RGBA")

            if axis_tgt.size != pil_tgt.size:
                pil_tgt = pil_tgt.resize(axis_tgt.size, Image.BICUBIC)
            pil_tgt_rgba = pil_tgt.convert("RGBA")
            overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
        else:
            # No target image: blank overlay and zeroed relative angles.
            overlaid_tgt = None
            rel_az = rel_el = rel_ro = 0.0
    finally:
        # Remove the temp render files even when rendering fails.
        if os.path.exists(tmp_ref.name):
            os.remove(tmp_ref.name)
            print('cleaned {}'.format(tmp_ref.name))
        if os.path.exists(tmp_tgt.name):
            os.remove(tmp_tgt.name)
            print('cleaned {}'.format(tmp_tgt.name))

    return [
        overlaid_ref,
        overlaid_tgt,
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]
|
|
|
|
def predict(
    image_ref: str,
    image_tgt: str | None = None,
    do_rm_bkg: bool = True,
) -> dict[str, float | int]:
    """JSON API entry point: decode image references, infer, serialize.

    Inputs are data URLs or http(s) URLs (see read_api_image); decoding
    and validation errors propagate as ValueError, inference failures are
    surfaced as gr.Error.
    """
    decoded_ref = read_api_image(image_ref)
    decoded_tgt = read_api_image(image_tgt)
    decoded_ref, decoded_tgt = preprocess_pil_inputs(decoded_ref, decoded_tgt, do_rm_bkg)

    try:
        return serialize_prediction(inf_single_case(decoded_ref, decoded_tgt))
    except Exception as exc:
        print("API inference error:", exc)
        raise gr.Error(f"Inference failed: {exc}") from exc
|
|
|
|
| |
| with gr.Blocks(title="Orient-Anything-V2 Demo") as demo: |
| gr.Markdown("# Orient-Anything-V2 Demo") |
| gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.") |
|
|
| with gr.Row(): |
| |
| with gr.Column(): |
| with gr.Row(): |
| ref_img = gr.Image( |
| label="Reference Image (required)", |
| type="numpy", |
| height=256, |
| width=256, |
| value=None, |
| interactive=True |
| ) |
| tgt_img = gr.Image( |
| label="Target Image (optional)", |
| type="numpy", |
| height=256, |
| width=256, |
| value=None, |
| interactive=True |
| ) |
| rm_bkg = gr.Checkbox(label="Remove Background", value=True) |
| run_btn = gr.Button("Run Inference", variant="primary") |
| |
| with gr.Row(): |
| gr.Examples( |
| examples=[ |
| ["assets/examples/F35-0.jpg", "assets/examples/F35-1.jpg"], |
| ["assets/examples/skateboard-0.jpg", "assets/examples/skateboard-1.jpg"], |
| ], |
| inputs=[ref_img, tgt_img], |
| examples_per_page=2, |
| label="Example Inputs (click to load)" |
| ) |
| gr.Examples( |
| examples=[ |
| ["assets/examples/table-0.jpg", "assets/examples/table-1.jpg"], |
| ["assets/examples/bottle.jpg", None], |
| ], |
| inputs=[ref_img, tgt_img], |
| examples_per_page=2, |
| label="" |
| ) |
|
|
| |
| with gr.Column(): |
| |
| with gr.Row(): |
| res_ref_img = gr.Image( |
| label="Rendered Reference", |
| type="pil", |
| height=256, |
| width=256, |
| interactive=False |
| ) |
| res_tgt_img = gr.Image( |
| label="Rendered Target (if provided)", |
| type="pil", |
| height=256, |
| width=256, |
| interactive=False |
| ) |
| |
| |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("### Absolute Pose (Reference)") |
| az_out = gr.Textbox(label="Azimuth (0~360°)") |
| el_out = gr.Textbox(label="Polar (-90~90°)") |
| ro_out = gr.Textbox(label="Rotation (-90~90°)") |
| alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)") |
| with gr.Column(): |
| gr.Markdown("### Relative Pose (Target w.r.t Reference)") |
| rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)") |
| rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)") |
| rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)") |
|
|
| |
| run_btn.click( |
| fn=run_inference, |
| inputs=[ref_img, tgt_img, rm_bkg], |
| outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out], |
| preprocess=True, |
| postprocess=True |
| ) |
|
|
| gr.api( |
| predict, |
| api_name="predict", |
| api_visibility="undocumented", |
| queue=False, |
| ) |
|
|
| demo.launch() |
|
|