# Captured from a Hugging Face Space file view (status at capture: "Build error"; file size: 5,503 bytes).
import os
import cv2
import torch
import numpy as np
import gradio as gr
from contextlib import nullcontext
# DA² utilities
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import read_cv2_image, torch_transform, tensorize
# 引入真正的 Accelerator
from accelerate import Accelerator
from accelerate.logging import get_logger
# ============================================================
# Global Initialization (Run once at startup)
# ============================================================
def initialize_app(config_path="configs/infer.json"):
    """Build the accelerator, config and model used by the demo.

    Args:
        config_path: Path handed to ``load_config``.

    Returns:
        A ``(config, accelerator, model)`` tuple. The model has been moved
        to the accelerator's device and switched to eval mode.
    """
    # The Accelerator is created before the logger; the original author notes
    # that this ordering resolves a logging error.
    accel = Accelerator()

    cfg = load_config(config_path)

    # Attach a logger and a default seed under the "env" section of the config.
    log = get_logger(__name__, log_level="INFO")
    cfg.setdefault("env", {})
    env_section = cfg["env"]
    env_section["logger"] = log
    env_section.setdefault("seed", 42)

    accel.print(f"Running on device: {accel.device}")

    # Load the model a single time so inference calls do not reload it.
    net = load_model(cfg, accel).to(accel.device)
    net.eval()
    return cfg, accel, net
# Initialise the module-level globals once, at import time.
try:
    CONFIG, ACCELERATOR, MODEL = initialize_app()
    print("Model loaded successfully!")
except Exception as load_error:
    # Best effort: keep the module importable and leave the globals as None
    # so callers can detect that no model is available.
    print(f"Error loading model: {load_error}")
    CONFIG = ACCELERATOR = MODEL = None
# ============================================================
# Mask loader
# ============================================================
def read_mask_demo(mask_path, img_shape):
    """Load a boolean mask that matches an image's spatial size.

    Args:
        mask_path: Path of a mask image readable by OpenCV, or ``None``.
        img_shape: Shape of the image the mask belongs to. Its last two
            entries are used as ``(height, width)``, so both ``(H, W)`` and
            channel-first shapes such as ``(C, H, W)`` are accepted. Any
            sequence works (tuple, list, ``torch.Size``).

    Returns:
        A 2-D boolean numpy array of shape ``(height, width)``. Entries are
        True where the mask image is non-zero. If ``mask_path`` is ``None``
        or the file cannot be decoded, every entry is True.
    """
    # Always take the trailing two dimensions, so the all-True fallback and
    # the resize target below agree on what "height, width" means. Casting to
    # plain ints makes the comparison independent of the sequence type.
    height, width = (int(dim) for dim in tuple(img_shape)[-2:])

    # cv2.imread signals an unreadable file by returning None.
    mask = None if mask_path is None else cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return np.ones((height, width), dtype=bool)

    # Bring a differently sized mask to the target size; nearest-neighbour
    # keeps the values binary. cv2.resize takes (width, height).
    if tuple(mask.shape[:2]) != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    return mask > 0
# ============================================================
# Core inference function
# ============================================================
def run_inference_and_save_depth(image_path, mask_path=None):
    """Run depth inference on one image and save the result as an 8-bit PNG.

    Args:
        image_path: Path of the input image on disk.
        mask_path: Optional path of a mask image.

    Returns:
        A ``(depth_8bit, out_path)`` tuple: the min-max normalised depth map
        as a uint8 array at the input image's height and width, and the path
        of the PNG written under ``outputs/``. Returns ``(None, None)`` when
        the model is not loaded or the image cannot be read, and
        ``(depth_8bit, None)`` when the PNG could not be written.
    """
    if MODEL is None:
        # The second value is shown by a file output in the UI, so it must be
        # a real path or None, never an error message.
        print("Error: Model not loaded.")
        return None, None

    device = ACCELERATOR.device

    # 1. Read the original image.
    cv2_img = read_cv2_image(image_path)
    if cv2_img is None:
        print(f"Error reading image: {image_path}")
        return None, None

    # Remember the original size; a cv2 image is laid out as (H, W, C).
    original_h, original_w = cv2_img.shape[:2]

    # 2. Convert to a tensor. Per the original author this yields a CxHxW
    # tensor, possibly resized to the model input size -- TODO confirm.
    img_tensor = torch_transform(cv2_img)

    # NOTE(review): the mask is loaded but not applied anywhere below, so it
    # currently has no effect on the output -- confirm the intended use.
    mask = read_mask_demo(mask_path, img_tensor.shape)

    model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
    input_tensor = tensorize(img_tensor, model_dtype, device)

    # Autocast is only enabled on CUDA devices.
    use_autocast = (device.type == "cuda")
    autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()

    # 3. Inference without gradient tracking.
    with autocast_ctx, torch.no_grad():
        pred = MODEL(input_tensor)
    if isinstance(pred, (tuple, list)):
        pred = pred[0]

    # 4. Back to a float32 numpy array with singleton dimensions dropped.
    depth = np.squeeze(pred.float().cpu().numpy())

    # Resize the depth map to the original image size; cv2.resize takes
    # (width, height).
    if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
        depth = cv2.resize(depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC)

    # 5. Min-max normalise to 8 bit. The range is taken over finite values
    # only, so NaN/inf entries can neither poison the scale nor reach the
    # uint8 cast, whose result for them is undefined.
    finite = np.isfinite(depth)
    if finite.any():
        dmin = float(depth[finite].min())
        dmax = float(depth[finite].max())
    else:
        dmin = dmax = 0.0
    if dmax - dmin > 1e-6:
        depth_norm = (depth - dmin) / (dmax - dmin)
        depth_norm = np.nan_to_num(depth_norm, nan=0.0, posinf=1.0, neginf=0.0)
        depth_norm = np.clip(depth_norm, 0.0, 1.0)
    else:
        # Constant (or entirely non-finite) depth map: emit all zeros.
        depth_norm = np.zeros_like(depth, dtype=np.float32)
    depth_8bit = (depth_norm * 255).astype(np.uint8)

    # 6. Save next to the other outputs, named after the input file.
    os.makedirs("outputs", exist_ok=True)
    base = os.path.splitext(os.path.basename(image_path))[0]
    out_path = f"outputs/{base}_depth.png"
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(out_path, depth_8bit):
        print(f"Error writing depth map: {out_path}")
        return depth_8bit, None
    return depth_8bit, out_path
# ============================================================
# Gradio UI
# ============================================================
def gradio_fn(image, mask):
    """Gradio callback: run depth inference for the uploaded image.

    Args:
        image: Input image value from the UI, or ``None`` if nothing was
            uploaded.
        mask: Optional mask value from the UI, or ``None``.

    Returns:
        The ``(depth image, output path)`` pair produced by
        ``run_inference_and_save_depth``, or ``(None, None)`` when no image
        was provided.
    """
    if image is not None:
        result_array, result_path = run_inference_and_save_depth(image, mask)
        return result_array, result_path
    return None, None
# Gradio interface with two inputs and two outputs, wired to gradio_fn.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        # Both inputs are configured with type="filepath".
        gr.Image(label="Input Image", type="filepath"),
        gr.Image(label="Optional Mask", type="filepath"),
    ],
    outputs=[
        # First output displays a numpy image; second offers a file download.
        gr.Image(label="Depth (8-bit Grayscale)", type="numpy"),
        gr.File(label="Download Depth PNG"),
    ],
    title="DA² — Minimal Depth Demo",
    description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG (Resized to Original).",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()