File size: 3,830 Bytes
7decfe1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
#   https://marigoldmonodepth.github.io
#   https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
#   https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
#   https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
#   https://huggingface.co/prs-eth
# Related projects:
#   https://rollingdepth.github.io/
#   https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
#   https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------

import cv2
import numpy as np
import tarfile
import torch
from PIL import Image
from io import BytesIO
from typing import Union


def img_hwc2chw(img: Union[np.ndarray, torch.Tensor]):
    """Reorder the axes of a 3-D image from (H, W, C) to (C, H, W).

    Args:
        img: A NumPy array or torch tensor with exactly three dimensions.

    Returns:
        The input with its axes permuted as ``(2, 0, 1)``, in the same
        container type as ``img``.

    Raises:
        AssertionError: If ``img`` does not have exactly three dimensions.
        TypeError: If ``img`` is neither ``np.ndarray`` nor ``torch.Tensor``.
    """
    assert 3 == len(img.shape)
    axes = (2, 0, 1)  # last axis (channels) goes first
    if isinstance(img, torch.Tensor):
        return img.permute(*axes)
    if isinstance(img, np.ndarray):
        return img.transpose(axes)
    raise TypeError("img should be np.ndarray or torch.Tensor")


def img_chw2hwc(chw):
    """Reorder the axes of a 3-D image from (C, H, W) to (H, W, C).

    Args:
        chw: A torch tensor or NumPy array with exactly three dimensions.

    Returns:
        The input with its first axis moved to the last position, in the
        same container type as ``chw``.

    Raises:
        AssertionError: If ``chw`` does not have exactly three dimensions.
        TypeError: If ``chw`` is neither ``torch.Tensor`` nor ``np.ndarray``.
    """
    assert len(chw.shape) == 3
    if isinstance(chw, torch.Tensor):
        return chw.permute(1, 2, 0)
    if isinstance(chw, np.ndarray):
        # For a 3-D array this is the same view as moving axis 0 to the end.
        return np.transpose(chw, (1, 2, 0))
    raise TypeError("img should be np.ndarray or torch.Tensor")


def img_int2float(img, dtype=None):
    """Divide pixel values by 255, optionally casting first.

    Args:
        img: A NumPy array, or an object exposing ``.to(dtype)`` (such as a
            torch tensor).
        dtype: Optional dtype to cast ``img`` to before the division. When
            ``None`` no explicit cast is performed.

    Returns:
        ``img / 255.0`` (after the optional cast).
    """
    if dtype is None:
        return img / 255.0
    # NumPy arrays cast with ``astype``; everything else is expected to
    # provide a ``to`` method.
    casted = img.astype(dtype) if isinstance(img, np.ndarray) else img.to(dtype)
    return casted / 255.0


def img_float2int(img):
    """Convert a float image with nominal range [0, 1] to 8-bit integers.

    Values are scaled by 255, clamped to [0, 255] and then truncated to
    ``uint8``. The clamp prevents out-of-range inputs (e.g. slightly negative
    values or values above 1) from wrapping around during the integer cast.
    Results for inputs already inside [0, 1] are unchanged.

    Args:
        img: A NumPy array, or a torch tensor.

    Returns:
        A ``np.uint8`` array for NumPy input, otherwise a ``torch.uint8``
        tensor.
    """
    scaled = img * 255.0
    if isinstance(img, np.ndarray):
        return np.clip(scaled, 0, 255).astype(np.uint8)
    return scaled.clamp(0, 255).to(torch.uint8)


def img_normalize(img):
    """Apply the affine map ``x -> 2 * x - 1`` (0 maps to -1, 1 maps to 1).

    Args:
        img: Any value supporting multiplication and subtraction with floats.

    Returns:
        ``img * 2.0 - 1.0``.
    """
    doubled = img * 2.0
    return doubled - 1.0


def img_denormalize(img):
    """Apply the affine map ``x -> 0.5 * x + 0.5`` (-1 maps to 0, 1 maps to 1).

    Args:
        img: Any value supporting multiplication and addition with floats.

    Returns:
        ``img * 0.5 + 0.5``.
    """
    halved = img * 0.5
    return halved + 0.5


def img_linear2srgb(img):
    """Encode linear values with a power-law curve of exponent ``1 / 2.2``.

    This is a plain gamma curve; no piecewise linear segment is applied.

    Args:
        img: Any value supporting the ``**`` operator with a float exponent.

    Returns:
        ``img ** (1 / 2.2)``.
    """
    exponent = 1 / 2.2
    return img ** exponent


def img_srgb2linear(img):
    """Decode values with a power-law curve of exponent ``2.2``.

    This is a plain gamma curve; no piecewise linear segment is applied.

    Args:
        img: Any value supporting the ``**`` operator with a float exponent.

    Returns:
        ``img ** 2.2``.
    """
    exponent = 2.2
    return img ** exponent


def write_img(img: np.ndarray, path):
    """Quantize an image with ``img_float2int`` and write it using OpenCV.

    Args:
        img: Image array; it is passed through ``img_float2int`` before
            being written.
        path: Destination handed to ``cv2.imwrite``.

    NOTE(review): the return value of ``cv2.imwrite`` is discarded, so a
    failed write is not reported to the caller.
    """
    quantized = img_float2int(img)
    if quantized.ndim == 3:
        # Reverse the last axis: RGB -> BGR, the order OpenCV writes.
        quantized = quantized[..., ::-1]
    cv2.imwrite(path, quantized)


def _read_image_from_buffer(buffer: BytesIO, is_hdr: bool) -> np.ndarray:
    """Decode an encoded image held in memory into an array.

    Args:
        buffer: In-memory file containing the encoded image bytes.
        is_hdr: Selects the decoder. When true the bytes are decoded with
            OpenCV (``IMREAD_UNCHANGED``), converted BGR -> RGB and clipped
            to [0, 1]. Otherwise the image is opened with PIL and scaled by
            ``img_int2float``.

    Returns:
        The decoded image array.
    """
    if not is_hdr:
        # PIL path: decode, view as an array, rescale by 1/255.
        pil_image = Image.open(buffer)
        return img_int2float(np.asarray(pil_image))

    # OpenCV path: decode from the raw byte stream.
    encoded = np.frombuffer(buffer.read(), dtype=np.uint8)
    decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
    rgb = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    return np.clip(rgb, 0, 1)


def is_hdr(path: str) -> bool:
    """Return whether ``path`` names an OpenEXR file, judged by its extension.

    The comparison is case-insensitive so that ``.EXR`` / ``.Exr`` files are
    recognized as well as ``.exr``.

    Args:
        path: File path or archive-relative path.

    Returns:
        ``True`` if ``path`` ends with ``.exr`` (ignoring case).
    """
    return path.lower().endswith(".exr")


def read_img_from_tar(tar_file: tarfile.TarFile, rel_path: str) -> np.ndarray:
    """Read and decode an image stored as a member of an open tar archive.

    Args:
        tar_file: An open ``tarfile.TarFile``.
        rel_path: Name of the member inside the archive; its extension
            decides (via ``is_hdr``) which decoder is used.

    Returns:
        The decoded image as produced by ``_read_image_from_buffer``.

    Raises:
        KeyError: If ``rel_path`` is not present in the archive (raised by
            ``TarFile.extractfile``).
        ValueError: If the member exists but is not a regular file or link
            (e.g. a directory), in which case there is no data to read.
    """
    member = tar_file.extractfile(rel_path)
    if member is None:
        # extractfile() yields None for members without file content.
        raise ValueError(
            f"'{rel_path}' is not a regular file or link in the tar archive"
        )
    # Close the extracted file object as soon as its bytes are copied out.
    with member:
        buffer = BytesIO(member.read())
    return _read_image_from_buffer(buffer, is_hdr(rel_path))


def read_img_from_file(path: str) -> np.ndarray:
    """Read and decode an image from a file on disk.

    The file is read fully into memory; its extension decides (via
    ``is_hdr``) which decoder ``_read_image_from_buffer`` uses.

    Args:
        path: Path of the image file.

    Returns:
        The decoded image as produced by ``_read_image_from_buffer``.
    """
    with open(path, "rb") as fh:
        raw_bytes = fh.read()
    # Decoding works on the in-memory copy, so the file can already be closed.
    return _read_image_from_buffer(BytesIO(raw_bytes), is_hdr(path))