File size: 2,364 Bytes
a846205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import imageio.v3 as imageio
import numpy as np
import warnings
import os

import torchvision.transforms.functional as F

def read_image(filename: str, out: torch.Tensor=None) -> torch.Tensor:
    '''
    Read a local image file into a float tensor (pixel values are normalized to [0, 1], CxHxW)

    Args:
        filename: Image file path.
        out: Fill in this tensor rather than return a new tensor if provided.

    Returns:
        Loaded image tensor.
    '''
    with warnings.catch_warnings():
        warnings.simplefilter("ignore") # ignore PIL's user warning that reads fp16 img as fp32
        img: np.ndarray = imageio.imread(filename)

    # Convert the image array to float tensor according to its data type
    res = None
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0
    elif img.dtype == np.uint16 or img.dtype == np.int32:
        img = img.astype(np.float32) / 65535.0
    else:
        raise ValueError(f'Unrecognized image pixel value type: {img.dtype}')
    if img.ndim == 2:
        res = torch.from_numpy(img).unsqueeze(0)  # 1xHxW for grayscale images
    elif img.ndim == 3:
        res = torch.from_numpy(img).movedim(2, 0)[:3] # HxWxC to CxHxW
    else:
        raise ValueError(f'Unrecognized image dimension: {img.shape}')
    
    if out is None:
        return res
    out.copy_(res)

def create_img(img: torch.Tensor):
    '''
    Convert tensor to PIL image

    Args:
        path: Image tensor CxHxW. Squeeze if BxCxHxW and B==1

    Returns:
        PIL image
    '''
    if img.dim() == 4:
        assert img.shape[0] == 1
        img = img.squeeze(0)

    if img.shape[0] == 4:
        out_img = F.to_pil_image(img, mode="CMYK")
        out_img = out_img.convert('RGB')
    elif img.shape[0] == 3:
        out_img = F.to_pil_image(img, mode="RGB")
    elif img.shape[0] == 1:
        out_img = F.to_pil_image(img, mode="L")
    else:
        raise ValueError("Unsupported image dimension.")
    return out_img
    
def save_maps(path: str, maps: dict):
    '''
    Save SVBRDF maps to a given path.

    Args:
        path: Output path.   
        maps: Named maps of tensor images. 
    '''
    if not os.path.exists(path):
        os.makedirs(path)
    for name, image in maps.items():
        out_img = create_img(image)
        out_img.save(os.path.join(path, name+".png"))