File size: 5,503 Bytes
6b12a63
e237bf4
 
 
 
6b12a63
e237bf4
59f5df0
6b12a63
 
e9ff5bc
9965e76
9701a8b
05c1031
 
59f5df0
 
05c1031
59f5df0
05c1031
9701a8b
05c1031
 
 
e9ff5bc
 
9701a8b
e9ff5bc
 
 
 
 
05c1031
 
9701a8b
05c1031
 
 
 
 
59f5df0
05c1031
 
 
 
 
 
 
e9ff5bc
59f5df0
05c1031
59f5df0
e9ff5bc
6b12a63
59f5df0
 
 
 
6b12a63
e9ff5bc
59f5df0
 
 
 
9701a8b
 
 
 
 
e237bf4
6b12a63
59f5df0
 
 
 
05c1031
 
e9ff5bc
05c1031
e9ff5bc
9701a8b
59f5df0
05c1031
 
 
 
9701a8b
 
 
 
 
59f5df0
05c1031
9701a8b
e9ff5bc
 
9701a8b
05c1031
59f5df0
e9ff5bc
9701a8b
59f5df0
 
e237bf4
9701a8b
e9ff5bc
05c1031
59f5df0
 
 
 
9701a8b
59f5df0
 
9701a8b
59f5df0
9701a8b
 
 
 
 
 
 
59f5df0
 
 
 
e9ff5bc
 
 
59f5df0
 
9701a8b
e9ff5bc
 
59f5df0
 
e9ff5bc
59f5df0
e9ff5bc
59f5df0
 
 
e9ff5bc
05c1031
 
e9ff5bc
 
6b12a63
 
e9ff5bc
 
 
59f5df0
e9ff5bc
 
59f5df0
e9ff5bc
 
59f5df0
9701a8b
59f5df0
6b12a63
 
e9ff5bc
05c1031
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import cv2
import torch
import numpy as np
import gradio as gr
from contextlib import nullcontext

# DA² utilities
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import read_cv2_image, torch_transform, tensorize

# 引入真正的 Accelerator
from accelerate import Accelerator
from accelerate.logging import get_logger

# ============================================================
# Global Initialization (Run once at startup)
# ============================================================
def initialize_app(config_path="configs/infer.json"):
    """Set up the Accelerator, load the inference config, and load the model.

    Intended to run exactly once at process startup so the model is not
    reloaded on every inference request.

    Args:
        config_path: Path to the inference JSON configuration.

    Returns:
        Tuple of ``(config, accelerator, model)`` ready for inference.
    """
    # Create the Accelerator first: it owns device selection and logging.
    accelerator = Accelerator()

    # Parse the inference configuration file.
    config = load_config(config_path)

    # Attach an accelerate-aware logger (and a default seed) to the config env.
    logger = get_logger(__name__, log_level="INFO")
    env = config.setdefault("env", {})
    env["logger"] = logger
    env.setdefault("seed", 42)

    accelerator.print(f"Running on device: {accelerator.device}")

    # Load weights once globally, move to the target device, and freeze
    # into eval mode for inference.
    depth_model = load_model(config, accelerator)
    depth_model = depth_model.to(accelerator.device)
    depth_model.eval()

    return config, accelerator, depth_model

# Initialize module-level globals; fall back to None on any startup failure
# so the Gradio UI can still come up and report the error per-request.
CONFIG = ACCELERATOR = MODEL = None
try:
    CONFIG, ACCELERATOR, MODEL = initialize_app()
except Exception as e:
    print(f"Error loading model: {e}")
else:
    print("Model loaded successfully!")

# ============================================================
# Mask loader
# ============================================================
def read_mask_demo(mask_path, img_shape):
    """Load an optional grayscale mask and return it as a boolean array.

    Args:
        mask_path: Path to a mask image, or None for "no mask".
        img_shape: Shape of the image/tensor the mask must match.
            A 3-tuple is assumed to be channels-first (C, H, W) — the
            fallback uses indices 1 and 2 as (H, W); a 2-tuple is (H, W).

    Returns:
        Boolean ndarray of shape (H, W); True marks valid pixels.
        Falls back to an all-True mask when no path is given or the
        file cannot be read.
    """
    def _full_mask():
        # Fallback: accept every pixel. Shared by the "no path" and
        # "unreadable file" branches (previously duplicated inline).
        if len(img_shape) == 3:
            return np.ones((img_shape[1], img_shape[2]), dtype=bool)
        return np.ones(img_shape[:2], dtype=bool)

    if mask_path is None:
        return _full_mask()

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return _full_mask()

    # Resize the mask to the tensor's spatial dims if they disagree, to
    # avoid a shape-mismatch crash downstream. cv2.resize takes (W, H).
    if mask.shape[:2] != img_shape[-2:]:
        mask = cv2.resize(mask, (img_shape[-1], img_shape[-2]), interpolation=cv2.INTER_NEAREST)

    return mask > 0

# ============================================================
# Core inference function
# ============================================================
def run_inference_and_save_depth(image_path, mask_path=None):
    """Run depth inference on one image and save an 8-bit depth PNG.

    Reads the image, runs the globally loaded MODEL, resizes the predicted
    depth back to the original resolution, min-max normalizes it to 8-bit,
    and writes it under ``outputs/``.

    Args:
        image_path: Filesystem path to the input image.
        mask_path: Optional path to a mask image (currently loaded but the
            resulting ``mask`` is not applied anywhere below — TODO confirm
            whether masking was intended to affect the prediction).

    Returns:
        ``(depth_8bit, out_path)`` on success; ``(None, None)`` if the image
        cannot be read. NOTE(review): the "model not loaded" branch returns a
        string as the second element, which the gr.File output may not
        handle — consider unifying with the (None, None) convention.
    """
    if MODEL is None:
        return None, "Error: Model not loaded."

    device = ACCELERATOR.device

    # 1. Read the original image (BGR ndarray via project helper).
    cv2_img = read_cv2_image(image_path)
    if cv2_img is None:
        print(f"Error reading image: {image_path}")
        return None, None

    # Capture the original (Height, Width) before any model-side resizing,
    # so the depth map can be restored to it later. cv2 shape is (H, W, C).
    original_h, original_w = cv2_img.shape[:2]

    # 2. Convert to tensor — presumably resizes to the model's input size
    #    (e.g. 518x518); exact behavior is inside torch_transform.
    img_tensor = torch_transform(cv2_img)  # CxHxW tensor
    
    # Load the optional mask against the tensor's spatial dims.
    mask = read_mask_demo(mask_path, img_tensor.shape)

    # Prepare model input with the configured dtype on the target device.
    model_dtype = CONFIG.get("spherevit", {}).get("dtype", "float32")
    input_tensor = tensorize(img_tensor, model_dtype, device)

    # Autocast only makes sense on CUDA; use a no-op context elsewhere.
    use_autocast = (device.type == "cuda")
    autocast_ctx = torch.autocast(device_type="cuda") if use_autocast else nullcontext()

    # 3. Inference (no gradients needed).
    with autocast_ctx, torch.no_grad():
        pred = MODEL(input_tensor)

        # Some model variants return (depth, extras...); keep the first item.
        if isinstance(pred, (tuple, list)):
            pred = pred[0]

        # Move back to CPU as float32 numpy for post-processing.
        depth = pred.float().cpu().numpy()

    # 4. Post-processing: drop singleton batch/channel dims.
    depth = np.squeeze(depth)

    # Resize the depth map back to the original image resolution.
    # cv2.resize takes (Width, Height), not (Height, Width).
    if (depth.shape[0] != original_h) or (depth.shape[1] != original_w):
        depth = cv2.resize(depth, (original_w, original_h), interpolation=cv2.INTER_CUBIC)

    # 5. Min-max normalization to [0, 1], then quantize to 8-bit.
    #    nan-aware min/max guards against NaNs in the prediction.
    dmin, dmax = float(np.nanmin(depth)), float(np.nanmax(depth))

    if dmax - dmin > 1e-6:
        depth_norm = (depth - dmin) / (dmax - dmin)
    else:
        # Degenerate (constant) depth map: emit all-zeros instead of dividing
        # by ~0.
        depth_norm = np.zeros_like(depth, dtype=np.float32)

    depth_8bit = (depth_norm * 255).astype(np.uint8)

    # 6. Save next to other demo outputs, named after the input file.
    os.makedirs("outputs", exist_ok=True)
    base = os.path.splitext(os.path.basename(image_path))[0]
    out_path = f"outputs/{base}_depth.png"
    cv2.imwrite(out_path, depth_8bit)

    return depth_8bit, out_path

# ============================================================
# Gradio UI
# ============================================================
def gradio_fn(image, mask):
    """Gradio callback: forward the uploaded paths to the depth pipeline."""
    # Guard: nothing uploaded yet — clear both outputs.
    if image is None:
        return None, None
    return run_inference_and_save_depth(image, mask)

# Build the Gradio UI: two filepath inputs (image + optional mask),
# a grayscale depth preview, and a downloadable PNG.
demo = gr.Interface(
    fn=gradio_fn,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Image(label="Optional Mask", type="filepath"),
    ],
    outputs=[
        gr.Image(label="Depth (8-bit Grayscale)", type="numpy"),
        gr.File(label="Download Depth PNG"),
    ],
    title="DA² — Minimal Depth Demo",
    description="Upload an image (and optional mask) -> outputs an 8-bit grayscale depth PNG (Resized to Original).",
    # NOTE(review): `allow_flagging` is deprecated in Gradio 4.x in favor of
    # `flagging_mode` — confirm the installed Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()