| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import sys |
| | import os |
| |
|
# Prepend the directory two levels above this script to the module search
# path. Presumably that is the repository root holding the `olbedo` package
# imported further down -- confirm against the repository layout.
_script_dir = os.path.dirname(__file__)
sys.path.insert(0, os.path.abspath(os.path.join(_script_dir, "..", "..")))
| |
|
| | import argparse |
| | import logging |
| | import numpy as np |
| | import os |
| | import torch |
| | from PIL import Image |
| | from glob import glob |
| | from tqdm.auto import tqdm |
| |
|
| | from olbedo import OlbedoIIDPipeline, OlbedoIIDOutput |
| | from olbedo.util.image_util import chw2hwc |
| |
|
| | import rasterio |
| |
|
# Lower-case file extensions (with leading dot) that are accepted as input
# images when scanning the input directory.
EXTENSION_LIST = [
    ".jpg",
    ".jpeg",
    ".png",
    ".tif",
    ".tiff",
]
| |
|
| |
|
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            lets ``argparse`` read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Olbedo: Intrinsic Image Decomposition for Large-Scale Outdoor Environments"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoint",
        help="Checkpoint path or hub name.",
    )
    parser.add_argument(
        "--input_rgb_dir",
        type=str,
        required=True,
        help="Path to the input image folder.",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )
    parser.add_argument(
        "--denoise_steps",
        type=int,
        default=4,
        help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed. "
        "Default: `4`.",
    )
    parser.add_argument(
        "--processing_res",
        type=int,
        default=2000,
        help="Resolution to which the input is resized before performing estimation. `0` uses the original input "
        "resolution. Default: `2000`.",
    )
    parser.add_argument(
        "--ensemble_size",
        type=int,
        default=1,
        help="Number of predictions to be ensembled. Default: `1`.",
    )
    parser.add_argument(
        "--half_precision",
        "--fp16",
        action="store_true",
        help="Run with half-precision (16-bit float), might lead to suboptimal result.",
    )
    parser.add_argument(
        "--output_processing_res",
        action="store_true",
        help="Setting this flag will output the result at the effective value of `processing_res`, otherwise the "
        "output will be resized to the input resolution.",
    )
    parser.add_argument(
        "--resample_method",
        choices=["bilinear", "bicubic", "nearest"],
        default="bilinear",
        help="Resampling method used to resize images and predictions. This can be one of `bilinear`, `bicubic` or "
        "`nearest`. Default: `bilinear`",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Reproducibility seed. Set to `None` for randomized inference. Default: `None`",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=0,
        help="Inference batch size. Default: 0 (will be set automatically).",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="rgbx",
        choices=["rgbx", "others"],
        help="Choose model",
    )
    return parser.parse_args(argv)


def _write_albedo_geotiff(albedo_image, src_path, dst_path):
    """Write ``albedo_image`` as a GeoTIFF that inherits the source raster's profile.

    The profile of ``src_path`` is copied, but ``count``, ``height`` and
    ``width`` are replaced by the actual shape of the prediction, so the write
    also succeeds when the source has a different band count or when the
    prediction is not at the input resolution. In the latter case the
    geotransform is rescaled so the output keeps the same footprint.

    Args:
        albedo_image: Predicted albedo; anything ``np.asarray`` turns into an
            ``(H, W)`` or ``(H, W, C)`` array.
        src_path: Path of the input raster whose profile is reused.
        dst_path: Path of the GeoTIFF to create.
    """
    pixels = np.asarray(albedo_image)
    if pixels.ndim == 2:
        # Single-channel prediction: add the band axis rasterio expects.
        bands = pixels[np.newaxis, ...]
    else:
        # HWC -> CHW (band-first layout used by rasterio).
        bands = pixels.transpose((2, 0, 1))
    count, height, width = bands.shape

    with rasterio.open(src_path) as src:
        profile = src.profile.copy()
        if (src.width, src.height) != (width, height):
            # Pixel size changed: scale the affine transform accordingly.
            profile["transform"] = src.transform * src.transform.scale(
                src.width / width, src.height / height
            )
    profile.update(count=count, height=height, width=width)

    # NOTE(review): values are cast to the source dtype without rescaling; if
    # the source is wider than the prediction (e.g. uint16 vs uint8) the output
    # keeps the narrower value range -- confirm this is intended.
    with rasterio.open(dst_path, "w", **profile) as dst:
        dst.write(bands.astype(profile["dtype"]))


def main(argv=None):
    """Run albedo inference on every supported image of a folder.

    Parses the CLI options, loads the ``OlbedoIIDPipeline`` from the given
    checkpoint, runs it on each image found in ``--input_rgb_dir`` and stores
    the predicted ``albedo`` target in ``<output_dir>/albedo``. TIFF inputs are
    written through rasterio with the source profile; all other formats are
    saved with the image's own ``save`` method using the input's extension.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    args = parse_args(argv)

    checkpoint_path = args.checkpoint
    input_rgb_dir = args.input_rgb_dir
    output_dir = args.output_dir

    denoise_steps = args.denoise_steps
    ensemble_size = args.ensemble_size
    if ensemble_size > 15:
        logging.warning("Running with large ensemble size will be slow.")
    half_precision = args.half_precision

    processing_res = args.processing_res
    match_input_res = not args.output_processing_res
    if processing_res == 0 and not match_input_res:
        logging.warning(
            "Processing at native resolution without resizing output might NOT lead to exactly the same resolution, "
            "due to the padding and pooling properties of conv layers."
        )
    resample_method = args.resample_method

    seed = args.seed
    batch_size = args.batch_size
    model = args.model

    # -------------------- Output directories --------------------
    output_dir_vis = os.path.join(output_dir, "albedo")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_dir_vis, exist_ok=True)
    logging.info(f"output dir = {output_dir}")

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")

    # -------------------- Input files --------------------
    rgb_filename_list = sorted(
        f
        for f in glob(os.path.join(input_rgb_dir, "*"))
        if os.path.splitext(f)[1].lower() in EXTENSION_LIST
    )
    n_images = len(rgb_filename_list)
    if n_images > 0:
        logging.info(f"Found {n_images} images")
    else:
        logging.error(f"No image found in '{input_rgb_dir}'")
        # sys.exit instead of the `site`-provided `exit`, which is absent
        # when the interpreter runs without the site module.
        sys.exit(1)

    # -------------------- Model --------------------
    if half_precision:
        dtype = torch.float16
        variant = "fp16"
        logging.info(
            f"Running with half precision ({dtype}), might lead to suboptimal result."
        )
    else:
        dtype = torch.float32
        variant = None

    pipe: OlbedoIIDPipeline = OlbedoIIDPipeline.from_pretrained(
        checkpoint_path, variant=variant, torch_dtype=dtype
    )
    pipe.mode = model

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except ImportError as err:
        # Best effort only: keep going with the default attention.
        logging.info(
            "xformers memory-efficient attention is unavailable (%s); continuing without it.",
            err,
        )

    pipe = pipe.to(device)
    logging.info("Loaded IID pipeline")

    logging.info(
        f"Inference settings: checkpoint = `{checkpoint_path}`, "
        f"with predicted target names: {pipe.target_names}, "
        f"denoise_steps = {denoise_steps or pipe.default_denoising_steps}, "
        f"ensemble_size = {ensemble_size}, "
        f"processing resolution = {processing_res or pipe.default_processing_resolution}, "
        f"seed = {seed}; "
    )

    has_albedo = "albedo" in pipe.target_names
    if not has_albedo:
        logging.warning(
            "Pipeline does not predict an `albedo` target (targets: %s); nothing will be saved.",
            pipe.target_names,
        )

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for rgb_path in tqdm(rgb_filename_list, desc="IID Inference", leave=True):
            # A fresh generator per image so every image starts from the same seed.
            if seed is None:
                generator = None
            else:
                generator = torch.Generator(device=device)
                generator.manual_seed(seed)

            # Context manager closes the input file once inference is done.
            with Image.open(rgb_path) as input_image:
                pipe_out: OlbedoIIDOutput = pipe(
                    input_image,
                    denoising_steps=denoise_steps,
                    ensemble_size=ensemble_size,
                    processing_res=processing_res,
                    match_input_res=match_input_res,
                    batch_size=batch_size,
                    show_progress_bar=False,
                    resample_method=resample_method,
                    generator=generator,
                )

            if not has_albedo:
                continue

            albedo_image = pipe_out["albedo"].image
            rgb_name_base, rgb_ext = os.path.splitext(os.path.basename(rgb_path))

            if rgb_ext.lower() in (".tif", ".tiff"):
                _write_albedo_geotiff(
                    albedo_image,
                    rgb_path,
                    os.path.join(output_dir_vis, f"{rgb_name_base}.tif"),
                )
            else:
                albedo_image.save(
                    os.path.join(output_dir_vis, f"{rgb_name_base}{rgb_ext}")
                )


if __name__ == "__main__":
    main()