import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider

from depth_anything.dpt import DepthAnything
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

css = """
#img-display-container {
    max-height: 100vh;
    }
#img-display-input {
    max-height: 80vh;
    }
#img-display-output {
    max-height: 80vh;
    }
"""

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model configurations - supports different model variants
MODEL_CONFIGS = {
    "vits14": {
        "model_name": "LiheYoung/depth_anything_vits14",
        "display_name": "Depth Anything ViT-S (Small, Fastest)",
        "description": "Smallest and fastest model variant"
    },
    "vitb14": {
        "model_name": "LiheYoung/depth_anything_vitb14", 
        "display_name": "Depth Anything ViT-B (Base, Balanced)",
        "description": "Balanced model with good speed/quality tradeoff"
    },
    "vitl14": {
        "model_name": "LiheYoung/depth_anything_vitl14",
        "display_name": "Depth Anything ViT-L (Large, Best Quality)",
        "description": "Largest model with best quality (default)"
    }
}
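# The model_name fields are Hugging Face Hub repo ids; DepthAnything.from_pretrained
# resolves and downloads the corresponding checkpoint on first use.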

# Global model cache
current_model = None
current_model_name = None
cached_models = {}  # Store all downloaded models

title = "# Depth Anything with Model Selection"
description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data** with multiple model variants.

You can choose between different model sizes for speed vs quality tradeoffs:
- **ViT-S**: Fastest inference, good for real-time applications
- **ViT-B**: Balanced performance and quality 
- **ViT-L**: Best quality, slower inference

Please refer to our [paper](https://arxiv.org/abs/2401.10891), [project page](https://depth-anything.github.io), or [github](https://github.com/LiheYoung/Depth-Anything) for more details."""

transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
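# The pipeline above mirrors the official Depth Anything preprocessing: it takes
# {'image': HxWx3 float RGB in [0, 1]}, resizes so the shorter side is at least
# 518 px (keeping aspect ratio), snaps both dimensions to multiples of 14 (the
# ViT patch size), normalizes with ImageNet statistics, and returns a CHW
# float32 array ready for the network.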

def get_memory_status():
    """Get current memory usage status"""
    try:
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3  # GB
            total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
            return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total_memory:.2f}GB total"
        else:
            return "Running on CPU"
    except Exception:
        return "Memory status unavailable"

def download_all_models():
    """Download and cache all model variants at startup"""
    global cached_models
    
    print("πŸ”„ Downloading all Depth Anything model variants...")
    print("This may take a few minutes depending on your internet connection...")
    
    for key, config in MODEL_CONFIGS.items():
        try:
            print(f"πŸ“₯ Downloading {config['display_name']}...")
            model = DepthAnything.from_pretrained(config['model_name']).to(DEVICE).eval()
            cached_models[key] = model
            print(f"βœ… {config['display_name']} downloaded and cached successfully")
        except Exception as e:
            print(f"❌ Failed to download {config['display_name']}: {e}")
            cached_models[key] = None
    
    print(f"πŸŽ‰ Model download complete! {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)} models cached successfully.")
    return cached_models
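
# Rough sizing note (derived from the parameter table further below): at fp32
# the variants hold roughly 0.1 GB (ViT-S), 0.4 GB (ViT-B) and 1.3 GB (ViT-L)
# of weights, so caching all three on a small GPU may be tight; keeping unused
# models on CPU and moving them to the device on demand is one alternative.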

def load_model(model_selection):
    """Load the selected model variant from cache"""
    global current_model, current_model_name
    
    # Find the model key from the display name
    selected_key = None
    for key, config in MODEL_CONFIGS.items():
        if config["display_name"] == model_selection:
            selected_key = key
            break
    
    if selected_key is None:
        # Fallback to vitl14 if not found
        selected_key = "vitl14"
    
    # Check if we need to switch to a different model
    if current_model_name != selected_key:
        print(f"πŸ”„ Switching to model: {MODEL_CONFIGS[selected_key]['display_name']}")
        
        # Get model from cache
        if selected_key in cached_models and cached_models[selected_key] is not None:
            current_model = cached_models[selected_key]
            current_model_name = selected_key
            print(f"βœ… Model {selected_key} loaded from cache successfully")
        else:
            # Fallback: download model if not in cache
            print(f"⚠️ Model {selected_key} not in cache, downloading...")
            try:
                current_model = DepthAnything.from_pretrained(MODEL_CONFIGS[selected_key]['model_name']).to(DEVICE).eval()
                cached_models[selected_key] = current_model
                current_model_name = selected_key
                print(f"βœ… Model {selected_key} downloaded and loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load model {selected_key}: {e}")
                # Fallback to any available cached model
                for fallback_key, fallback_model in cached_models.items():
                    if fallback_model is not None:
                        current_model = fallback_model
                        current_model_name = fallback_key
                        print(f"πŸ”„ Using fallback model: {fallback_key}")
                        break
    
    return current_model

@torch.no_grad()
def predict_depth(model, image):
    """Forward pass without gradients; returns a (B, H, W) relative depth map
    at the network input resolution (larger values = closer, disparity-like)."""
    return model(image)

def on_submit(model_selection, image):
    if image is None:
        return None, None
    
    # Load the selected model; bail out early if nothing usable is available
    try:
        model = load_model(model_selection)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None
    if model is None:
        return None, None
    
    original_image = image.copy()

    h, w = image.shape[:2]

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
    image = transform({'image': image})['image']
    image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

    depth = predict_depth(model, image)
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]

    # Save the un-normalized prediction as a 16-bit PNG so the full dynamic
    # range survives for downstream use
    raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
    tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp.name)

    # Min-max normalize to 8-bit for visualization; [:, :, ::-1] flips
    # OpenCV's BGR colormap output back to RGB for Gradio
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]

    return [(original_image, colored_depth), tmp.name]
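
# Minimal sketch of running the pipeline outside the UI (the image path below
# is hypothetical; any HxWx3 RGB uint8 array works):
#   img = cv2.cvtColor(cv2.imread('assets/examples/demo.jpg'), cv2.COLOR_BGR2RGB)
#   (orig, colored), raw_path = on_submit(MODEL_CONFIGS['vitl14']['display_name'], img)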

# Download and cache all models at startup
print("πŸš€ Initializing Depth Anything with all model variants...")
cached_models = download_all_models()

# Set default model to the first successfully cached model
default_model_key = None
for key in ["vitl14", "vitb14", "vits14"]:  # Priority order
    if key in cached_models and cached_models[key] is not None:
        default_model_key = key
        break

if default_model_key:
    current_model = cached_models[default_model_key]
    current_model_name = default_model_key
    print(f"🎯 Default model set to: {MODEL_CONFIGS[default_model_key]['display_name']}")
else:
    print("❌ No models were successfully cached!")
    current_model = None
    current_model_name = None

with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model Selection")
            model_selector = gr.Dropdown(
                choices=[config["display_name"] for config in MODEL_CONFIGS.values()],
                value=MODEL_CONFIGS[default_model_key]["display_name"] if default_model_key else MODEL_CONFIGS["vitl14"]["display_name"],
                label="Choose Model Variant",
                info="Select the model size based on your speed/quality requirements"
            )
            
            # Add model info display
            initial_info = f"**Selected Model**: {MODEL_CONFIGS[default_model_key]['description']}" if default_model_key else "**Selected Model**: Unknown"
            model_info = gr.Markdown(initial_info)
            
            # Add memory status display
            memory_status = gr.Markdown(f"**Memory Status**: {get_memory_status()}")
            
            def update_model_info(selection):
                info_text = "**Selected Model**: Unknown"
                for key, config in MODEL_CONFIGS.items():
                    if config["display_name"] == selection:
                        cached_status = "βœ… Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
                        info_text = f"**Selected Model**: {config['description']} ({cached_status})"
                        break
                
                memory_text = f"**Memory Status**: {get_memory_status()}"
                return info_text, memory_text
            
            model_selector.change(update_model_info, inputs=[model_selector], outputs=[model_info, memory_status])

    gr.Markdown("### Depth Prediction Demo")
    gr.Markdown("You can slide the output to compare the depth prediction with input image")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    
    raw_file = gr.File(label="16-bit raw depth (can be considered as disparity)")
    
    submit = gr.Button("Submit", variant="primary")

    submit.click(on_submit, inputs=[model_selector, input_image], outputs=[depth_image_slider, raw_file])
    
    # Examples section
    if os.path.exists('assets/examples'):
        example_files = os.listdir('assets/examples')
        example_files.sort()
        example_files = [os.path.join('assets/examples', filename) for filename in example_files]
        
        # Note: gr.Examples calls fn directly, and a component's .value holds only
        # its initial value, so examples always run with the startup default model
        # rather than the live dropdown selection.
        examples = gr.Examples(
            examples=example_files,
            inputs=[input_image],
            outputs=[depth_image_slider, raw_file],
            fn=lambda img: on_submit(model_selector.value, img),
            cache_examples=False,
            label="Example Images"
        )
    
    # Model comparison section
    with gr.Accordion("πŸ“Š Model Comparison & Cache Status", open=False):
        # Cache status is built by a helper so the refresh button below can
        # re-render it without duplicating this logic (the status was previously
        # rendered twice in this accordion).
        def build_cache_status_md():
            md = "### πŸ“¦ Cached Models Status\n"
            for key, config in MODEL_CONFIGS.items():
                status = "βœ… Cached" if cached_models.get(key) is not None else "❌ Not cached"
                md += f"- **{config['display_name']}**: {status}\n"
            md += f"\n**Total Models Cached**: {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)}\n"
            md += f"**Current Memory**: {get_memory_status()}\n\n"
            return md

        status_display = gr.Markdown(build_cache_status_md())

        gr.Markdown("""
        ### πŸ“ˆ Model Performance Comparison
        | Model | Parameters | Speed | Quality | Use Case |
        |-------|------------|-------|---------|----------|
        | ViT-S | ~25M | Fastest | Good | Real-time applications |
        | ViT-B | ~97M | Medium | Better | Balanced performance |
        | ViT-L | ~335M | Slower | Best | High-quality results |
        
        **Note**: All models are pre-downloaded and cached for instant switching!  
        **Processing times** are approximate and depend on hardware and image resolution.
        """)
        
        # Refresh button re-renders the cache/memory status on demand
        refresh_btn = gr.Button("πŸ”„ Refresh Status", size="sm")
        refresh_btn.click(build_cache_status_md, outputs=[status_display])
    
    # Citation section
    with gr.Accordion("πŸ“– Citation", open=False):
        gr.Markdown("""
        ```bibtex
        @article{depthanything,
            title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data}, 
            author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
            journal={arXiv:2401.10891},
            year={2024}
        }
        ```
        """)

if __name__ == '__main__':
    demo.queue().launch()