File size: 3,830 Bytes
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------
import cv2
import numpy as np
import tarfile
import torch
from PIL import Image
from io import BytesIO
from typing import Union
def img_hwc2chw(img: Union[np.ndarray, torch.Tensor]):
    """Move the last axis of a 3-D image to the front (HWC -> CHW).

    Args:
        img: A three-dimensional numpy array or torch tensor.

    Returns:
        The input with its axes reordered as (2, 0, 1), of the same
        container type as ``img``.

    Raises:
        AssertionError: If ``img`` does not have exactly three dimensions.
        TypeError: If ``img`` is neither a numpy array nor a torch tensor.
    """
    n_dims = len(img.shape)
    assert n_dims == 3
    if isinstance(img, torch.Tensor):
        return img.permute(2, 0, 1)
    elif isinstance(img, np.ndarray):
        return np.transpose(img, (2, 0, 1))
    else:
        raise TypeError("img should be np.ndarray or torch.Tensor")
def img_chw2hwc(chw):
    """Move the first axis of a 3-D image to the end (CHW -> HWC).

    Args:
        chw: A three-dimensional numpy array or torch tensor.

    Returns:
        The input with its leading axis moved to the last position, of the
        same container type as ``chw``.

    Raises:
        AssertionError: If ``chw`` does not have exactly three dimensions.
        TypeError: If ``chw`` is neither a numpy array nor a torch tensor.
    """
    assert len(chw.shape) == 3
    if isinstance(chw, np.ndarray):
        return np.moveaxis(chw, 0, -1)
    if isinstance(chw, torch.Tensor):
        return chw.permute(1, 2, 0)
    raise TypeError("img should be np.ndarray or torch.Tensor")
def img_int2float(img, dtype=None):
    """Scale an image by 1/255, optionally casting it first.

    Args:
        img: Image data; a numpy array, or otherwise an object exposing
            ``.to(dtype)`` (e.g. a torch tensor) when ``dtype`` is given.
        dtype: Optional target dtype applied before the division. When
            ``None`` the input is divided as-is.

    Returns:
        ``img`` (after the optional cast) divided by 255.0.
    """
    if dtype is None:
        return img / 255.0
    # numpy arrays cast via astype; anything else is cast via .to().
    converted = img.astype(dtype) if isinstance(img, np.ndarray) else img.to(dtype)
    return converted / 255.0
def img_float2int(img):
    """Convert a float image to 8-bit unsigned integers.

    The input is multiplied by 255, clamped to [0, 255] and then cast to
    uint8 (the cast truncates the fractional part). Clamping before the cast
    keeps values outside the nominal [0, 1] range from overflowing the uint8
    range, where the cast result would wrap around or be undefined.

    Args:
        img: A numpy array, or otherwise an object with torch-tensor
            semantics (``.clamp`` / ``.to``).

    Returns:
        A uint8 numpy array if ``img`` is a numpy array, otherwise a
        ``torch.uint8`` tensor.
    """
    scaled = img * 255.0
    if isinstance(img, np.ndarray):
        return np.clip(scaled, 0.0, 255.0).astype(np.uint8)
    return scaled.clamp(0.0, 255.0).to(torch.uint8)
def img_normalize(img):
    """Map values linearly via ``x * 2 - 1`` (e.g. [0, 1] -> [-1, 1])."""
    doubled = img * 2.0
    return doubled - 1.0
def img_denormalize(img):
    """Map values linearly via ``x * 0.5 + 0.5`` (e.g. [-1, 1] -> [0, 1])."""
    halved = img * 0.5
    return halved + 0.5
def img_linear2srgb(img):
    """Apply a power-law gamma encoding with exponent ``1 / 2.2``."""
    exponent = 1 / 2.2
    return img**exponent
def img_srgb2linear(img):
    """Apply a power-law gamma decoding with exponent ``2.2``."""
    exponent = 2.2
    return img**exponent
def write_img(img: np.ndarray, path):
    """Convert a float image to 8-bit and write it to ``path`` with OpenCV.

    Args:
        img: Float image; converted with ``img_float2int`` before writing.
        path: Destination passed straight to ``cv2.imwrite``.

    Note:
        The return value of ``cv2.imwrite`` is not inspected here.
    """
    img_u8 = img_float2int(img)
    if img_u8.ndim == 3:
        # Reverse the last axis: RGB->BGR, the order cv2.imwrite expects.
        img_u8 = img_u8[..., ::-1]
    cv2.imwrite(path, img_u8)
def _read_image_from_buffer(buffer: BytesIO, is_hdr: bool) -> np.ndarray:
    """Decode an image held in an in-memory buffer.

    Args:
        buffer: Binary buffer containing the encoded image file.
        is_hdr: Selects the decoding path. Note this parameter shadows the
            module-level ``is_hdr`` function inside this body.

    Returns:
        For the HDR path: the OpenCV-decoded image, converted BGR->RGB and
        clipped to [0, 1]. Otherwise: the PIL-decoded image as an array,
        passed through ``img_int2float`` (i.e. divided by 255).
    """
    if not is_hdr:
        pil_img = Image.open(buffer)  # [H, W, rgb]
        return img_int2float(np.asarray(pil_img))

    # HDR path: hand the raw bytes to OpenCV and keep the stored bit depth.
    raw = np.frombuffer(buffer.read(), dtype=np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_UNCHANGED)
    rgb = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    return np.clip(rgb, 0, 1)
def is_hdr(path: str):
    """Return whether ``path`` names an OpenEXR file, judged by extension.

    The comparison is case-insensitive so that files such as ``IMG.EXR`` are
    recognised as well as ``img.exr``.

    Args:
        path: File path or archive member name.

    Returns:
        ``True`` if ``path`` ends with ``.exr`` in any letter case.
    """
    return path.lower().endswith(".exr")
def read_img_from_tar(tar_file: tarfile.TarFile, rel_path: str) -> np.ndarray:
    """Read and decode an image stored as a member of an open tar archive.

    Args:
        tar_file: An opened ``tarfile.TarFile``.
        rel_path: Name of the member inside the archive.

    Returns:
        The decoded image as produced by ``_read_image_from_buffer``.

    Raises:
        KeyError: Raised by ``extractfile`` if ``rel_path`` is not in the
            archive.
        ValueError: If ``rel_path`` exists but is not a regular file or link
            (``extractfile`` returns ``None`` for such members).
    """
    tar_obj = tar_file.extractfile(rel_path)
    if tar_obj is None:
        raise ValueError(
            f"'{rel_path}' is not a regular file in the tar archive"
        )
    # Close the extracted member once its bytes are copied into memory.
    with tar_obj:
        buffer = BytesIO(tar_obj.read())
    return _read_image_from_buffer(buffer, is_hdr(rel_path))
def read_img_from_file(path: str) -> np.ndarray:
    """Read and decode an image file from disk.

    Args:
        path: Path of the image file; its extension decides (via ``is_hdr``)
            which decoding path is used.

    Returns:
        The decoded image as produced by ``_read_image_from_buffer``.
    """
    with open(path, "rb") as fh:
        raw_bytes = fh.read()
    return _read_image_from_buffer(BytesIO(raw_bytes), is_hdr(path))
|