import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import depth_only_parameters as params
from models.depth_only_model import PVSDNet
from models.depth_only_lite_model import PVSDNet_Lite
import helperFunctions as helper
from huggingface_hub import hf_hub_download
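
# Hugging Face Hub repository hosting the pretrained model checkpoints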
REPO_ID = "3ZadeSSG/PVSDNet-Depth-Only"

print("Downloading/Loading checkpoints from Hugging Face Hub...")
params.MODEL_Small_Location = hf_hub_download(
    repo_id=REPO_ID,
    filename="depth_only_lite_model.pth"
)
params.MODEL_Large_Location = hf_hub_download(
    repo_id=REPO_ID,
    filename="depth_only_model.pth"
)
print(f"Large Model loaded at: {params.MODEL_Large_Location}")
print(f"Lite Model loaded at: {params.MODEL_Small_Location}")

def get_valid_resolutions(width, height):
    """Dynamically determines valid resolutions based on input size.
    - Caps the highest resolution at 1024px to avoid unnecessary high-res computations.
    - Uses 6 resolutions for large images to improve multi-scale fusion quality.
    - Uses 4 resolutions for smaller images (< 512px width or height).
    """
    def make_divisible(n, base=16):
        return max(base, int(round(n / base) * base))

    max_resolution = 1024
    high_w, high_h = make_divisible(min(width, max_resolution)), make_divisible(min(height, max_resolution))

    # Calculate more intermediate steps for better fusion
    r80_w, r80_h = make_divisible(int(high_w // 1.25)), make_divisible(int(high_h // 1.25))
    r66_w, r66_h = make_divisible(int(high_w // 1.5)), make_divisible(int(high_h // 1.5))
    r50_w, r50_h = make_divisible(int(high_w // 2)), make_divisible(int(high_h // 2))
    r40_w, r40_h = make_divisible(int(high_w // 2.5)), make_divisible(int(high_h // 2.5))
    r33_w, r33_h = make_divisible(max(256, int(high_w // 3))), make_divisible(max(256, int(high_h // 3)))

    if width < 512 or height < 512:
        return [(high_w, high_h), (r80_w, r80_h), (r66_w, r66_h), (r50_w, r50_h)]
    else:
        return [
            (high_w, high_h),
            (r80_w, r80_h),
            (r66_w, r66_h),
            (r50_w, r50_h),
            (r40_w, r40_h),
            (r33_w, r33_h)
        ]
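
# For example (values produced by get_valid_resolutions above): a 1920x1080
# input maps to six scales (1024x1024, 816x816, 688x688, 512x512, 416x416,
# 336x336), while a 1024x436 input (height < 512) maps to four
# (1024x432, 816x352, 688x288, 512x224).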

def get_transforms(resolutions):
    return [transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()]) for w, h in resolutions]
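# Note: each resolution is a (width, height) tuple, while transforms.Resize
# expects (height, width), hence the (h, w) swap above.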

def get_prediction(image, transform, model):
    img_input = image.convert('RGB')
    img_input = transform(img_input).unsqueeze(0).to(params.DEVICE)
    # no_grad avoids building the autograd graph during inference
    with torch.no_grad():
        depth_out = model(img_input).squeeze(0).to("cpu")
    return depth_out
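# The returned depth tensor is assumed to be single-channel, i.e. shape
# (1, H, W) after squeeze(0); predict_single_image below relies on this
# when it indexes [0, 0] after interpolation.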

def predict_single_image(image, model_type):
    if image is None:
        return None, None

    # Select model class and checkpoint
    if model_type == "Lite":
        model_class = PVSDNet_Lite
        checkpoint = params.MODEL_Small_Location
    else:  # Default to "Large"
        model_class = PVSDNet
        checkpoint = params.MODEL_Large_Location

    model = model_class(total_image_input=params.params_number_input)
    model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
    model.to(params.DEVICE)
    model.eval()

    original_width, original_height = image.size
    resolutions = get_valid_resolutions(original_width, original_height)
    print(f"Resolutions: {resolutions} for Model Type: {model_type}")

    transforms_list = get_transforms(resolutions)
    depth_maps = [get_prediction(image, t, model) for t in transforms_list]
    depth_maps_resized = [
        F.interpolate(depth[None], (original_height, original_width), mode='bilinear', align_corners=False)[0, 0]
        for depth in depth_maps
    ]
    depth_final = sum(depth_maps_resized) / len(depth_maps_resized)

    # Min-max normalize; the small epsilon guards against division by zero
    # on a constant depth map
    depth_image = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min() + 1e-8)
    img_out = depth_image.numpy()
    img_out_colored = plt.get_cmap('inferno')(img_out)[:, :, :3]
    img_out_colored = (img_out_colored * 255).astype(np.uint8)
    gray_scale_img_out = (img_out * 255).astype(np.uint8)
    return Image.fromarray(img_out_colored), Image.fromarray(gray_scale_img_out)
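
# Example usage outside the UI (a minimal sketch; "sample.png" is a
# hypothetical local file):
#   img = Image.open("sample.png")
#   color_map, gray_map = predict_single_image(img, "Lite")
#   color_map.save("depth_color.png")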

with gr.Blocks(title="PVSDNet-Depth-Only Model", theme="default") as demo:
    gr.Markdown(
        """
        ## PVSDNet-Depth-Only ZeroShot Relative Depth Estimation Model
        * Upload an image and get its depth estimation with multi-scale fusion.
        * Each image is processed at 4 or 6 resolutions for multi-scale fusion, depending on its size.

        **Note:** The Hugging Face demo runs on CPU, so inference will be slow.
        ### Head to our [Project Page](https://realistic3d-miun.github.io/PVSDNet/) for more details about the models.
        """)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="RGB Image", height=384)
            with gr.Accordion("Advanced Settings", open=False):
                model_type_dropdown = gr.Dropdown(["Large", "Lite"], label="Model Type", value="Large")
            generate_btn = gr.Button("Estimate Depth", variant="primary")
        with gr.Column():
            output_color = gr.Image(type="pil", label="Depth Map (Color)", height=384)
            output_gray = gr.Image(type="pil", label="Depth Map (Grayscale)", height=384)

    generate_btn.click(
        fn=predict_single_image,
        inputs=[img_input, model_type_dropdown],
        outputs=[output_color, output_gray]
    )
| gr.Markdown("### Example Samples") | |
| with gr.Column(): | |
| with gr.Row(): | |
| with gr.Column(scale=2): gr.Markdown("**Example Image (Click to load)**") | |
| with gr.Column(scale=1): gr.Markdown("**Resolution**") | |
| with gr.Column(scale=2): gr.Markdown("**Fusion Resolutions**") | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=2): | |
| diode_preview = gr.Image("./samples/DIODE/00022_00195_outdoor_010_030.png", label="DIODE", height=120, interactive=False, show_label=True) | |
| with gr.Column(scale=1): | |
| gr.Markdown("1024 x 768") | |
| with gr.Column(scale=2): | |
| gr.Markdown("1024x768, 816x608, 688x512, 512x384, 416x304, 336x256") | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=2): | |
| eth3d_preview = gr.Image("./samples/ETH3D/DSC_0243.JPG", label="ETH3D", height=120, interactive=False, show_label=True) | |
| with gr.Column(scale=1): | |
| gr.Markdown("6048 x 4032") | |
| with gr.Column(scale=2): | |
| gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336") | |
| with gr.Row(variant="panel"): | |
| with gr.Column(scale=2): | |
| sintel_preview = gr.Image("./samples/Sintel/frame_0028_temple.png", label="Sintel", height=120, interactive=False, show_label=True) | |
| with gr.Column(scale=1): | |
| gr.Markdown("1024 x 436") | |
| with gr.Column(scale=2): | |
| gr.Markdown("1024x432, 816x352, 688x288, 512x224") | |
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                kitti_preview = gr.Image("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png", label="KITTI", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1216 x 352")
            with gr.Column(scale=2):
                gr.Markdown("1024x352, 816x288, 688x240, 512x176")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_1_preview = gr.Image("./samples/Wild/toy.jpeg", label="Wild Image 1", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("3019 x 3018")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_2_preview = gr.Image("./samples/Wild/hamburg.jpeg", label="Wild Image 2", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1536 x 1920")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_3_preview = gr.Image("./samples/Wild/north_hill.jpeg", label="Wild Image 3", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("2320 x 2321")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_4_preview = gr.Image("./samples/Wild/EH.jpeg", label="Wild Image 4", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1920 x 1080")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_5_preview = gr.Image("./samples/Wild/train_station.jpeg", label="Wild Image 5", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1066 x 1060")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")

    # Define click events to load images
    eth3d_preview.select(fn=lambda: Image.open("./samples/ETH3D/DSC_0243.JPG"), outputs=img_input)
    sintel_preview.select(fn=lambda: Image.open("./samples/Sintel/frame_0028_temple.png"), outputs=img_input)
    kitti_preview.select(fn=lambda: Image.open("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png"), outputs=img_input)
    diode_preview.select(fn=lambda: Image.open("./samples/DIODE/00022_00195_outdoor_010_030.png"), outputs=img_input)
    wild_1_preview.select(fn=lambda: Image.open("./samples/Wild/toy.jpeg"), outputs=img_input)
    wild_2_preview.select(fn=lambda: Image.open("./samples/Wild/hamburg.jpeg"), outputs=img_input)
    wild_3_preview.select(fn=lambda: Image.open("./samples/Wild/north_hill.jpeg"), outputs=img_input)
    wild_4_preview.select(fn=lambda: Image.open("./samples/Wild/EH.jpeg"), outputs=img_input)
    wild_5_preview.select(fn=lambda: Image.open("./samples/Wild/train_station.jpeg"), outputs=img_input)

demo.launch()