shriarul5273's picture
initial commit of depth-anything compare
e0f1d2e
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider
from depth_anything.dpt import DepthAnything
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model configurations - supports different model variants
MODEL_CONFIGS = {
"vits14": {
"model_name": "LiheYoung/depth_anything_vits14",
"display_name": "Depth Anything ViT-S (Small, Fastest)",
"description": "Smallest and fastest model variant"
},
"vitb14": {
"model_name": "LiheYoung/depth_anything_vitb14",
"display_name": "Depth Anything ViT-B (Base, Balanced)",
"description": "Balanced model with good speed/quality tradeoff"
},
"vitl14": {
"model_name": "LiheYoung/depth_anything_vitl14",
"display_name": "Depth Anything ViT-L (Large, Best Quality)",
"description": "Largest model with best quality (default)"
}
}
# Global model cache
current_model = None
current_model_name = None
cached_models = {} # Store all downloaded models
title = "# Depth Anything with Model Selection"
description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data** with multiple model variants.
You can choose between different model sizes for speed vs quality tradeoffs:
- **ViT-S**: Fastest inference, good for real-time applications
- **ViT-B**: Balanced performance and quality
- **ViT-L**: Best quality, slower inference
Please refer to our [paper](https://arxiv.org/abs/2401.10891), [project page](https://depth-anything.github.io), or [github](https://github.com/LiheYoung/Depth-Anything) for more details."""
transform = Compose([
Resize(
width=518,
height=518,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
def get_memory_status():
"""Get current memory usage status"""
try:
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1024**3 # GB
cached = torch.cuda.memory_reserved() / 1024**3 # GB
total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3 # GB
return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total_memory:.2f}GB total"
else:
return "Running on CPU"
except:
return "Memory status unavailable"
def download_all_models():
"""Download and cache all model variants at startup"""
global cached_models
print("πŸ”„ Downloading all Depth Anything model variants...")
print("This may take a few minutes depending on your internet connection...")
for key, config in MODEL_CONFIGS.items():
try:
print(f"πŸ“₯ Downloading {config['display_name']}...")
model = DepthAnything.from_pretrained(config['model_name']).to(DEVICE).eval()
cached_models[key] = model
print(f"βœ… {config['display_name']} downloaded and cached successfully")
except Exception as e:
print(f"❌ Failed to download {config['display_name']}: {e}")
cached_models[key] = None
print(f"πŸŽ‰ Model download complete! {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)} models cached successfully.")
return cached_models
def load_model(model_selection):
"""Load the selected model variant from cache"""
global current_model, current_model_name
# Find the model key from the display name
selected_key = None
for key, config in MODEL_CONFIGS.items():
if config["display_name"] == model_selection:
selected_key = key
break
if selected_key is None:
# Fallback to vitl14 if not found
selected_key = "vitl14"
# Check if we need to switch to a different model
if current_model_name != selected_key:
print(f"πŸ”„ Switching to model: {MODEL_CONFIGS[selected_key]['display_name']}")
# Get model from cache
if selected_key in cached_models and cached_models[selected_key] is not None:
current_model = cached_models[selected_key]
current_model_name = selected_key
print(f"βœ… Model {selected_key} loaded from cache successfully")
else:
# Fallback: download model if not in cache
print(f"⚠️ Model {selected_key} not in cache, downloading...")
try:
current_model = DepthAnything.from_pretrained(MODEL_CONFIGS[selected_key]['model_name']).to(DEVICE).eval()
cached_models[selected_key] = current_model
current_model_name = selected_key
print(f"βœ… Model {selected_key} downloaded and loaded successfully")
except Exception as e:
print(f"❌ Failed to load model {selected_key}: {e}")
# Fallback to any available cached model
for fallback_key, fallback_model in cached_models.items():
if fallback_model is not None:
current_model = fallback_model
current_model_name = fallback_key
print(f"πŸ”„ Using fallback model: {fallback_key}")
break
return current_model
@torch.no_grad()
def predict_depth(model, image):
return model(image)
def on_submit(model_selection, image):
if image is None:
return None, None
# Load the selected model
try:
model = load_model(model_selection)
except Exception as e:
print(f"Error loading model: {e}")
return None, None
original_image = image.copy()
h, w = image.shape[:2]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
image = transform({'image': image})['image']
image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
depth = predict_depth(model, image)
depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
raw_depth.save(tmp.name)
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth = depth.cpu().numpy().astype(np.uint8)
colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
return [(original_image, colored_depth), tmp.name]
# Download and cache all models at startup
print("πŸš€ Initializing Depth Anything with all model variants...")
cached_models = download_all_models()
# Set default model to the first successfully cached model
default_model_key = None
for key in ["vitl14", "vitb14", "vits14"]: # Priority order
if key in cached_models and cached_models[key] is not None:
default_model_key = key
break
if default_model_key:
current_model = cached_models[default_model_key]
current_model_name = default_model_key
print(f"🎯 Default model set to: {MODEL_CONFIGS[default_model_key]['display_name']}")
else:
print("❌ No models were successfully cached!")
current_model = None
current_model_name = None
with gr.Blocks(css=css) as demo:
gr.Markdown(title)
gr.Markdown(description)
with gr.Row():
with gr.Column():
gr.Markdown("### Model Selection")
model_selector = gr.Dropdown(
choices=[config["display_name"] for config in MODEL_CONFIGS.values()],
value=MODEL_CONFIGS[default_model_key]["display_name"] if default_model_key else MODEL_CONFIGS["vitl14"]["display_name"],
label="Choose Model Variant",
info="Select the model size based on your speed/quality requirements"
)
# Add model info display
initial_info = f"**Selected Model**: {MODEL_CONFIGS[default_model_key]['description']}" if default_model_key else "**Selected Model**: Unknown"
model_info = gr.Markdown(initial_info)
# Add memory status display
memory_status = gr.Markdown(f"**Memory Status**: {get_memory_status()}")
def update_model_info(selection):
info_text = "**Selected Model**: Unknown"
for key, config in MODEL_CONFIGS.items():
if config["display_name"] == selection:
cached_status = "βœ… Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
info_text = f"**Selected Model**: {config['description']} ({cached_status})"
break
memory_text = f"**Memory Status**: {get_memory_status()}"
return info_text, memory_text
model_selector.change(update_model_info, inputs=[model_selector], outputs=[model_info, memory_status])
gr.Markdown("### Depth Prediction Demo")
gr.Markdown("You can slide the output to compare the depth prediction with input image")
with gr.Row():
input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
raw_file = gr.File(label="16-bit raw depth (can be considered as disparity)")
submit = gr.Button("Submit", variant="primary")
submit.click(on_submit, inputs=[model_selector, input_image], outputs=[depth_image_slider, raw_file])
# Examples section
if os.path.exists('assets/examples'):
example_files = os.listdir('assets/examples')
example_files.sort()
example_files = [os.path.join('assets/examples', filename) for filename in example_files]
examples = gr.Examples(
examples=example_files,
inputs=[input_image],
outputs=[depth_image_slider, raw_file],
fn=lambda img: on_submit(model_selector.value, img),
cache_examples=False,
label="Example Images"
)
# Model comparison section
with gr.Accordion("πŸ“Š Model Comparison & Cache Status", open=False):
# Create cache status dynamically
cache_status_md = "### πŸ“¦ Cached Models Status\n"
for key, config in MODEL_CONFIGS.items():
status = "βœ… Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
cache_status_md += f"- **{config['display_name']}**: {status}\n"
cache_status_md += f"\n**Total Models Cached**: {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)}\n"
cache_status_md += f"**Current Memory**: {get_memory_status()}\n\n"
gr.Markdown(cache_status_md)
gr.Markdown("""
### πŸ“ˆ Model Performance Comparison
| Model | Parameters | Speed | Quality | Use Case |
|-------|------------|-------|---------|----------|
| ViT-S | ~25M | Fastest | Good | Real-time applications |
| ViT-B | ~97M | Medium | Better | Balanced performance |
| ViT-L | ~335M | Slower | Best | High-quality results |
**Note**: All models are pre-downloaded and cached for instant switching!
**Processing times** are approximate and depend on hardware and image resolution.
""")
# Add refresh button for memory status
def refresh_status():
updated_status_md = "### πŸ“¦ Cached Models Status\n"
for key, config in MODEL_CONFIGS.items():
status = "βœ… Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
updated_status_md += f"- **{config['display_name']}**: {status}\n"
updated_status_md += f"\n**Total Models Cached**: {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)}\n"
updated_status_md += f"**Current Memory**: {get_memory_status()}\n\n"
return updated_status_md
refresh_btn = gr.Button("πŸ”„ Refresh Status", size="sm")
status_display = gr.Markdown(cache_status_md)
refresh_btn.click(refresh_status, outputs=[status_display])
# Citation section
with gr.Accordion("πŸ“– Citation", open=False):
gr.Markdown("""
```bibtex
@article{depthanything,
title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
journal={arXiv:2401.10891},
year={2024}
}
```
""")
if __name__ == '__main__':
demo.queue().launch()