File size: 995 Bytes
7349148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import cv2
import numpy as np
import torch
from jaxtyping import Float


def read_img(path):
    """Load an image from disk as a float32 array scaled to ``[-1, 1]``.

    Three-dimensional (color) images have any channels beyond the third
    dropped and are reordered from OpenCV's BGR to RGB. Single-channel
    images get a trailing channel axis, so the result is always ``(H, W, C)``.

    Args:
        path: Image path; anything whose ``str()`` is a filesystem path
            (``str``, ``pathlib.Path``, ...).

    Returns:
        ``float32`` array of shape ``(H, W, C)``. Values are divided by the
        maximum of the image's integer dtype and then mapped
        ``[0, 1] -> [-1, 1]``.

    Raises:
        OSError: if OpenCV cannot read or decode the file.
        ValueError: if the decoded image does not have an integer dtype
            (e.g. a floating-point image loaded via ``IMREAD_ANYDEPTH``),
            since there is no integer range to normalize by.
    """
    # ANYCOLOR | ANYDEPTH keeps the file's native channel layout and bit
    # depth instead of forcing 8-bit 3-channel output.
    img = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"could not read image: {path}")
    if not np.issubdtype(img.dtype, np.integer):
        # np.iinfo below only works for integer dtypes.
        raise ValueError(
            f"expected an integer-typed image, got dtype {img.dtype}: {path}"
        )
    if img.ndim == 3:
        # Keep the first three channels (drops alpha) and convert BGR -> RGB.
        img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB)
    elif img.ndim == 2:
        img = img[..., np.newaxis]
    dinfo = np.iinfo(img.dtype)
    return (img.astype(np.float32) / dinfo.max) * 2 - 1


def write_img(path: str, data: np.ndarray):
    """Save a ``[-1, 1]``-scaled image to disk as 8-bit.

    Values are mapped ``[-1, 1] -> [0, 1]``, clipped, scaled to
    ``[0, 255]``, rounded to the nearest integer and written with OpenCV.
    This is the 8-bit inverse of ``read_img``.

    Args:
        path: Destination path (``str`` or ``pathlib.Path``). OpenCV picks
            the encoder from the extension.
        data: Array of shape ``(H, W)`` or ``(H, W, C)``. When ``C == 3`` the
            channels are taken to be RGB and are reordered to BGR for OpenCV.
            Values are nominally in ``[-1, 1]``.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the image was not written.
    """
    # Work in float32: cv2.cvtColor rejects float64 input, which is what
    # NumPy arithmetic yields for float64 or integer arrays.
    data = np.clip(np.asarray(data, dtype=np.float32) * 0.5 + 0.5, 0, 1)
    if data.ndim == 3 and data.shape[-1] == 3:
        data = cv2.cvtColor(data, cv2.COLOR_RGB2BGR)
    elif data.ndim == 2:
        data = data[..., np.newaxis]

    # Round instead of truncating: values produced by read_img (k / 255)
    # otherwise come back as k - 1 whenever float round-off lands just
    # below k.
    data = np.rint(data * 255).astype(np.uint8)
    if not cv2.imwrite(str(path), data):
        raise OSError(f"could not write image: {path}")


def to_torch(img: Float[np.ndarray, "H W C"]) -> Float[torch.Tensor, "C H W"]:
    """Expose an ``(H, W, C)`` NumPy image as a ``(C, H, W)`` torch tensor.

    No data is copied. ``torch.from_numpy`` shares the array's memory and
    ``permute`` returns a view, so in-place edits to either object are
    visible in the other. The returned tensor is generally non-contiguous.
    """
    hwc = torch.from_numpy(img)
    # Bring the trailing channel axis to the front: (H, W, C) -> (C, H, W).
    chw = hwc.permute(2, 0, 1)
    return chw


def from_torch(img: Float[torch.Tensor, "C H W"]) -> Float[np.ndarray, "H W C"]:
    """Convert a ``(C, H, W)`` tensor into an ``(H, W, C)`` float32 NumPy array.

    The tensor is detached from autograd, moved to the CPU and cast to
    float32 before conversion. When the input is already a CPU float32
    tensor, none of these steps copies, so the returned array shares memory
    with ``img``.
    """
    # Move the channel axis to the end: (C, H, W) -> (H, W, C).
    hwc = img.permute(1, 2, 0)
    # NumPy can only wrap host memory that is not tracked by autograd.
    on_host = hwc.detach().cpu()
    as_f32 = on_host.float()
    return as_f32.numpy()