"""Sample script to run DepthPro.

Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""
|
|
|
|
| import argparse |
| import logging |
| from pathlib import Path |
|
|
| import numpy as np |
| import PIL.Image |
| import torch |
| from matplotlib import pyplot as plt |
| from tqdm import tqdm |
|
|
| from depth_pro import create_model_and_transforms, load_rgb |
|
|
# Module-level logger named after this module. `run` only configures INFO-level
# output (via `logging.basicConfig`) when the verbose flag is set.
LOGGER = logging.getLogger(__name__)
|
|
|
|
def get_torch_device() -> torch.device:
    """Select the Torch device to run on.

    Returns:
        ``cuda:0`` when CUDA is available, otherwise ``mps`` when the MPS
        backend is available, otherwise ``cpu``.
    """
    # Check accelerators in priority order and return as soon as one matches.
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
|
|
|
def _list_image_files(image_path):
    """Return ``(image_files, relative_root)`` for a file or directory input."""
    if image_path.is_dir():
        # A recursive glob also yields sub-directories; keep regular files only
        # so directories are never handed to the image loader. Sorting makes
        # the processing order deterministic and lets tqdm show a total.
        files = sorted(p for p in image_path.glob("**/*") if p.is_file())
        return files, image_path
    return [image_path], image_path.parent


def _colorize_depth(depth):
    """Map a depth array to an RGB uint8 image with the "turbo_r" colormap."""
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        normalized_depth = (depth - depth_min) / depth_range
    else:
        # Constant depth map: the min/max normalization would compute 0/0
        # (NaN) and the uint8 cast below would produce garbage colors.
        normalized_depth = np.zeros_like(depth)
    cmap = plt.get_cmap("turbo_r")
    # The colormap returns RGBA floats in [0, 1]; drop alpha and scale to bytes.
    return (cmap(normalized_depth)[..., :3] * 255).astype(np.uint8)


def run(args):
    """Run Depth Pro on a single image or on every file under a directory.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``image_path`` (file or directory), ``output_path`` (directory for
            results, or ``None`` to skip saving), ``skip_display`` and
            ``verbose``.
    """
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Load the model and its preprocessing transform in half precision.
    model, transform = create_model_and_transforms(
        device=get_torch_device(),
        precision=torch.half,
    )
    model.eval()

    # `relative_path` is the root against which output sub-folders are computed.
    image_paths, relative_path = _list_image_files(args.image_path)

    if not args.skip_display:
        plt.ion()
        fig = plt.figure()
        ax_rgb = fig.add_subplot(121)
        ax_disp = fig.add_subplot(122)

    for image_path in tqdm(image_paths):
        # Load the image; the third value returned by `load_rgb` is forwarded
        # to the model as `f_px` (logged below as the focal length from exif).
        try:
            LOGGER.info(f"Loading image {image_path} ...")
            image, _, f_px = load_rgb(image_path)
        except Exception as e:
            # Best effort: report unreadable inputs and move on to the next.
            LOGGER.error(str(e))
            continue

        # Run prediction; `f_px` may be None, see the focal-length logging below.
        prediction = model.infer(transform(image), f_px=f_px)

        # Extract the depth and focal length.
        depth = prediction["depth"].detach().cpu().numpy().squeeze()
        if f_px is not None:
            LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}")
        elif prediction["focallength_px"] is not None:
            focallength_px = prediction["focallength_px"].detach().cpu().item()
            LOGGER.info(f"Estimated focal length: {focallength_px}")

        # Save the depth as a compressed npz file plus a color-mapped JPEG,
        # mirroring the input's folder layout below `output_path`.
        if args.output_path is not None:
            output_file = (
                args.output_path
                / image_path.relative_to(relative_path).parent
                / image_path.stem
            )
            LOGGER.info(f"Saving depth map to: {str(output_file)}")
            output_file.parent.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(output_file, depth=depth)

            color_depth = _colorize_depth(depth)
            color_map_output_file = str(output_file) + ".jpg"
            LOGGER.info(f"Saving color-mapped depth to: : {color_map_output_file}")
            PIL.Image.fromarray(color_depth).save(
                color_map_output_file, format="JPEG", quality=90
            )

        # Display the image and estimated depth map.
        if not args.skip_display:
            ax_rgb.imshow(image)
            ax_disp.imshow(depth, cmap="turbo_r")
            fig.canvas.draw()
            fig.canvas.flush_events()

    LOGGER.info("Done predicting depth!")
    if not args.skip_display:
        # Keep the window open until the user closes it.
        plt.show(block=True)
|
|
|
|
def main():
    """Parse the command-line options and launch DepthPro inference."""
    cli = argparse.ArgumentParser(
        description="Inference scripts of DepthPro with PyTorch models."
    )

    # Input / output locations.
    cli.add_argument(
        "-i",
        "--image-path",
        type=Path,
        default="./data/example.jpg",
        help="Path to input image.",
    )
    cli.add_argument(
        "-o",
        "--output-path",
        type=Path,
        help="Path to store output files.",
    )

    # Boolean switches.
    cli.add_argument(
        "--skip-display",
        action="store_true",
        help="Skip matplotlib display.",
    )
    cli.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show verbose output.",
    )

    options = cli.parse_args()
    run(options)
|
|
|
|
# Script entry point: only parse arguments and run inference when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|