import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
import tempfile
from gradio_imageslider import ImageSlider

from depth_anything.dpt import DepthAnything
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
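# Note: `depth_anything` is the package shipped with the LiheYoung/Depth-Anything
# repository (linked in the description below), and `gradio_imageslider` is a
# third-party Gradio component; both must be installed for this demo to run.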
| css = """ | |
| #img-display-container { | |
| max-height: 100vh; | |
| } | |
| #img-display-input { | |
| max-height: 80vh; | |
| } | |
| #img-display-output { | |
| max-height: 80vh; | |
| } | |
| """ | |
| DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
# Model configurations - supports different model variants
MODEL_CONFIGS = {
    "vits14": {
        "model_name": "LiheYoung/depth_anything_vits14",
        "display_name": "Depth Anything ViT-S (Small, Fastest)",
        "description": "Smallest and fastest model variant"
    },
    "vitb14": {
        "model_name": "LiheYoung/depth_anything_vitb14",
        "display_name": "Depth Anything ViT-B (Base, Balanced)",
        "description": "Balanced model with good speed/quality tradeoff"
    },
    "vitl14": {
        "model_name": "LiheYoung/depth_anything_vitl14",
        "display_name": "Depth Anything ViT-L (Large, Best Quality)",
        "description": "Largest model with best quality (default)"
    }
}

# Global model cache
current_model = None
current_model_name = None
cached_models = {}  # Store all downloaded models
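# Design note: keeping every variant resident in `cached_models` trades memory
# for instant switching in the UI; no download or device transfer happens at
# inference time. On a memory-constrained GPU it may be preferable to cache
# weights on CPU and move only the active model to DEVICE (not done here).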
| title = "# Depth Anything with Model Selection" | |
| description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data** with multiple model variants. | |
| You can choose between different model sizes for speed vs quality tradeoffs: | |
| - **ViT-S**: Fastest inference, good for real-time applications | |
| - **ViT-B**: Balanced performance and quality | |
| - **ViT-L**: Best quality, slower inference | |
| Please refer to our [paper](https://arxiv.org/abs/2401.10891), [project page](https://depth-anything.github.io), or [github](https://github.com/LiheYoung/Depth-Anything) for more details.""" | |
transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
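# A note on the preprocessing above: `lower_bound` resizing scales the image so
# both sides are at least 518 px (the shorter side lands near 518), snaps the
# result to multiples of 14 to match the ViT-14 patch size, normalizes with
# ImageNet mean/std, and returns a CHW float32 array ready for torch.from_numpy.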
def get_memory_status():
    """Get current memory usage status."""
    try:
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3  # GB
            total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB
            return f"GPU Memory: {allocated:.2f}GB allocated, {cached:.2f}GB cached, {total_memory:.2f}GB total"
        else:
            return "Running on CPU"
    except Exception:
        return "Memory status unavailable"
def download_all_models():
    """Download and cache all model variants at startup."""
    global cached_models
    print("🚀 Downloading all Depth Anything model variants...")
    print("This may take a few minutes depending on your internet connection...")
    for key, config in MODEL_CONFIGS.items():
        try:
            print(f"📥 Downloading {config['display_name']}...")
            model = DepthAnything.from_pretrained(config['model_name']).to(DEVICE).eval()
            cached_models[key] = model
            print(f"✅ {config['display_name']} downloaded and cached successfully")
        except Exception as e:
            print(f"❌ Failed to download {config['display_name']}: {e}")
            cached_models[key] = None
    print(f"🎉 Model download complete! {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)} models cached successfully.")
    return cached_models
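# Failed downloads are recorded as None rather than dropped, so the UI can
# report per-variant cache status and load_model() can retry on demand.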
def load_model(model_selection):
    """Load the selected model variant from cache."""
    global current_model, current_model_name
    # Find the model key from the display name
    selected_key = None
    for key, config in MODEL_CONFIGS.items():
        if config["display_name"] == model_selection:
            selected_key = key
            break
    if selected_key is None:
        # Fallback to vitl14 if not found
        selected_key = "vitl14"
    # Check if we need to switch to a different model
    if current_model_name != selected_key:
        print(f"🔄 Switching to model: {MODEL_CONFIGS[selected_key]['display_name']}")
        # Get model from cache
        if selected_key in cached_models and cached_models[selected_key] is not None:
            current_model = cached_models[selected_key]
            current_model_name = selected_key
            print(f"✅ Model {selected_key} loaded from cache successfully")
        else:
            # Fallback: download model if not in cache
            print(f"⚠️ Model {selected_key} not in cache, downloading...")
            try:
                current_model = DepthAnything.from_pretrained(MODEL_CONFIGS[selected_key]['model_name']).to(DEVICE).eval()
                cached_models[selected_key] = current_model
                current_model_name = selected_key
                print(f"✅ Model {selected_key} downloaded and loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load model {selected_key}: {e}")
                # Fallback to any available cached model
                for fallback_key, fallback_model in cached_models.items():
                    if fallback_model is not None:
                        current_model = fallback_model
                        current_model_name = fallback_key
                        print(f"🔄 Using fallback model: {fallback_key}")
                        break
    return current_model
@torch.no_grad()  # inference only: skip autograd bookkeeping to save memory
def predict_depth(model, image):
    return model(image)
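# Depth Anything predicts relative inverse depth (disparity-like, unitless):
# larger values mean closer surfaces, which is why the raw 16-bit download
# below is labeled "can be considered as disparity".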
def on_submit(model_selection, image):
    if image is None:
        return None, None
    # Load the selected model
    try:
        model = load_model(model_selection)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None
    if model is None:
        print("No model available for inference")
        return None, None
    original_image = image.copy()
    h, w = image.shape[:2]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
    image = transform({'image': image})['image']
    image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
    depth = predict_depth(model, image)
    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
    # Save the un-normalized prediction as a 16-bit PNG before rescaling for display
    raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
    tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    raw_depth.save(tmp.name)
    # Normalize to [0, 255] and colorize for the slider view
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.cpu().numpy().astype(np.uint8)
    colored_depth = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
    return [(original_image, colored_depth), tmp.name]
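# on_submit returns the (input, colorized depth) pair consumed by ImageSlider,
# plus the path of the 16-bit PNG exposed through the gr.File component.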
# Download and cache all models at startup
print("🚀 Initializing Depth Anything with all model variants...")
cached_models = download_all_models()

# Set default model to the first successfully cached model
default_model_key = None
for key in ["vitl14", "vitb14", "vits14"]:  # Priority order
    if key in cached_models and cached_models[key] is not None:
        default_model_key = key
        break

if default_model_key:
    current_model = cached_models[default_model_key]
    current_model_name = default_model_key
    print(f"🎯 Default model set to: {MODEL_CONFIGS[default_model_key]['display_name']}")
else:
    print("❌ No models were successfully cached!")
    current_model = None
    current_model_name = None
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model Selection")
            model_selector = gr.Dropdown(
                choices=[config["display_name"] for config in MODEL_CONFIGS.values()],
                value=MODEL_CONFIGS[default_model_key]["display_name"] if default_model_key else MODEL_CONFIGS["vitl14"]["display_name"],
                label="Choose Model Variant",
                info="Select the model size based on your speed/quality requirements"
            )
            # Add model info display
            initial_info = f"**Selected Model**: {MODEL_CONFIGS[default_model_key]['description']}" if default_model_key else "**Selected Model**: Unknown"
            model_info = gr.Markdown(initial_info)
            # Add memory status display
            memory_status = gr.Markdown(f"**Memory Status**: {get_memory_status()}")

    def update_model_info(selection):
        info_text = "**Selected Model**: Unknown"
        for key, config in MODEL_CONFIGS.items():
            if config["display_name"] == selection:
                cached_status = "✅ Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
                info_text = f"**Selected Model**: {config['description']} ({cached_status})"
                break
        memory_text = f"**Memory Status**: {get_memory_status()}"
        return info_text, memory_text

    model_selector.change(update_model_info, inputs=[model_selector], outputs=[model_info, memory_status])

    gr.Markdown("### Depth Prediction Demo")
    gr.Markdown("You can slide the output to compare the depth prediction with the input image.")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
    raw_file = gr.File(label="16-bit raw depth (can be considered as disparity)")
    submit = gr.Button("Submit", variant="primary")

    submit.click(on_submit, inputs=[model_selector, input_image], outputs=[depth_image_slider, raw_file])
    # Examples section
    if os.path.exists('assets/examples'):
        example_files = os.listdir('assets/examples')
        example_files.sort()
        example_files = [os.path.join('assets/examples', filename) for filename in example_files]
        # Note: `model_selector.value` is the dropdown's initial value, not a live
        # binding, so any example run through `fn` uses the default model rather
        # than the current selection.
        examples = gr.Examples(
            examples=example_files,
            inputs=[input_image],
            outputs=[depth_image_slider, raw_file],
            fn=lambda img: on_submit(model_selector.value, img),
            cache_examples=False,
            label="Example Images"
        )
    # Model comparison section
    with gr.Accordion("📊 Model Comparison & Cache Status", open=False):
        def build_cache_status():
            """Build the cache status markdown (shared by the initial view and the refresh button)."""
            status_md = "### 📦 Cached Models Status\n"
            for key, config in MODEL_CONFIGS.items():
                status = "✅ Cached" if key in cached_models and cached_models[key] is not None else "❌ Not cached"
                status_md += f"- **{config['display_name']}**: {status}\n"
            status_md += f"\n**Total Models Cached**: {len([m for m in cached_models.values() if m is not None])}/{len(MODEL_CONFIGS)}\n"
            status_md += f"**Current Memory**: {get_memory_status()}\n\n"
            return status_md

        status_display = gr.Markdown(build_cache_status())

        gr.Markdown("""
        ### 📊 Model Performance Comparison

        | Model | Parameters | Speed | Quality | Use Case |
        |-------|------------|-------|---------|----------|
        | ViT-S | ~25M | Fastest | Good | Real-time applications |
        | ViT-B | ~97M | Medium | Better | Balanced performance |
        | ViT-L | ~335M | Slower | Best | High-quality results |

        **Note**: All models are pre-downloaded and cached for instant switching!
        **Processing times** are approximate and depend on hardware and image resolution.
        """)

        # Refresh button for cache and memory status
        refresh_btn = gr.Button("🔄 Refresh Status", size="sm")
        refresh_btn.click(build_cache_status, outputs=[status_display])
    # Citation section
    with gr.Accordion("📖 Citation", open=False):
        gr.Markdown("""
        ```bibtex
        @article{depthanything,
            title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
            author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
            journal={arXiv:2401.10891},
            year={2024}
        }
        ```
        """)

if __name__ == '__main__':
    demo.queue().launch()