File size: 7,558 Bytes
7382c66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
    InitProcessGroupKwargs, 
    ProjectConfiguration, 
    set_seed
)
import torch
from contextlib import nullcontext
import trimesh
import gradio as gr
from gradio_imageslider import ImageSlider
from da2.utils.base import load_config
from da2.utils.model import load_model
from da2.utils.io import (
    read_cv2_image,
    torch_transform,
    tensorize
)
from da2.utils.vis import colorize_distance
from da2.utils.d2pc import distance2pointcloud
from datetime import (
    timedelta,
    datetime
)
import cv2
import numpy as np

# Path of the .glb export recorded by the most recent fn() call (None before
# the first call). fn() deletes the file at this path on its next run.
last_glb_path = None

def prepare_to_run_demo():
    """Load the inference config and build the Accelerator for the demo.

    Reads ``configs/infer.json``, creates ``output/infer`` if it does not
    exist, constructs an ``Accelerator`` from the ``accelerator`` section of
    the config, stores a logger under ``config['env']['logger']`` and calls
    ``set_seed`` with ``config['env']['seed']``.

    Returns:
        A ``(config, accelerator)`` tuple.
    """
    config = load_config('configs/infer.json')
    accel_cfg = config['accelerator']

    process_group_kwargs = InitProcessGroupKwargs(
        timeout=timedelta(seconds=accel_cfg['timeout'])
    )

    project_dir = 'output/infer'
    if not os.path.exists(project_dir):
        os.makedirs(project_dir, exist_ok=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=accel_cfg['accumulation_nsteps'],
        mixed_precision=accel_cfg['mixed_precision'],
        log_with=accel_cfg['report_to'],
        project_config=ProjectConfiguration(project_dir=project_dir),
        kwargs_handlers=[process_group_kwargs],
    )

    # The logger is stashed in the config itself, under the 'env' section.
    config['env']['logger'] = get_logger(__name__, log_level='INFO')
    set_seed(config['env']['seed'])
    return config, accelerator

def read_mask_demo(mask_path, shape):
    """Return a boolean validity mask for the demo image.

    Args:
        mask_path: Path of a mask image file, or ``None`` when the user did
            not supply a mask.
        shape: Shape of the transformed image. Everything after its first
            entry is used as the mask shape when ``mask_path`` is ``None``.

    Returns:
        A boolean ``numpy`` array. Without a mask file every element is
        ``True``; otherwise an element is ``True`` where the mask, read as
        grayscale, is greater than zero.

    Raises:
        ValueError: If the mask file cannot be read as an image.
    """
    if mask_path is None:
        return np.ones(shape[1:], dtype=bool)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports an unreadable file by returning None instead of
        # raising; fail here with a clear message rather than a TypeError
        # from the comparison below.
        raise ValueError(f'Could not read mask image: {mask_path!r}')
    return mask > 0

def load_infer_data_demo(image, mask, model_dtype, device):
    """Prepare the model input, the raw image and the validity mask.

    Args:
        image: Passed to ``read_cv2_image`` to load the input image.
        mask: Mask path forwarded to ``read_mask_demo`` (may be ``None``).
        model_dtype: Forwarded to ``tensorize``.
        device: Forwarded to ``tensorize``.

    Returns:
        ``(tensorized_image, cv2_image, mask)`` where ``cv2_image`` is the
        value returned by ``read_cv2_image`` and ``mask`` is the boolean
        array from ``read_mask_demo``.
    """
    cv2_image = read_cv2_image(image)
    transformed = torch_transform(cv2_image)
    # The fallback mask shape is derived from the transformed image, so the
    # mask has to be built before the image is tensorized.
    valid_mask = read_mask_demo(mask, transformed.shape)
    model_input = tensorize(transformed, model_dtype, device)
    return model_input, cv2_image, valid_mask

def ply2glb(ply_path, glb_path):
    """Re-export a PLY point cloud as GLB and delete the PLY file.

    Args:
        ply_path: Path of the existing ``.ply`` file; removed after export.
        glb_path: Destination path for the exported file.
    """
    source = trimesh.load(ply_path)
    trimesh.points.PointCloud(
        vertices=np.asarray(source.vertices),
        colors=np.asarray(source.visual.vertex_colors),
    ).export(glb_path)
    # The PLY was only an intermediate artefact.
    os.remove(ply_path)

def fn(image_path, mask_path):
    """Run the demo on one image and return its visualisations.

    Loads the config, accelerator and model, predicts a distance map for the
    image, colorizes it, converts it to a point cloud (saved as ``.glb`` under
    ``cache/``) and returns the outputs for the Gradio components.

    Args:
        image_path: File path of the input image.
        mask_path: File path of an optional mask, or ``None``.

    Returns:
        ``(glb_path, [distance_vis, normal_image])`` where ``glb_path`` is the
        exported point cloud, ``distance_vis`` comes from
        ``colorize_distance`` and ``normal_image`` from
        ``distance2pointcloud``.

    Raises:
        gr.Error: If no input image was provided.
    """
    global last_glb_path

    if image_path is None:
        # Fail early with a user-facing message instead of an opaque error
        # from the image reader.
        raise gr.Error('Please provide an input image.')

    config, accelerator = prepare_to_run_demo()
    model = load_model(config, accelerator)
    image, cv2_image, mask = load_infer_data_demo(
        image_path, mask_path,
        model_dtype=config['spherevit']['dtype'], device=accelerator.device)

    # Autocast is skipped whenever MPS is available.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type)

    with autocast_ctx, torch.no_grad():
        distance = model(image).cpu().numpy()[0]

        # Drop the export of the previous call. The file may already be gone
        # (e.g. the previous export failed or the cache was cleaned), which
        # must not break this request.
        if last_glb_path is not None:
            try:
                os.remove(last_glb_path)
            except FileNotFoundError:
                pass
            last_glb_path = None

        distance_vis = colorize_distance(distance, mask)

        os.makedirs('cache', exist_ok=True)
        save_path = f'cache/tmp_{datetime.now().strftime("%Y%m%d_%H%M%S")}.glb'
        ply_path = save_path.replace('.glb', '.ply')
        normal_image = distance2pointcloud(
            distance, cv2_image, mask, save_path=ply_path,
            return_normal=True, save_distance=False)
        ply2glb(ply_path, save_path)

        # Record the path only once the .glb really exists on disk.
        last_glb_path = save_path
        return save_path, [distance_vis, normal_image]

# Input widgets: both are configured with type="filepath", matching the
# (image_path, mask_path) parameters of fn().
input_image = gr.Image(label="Input Image", type="filepath")
input_mask = gr.Image(label="Input Mask", type="filepath")
inputs = [input_image, input_mask]

# Output widgets, in the order fn() returns its values: the .glb path first,
# then the [distance_vis, normal_image] pair.
pointcloud_view = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Point Cloud")
depth_normal_slider = gr.ImageSlider(
    label="Output Depth / Normal (transformed from the depth)",
    type="pil",
    slider_position=75,
)
outputs = [pointcloud_view, depth_normal_slider]

# Directory containing this script; bundled example assets live below it.
_DEMO_DIR = os.path.dirname(__file__)

# Example file stems, in the order they are listed in the examples gallery.
_EXAMPLE_STEMS = [
    "a1", "a2", "a3", "a4", "b0", "b1",
    "a5", "a6", "a7", "a8", "b2", "b3",
    "a9", "a10", "a11", "a0", "b4", "b5",
]


def _example_row(stem):
    """Return ``[image_path, mask_path_or_None]`` for one bundled example.

    Only the "b*" examples come with a mask file; every other example uses
    ``None`` for the mask.
    """
    image_file = os.path.join(_DEMO_DIR, f"assets/demos/{stem}.png")
    if stem.startswith("b"):
        mask_file = os.path.join(_DEMO_DIR, f"assets/masks/{stem}.png")
    else:
        mask_file = None
    return [image_file, mask_file]


demo = gr.Interface(
    fn=fn,
    title="DA<sup>2</sup>: <u>D</u>epth <u>A</u>nything in <u>A</u>ny <u>D</u>irection",
    description="""
        <p align="center">
        <a title="Project Page" href="https://depth-any-in-any-dir.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
        </a>
        <a title="arXiv" href="http://arxiv.org/abs/2509.26618" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
        </a>
        <a title="Github" href="https://github.com/EnVision-Research/DA-2" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://img.shields.io/github/stars/EnVision-Research/DA-2?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
        </a>
        <a title="Social" href="https://x.com/_akhaliq/status/1973283687652606411" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
        </a>
        <a title="Social" href="https://x.com/haodongli00/status/1973287870317338747" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
            <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
        </a>
        <br>
        <strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/DA-2" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful!</strong>
        </p>
        <p><strong>Note: the "Input Mask" is optional, all pixels are assumed to be valid if mask is None.</strong></p>
    """,
    inputs=inputs,
    outputs=outputs,
    examples=[_example_row(stem) for stem in _EXAMPLE_STEMS],
    examples_per_page=20
)

# Start the Gradio server, bound to 0.0.0.0 on port 6381.
demo.launch(server_name="0.0.0.0", server_port=6381)