Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| import torch | |
| from jaxtyping import Float | |
def read_img(path):
    """Load an image from disk as a float32 array scaled to [-1, 1].

    The file is decoded with ``cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH`` so
    the source bit depth is preserved (e.g. 16-bit PNGs keep their full range)
    before normalisation.

    Args:
        path: Location of the image file; anything accepted by ``str()``
            (e.g. ``str`` or ``pathlib.Path``).

    Returns:
        ``np.float32`` array of shape (H, W, C). Multi-channel images are
        reordered to RGB with any channels beyond the third (alpha) dropped;
        single-channel images get a trailing channel axis (C == 1).

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
    # cv2.imread signals failure (missing file, unsupported format) by
    # returning None instead of raising; fail loudly here.
    if img is None:
        raise OSError(f"cv2.imread could not read or decode image: {path}")
    if img.ndim == 3:
        # OpenCV decodes to BGR(A): keep the first three channels, reorder to RGB.
        img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB)
    elif img.ndim == 2:
        img = img[..., np.newaxis]
    if np.issubdtype(img.dtype, np.floating):
        # np.iinfo only supports integer dtypes. Float images are assumed to
        # already be in [0, 1] — NOTE(review): confirm for HDR inputs (> 1).
        unit = img.astype(np.float32)
    else:
        # Normalise by the dtype's maximum (255 for uint8, 65535 for uint16).
        unit = img.astype(np.float32) / np.iinfo(img.dtype).max
    return unit * 2 - 1
def write_img(path: str, data: np.ndarray):
    """Save an image with values in [-1, 1] to disk at 8 bits per channel.

    Args:
        path: Destination file; its extension selects the encoder used by
            ``cv2.imwrite``. ``pathlib.Path`` is accepted as well as ``str``.
        data: Array of shape (H, W), (H, W, 1), (H, W, 3) RGB or (H, W, 4)
            RGBA, with values nominally in [-1, 1]; values outside that range
            are clipped.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    data = np.clip(np.asarray(data) * 0.5 + 0.5, 0, 1)
    # Round instead of truncating: after a read_img round trip, float32 error
    # can leave e.g. 3/255 as 2.9999998/255, which truncation turns into 2.
    data = np.round(data * 255).astype(np.uint8)
    if data.ndim == 2:
        data = data[..., np.newaxis]
    elif data.ndim == 3 and data.shape[-1] in (3, 4):
        # OpenCV expects BGR(A): swap R and B, leave alpha in place. Indexing
        # after the uint8 cast works for any input float dtype.
        order = [2, 1, 0] if data.shape[-1] == 3 else [2, 1, 0, 3]
        data = data[..., order]
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(str(path), data):
        raise OSError(f"cv2.imwrite failed to write image: {path}")
def to_torch(img: Float[np.ndarray, "H W C"]) -> Float[torch.Tensor, "C H W"]:
    """Convert an HWC NumPy image into a CHW torch tensor.

    ``torch.from_numpy`` wraps the array without copying, so the result is a
    permuted view sharing memory (and dtype) with ``img``.
    """
    channels_first = (2, 0, 1)
    tensor = torch.from_numpy(img)
    return tensor.permute(*channels_first)
def from_torch(img: Float[torch.Tensor, "C H W"]) -> Float[np.ndarray, "H W C"]:
    """Convert a CHW torch tensor into an HWC float32 NumPy image.

    The tensor is detached from autograd, moved to the CPU and cast to
    float32, which also covers dtypes NumPy cannot represent (e.g. bfloat16).

    Returns:
        C-contiguous ``np.float32`` array of shape (H, W, C).
    """
    hwc = img.detach().cpu().float().permute(1, 2, 0)
    # permute() only rewrites strides; materialise a C-contiguous layout so the
    # result can be passed to OpenCV calls that write into their argument
    # (e.g. drawing functions), which reject arbitrarily strided arrays.
    return hwc.contiguous().numpy()