Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import depth_pro | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from typing import Union | |
| from pathlib import Path | |
| import os | |
def predict_depth(image: Image.Image, auto_rotate: bool, remove_alpha: bool, grayscale: bool, model, transform):
    """Estimate depth for ``image`` and render the result as a picture.

    The image is written to a private scratch file, re-read through
    ``depth_pro.load_rgb``, passed through ``transform`` and fed to
    ``model.infer``. The predicted depth is visualised as *inverse* depth,
    limited to the 0.1 m .. 250 m range.

    Args:
        image: Input picture. ``None`` (nothing uploaded) is rejected with a
            ``gr.Error``.
        auto_rotate: Forwarded to ``depth_pro.load_rgb``.
        remove_alpha: Forwarded to ``depth_pro.load_rgb``.
        grayscale: If true, return an 8-bit single-channel image; otherwise
            colourise with matplotlib's ``turbo_r`` colormap.
        model: Object exposing ``infer(tensor, f_px=...)`` that returns a
            mapping with ``"depth"`` and ``"focallength_px"`` entries.
        transform: Callable applied to the loaded RGB array before inference.

    Returns:
        A ``(depth_image, focallength)`` pair: the rendered PIL image and the
        ``"focallength_px"`` value reported by the model.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    import tempfile  # local import keeps this block self-contained

    if image is None:
        raise gr.Error("Please upload an image before requesting a depth map.")

    # PNG is lossless and can hold an alpha channel, but not every PIL mode
    # (e.g. CMYK); normalise anything unusual so the save cannot fail.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGBA" if "A" in image.getbands() else "RGB")

    # A per-call scratch directory keeps concurrent requests from clobbering
    # each other's file and guarantees clean-up even if inference raises.
    # NOTE(review): Image.save() does not copy EXIF unless asked to, so
    # auto_rotate / an EXIF focal length presumably have nothing to act on
    # here -- confirm whether that is intended.
    with tempfile.TemporaryDirectory() as scratch_dir:
        image_path = os.path.join(scratch_dir, "input.png")
        image.save(image_path)
        loaded_image, _, f_px = depth_pro.load_rgb(
            image_path, auto_rotate=auto_rotate, remove_alpha=remove_alpha
        )

    # Run inference
    prediction = model.infer(transform(loaded_image), f_px=f_px)
    depth = prediction["depth"].detach().cpu().numpy().squeeze()  # Depth in [m]
    inverse_depth_normalized = _normalized_inverse_depth(depth)

    focallength = prediction["focallength_px"]
    # Presumably a tensor when the model estimated it; tolerate a plain
    # number as well (e.g. a caller-supplied f_px echoed back) -- verify
    # against depth_pro.
    if hasattr(focallength, "detach"):
        focallength = focallength.detach().cpu().numpy()

    if grayscale:
        # Values are already clamped to [0, 1], so the uint8 cast cannot wrap.
        grayscale_depth = (inverse_depth_normalized * 255).astype(np.uint8)
        depth_image = Image.fromarray(grayscale_depth)
    else:
        cmap = plt.get_cmap("turbo_r")
        color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(np.uint8)
        depth_image = Image.fromarray(color_depth)

    return depth_image, focallength  # Return depth map and f_px


def _normalized_inverse_depth(depth, near_m=0.1, far_m=250.0):
    """Map metric depth to inverse depth scaled into [0, 1] for display.

    The displayed range is the scene's own inverse-depth range, limited to
    ``[1 / far_m, 1 / near_m]``. Pixels outside that range are clamped. A
    degenerate range (constant or entirely out-of-range depth) yields an
    all-zero map instead of dividing by zero.
    """
    inverse_depth = 1 / depth
    max_invdepth_vizu = min(inverse_depth.max(), 1 / near_m)
    min_invdepth_vizu = max(1 / far_m, inverse_depth.min())
    span = max_invdepth_vizu - min_invdepth_vizu
    if not span > 0:  # also catches NaN
        return np.zeros_like(inverse_depth)
    return np.clip((inverse_depth - min_invdepth_vizu) / span, 0.0, 1.0)
def main():
    """Load the Depth Pro model and serve the Gradio demo.

    Creates the model and its preprocessing transform once, switches the
    model to evaluation mode, wires ``predict_depth`` into a ``gr.Interface``
    with an image input plus three option checkboxes, and launches it.
    """
    # Load model and preprocessing transform
    model, transform = depth_pro.create_model_and_transforms()
    model.eval()

    def run_prediction(image, auto_rotate, remove_alpha, grayscale):
        """Bind the shared model and transform to the per-request inputs."""
        return predict_depth(image, auto_rotate, remove_alpha, grayscale, model, transform)

    # Set up Gradio interface
    iface = gr.Interface(
        fn=run_prediction,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),  # Use image browser for input
            gr.Checkbox(label="Auto Rotate", value=True),  # Checkbox for auto_rotate
            gr.Checkbox(label="Remove Alpha", value=True),  # Checkbox for remove_alpha
            gr.Checkbox(label="Grayscale Depth", value=False),  # Checkbox for grayscale
        ],
        outputs=[
            gr.Image(label="Depth Map"),  # Use PIL image output
            gr.Textbox(label="Focal Length in Pixels", placeholder="Focal length"),  # Output for f_px
        ],
        title="Depth Pro: Sharp Monocular Metric Depth Estimation",
        description="Upload an image and adjust options to estimate its depth map using a depth estimation model.",
        # Gradio documents string flagging modes; a boolean is deprecated and
        # may be rejected by newer releases. "never" hides the flag button.
        allow_flagging="never",
    )

    # Launch the interface
    iface.launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()