| from typing import Protocol, runtime_checkable |
|
|
| import torch |
| from einops import rearrange, reduce |
| from jaxtyping import Bool, Float |
| from torch import Tensor |
|
|
|
|
@runtime_checkable
class ColorFunction(Protocol):
    """Structural type for callables that shade a batch of 2D sample points.

    Any callable accepting an ``(point, 2)`` tensor of sample coordinates and
    returning an ``(point, 4)`` tensor (one 4-channel value per point)
    satisfies this protocol. Elsewhere in this module the 4 channels are split
    as 3 color channels plus 1 alpha channel.
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        ...
|
|
|
|
def generate_sample_grid(
    shape: tuple[int, int],
    device: torch.device,
) -> Float[Tensor, "height width 2"]:
    """Return the (x, y) coordinates of every pixel center of a grid.

    Args:
        shape: ``(height, width)`` of the grid.
        device: Device on which the coordinates are allocated.

    Returns:
        Tensor of shape ``(height, width, 2)`` whose entry ``[i, j]`` is
        ``(j + 0.5, i + 0.5)``, i.e. x first, then y.
    """
    height, width = shape
    row_centers = torch.arange(height, device=device) + 0.5
    col_centers = torch.arange(width, device=device) + 0.5
    # "ij" indexing: both outputs are (height, width); grid_y varies along rows,
    # grid_x along columns.
    grid_y, grid_x = torch.meshgrid(row_centers, col_centers, indexing="ij")
    return torch.stack((grid_x, grid_y), dim=-1)
|
|
|
|
def detect_msaa_pixels(
    image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
    """Flag pixels whose value differs from any of their 8 neighbours.

    Two adjacent pixels (horizontally, vertically or diagonally) "differ" when
    at least one channel compares unequal with exact ``!=``. Both pixels of
    every differing pair are flagged.

    Args:
        image: Batch of images, shape ``(batch, channels, height, width)``.

    Returns:
        Boolean mask of shape ``(batch, height, width)``.
    """
    batch, _, height, width = image.shape
    flagged = torch.zeros(
        (batch, height, width), dtype=torch.bool, device=image.device
    )

    full = slice(None)
    head = slice(None, -1)  # every row/column except the last
    tail = slice(1, None)  # every row/column except the first

    # Each entry pairs two equally sized (row, col) windows offset by one
    # neighbour step: horizontal, vertical, top-left/bottom-right diagonal,
    # and top-right/bottom-left diagonal.
    neighbour_pairs = (
        ((full, tail), (full, head)),
        ((tail, full), (head, full)),
        ((tail, tail), (head, head)),
        ((head, tail), (tail, head)),
    )
    for first, second in neighbour_pairs:
        differs = (image[(..., *first)] != image[(..., *second)]).any(dim=1)
        flagged[(..., *first)] |= differs
        flagged[(..., *second)] |= differs

    return flagged
|
|
|
|
def reduce_straight_alpha(
    rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
    """Collapse each straight-alpha RGBA image in a batch to a single RGBA value.

    The 3 color channels are averaged using alpha as the weight, so fully
    transparent samples contribute no color. The alpha channel is the plain
    mean over all samples.

    Args:
        rgba: Tensor of shape ``(batch, 4, height, width)``.

    Returns:
        Tensor of shape ``(batch, 4)``.
    """
    spatial_dims = (2, 3)
    rgb, opacity = torch.split(rgba, [3, 1], dim=1)

    weighted_rgb_sum = (rgb * opacity).sum(dim=spatial_dims)
    opacity_sum = opacity.sum(dim=spatial_dims)
    # The tiny epsilon keeps the result finite (zero color) when every sample
    # is fully transparent.
    mean_rgb = weighted_rgb_sum / (opacity_sum + 1e-10)

    mean_opacity = opacity.mean(dim=spatial_dims)
    return torch.cat((mean_rgb, mean_opacity), dim=-1)
|
|
|
|
@torch.no_grad()
def run_msaa_pass(
    xy: Float[Tensor, "batch height width 2"],
    color_function: ColorFunction,
    scale: float,
    subdivision: int,
    remaining_passes: int,
    device: torch.device,
    batch_size: int = int(2**16),
) -> Float[Tensor, "batch 4 height width"]:
    """Shade a batch of sample grids, recursively supersampling edge pixels.

    Every point in ``xy`` is evaluated with ``color_function`` in chunks of at
    most ``batch_size`` points. While ``remaining_passes > 0``, the pixels
    flagged by ``detect_msaa_pixels`` are re-shaded on a
    ``subdivision x subdivision`` grid of sub-samples. The sub-samples cover a
    square of side ``scale`` centred on the original sample. They are merged
    back into one value per pixel with ``reduce_straight_alpha``.

    Args:
        xy: Sample coordinates, shape ``(batch, height, width, 2)``.
        color_function: Callable mapping ``(point, 2)`` coordinates to
            ``(point, 4)`` values.
        scale: Side length of the square that each sample in ``xy`` stands for.
        subdivision: Number of sub-samples per axis for refined pixels.
        remaining_passes: Number of further refinement levels to apply.
        device: Device on which the sub-sample offset grid is allocated;
            presumably must match ``xy.device`` (TODO confirm with callers).
        batch_size: Maximum number of points passed to ``color_function`` at
            once.

    Returns:
        Shaded values of shape ``(batch, 4, height, width)``.
    """
    b, h, w, _ = xy.shape
    chunks = rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size)
    color = torch.cat([color_function(chunk) for chunk in chunks], dim=0)
    color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)

    if remaining_passes <= 0:
        return color

    mask = detect_msaa_pixels(color)
    batch_index, row_index, col_index = torch.where(mask)
    # Nothing needs refinement. Stop here rather than recursing on an empty
    # batch, which would invoke color_function on zero-length input at every
    # remaining level for no effect.
    if batch_index.numel() == 0:
        return color

    edge_xy = xy[batch_index, row_index, col_index]

    # Offsets of the sub-sample centres relative to the parent sample, spanning
    # (-scale / 2, scale / 2) along each axis.
    offsets = generate_sample_grid((subdivision, subdivision), device)
    offsets = (offsets / subdivision - 0.5) * scale

    color_fine = run_msaa_pass(
        edge_xy[:, None, None] + offsets,
        color_function,
        scale / subdivision,
        subdivision,
        remaining_passes - 1,
        device,
        batch_size=batch_size,
    )
    # Advanced indices separated by a slice put the index dimension first, so
    # the left-hand side has shape (num_edge_pixels, 4), matching the reduction.
    color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)
    return color
|
|
|
|
@torch.no_grad()
def render(
    shape: tuple[int, int],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 2,
) -> Float[Tensor, "4 height width"]:
    """Render an image by shading pixel centres with adaptive supersampling.

    Args:
        shape: ``(height, width)`` of the output in pixels.
        color_function: Callable mapping ``(point, 2)`` coordinates to
            ``(point, 4)`` values.
        device: Device on which the sample grid is created.
        subdivision: Sub-samples per axis for each refined pixel.
        num_passes: Number of recursive refinement passes.

    Returns:
        Tensor of shape ``(4, height, width)``.
    """
    pixel_centres = generate_sample_grid(shape, device).unsqueeze(0)
    batched = run_msaa_pass(
        pixel_centres,
        color_function,
        scale=1.0,
        subdivision=subdivision,
        remaining_passes=num_passes,
        device=device,
    )
    return batched.squeeze(0)
|
|
|
|
def render_over_image(
    image: Float[Tensor, "3 height width"],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
    """Composite a rendered overlay onto a 3-channel image.

    The overlay is rendered at the image's resolution with ``render``, split
    into 3 color channels and 1 alpha channel, and blended as
    ``image * (1 - alpha) + color * alpha``.

    Args:
        image: Tensor of shape ``(3, height, width)``.
        color_function: Callable mapping ``(point, 2)`` coordinates to
            ``(point, 4)`` values.
        device: Device on which the overlay is rendered; presumably must match
            ``image.device`` for the blend (TODO confirm with callers).
        subdivision: Sub-samples per axis for each refined pixel.
        num_passes: Number of recursive refinement passes.

    Returns:
        Blended tensor of shape ``(3, height, width)``.
    """
    _, height, width = image.shape
    rendered = render(
        (height, width),
        color_function,
        device,
        subdivision=subdivision,
        num_passes=num_passes,
    )
    overlay_rgb, overlay_alpha = torch.split(rendered, [3, 1], dim=0)
    return image * (1 - overlay_alpha) + overlay_rgb * overlay_alpha
|
|