import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider
from depth_anything.dpt import DepthAnything
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model configurations - supports different model variants
MODEL_CONFIGS = {
"vits14": {
"model_name": "LiheYoung/depth_anything_vits14",
"display_name": "Depth Anything ViT-S (Small, Fastest)",
"description": "Smallest and fastest model variant"
},
"vitb14": {
"model_name": "LiheYoung/depth_anything_vitb14",
"display_name": "Depth Anything ViT-B (Base, Balanced)",
"description": "Balanced model with good speed/quality tradeoff"
},
"vitl14": {
"model_name": "LiheYoung/depth_anything_vitl14",
"display_name": "Depth Anything ViT-L (Large, Best Quality)",
"description": "Largest model with best quality (default)"
}
}
# Global model cache
current_model = None
current_model_name = None
cached_models = {} # Store all downloaded models
title = "# Depth Anything with Model Selection"
description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data** with multiple model variants.
You can choose between different model sizes for speed vs quality tradeoffs:
- **ViT-S**: Fastest inference, good for real-time applications
- **ViT-B**: Balanced performance and quality
- **ViT-L**: Best quality, slower inference
Please refer to our [paper](https://arxiv.org/abs/2401.10891), [project page](https://depth-anything.github.io), or [github](https://github.com/LiheYoung/Depth-Anything) for more details."""
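# Preprocessing pipeline for the network input: resize so the shorter side is
# 518 px (keeping aspect ratio) and snap both dimensions to multiples of 14,
# the patch size of the ViT-*/14 backbones, then apply ImageNet mean/std
# normalization and convert to the CHW float layout the model expects.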
transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
def get_memory_status():
    """Get current memory usage status"""
    try:
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3  # GB
            total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
            return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total_memory:.2f}GB total"
        else:
            return "Running on CPU"
    except Exception:
        return "Memory status unavailable"
def download_all_models():
    """Download and cache all model variants at startup"""
    global cached_models
    print("Downloading all Depth Anything model variants...")
    print("This may take a few minutes depending on your internet connection...")
    for key, config in MODEL_CONFIGS.items():
        try:
            print(f"Downloading {config['display_name']}...")
            model = DepthAnything.from_pretrained(config['model_name']).to(DEVICE).eval()
            cached_models[key] = model
            print(f"✅ {config['display_name']} downloaded and cached successfully")
        except Exception as e:
            print(f"❌ Failed to download {config['display_name']}: {e}")
            cached_models[key] = None
    print(f"Model download complete! {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)} models cached successfully.")
    return cached_models
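# Rough memory estimate for keeping every variant resident on DEVICE at once:
# using the approximate parameter counts from the comparison table below
# (~25M + ~97M + ~335M, about 457M parameters in total), the fp32 weights alone
# occupy on the order of 457M * 4 bytes, roughly 1.8 GB, before activations and
# allocator overhead.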
def load_model(model_selection):
    """Load the selected model variant from cache"""
    global current_model, current_model_name
    # Find the model key from the display name
    selected_key = None
    for key, config in MODEL_CONFIGS.items():
        if config["display_name"] == model_selection:
            selected_key = key
            break
    if selected_key is None:
        # Fallback to vitl14 if not found
        selected_key = "vitl14"
    # Check if we need to switch to a different model
    if current_model_name != selected_key:
        print(f"Switching to model: {MODEL_CONFIGS[selected_key]['display_name']}")
        # Get model from cache
        if selected_key in cached_models and cached_models[selected_key] is not None:
            current_model = cached_models[selected_key]
            current_model_name = selected_key
            print(f"✅ Model {selected_key} loaded from cache successfully")
        else:
            # Fallback: download model if not in cache
            print(f"⚠️ Model {selected_key} not in cache, downloading...")
            try:
                current_model = DepthAnything.from_pretrained(MODEL_CONFIGS[selected_key]['model_name']).to(DEVICE).eval()
                cached_models[selected_key] = current_model
                current_model_name = selected_key
                print(f"✅ Model {selected_key} downloaded and loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load model {selected_key}: {e}")
                # Fallback to any available cached model
                for fallback_key, fallback_model in cached_models.items():
                    if fallback_model is not None:
                        current_model = fallback_model
                        current_model_name = fallback_key
                        print(f"Using fallback model: {fallback_key}")
                        break
    return current_model
@torch.no_grad()
def predict_depth(model, image):
    return model(image)
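# Minimal usage sketch outside the Gradio UI (illustrative only; assumes some
# local image file such as 'demo.jpg' exists):
#   img = cv2.imread('demo.jpg')                                     # HxWx3, BGR uint8
#   inp = transform({'image': cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0})['image']
#   inp = torch.from_numpy(inp).unsqueeze(0).to(DEVICE)
#   rel_depth = predict_depth(current_model, inp)                    # (1, h', w') relative depth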
def on_submit(model_selection, image):
    if image is None:
        return None, None
    # Load the selected model
    try:
        model = load_model(model_selection)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None
    if model is None:
        # No model could be loaded (e.g. nothing was cached successfully)
        return None, None
    original_image = image.copy()
    h, w = image.shape[:2]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
    image = transform({'image': image})['image']
    image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
    depth = predict_depth(model, image)
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
    raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
    tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp.name)
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
    return [(original_image, colored_depth), tmp.name]
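# Note on the outputs above: Depth Anything predicts relative (inverse) depth
# rather than metric depth, with larger values meaning closer surfaces, which
# is why the saved 16-bit PNG is labelled in the UI as something that "can be
# considered as disparity".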
# Download and cache all models at startup
print("Initializing Depth Anything with all model variants...")
cached_models = download_all_models()

# Set default model to the first successfully cached model
default_model_key = None
for key in ["vitl14", "vitb14", "vits14"]:  # Priority order
    if key in cached_models and cached_models[key] is not None:
        default_model_key = key
        break

if default_model_key:
    current_model = cached_models[default_model_key]
    current_model_name = default_model_key
    print(f"Default model set to: {MODEL_CONFIGS[default_model_key]['display_name']}")
else:
    print("❌ No models were successfully cached!")
    current_model = None
    current_model_name = None
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model Selection")
            model_selector = gr.Dropdown(
                choices=[config["display_name"] for config in MODEL_CONFIGS.values()],
                value=MODEL_CONFIGS[default_model_key]["display_name"] if default_model_key else MODEL_CONFIGS["vitl14"]["display_name"],
                label="Choose Model Variant",
                info="Select the model size based on your speed/quality requirements"
            )
            # Add model info display
            initial_info = f"**Selected Model**: {MODEL_CONFIGS[default_model_key]['description']}" if default_model_key else "**Selected Model**: Unknown"
            model_info = gr.Markdown(initial_info)
            # Add memory status display
            memory_status = gr.Markdown(f"**Memory Status**: {get_memory_status()}")

    def update_model_info(selection):
        info_text = "**Selected Model**: Unknown"
        for key, config in MODEL_CONFIGS.items():
            if config["display_name"] == selection:
                cached_status = "✅ Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
                info_text = f"**Selected Model**: {config['description']} ({cached_status})"
                break
        memory_text = f"**Memory Status**: {get_memory_status()}"
        return info_text, memory_text

    model_selector.change(update_model_info, inputs=[model_selector], outputs=[model_info, memory_status])
gr.Markdown("### Depth Prediction Demo")
gr.Markdown("You can slide the output to compare the depth prediction with input image")
with gr.Row():
input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
raw_file = gr.File(label="16-bit raw depth (can be considered as disparity)")
submit = gr.Button("Submit", variant="primary")
submit.click(on_submit, inputs=[model_selector, input_image], outputs=[depth_image_slider, raw_file])
# Examples section
if os.path.exists('assets/examples'):
example_files = os.listdir('assets/examples')
example_files.sort()
example_files = [os.path.join('assets/examples', filename) for filename in example_files]
examples = gr.Examples(
examples=example_files,
inputs=[input_image],
outputs=[depth_image_slider, raw_file],
fn=lambda img: on_submit(model_selector.value, img),
cache_examples=False,
label="Example Images"
)
    # Model comparison section
    with gr.Accordion("Model Comparison & Cache Status", open=False):
        # Create cache status dynamically
        cache_status_md = "### Cached Models Status\n"
        for key, config in MODEL_CONFIGS.items():
            status = "✅ Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
            cache_status_md += f"- **{config['display_name']}**: {status}\n"
        cache_status_md += f"\n**Total Models Cached**: {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)}\n"
        cache_status_md += f"**Current Memory**: {get_memory_status()}\n\n"
        gr.Markdown(cache_status_md)
gr.Markdown("""
### π Model Performance Comparison
| Model | Parameters | Speed | Quality | Use Case |
|-------|------------|-------|---------|----------|
| ViT-S | ~25M | Fastest | Good | Real-time applications |
| ViT-B | ~97M | Medium | Better | Balanced performance |
| ViT-L | ~335M | Slower | Best | High-quality results |
**Note**: All models are pre-downloaded and cached for instant switching!
**Processing times** are approximate and depend on hardware and image resolution.
""")
        # Add refresh button for memory status
        def refresh_status():
            updated_status_md = "### Cached Models Status\n"
            for key, config in MODEL_CONFIGS.items():
                status = "✅ Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
                updated_status_md += f"- **{config['display_name']}**: {status}\n"
            updated_status_md += f"\n**Total Models Cached**: {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)}\n"
            updated_status_md += f"**Current Memory**: {get_memory_status()}\n\n"
            return updated_status_md

        refresh_btn = gr.Button("Refresh Status", size="sm")
        status_display = gr.Markdown(cache_status_md)
        refresh_btn.click(refresh_status, outputs=[status_display])
    # Citation section
    with gr.Accordion("Citation", open=False):
        gr.Markdown("""
```bibtex
@article{depthanything,
    title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
    author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
    journal={arXiv:2401.10891},
    year={2024}
}
```
""")
if __name__ == '__main__':
    demo.queue().launch()