File size: 5,629 Bytes
b15831b
4845d25
b15831b
 
d3d9a93
b15831b
d3d9a93
4845d25
b15831b
 
1ecbfb8
b15831b
 
 
 
d3d9a93
 
b15831b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3d9a93
b15831b
 
d3d9a93
b15831b
d3d9a93
 
b15831b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3d9a93
b15831b
 
d3d9a93
b15831b
 
 
 
 
d3d9a93
 
 
1ecbfb8
b15831b
1ecbfb8
b15831b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# app.py (safe CPU startup for HF Spaces)
# Minimal single-image depth demo: loads the model once at import time on CPU,
# then serves it through the Gradio interface defined at the bottom of the file.
import os
import io
import numpy as np
import torch
from PIL import Image
import gradio as gr

# Import the CPU-patched class you added earlier
# NOTE(review): project-local dependency — this module must be importable on a
# CPU-only box for the Space to start.
from depth_anything_3.api import DepthAnything3

# ---------------------------
# Configuration
# ---------------------------
# Keep the same model path you used earlier (default is the one in your logs)
# Overridable via the DA3_MODEL_DIR environment variable.
MODEL_DIR = os.environ.get("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")

# Lower processing resolution to make CPU inference feasible.
# Increase if you want better quality but expect it to be much slower.
# Overridable via DA3_PROCESS_RES; also the default of the UI slider below.
PROCESS_RES = int(os.environ.get("DA3_PROCESS_RES", "384"))

# ---------------------------
# Model loading (CPU)
# ---------------------------
# Runs at import time so HF Spaces downloads/loads the weights exactly once;
# any failure here aborts startup rather than failing per-request.
print(f"🔄 Loading DepthAnything3 from '{MODEL_DIR}' on CPU (this may take a moment)...")
# Uses the PyTorchModelHubMixin.from_pretrained you have in the class
model = DepthAnything3.from_pretrained(MODEL_DIR)
model.to(torch.device("cpu"))
model.eval()  # inference mode: disable dropout/batch-norm updates
print("✅ Model ready on CPU")

# ---------------------------
# Inference helper
# ---------------------------
def _normalize_depth_to_uint8(depth: np.ndarray) -> np.ndarray:
    """Normalize a depth map (H,W) to uint8 grayscale for display."""
    if depth is None:
        return None
    # convert to float
    d = depth.astype(np.float32)
    # clip NaNs / infs
    d = np.nan_to_num(d, nan=0.0, posinf=0.0, neginf=0.0)
    # Normalize robustly: use 1st and 99th percentiles to avoid outliers
    vmin = np.percentile(d, 1.0)
    vmax = np.percentile(d, 99.0)
    if vmax - vmin < 1e-6:
        vmax = vmin + 1.0
    d = (d - vmin) / (vmax - vmin)
    d = np.clip(d, 0.0, 1.0)
    img = (d * 255.0).astype(np.uint8)
    return img

def _extract_depth_map(pred):
    """Pull a depth array out of a prediction object, tolerating several layouts."""
    # Common case: attribute on the Prediction object itself.
    if hasattr(pred, "depth"):
        return pred.depth
    # Some wrappers return plain dicts.
    if isinstance(pred, dict) and "depth" in pred:
        return pred["depth"]
    # Fallback: some wrappers store a list of per-image predictions.
    if hasattr(pred, "predictions") and len(pred.predictions) > 0:
        first = pred.predictions[0]
        return first.depth if hasattr(first, "depth") else None
    return None


def run_depth(single_img: Image.Image, process_res: int = PROCESS_RES):
    """Run single-image depth inference with the patched DepthAnything3 API.

    Args:
        single_img: Input PIL image; ``None`` short-circuits to ``None``.
        process_res: Target processing resolution (smaller = faster on CPU).

    Returns:
        A grayscale PIL image visualizing depth, a fallback RGB image when
        inference fails or produces no depth, or ``None`` for no input.
    """
    if single_img is None:
        return None

    try:
        # Single-image batch; keep other args minimal to avoid heavy processing.
        pred = model.inference(
            [single_img],
            process_res=process_res,
            process_res_method="upper_bound_resize",
            export_format="mini_npz",  # minimal export
        )
    except Exception as e:
        msg = f"Inference error: {e}"
        print(msg)
        # Render the error text onto the image so the UI actually shows it
        # (the original returned a blank white image despite promising text).
        from PIL import ImageDraw
        err_img = Image.new("RGB", (640, 120), color=(255, 255, 255))
        ImageDraw.Draw(err_img).text((10, 10), msg, fill=(0, 0, 0))
        return err_img

    depth_map = _extract_depth_map(pred)

    if depth_map is None:
        # Fallback: show the first processed image as a visual sanity check.
        try:
            if hasattr(pred, "processed_images"):
                imgs = pred.processed_images
                if isinstance(imgs, np.ndarray) and imgs.shape[0] > 0:
                    first = imgs[0]
                    if first.dtype != np.uint8:
                        # presumably floats in [0, 1] — scale and clip so an
                        # out-of-range value cannot wrap around in uint8
                        first = np.clip(first * 255.0, 0.0, 255.0).astype(np.uint8)
                    return Image.fromarray(first)
        except Exception:
            pass
        # Nothing usable anywhere in the prediction.
        print("No depth found in prediction; returning empty image.")
        return Image.new("RGB", (640, 480), color=(255, 255, 255))

    # Normalize container/tensor layouts down to a single (H, W) array.
    if isinstance(depth_map, (list, tuple)):
        depth_map = depth_map[0]
    if isinstance(depth_map, torch.Tensor):
        # Convert BEFORE the shape checks: the original converted after the
        # ndarray-only squeeze, so a batched (1,H,W) tensor slipped through
        # un-squeezed and broke Image.fromarray(mode="L") downstream.
        depth_map = depth_map.detach().cpu().numpy()
    if isinstance(depth_map, np.ndarray) and depth_map.ndim == 3 and depth_map.shape[0] == 1:
        depth_map = depth_map[0]  # squeeze (1, H, W) batch axis
    if depth_map.ndim == 3 and depth_map.shape[0] == 3:
        # Unexpected 3-channel map: collapse to one plane by averaging.
        depth_map = depth_map.mean(axis=0)

    depth_uint8 = _normalize_depth_to_uint8(depth_map)
    if depth_uint8 is None:
        return Image.new("RGB", (640, 480), color=(255, 255, 255))

    # Grayscale PIL image for the Gradio output component.
    return Image.fromarray(depth_uint8, mode="L")

# ---------------------------
# Gradio interface
# ---------------------------
# UI copy shown above the interface.
title = "Depth Anything 3 — CPU (single-image)"
description = (
    "CPU-only minimal interface. Upload a single image and get a quick depth visualization.\n"
    "This Space is intentionally lightweight to allow CPU startup. For better quality/multiview features you need GPU or the full app."
)

# Make the Gradio Interface the top-level `app` variable so HF Spaces detects it
app = gr.Interface(
    fn=run_depth,
    inputs=[
        # Positional mapping: image -> run_depth's single_img,
        # slider -> run_depth's process_res.
        gr.Image(type="pil", label="Upload image"),
        gr.Slider(minimum=128, maximum=1024, step=64, value=PROCESS_RES, label="Process resolution (smaller = faster)")
    ],
    outputs=gr.Image(label="Predicted depth (grayscale)"),
    title=title,
    description=description,
)

# For local running
# (HF Spaces imports this module and serves `app` itself; this guard only
# fires when the file is executed directly.)
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)