# DA-2 / app.py
# fcu52005505's picture
# Update app.py
# 9701a8b verified
import os
import traceback
from contextlib import nullcontext

import cv2
import gradio as gr
import numpy as np
import torch
# Import the real Accelerator
from accelerate import Accelerator
from accelerate.logging import get_logger

# DA² utilities
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import read_cv2_image, torch_transform, tensorize
# ============================================================
# Global Initialization (Run once at startup)
# ============================================================
def initialize_app(config_path="configs/infer.json"):
    """Create the config, Accelerator and model used by the demo.

    The Accelerator is constructed before the accelerate logger is requested.
    The config loaded from ``config_path`` receives an ``env`` section
    (created when absent) holding the logger and a ``seed`` that defaults to
    42. The model is loaded once, moved to the Accelerator's device and put
    in eval mode.

    Args:
        config_path: Path of the inference config handed to ``load_config``.

    Returns:
        A ``(config, accelerator, model)`` tuple.
    """
    # 1. Build the Accelerator first (noted by the author as the fix for a
    #    logging error).
    accel = Accelerator()

    # 2. Load the config and attach the logger / default seed under "env".
    cfg = load_config(config_path)
    cfg.setdefault("env", {})
    env = cfg["env"]
    env["logger"] = get_logger(__name__, log_level="INFO")
    env.setdefault("seed", 42)

    accel.print(f"Running on device: {accel.device}")

    # 3. Load the model a single time so inference never has to reload it.
    net = load_model(cfg, accel).to(accel.device)
    net.eval()

    return cfg, accel, net
# Initialize the module-level globals once, at import time.
try:
    CONFIG, ACCELERATOR, MODEL = initialize_app()
except Exception as e:
    # Best-effort startup: the module stays importable when loading fails
    # (the globals become None), but the full traceback is printed so the
    # root cause is not reduced to a one-line message.
    print(f"Error loading model: {e}")
    traceback.print_exc()
    CONFIG, ACCELERATOR, MODEL = None, None, None
else:
    print("Model loaded successfully!")
# ============================================================
# Mask loader
# ============================================================
def read_mask_demo(mask_path, img_shape):
    """Load a boolean mask that matches the image's spatial size.

    Args:
        mask_path: Path of a mask image, or ``None`` when no mask was given.
        img_shape: Shape of the image the mask must match. Only its last two
            entries are used, read as ``(height, width)``, so ``(H, W)``,
            ``(C, H, W)`` and ``(B, C, H, W)`` all yield the same mask size.

    Returns:
        A 2-D boolean array of shape ``(height, width)``; pixels whose
        grayscale value is greater than zero are ``True``. When ``mask_path``
        is ``None`` or the file cannot be decoded, every pixel is ``True``.
    """
    # One source of truth for the target size, used by every branch below.
    height, width = int(img_shape[-2]), int(img_shape[-1])

    mask = None
    if mask_path is not None:
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    # No mask supplied, or cv2 could not decode it: treat all pixels as valid.
    if mask is None:
        return np.ones((height, width), dtype=bool)

    # Compare tuple-to-tuple so the check does not depend on the sequence
    # type of img_shape. cv2.resize takes its size as (width, height).
    if mask.shape[:2] != (height, width):
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    return mask > 0
# ============================================================
# Core inference function
# ============================================================
def _depth_to_uint8(depth):
    """Min-max normalize a depth map and quantize it to ``uint8``.

    The range is taken over finite pixels only. NaN and -inf pixels map to 0
    and +inf pixels map to 255, so no non-finite value reaches the integer
    cast. A map with no finite pixel, or with a range of at most 1e-6, comes
    back as all zeros.
    """
    finite = np.isfinite(depth)
    if not finite.any():
        return np.zeros(depth.shape, dtype=np.uint8)

    dmin = float(depth[finite].min())
    dmax = float(depth[finite].max())
    if dmax - dmin > 1e-6:
        depth_norm = (depth - dmin) / (dmax - dmin)
    else:
        depth_norm = np.zeros_like(depth, dtype=np.float32)

    # Casting NaN/inf to uint8 is undefined, so sanitize and clamp first.
    depth_norm = np.nan_to_num(depth_norm, nan=0.0, posinf=1.0, neginf=0.0)
    depth_norm = np.clip(depth_norm, 0.0, 1.0)
    return (depth_norm * 255).astype(np.uint8)


def run_inference_and_save_depth(image_path, mask_path=None):
    """Run depth inference on one image and save the result as an 8-bit PNG.

    Args:
        image_path: Path of the input image.
        mask_path: Optional path of a mask image.

    Returns:
        A ``(depth_8bit, out_path)`` tuple:

        * on success, a ``uint8`` depth array at the input image's resolution
          and the path of the PNG written under ``outputs/``;
        * ``(None, "Error: Model not loaded.")`` when the global model is
          unavailable;
        * ``(None, None)`` when the input image could not be read;
        * ``(depth_8bit, None)`` when the PNG could not be written.
    """
    if MODEL is None:
        return None, "Error: Model not loaded."

    device = ACCELERATOR.device

    # 1. Read the original image.
    cv2_img = read_cv2_image(image_path)
    if cv2_img is None:
        print(f"Error reading image: {image_path}")
        return None, None

    # Remember the original size; a cv2 image's shape is (H, W, C).
    original_h, original_w = cv2_img.shape[:2]

    # 2. Convert to a CxHxW tensor. Per the original author's note this step
    #    usually resizes to the model input size (e.g. 518x518).
    img_tensor = torch_transform(cv2_img)

    # NOTE(review): the mask is loaded but never applied to the prediction or
    # to the saved output. The call is kept to preserve existing behavior;
    # confirm how the mask is meant to be used.
    mask = read_mask_demo(mask_path, img_tensor.shape)  # noqa: F841

    # Prepare the model input.
    model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
    input_tensor = tensorize(img_tensor, model_dtype, device)

    # Autocast only when running on CUDA.
    if device.type == "cuda":
        autocast_ctx = torch.autocast(device_type="cuda")
    else:
        autocast_ctx = nullcontext()

    # 3. Inference.
    with autocast_ctx, torch.no_grad():
        pred = MODEL(input_tensor)
        if isinstance(pred, (tuple, list)):
            pred = pred[0]

    # Back to a float32 NumPy array with singleton dimensions dropped.
    depth = np.squeeze(pred.float().cpu().numpy())

    # 4. Resize the depth map back to the original image size.
    #    cv2.resize takes its size as (width, height).
    if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
        depth = cv2.resize(
            depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC
        )

    # 5. Normalize to 8 bits.
    depth_8bit = _depth_to_uint8(depth)

    # 6. Save.
    os.makedirs("outputs", exist_ok=True)
    base = os.path.splitext(os.path.basename(image_path))[0]
    out_path = f"outputs/{base}_depth.png"
    if not cv2.imwrite(out_path, depth_8bit):
        # Never hand back a path to a file that was not actually written.
        print(f"Error writing depth map: {out_path}")
        return depth_8bit, None

    return depth_8bit, out_path
# ============================================================
# Gradio UI
# ============================================================
def gradio_fn(image, mask):
    """Gradio callback: run depth inference for the uploaded image.

    Args:
        image: File path of the uploaded image, or ``None`` when empty.
        mask: File path of the optional mask, or ``None``.

    Returns:
        ``(depth_img, out_path)`` for the image and file outputs, or
        ``(None, None)`` when no image was provided.

    Raises:
        gr.Error: When inference produced no depth image, so the failure is
            shown to the user.
    """
    if image is None:
        return None, None

    depth_img, out_path = run_inference_and_save_depth(image, mask)
    if depth_img is None:
        # On failure the second slot holds an error message (or None), not a
        # file path; forwarding it to the gr.File output would be treated as
        # a path to a file that does not exist.
        raise gr.Error(out_path or "Error: could not read the input image.")

    return depth_img, out_path
# Two file-path image inputs (image and optional mask); outputs are the depth
# preview as a NumPy image plus the saved PNG as a downloadable file.
demo = gr.Interface(
    title="DA² — Minimal Depth Demo",
    description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG (Resized to Original).",
    fn=gradio_fn,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Image(type="filepath", label="Optional Mask"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Depth (8-bit Grayscale)"),
        gr.File(label="Download Depth PNG"),
    ],
    allow_flagging="never",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()