import torch
import imageio.v3 as imageio
import numpy as np
import warnings
import os
import safetensors
import torchvision.transforms.functional as F

def read_image(filename: str, out: torch.Tensor = None) -> torch.Tensor:
    '''
    Read a local image file into a float tensor (pixel values are normalized to [0, 1], CxHxW).
    Args:
        filename: Image file path.
        out: Fill in this tensor rather than returning a new tensor if provided.
    Returns:
        Loaded image tensor.
    '''
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # ignore PIL's UserWarning about fp16 images being read as fp32
        img: np.ndarray = imageio.imread(filename)
    # Convert the image array to a float tensor according to its data type
    res = None
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0
    elif img.dtype == np.uint16 or img.dtype == np.int32:
        img = img.astype(np.float32) / 65535.0
    else:
        raise ValueError(f'Unrecognized image pixel value type: {img.dtype}')
    if img.ndim == 2:
        res = torch.from_numpy(img).unsqueeze(0)  # 1xHxW for grayscale images
    elif img.ndim == 3:
        res = torch.from_numpy(img).movedim(2, 0)[:3]  # HxWxC to CxHxW
    else:
        raise ValueError(f'Unrecognized image dimension: {img.shape}')
    if out is None:
        return res
    out.copy_(res)
    return out
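
# Minimal usage sketch for read_image (assumption: 'input.png' is a hypothetical
# local image, not part of this module).
if __name__ == "__main__" and os.path.exists("input.png"):
    img = read_image("input.png")        # new CxHxW tensor in [0, 1]
    buf = torch.empty_like(img)
    read_image("input.png", out=buf)     # fill a preallocated buffer instead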

def create_img(img: torch.Tensor):
    '''
    Convert a tensor to a PIL image.
    Args:
        img: Image tensor CxHxW. Squeezed if BxCxHxW and B == 1.
    Returns:
        PIL image.
    '''
    if img.dim() == 4:
        assert img.shape[0] == 1
        img = img.squeeze(0)
    if img.shape[0] == 4:
        out_img = F.to_pil_image(img, mode="CMYK")
        out_img = out_img.convert('RGB')
    elif img.shape[0] == 3:
        out_img = F.to_pil_image(img, mode="RGB")
    elif img.shape[0] == 1:
        out_img = F.to_pil_image(img, mode="L")
    else:
        raise ValueError("Unsupported number of channels.")
    return out_img
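
# Minimal usage sketch for create_img: convert a synthetic 3xHxW tensor to a PIL image.
# (Illustrative only; the random tensor is not part of the original pipeline.)
if __name__ == "__main__":
    demo_rgb = torch.rand(3, 64, 64)     # CxHxW float tensor in [0, 1]
    demo_pil = create_img(demo_rgb)      # PIL image in RGB mode
    print(demo_pil.size, demo_pil.mode)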

def save_maps(path: str, maps: dict):
    '''
    Save SVBRDF maps to a given path.
    Args:
        path: Output path.
        maps: Named maps of tensor images.
    '''
    if not os.path.exists(path):
        os.makedirs(path)
    for name, image in maps.items():
        out_img = create_img(image)
        out_img.save(os.path.join(path, name + ".png"))
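
# Minimal usage sketch for save_maps: write synthetic SVBRDF maps as PNGs.
# (Illustrative only; 'demo_maps_out' is a hypothetical output directory.)
if __name__ == "__main__":
    demo_maps = {
        "albedo": torch.rand(3, 64, 64),     # RGB map
        "roughness": torch.rand(1, 64, 64),  # grayscale map
    }
    save_maps("demo_maps_out", demo_maps)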

def load_torch_file(ckpt, device=None):
    '''
    Load a checkpoint into a state dict, from either a .safetensors/.sft file or a regular torch checkpoint.
    Args:
        ckpt: Checkpoint file path.
        device: Device to map the loaded tensors to (defaults to CPU).
    Returns:
        The loaded state dict.
    '''
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        state_dict = {}
        with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
    else:
        checkpoint = torch.load(ckpt, map_location=device, weights_only=True)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
    return state_dict
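
# Minimal usage sketch for load_torch_file (assumption: 'model.safetensors' is a
# hypothetical checkpoint path, not provided by this module).
if __name__ == "__main__" and os.path.exists("model.safetensors"):
    sd = load_torch_file("model.safetensors")
    print(f"Loaded {len(sd)} tensors")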