Spaces:
Sleeping
Sleeping
import os
import traceback
from contextlib import nullcontext

import cv2
import torch
import numpy as np
import gradio as gr
# Use the real accelerate Accelerator
from accelerate import Accelerator
from accelerate.logging import get_logger

# DA² utilities
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import read_cv2_image, torch_transform, tensorize
# ============================================================
# Global Initialization (Run once at startup)
# ============================================================
def initialize_app(config_path="configs/infer.json"):
    """Create the accelerator, config and model shared by every request.

    Args:
        config_path: Path of the inference config handed to ``load_config``.

    Returns:
        A ``(config, accelerator, model)`` tuple. The model has been moved to
        ``accelerator.device`` and switched to eval mode.
    """
    # The Accelerator is built before anything else; the original author
    # notes that this is what resolves an accelerate logging error.
    accelerator = Accelerator()

    config = load_config(config_path)
    logger = get_logger(__name__, log_level="INFO")

    # Make the logger and a default seed available under config["env"].
    config.setdefault("env", {})
    env = config["env"]
    env["logger"] = logger
    env.setdefault("seed", 42)

    accelerator.print(f"Running on device: {accelerator.device}")

    # Load the model once here so inference calls do not reload it.
    model = load_model(config, accelerator).to(accelerator.device)
    model.eval()
    return config, accelerator, model
# Initialize the module-level globals used by inference. If startup fails they
# are all set to None so the app can still launch; inference checks
# `MODEL is None` before using them.
try:
    CONFIG, ACCELERATOR, MODEL = initialize_app()
except Exception as e:  # deliberate best-effort boundary: keep the UI up
    print(f"Error loading model: {e}")
    # Print the full traceback too; the bare message alone usually does not
    # reveal where config/model loading actually failed.
    traceback.print_exc()
    CONFIG, ACCELERATOR, MODEL = None, None, None
else:
    print("Model loaded successfully!")
# ============================================================
# Mask loader
# ============================================================
def read_mask_demo(mask_path, img_shape):
    """Load an optional binary mask sized to the spatial dims of ``img_shape``.

    Args:
        mask_path: Path to a mask image, or ``None`` for "no mask".
        img_shape: A shape whose last two entries are ``(H, W)``. Examples are
            an ``(H, W)`` tuple or a ``(C, H, W)`` tensor shape. It must have
            at least two dimensions.

    Returns:
        A boolean ``(H, W)`` array that is ``True`` where the mask is non-zero.
        If no mask is given, or the file cannot be read, every pixel is
        ``True``. A mask of a different size is resized to ``(H, W)``.
    """
    # Always take the trailing two dims so the fallback and the resize target
    # agree for (H, W), (C, H, W) and higher-rank shapes alike.
    h, w = (int(d) for d in img_shape[-2:])
    if mask_path is None:
        return np.ones((h, w), dtype=bool)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # Missing or unreadable file: treat every pixel as valid.
        return np.ones((h, w), dtype=bool)
    if mask.shape[:2] != (h, w):
        # cv2.resize takes (width, height); nearest-neighbour avoids blending
        # mask values at region borders.
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
    return mask > 0
# ============================================================
# Core inference function
# ============================================================
def _depth_to_uint8(depth, valid):
    """Min-max normalise ``depth`` over the ``valid`` pixels into a uint8 image.

    Pixels outside ``valid``, and any non-finite depth values, are written as 0.
    If no pixel is valid, or the valid range is (near-)constant, the result is
    all zeros.
    """
    valid = valid & np.isfinite(depth)
    out = np.zeros(depth.shape, dtype=np.uint8)
    if not valid.any():
        return out
    values = depth[valid]
    dmin, dmax = float(values.min()), float(values.max())
    if dmax - dmin > 1e-6:
        norm = (values - dmin) / (dmax - dmin)
        out[valid] = (norm * 255).astype(np.uint8)
    return out


def run_inference_and_save_depth(image_path, mask_path=None):
    """Predict a depth map for ``image_path`` and save it as an 8-bit PNG.

    The depth map is resized back to the original image resolution. It is then
    min-max normalised to 0..255 over the pixels selected by the optional mask.
    Pixels outside the mask are set to 0.

    Args:
        image_path: Path to the input image.
        mask_path: Optional path to a mask image; ``None`` means the whole
            image is used.

    Returns:
        On success, ``(depth_8bit, out_path)``, where ``out_path`` is ``None`` if
        the PNG could not be written. Returns ``(None, "Error: Model not
        loaded.")`` if model initialisation failed, and ``(None, None)`` if the
        image could not be read.
    """
    if MODEL is None:
        return None, "Error: Model not loaded."
    device = ACCELERATOR.device

    # 1. Read the source image.
    cv2_img = read_cv2_image(image_path)
    if cv2_img is None:
        print(f"Error reading image: {image_path}")
        return None, None
    # cv2 images are (H, W, C). Remember the original size so the depth map
    # can be returned at the resolution that was uploaded.
    original_h, original_w = cv2_img.shape[:2]

    # 2. Convert to a CxHxW tensor. Per the original note, this step may
    #    resize to the model input size (e.g. 518x518).
    img_tensor = torch_transform(cv2_img)

    # The mask is applied to the final depth map, which is resized back to the
    # original image size, so load the mask at that size.
    mask = read_mask_demo(mask_path, (original_h, original_w))

    model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
    input_tensor = tensorize(img_tensor, model_dtype, device)

    use_autocast = device.type == "cuda"
    autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()

    # 3. Inference.
    with autocast_ctx, torch.no_grad():
        pred = MODEL(input_tensor)
    if isinstance(pred, (tuple, list)):
        pred = pred[0]
    depth = np.squeeze(pred.float().cpu().numpy())

    # 4. Resize back to the original size; cv2.resize takes (width, height).
    if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
        depth = cv2.resize(depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC)

    # 5. Normalise to 8-bit over the masked, finite pixels only. This avoids
    #    casting NaN to uint8, which is undefined.
    depth_8bit = _depth_to_uint8(depth, mask)

    # 6. Save.
    os.makedirs("outputs", exist_ok=True)
    base = os.path.splitext(os.path.basename(image_path))[0]
    out_path = f"outputs/{base}_depth.png"
    if not cv2.imwrite(out_path, depth_8bit):
        # Do not hand back a path to a file that was never written.
        print(f"Error writing depth map: {out_path}")
        return depth_8bit, None
    return depth_8bit, out_path
# ============================================================
# Gradio UI
# ============================================================
def gradio_fn(image, mask):
    """Gradio callback: run depth inference on the uploaded image.

    Args:
        image: Filepath of the uploaded image, or ``None`` if nothing was uploaded.
        mask: Filepath of the optional mask image, or ``None``.

    Returns:
        ``(depth_image, png_path)`` for the image and file outputs, or
        ``(None, None)`` when no image was provided.

    Raises:
        gr.Error: If inference fails. The message is then shown in the UI
            instead of being passed to the file output as a path.
    """
    if image is None:
        return None, None
    depth_img, out_path = run_inference_and_save_depth(image, mask)
    if depth_img is None:
        # On failure the inference function returns a None image, and the
        # second value is either None or an error message. Neither is a file
        # that the gr.File output can serve.
        raise gr.Error(out_path or "Could not read the input image.")
    return depth_img, out_path
def _build_demo():
    """Assemble the Gradio interface: an image plus optional mask in, a depth map out."""
    input_components = [
        gr.Image(label="Input Image", type="filepath"),
        gr.Image(label="Optional Mask", type="filepath"),
    ]
    output_components = [
        gr.Image(label="Depth (8-bit Grayscale)", type="numpy"),
        gr.File(label="Download Depth PNG"),
    ]
    return gr.Interface(
        fn=gradio_fn,
        inputs=input_components,
        outputs=output_components,
        title="DA² — Minimal Depth Demo",
        description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG (Resized to Original).",
        allow_flagging="never",
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()