# Author: ksangk
# Commit message: switch to safetensors
# Commit: 771c988
import torch
import imageio.v3 as imageio
import numpy as np
import warnings
import os
import safetensors
import torchvision.transforms.functional as F
def read_image(filename: str, out: torch.Tensor = None) -> torch.Tensor:
    '''
    Read a local image file into a float tensor (pixel values are normalized to [0, 1], CxHxW).

    Integer images are rescaled by their nominal maximum: uint8 by 255, uint16 and int32
    by 65535 (int32 presumably holds 16-bit data from some decoders -- TODO confirm).
    Floating-point images (e.g. fp16 data that PIL decodes as fp32) are converted to
    float32 without rescaling, i.e. they are assumed to already lie in [0, 1].
    Images with more than three channels keep only the first three (alpha is dropped).

    Args:
        filename: Image file path.
        out: Fill in this tensor rather than return a new tensor if provided.

    Returns:
        Loaded image tensor, or None when ``out`` is given (the result is copied into it).

    Raises:
        ValueError: If the pixel data type or the array rank is not supported.
    '''
    with warnings.catch_warnings():
        # Ignore PIL's user warning emitted when it reads an fp16 image as fp32.
        warnings.simplefilter("ignore")
        img: np.ndarray = imageio.imread(filename)

    # Convert the image array to float32 according to its data type.
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0
    elif img.dtype == np.uint16 or img.dtype == np.int32:
        img = img.astype(np.float32) / 65535.0
    elif np.issubdtype(img.dtype, np.floating):
        # Float images (the fp16-read-as-fp32 case handled above) used to be rejected here.
        img = img.astype(np.float32)
    else:
        raise ValueError(f'Unrecognized image pixel value type: {img.dtype}')

    if img.ndim == 2:
        res = torch.from_numpy(img).unsqueeze(0)  # 1xHxW for grayscale images
    elif img.ndim == 3:
        res = torch.from_numpy(img).movedim(2, 0)[:3]  # HxWxC to CxHxW, at most 3 channels
    else:
        raise ValueError(f'Unrecognized image dimension: {img.shape}')

    if out is None:
        return res
    out.copy_(res)
def create_img(img: torch.Tensor):
    '''
    Convert an image tensor to a PIL image.

    Args:
        img: Image tensor CxHxW, or BxCxHxW with B == 1 (the batch axis is squeezed).
            C must be 1 (saved as mode "L"), 3 ("RGB") or 4 (interpreted as "CMYK" and
            converted to RGB). Floating-point tensors are expected in [0, 1]; they are
            clamped to that range before conversion.

    Returns:
        PIL image

    Raises:
        ValueError: If a 4-D tensor has a batch size other than 1, or if the channel
            count is not 1, 3 or 4.
    '''
    if img.dim() == 4:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if img.shape[0] != 1:
            raise ValueError(f"Expected a batch size of 1, got {img.shape[0]}.")
        img = img.squeeze(0)

    if img.is_floating_point():
        # to_pil_image scales floats by 255 and casts to uint8; values outside [0, 1]
        # are not representable there and can wrap around, so clamp first.
        img = img.clamp(0.0, 1.0)

    channels = img.shape[0]
    if channels == 4:
        out_img = F.to_pil_image(img, mode="CMYK")
        out_img = out_img.convert('RGB')
    elif channels == 3:
        out_img = F.to_pil_image(img, mode="RGB")
    elif channels == 1:
        out_img = F.to_pil_image(img, mode="L")
    else:
        raise ValueError("Unsupported image dimension.")
    return out_img
def save_maps(path: str, maps: dict):
    '''
    Save SVBRDF maps to a given path, one PNG file per map.

    Args:
        path: Output directory; created (including missing parents) if it does not exist.
        maps: Named maps of tensor images; each entry is written to ``<path>/<name>.png``
            after conversion with ``create_img``.
    '''
    # exist_ok avoids the race between a separate exists() check and the creation.
    os.makedirs(path, exist_ok=True)
    for name, image in maps.items():
        create_img(image).save(os.path.join(path, name + ".png"))
def load_torch_file(ckpt, device=None):
    '''
    Load a model state dict from a checkpoint file.

    Files ending in ``.safetensors`` or ``.sft`` (case-insensitive) are read with
    ``safetensors.safe_open``; any other file is read with ``torch.load`` using
    ``weights_only=True``. If ``torch.load`` returns a dict that contains a
    ``"state_dict"`` key, that entry is returned instead of the whole object.

    Args:
        ckpt: Checkpoint path (``str`` or any ``os.PathLike``).
        device: Target device, as a ``torch.device`` or anything ``torch.device``
            accepts (e.g. ``"cuda:1"``). Defaults to CPU.

    Returns:
        The loaded state dict.
    '''
    ckpt = os.fspath(ckpt)
    device = torch.device("cpu") if device is None else torch.device(device)
    if ckpt.lower().endswith((".safetensors", ".sft")):
        # Pass the full device string: `device.type` would drop the index and put
        # tensors for e.g. "cuda:1" on the default CUDA device instead.
        with safetensors.safe_open(ckpt, framework="pt", device=str(device)) as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
    else:
        loaded = torch.load(ckpt, map_location=device, weights_only=True)
        # Lightning-style checkpoints nest the weights under "state_dict".
        if isinstance(loaded, dict) and "state_dict" in loaded:
            state_dict = loaded["state_dict"]
        else:
            state_dict = loaded
    return state_dict