import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as transforms
import depth_only_parameters as params
from models.depth_only_model import PVSDNet
from models.depth_only_lite_model import PVSDNet_Lite
import helperFunctions as helper
import socket
from huggingface_hub import hf_hub_download
import joblib

REPO_ID = "3ZadeSSG/PVSDNet-Depth-Only"

print("Downloading/Loading checkpoints from Hugging Face Hub...")
params.MODEL_Small_Location = hf_hub_download(
    repo_id=REPO_ID,
    filename="depth_only_lite_model.pth"
)
params.MODEL_Large_Location = hf_hub_download(
    repo_id=REPO_ID,
    filename="depth_only_model.pth"
)
print(f"Large Model loaded at: {params.MODEL_Large_Location}")
print(f"Lite Model loaded at: {params.MODEL_Small_Location}")


def get_valid_resolutions(width, height):
    """Dynamically determines valid resolutions based on input size.

    - Caps the highest resolution at 1024px to avoid unnecessary high-res computations.
    - Uses 6 resolutions for large images to improve multi-scale fusion quality.
    - Uses 4 resolutions for smaller images (< 512px width or height).
    """
    def make_divisible(n, base=16):
        return max(base, int(round(n / base) * base))

    max_resolution = 1024
    high_w, high_h = make_divisible(min(width, max_resolution)), make_divisible(min(height, max_resolution))

    # Calculate more intermediate steps for better fusion
    r80_w, r80_h = make_divisible(int(high_w // 1.25)), make_divisible(int(high_h // 1.25))
    r66_w, r66_h = make_divisible(int(high_w // 1.5)), make_divisible(int(high_h // 1.5))
    r50_w, r50_h = make_divisible(int(high_w // 2)), make_divisible(int(high_h // 2))
    r40_w, r40_h = make_divisible(int(high_w // 2.5)), make_divisible(int(high_h // 2.5))
    r33_w, r33_h = make_divisible(max(256, int(high_w // 3))), make_divisible(max(256, int(high_h // 3)))

    if width < 512 or height < 512:
        return [(high_w, high_h), (r80_w, r80_h), (r66_w, r66_h), (r50_w, r50_h)]
    else:
        return [
            (high_w, high_h),
            (r80_w, r80_h),
            (r66_w, r66_h),
            (r50_w, r50_h),
            (r40_w, r40_h),
            (r33_w, r33_h)
        ]


def get_transforms(resolutions):
    """Builds a Resize + ToTensor transform for each (width, height) resolution."""
    return [transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()]) for w, h in resolutions]


def get_prediction(image, transform, model):
    """Runs a single forward pass at the resolution defined by `transform`."""
    img_input = image.convert('RGB')
    img_input = transform(img_input).unsqueeze(0).to(params.DEVICE)
    depth_out = model(img_input).detach().squeeze(0).to("cpu")
    return depth_out


def predict_single_image(image, model_type):
    if image is None:
        return None, None

    # Select model class and checkpoint
    if model_type == "Lite":
        model_class = PVSDNet_Lite
        checkpoint = params.MODEL_Small_Location
    else:  # Default to "Large"
        model_class = PVSDNet
        checkpoint = params.MODEL_Large_Location

    model = model_class(total_image_input=params.params_number_input)
    model = helper.load_Checkpoint(checkpoint, model, load_cpu=True)
    model.to(params.DEVICE)
    model.eval()

    original_width, original_height = image.size
    resolutions = get_valid_resolutions(original_width, original_height)
    print(f"Resolutions: {resolutions} for Model Type: {model_type}")
    transforms_list = get_transforms(resolutions)

    # Multi-scale fusion: predict at every resolution, upsample each prediction
    # back to the original size, then average.
    depth_maps = [get_prediction(image, t, model) for t in transforms_list]
    depth_maps_resized = [
        F.interpolate(depth[None], (original_height, original_width), mode='bilinear', align_corners=False)[0, 0]
        for depth in depth_maps
    ]
    depth_final = sum(depth_maps_resized) / len(depth_maps_resized)
    depth_image = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min())
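    # depth_image is now min-max normalized to [0, 1] (this assumes the fused depth
    # map is not constant). Colorize it with matplotlib's 'inferno' colormap and keep
    # an 8-bit grayscale copy alongside it.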
    img_out = depth_image.numpy()
    img_out_colored = plt.get_cmap('inferno')(img_out / np.max(img_out))[:, :, :3]
    img_out_colored = (img_out_colored * 255).astype(np.uint8)
    gray_scale_img_out = (depth_image.numpy() * 255).astype(np.uint8)
    return Image.fromarray(img_out_colored), Image.fromarray(gray_scale_img_out)


with gr.Blocks(title="PVSDNet-Depth-Only Model", theme="default") as demo:
    gr.Markdown(
        """
        ## PVSDNet-Depth-Only Zero-Shot Relative Depth Estimation Model
        * Upload an image to get its estimated depth map with multi-scale fusion.
        * Images are processed at 4 - 6 resolutions for multi-scale fusion, depending on input size.

        **Note:** The Hugging Face demo runs on CPU, so inference will be slow.

        ### Head to our [Project Page](https://realistic3d-miun.github.io/PVSDNet/) for more details about the models.
        """)

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="RGB Image", height=384)
            with gr.Accordion("Advanced Settings", open=False):
                model_type_dropdown = gr.Dropdown(["Large", "Lite"], label="Model Type", value="Large")
            generate_btn = gr.Button("Estimate Depth", variant="primary")
        with gr.Column():
            output_color = gr.Image(type="pil", label="Depth Map (Color)", height=384)
            output_gray = gr.Image(type="pil", label="Depth Map (Grayscale)", height=384)

    generate_btn.click(
        fn=predict_single_image,
        inputs=[img_input, model_type_dropdown],
        outputs=[output_color, output_gray]
    )

    gr.Markdown("### Example Samples")
    with gr.Column():
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("**Example Image (Click to load)**")
            with gr.Column(scale=1):
                gr.Markdown("**Resolution**")
            with gr.Column(scale=2):
                gr.Markdown("**Fusion Resolutions**")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                diode_preview = gr.Image("./samples/DIODE/00022_00195_outdoor_010_030.png", label="DIODE", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1024 x 768")
            with gr.Column(scale=2):
                gr.Markdown("1024x768, 816x608, 688x512, 512x384, 416x304, 336x256")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                eth3d_preview = gr.Image("./samples/ETH3D/DSC_0243.JPG", label="ETH3D", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("6048 x 4032")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                sintel_preview = gr.Image("./samples/Sintel/frame_0028_temple.png", label="Sintel", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1024 x 436")
            with gr.Column(scale=2):
                gr.Markdown("1024x432, 816x352, 688x288, 512x224")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                kitti_preview = gr.Image("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png", label="KITTI", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1216 x 352")
            with gr.Column(scale=2):
                gr.Markdown("1024x352, 816x288, 688x240, 512x176")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_1_preview = gr.Image("./samples/Wild/toy.jpeg", label="Wild Image 1", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("3019 x 3018")
            with gr.Column(scale=2):
                gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336")
        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                wild_2_preview = gr.Image("./samples/Wild/hamburg.jpeg", label="Wild Image 2", height=120, interactive=False, show_label=True)
            with gr.Column(scale=1):
                gr.Markdown("1536 x 1920")
            with gr.Column(scale=2):
gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336") with gr.Row(variant="panel"): with gr.Column(scale=2): wild_3_preview = gr.Image("./samples/Wild/north_hill.jpeg", label="Wild Image 3", height=120, interactive=False, show_label=True) with gr.Column(scale=1): gr.Markdown("2320 x 2321") with gr.Column(scale=2): gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336") with gr.Row(variant="panel"): with gr.Column(scale=2): wild_4_preview = gr.Image("./samples/Wild/EH.jpeg", label="Wild Image 4", height=120, interactive=False, show_label=True) with gr.Column(scale=1): gr.Markdown("1920 x 1080") with gr.Column(scale=2): gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336") with gr.Row(variant="panel"): with gr.Column(scale=2): wild_5_preview = gr.Image("./samples/Wild/train_station.jpeg", label="Wild Image 5", height=120, interactive=False, show_label=True) with gr.Column(scale=1): gr.Markdown("1066 x 1060") with gr.Column(scale=2): gr.Markdown("1024x1024, 816x816, 688x688, 512x512, 416x416, 336x336") # Define click events to load images eth3d_preview.select(fn=lambda: Image.open("./samples/ETH3D/DSC_0243.JPG"), outputs=img_input) sintel_preview.select(fn=lambda: Image.open("./samples/Sintel/frame_0028_temple.png"), outputs=img_input) kitti_preview.select(fn=lambda: Image.open("./samples/KITTI/2011_10_03_drive_0047_sync_image_0000000383_image_02.png"), outputs=img_input) diode_preview.select(fn=lambda: Image.open("./samples/DIODE/00022_00195_outdoor_010_030.png"), outputs=img_input) wild_1_preview.select(fn=lambda: Image.open("./samples/Wild/toy.jpeg"), outputs=img_input) wild_2_preview.select(fn=lambda: Image.open("./samples/Wild/hamburg.jpeg"), outputs=img_input) wild_3_preview.select(fn=lambda: Image.open("./samples/Wild/north_hill.jpeg"), outputs=img_input) wild_4_preview.select(fn=lambda: Image.open("./samples/Wild/EH.jpeg"), outputs=img_input) wild_5_preview.select(fn=lambda: Image.open("./samples/Wild/train_station.jpeg"), outputs=img_input) demo.launch()