| import dataclasses |
| import importlib |
| import math |
| from dataclasses import dataclass |
| from typing import Any, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import PIL |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from jaxtyping import Bool, Float, Int, Num |
| from omegaconf import DictConfig, OmegaConf |
| from torch import Tensor |
|
|
|
|
class BaseModule(nn.Module):
    """Base class for configurable ``nn.Module`` subclasses.

    Subclasses declare their configuration schema as a nested ``Config``
    dataclass and build their layers/state in ``configure``; ``__init__``
    parses the raw config into ``self.cfg`` before calling that hook.
    """

    @dataclass
    class Config:
        # Per-subclass configuration schema; subclasses override this with
        # their own dataclass fields.
        pass

    # Structured config produced by parse_structured() from the raw input.
    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg`` against this class's ``Config`` schema, then run the
        subclass ``configure`` hook with any extra arguments.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        # Subclass hook: build modules/buffers from self.cfg. Must be overridden.
        raise NotImplementedError
|
|
|
|
def find_class(cls_string):
    """Resolve a dotted path such as ``"pkg.mod.Name"`` to the named object.

    Everything before the last dot is imported as a module; the final
    component is looked up on it with ``getattr``.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
|
|
|
|
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Merge a raw (possibly partial) config into a structured dataclass config.

    Keys not declared as fields on ``fields`` are dropped with a console
    warning instead of raising on merge.

    Args:
        fields: A dataclass type (or instance) defining the config schema.
        cfg: Raw config as a dict or DictConfig; ``None`` means "all defaults".

    Returns:
        An OmegaConf structured config validated against ``fields``.
    """
    # Fix: the signature advertises cfg=None as valid, but the original
    # unconditionally called cfg.copy() and crashed with AttributeError.
    cfg_ = {} if cfg is None else cfg.copy()

    field_names = {f.name for f in dataclasses.fields(fields)}
    # Iterate over a snapshot of the keys since we pop while scanning.
    for key in list(cfg_.keys()):
        if key not in field_names:
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
    return scfg
|
|
|
|
# Per-dtype epsilon floors, used as the default `eps` in normalize() below.
# Lower-precision float types get a larger epsilon so normalization of
# near-zero vectors does not blow up.
EPS_DTYPE = {
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}
|
|
|
|
def dot(x, y, dim=-1):
    """Elementwise dot product of `x` and `y` reduced along `dim`.

    The reduced dimension is kept (size 1) so the result broadcasts back
    against the inputs.
    """
    return (x * y).sum(dim=dim, keepdim=True)
|
|
|
|
def reflect(x, n):
    """Reflect vectors `x` about normals `n` (assumed unit length): x - 2(x.n)n."""
    projection = torch.sum(x * n, -1, keepdim=True)
    return x - 2 * projection * n
|
|
|
|
def normalize(x, dim=-1, eps=None):
    """L2-normalize `x` along `dim`.

    When `eps` is not given, a dtype-appropriate epsilon is looked up in
    EPS_DTYPE (raises KeyError for dtypes not listed there).
    """
    effective_eps = EPS_DTYPE[x.dtype] if eps is None else eps
    return F.normalize(x, p=2, dim=dim, eps=effective_eps)
|
|
|
|
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
    """Lift 2-D triangles to homogeneous coords with positive winding.

    Each vertex gets a homogeneous 1 appended; triangles whose signed area
    (the determinant of the 3x3 homogeneous matrix) is negative have their
    last two vertices swapped so every output triangle winds the same way.
    """
    homo = F.pad(tri, (0, 1), "constant", 1.0)
    swapped = torch.cat(
        (homo[..., 0:1, :], homo[..., 2:3, :], homo[..., 1:2, :]), dim=-2
    )
    needs_flip = torch.det(homo) < 0
    return torch.where(needs_flip[..., None, None], swapped, homo)
|
|
|
|
def triangle_intersection_2d(
    t1: Float[Tensor, "*B 3 2"],
    t2: Float[Tensor, "*B 3 2"],
    eps=1e-12,
) -> Float[Tensor, "*B"]:
    """Returns True if triangles collide, False otherwise"""

    def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]:
        # det(x) is twice the signed area of (edge point A, edge point B,
        # tested vertex). logdet is finite only when that determinant is
        # strictly positive, so this returns True when the vertex is NOT
        # strictly on the positive side of the edge by more than eps.
        logdetx = torch.logdet(x.double())
        if eps is None:
            return ~torch.isfinite(logdetx)
        return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))

    # Canonicalize both triangles to the same (positive) winding in
    # homogeneous coordinates so the side tests below are consistent.
    t1s = tri_winding(t1)
    t2s = tri_winding(t2)

    # Separating-axis test over the three edges of t1.
    # NOTE(review): `ret` is sized from t1.shape[0], i.e. a single leading
    # batch dimension, even though the annotations say "*B" — confirm callers
    # only ever pass one batch dim.
    ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
    for i in range(3):
        edge = torch.roll(t1s, i, dims=1)[:, :2, :]

        # This edge separates the triangles iff all three vertices of t2
        # lie on its non-positive side.
        upd = (
            chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
            & chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
            & chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
        )
        # Any separating edge found so far means no intersection.
        ret = ret | upd

    # Same separating-axis test over the three edges of t2.
    for i in range(3):
        edge = torch.roll(t2s, i, dims=1)[:, :2, :]

        upd = (
            chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
            & chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
            & chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
        )
        ret = ret | upd

    # No separating edge exists -> the triangles intersect.
    return ~ret
|
|
|
|
# A value range for scale_tensor(): either a scalar (low, high) pair, or a
# "2 D" tensor giving per-dimension low (row 0) and high (row 1) bounds.
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
|
|
|
def scale_tensor(
    dat: 'Num[Tensor, "... D"]',
    inp_scale: "Optional[ValidScale]" = None,
    tgt_scale: "Optional[ValidScale]" = None,
):
    """Affinely remap `dat` from the `inp_scale` range to the `tgt_scale` range.

    Args:
        dat: Tensor whose last dimension holds the values to remap.
        inp_scale: (low, high) pair or per-dimension "2 D" tensor describing
            the current range of `dat`. None means (0, 1).
        tgt_scale: Target range in the same format. None means (0, 1).

    Returns:
        `dat` linearly mapped so that inp_scale[0] -> tgt_scale[0] and
        inp_scale[1] -> tgt_scale[1].
    """
    # Fix: the body always accepted None, but the parameters had no defaults
    # and the hints were not Optional; defaults of None are now explicit
    # (backward compatible — positional callers are unaffected).
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        # Per-dimension target bounds must match dat's trailing dimension.
        assert dat.shape[-1] == tgt_scale.shape[-1]
    # Normalize to [0, 1] relative to the input range, then expand to target.
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat
|
|
|
|
def dilate_fill(img, mask, iterations=10):
    """Grow `mask` outward and fill newly covered pixels of `img` with the
    local mean color of already-valid neighbors (simple iterative inpainting
    of a band `iterations` pixels wide around the mask).

    NOTE(review): the hardcoded view shapes below assume a single image with
    3 channels, i.e. img of shape (1, 3, H, W) and mask of shape (1, 1, H, W).
    Also presumably img is already zeroed outside the mask, otherwise invalid
    pixels would contaminate the neighborhood means — confirm with callers.
    """
    oldMask = mask.float()
    oldImg = img

    # 3x3 all-ones kernel used to count valid neighbors per pixel.
    mask_kernel = torch.ones(
        (1, 1, 3, 3),
        dtype=oldMask.dtype,
        device=oldMask.device,
    )

    for i in range(iterations):
        # Dilate the mask by one pixel (3x3 max pool, stride 1, padding 1).
        newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)

        # Extract 3x3 neighborhoods: (1, C, 9, L) for the image and
        # (1, 1, 9, L) for the old/new masks. C=3 is hardcoded here.
        img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
        mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
        new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)

        # Mean color over the valid pixels in each neighborhood; clip(1)
        # guards against division by zero where no neighbor is valid.
        mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
            2
        )
        # Broadcast the mean into every dilated-mask slot of each patch.
        fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)

        # Number of overlapping patches contributing to each output pixel.
        mask_conv = F.conv2d(
            newMask, mask_kernel, padding=1
        )
        # Fold overlapping patches back and average by the contribution count.
        newImg = F.fold(
            fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
        ) / mask_conv.clamp(1)

        # 1 exactly where this iteration newly covered a pixel, 0 elsewhere.
        diffMask = newMask - oldMask

        oldMask = newMask
        # Copy newImg only into the newly covered pixels; keep the rest.
        oldImg = torch.lerp(oldImg, newImg, diffMask)

    return oldImg
|
|
|
|
def float32_to_uint8_np(
    x: 'Float[np.ndarray, "*B H W C"]',
    dither: bool = True,
    dither_mask: 'Optional[Float[np.ndarray, "*B H W C"]]' = None,
    dither_strength: float = 1.0,
) -> 'Int[np.ndarray, "*B H W C"]':
    """Quantize float images (expected in [0, 1]) to uint8.

    Args:
        x: Float image array.
        dither: If True, add random noise before quantization to reduce
            banding artifacts.
        dither_mask: Optional per-pixel multiplier for the noise.
        dither_strength: Global scale applied to the random noise.

    Returns:
        uint8 array of the same shape as `x`.
    """
    if dither:
        # Noise shared across channels (shape of x with C collapsed to 1).
        # Renamed from `dither` to avoid shadowing the boolean parameter.
        noise = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
        )
        if dither_mask is not None:
            noise = noise * dither_mask
        return np.clip(np.floor(256.0 * x + noise), 0, 255).astype(np.uint8)
    # Fix: the original cast used torch.uint8 on a NumPy array here, which
    # raises a TypeError whenever the no-dither path is taken.
    return np.clip(np.floor(256.0 * x), 0, 255).astype(np.uint8)
|
|
|
|
def convert_data(data):
    """Recursively convert torch tensors (and containers of them) to NumPy.

    - None passes through unchanged.
    - NumPy arrays are returned as-is.
    - Torch tensors are detached, moved to CPU, and converted; half-precision
      dtypes are upcast to float32 first (NumPy has no bfloat16).
    - Lists and dicts are converted element-wise.

    Raises:
        TypeError: for any other input type.
    """
    if data is None:
        return None
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        needs_upcast = data.dtype in (torch.float16, torch.bfloat16)
        tensor = data.float() if needs_upcast else data
        return tensor.detach().cpu().numpy()
    if isinstance(data, list):
        return [convert_data(item) for item in data]
    if isinstance(data, dict):
        return {key: convert_data(value) for key, value in data.items()}
    raise TypeError(
        "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
        type(data),
    )
|
|
|
|
class ImageProcessor:
    """Converts PIL / NumPy / torch images to float tensors resized to a square."""

    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        """Convert one image (HWC, or NHWC if already batched) to a float
        torch tensor and bilinearly resize it to (size, size).

        uint8 inputs are rescaled to [0, 1]; float inputs are passed through
        unchanged.
        """
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        # NHWC -> NCHW for interpolate, then back to NHWC.
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        """Resize a single image or a list of images; returns a batched
        (B, size, size, C) float tensor either way.
        """
        # Fix: the original tested isinstance(image, torch.FloatTensor),
        # which matches only float32 CPU tensors — a 4-D half/double/CUDA
        # tensor fell into the list branch below and picked up a spurious
        # extra batch dimension via torch.stack. Any 4-D array/tensor is
        # already a batch.
        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image
|
|
|
|
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a pinhole-camera intrinsic matrix from a field of view.

    Args:
        fov: Field of view in radians (focal length derived from H).
        H: Image height in pixels.
        W: Image width in pixels.
        bs: If > 0, tile the matrix into a batch of this size.

    Returns:
        float32 torch tensor of shape (3, 3), or (bs, 3, 3) when bs > 0.
    """
    focal = 0.5 * H / np.tan(0.5 * fov)
    K = np.array(
        [
            [focal, 0.0, W / 2.0],
            [0.0, focal, H / 2.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    if bs > 0:
        K = np.repeat(K[None], bs, axis=0)
    return torch.from_numpy(K)
|
|