| | from typing import Protocol, runtime_checkable |
| |
|
| | import torch |
| | from einops import rearrange, reduce |
| | from jaxtyping import Bool, Float |
| | from torch import Tensor |
| |
|
| |
|
@runtime_checkable
class ColorFunction(Protocol):
    """Structural type for callables that shade a flat list of 2D points.

    The renderer in this module passes the returned tensor to
    ``reduce_straight_alpha``, which treats the first three channels as
    color and the fourth as alpha.

    Note:
        Because the protocol is ``runtime_checkable``, ``isinstance`` only
        verifies that the object is callable; it does not validate the
        signature or tensor shapes.
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        """Return one 4-channel color for every ``(x, y)`` row of ``xy``."""
        ...
| |
|
| |
|
def generate_sample_grid(
    shape: tuple[int, int],
    device: torch.device,
) -> Float[Tensor, "height width 2"]:
    """Build a grid of pixel-center coordinates.

    Args:
        shape: ``(height, width)`` of the grid.
        device: Device on which the coordinate tensors are created.

    Returns:
        A ``(height, width, 2)`` tensor whose last axis holds ``(x, y)``.
        Entry ``[row, col]`` equals ``(col + 0.5, row + 0.5)``.
    """
    num_rows, num_cols = shape

    # Shift integer indices by half a unit so samples sit at cell centers.
    col_centers = torch.arange(num_cols, device=device) + 0.5
    row_centers = torch.arange(num_rows, device=device) + 0.5

    # "xy" indexing yields (height, width) outputs: x varies along columns,
    # y varies along rows.
    grid_x, grid_y = torch.meshgrid(col_centers, row_centers, indexing="xy")
    return torch.stack((grid_x, grid_y), dim=-1)
| |
|
| |
|
def detect_msaa_pixels(
    image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
    """Flag every pixel that differs from at least one of its 8 neighbors.

    Two pixels "differ" when any channel is not exactly equal. Both pixels
    of a differing pair are flagged.

    Args:
        image: Batched channel-first image.

    Returns:
        Boolean mask on ``image.device`` that is ``True`` at flagged pixels.
    """
    batch, _, height, width = image.shape
    edge_mask = torch.zeros(
        (batch, height, width),
        dtype=torch.bool,
        device=image.device,
    )

    everything = slice(None)
    leading = slice(None, -1)  # all but the last row/column
    trailing = slice(1, None)  # all but the first row/column

    # Each entry is a pair of (row_slice, col_slice) views; element [i, j] of
    # the first view is compared with element [i, j] of the second view, which
    # is the neighboring pixel in the named direction.
    neighbor_views = (
        ((everything, trailing), (everything, leading)),  # left <-> right
        ((trailing, everything), (leading, everything)),  # up <-> down
        ((trailing, trailing), (leading, leading)),  # top-left <-> bottom-right
        ((leading, trailing), (trailing, leading)),  # top-right <-> bottom-left
    )

    for (rows_a, cols_a), (rows_b, cols_b) in neighbor_views:
        differs = (
            image[:, :, rows_a, cols_a] != image[:, :, rows_b, cols_b]
        ).any(dim=1)
        # Mark both members of every mismatching pair.
        edge_mask[:, rows_a, cols_a] |= differs
        edge_mask[:, rows_b, cols_b] |= differs

    return edge_mask
| |
|
| |
|
def reduce_straight_alpha(
    rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
    """Collapse each image in the batch to a single straight-alpha color.

    Color channels are averaged with alpha as the weight, so transparent
    samples do not tint the result. Alpha itself is a plain mean over all
    samples.

    Args:
        rgba: Batched channel-first samples with three color channels
            followed by one alpha channel.

    Returns:
        A ``(batch, 4)`` tensor holding the alpha-weighted color and the
        mean alpha.
    """
    spatial_dims = (2, 3)
    rgb, opacity = torch.split(rgba, (3, 1), dim=1)

    # Alpha-weighted color average; the epsilon keeps fully transparent
    # inputs from dividing by zero.
    opacity_total = opacity.sum(dim=spatial_dims)
    weighted_rgb_total = (rgb * opacity).sum(dim=spatial_dims)
    mean_rgb = weighted_rgb_total / (opacity_total + 1e-10)

    # Coverage is an unweighted average over the samples.
    mean_opacity = opacity.mean(dim=spatial_dims)

    return torch.cat((mean_rgb, mean_opacity), dim=-1)
| |
|
| |
|
@torch.no_grad()
def run_msaa_pass(
    xy: Float[Tensor, "batch height width 2"],
    color_function: ColorFunction,
    scale: float,
    subdivision: int,
    remaining_passes: int,
    device: torch.device,
    batch_size: int = int(2**16),
) -> Float[Tensor, "batch 4 height width"]:
    """Shade a grid of sample points, recursively refining edge pixels.

    Every point in ``xy`` is shaded once. If passes remain, pixels flagged by
    ``detect_msaa_pixels`` are re-sampled on a ``subdivision x subdivision``
    grid spanning ``scale`` around the original point. The refined samples
    are collapsed with ``reduce_straight_alpha`` and written back over the
    coarse color. Runs with gradients disabled.

    Args:
        xy: Sample positions, one ``(x, y)`` pair per pixel.
        color_function: Callable mapping ``(point, 2)`` coordinates to
            ``(point, 4)`` colors.
        scale: Side length of the area each sample in ``xy`` represents.
        subdivision: Number of sub-samples per axis for refined pixels.
        remaining_passes: How many further refinement levels are allowed;
            ``0`` (or less) disables refinement.
        device: Device on which the sub-sample offsets are created.
        batch_size: Maximum number of points handed to ``color_function``
            in a single call.

    Returns:
        Channel-first colors for every pixel of ``xy``.
    """
    b, h, w, _ = xy.shape

    # Shade in flat chunks so one call never sees more than batch_size points.
    flat_xy = rearrange(xy, "b h w xy -> (b h w) xy")
    color = torch.cat(
        [color_function(chunk) for chunk in flat_xy.split(batch_size)],
        dim=0,
    )
    color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)

    if remaining_passes <= 0:
        return color

    # Only pixels that differ from a neighbor are worth supersampling.
    mask = detect_msaa_pixels(color)
    batch_index, row_index, col_index = torch.where(mask)
    if batch_index.numel() == 0:
        # Nothing to refine: skip the recursion so color_function is never
        # invoked with a zero-length batch on the remaining passes.
        return color

    edge_xy = xy[batch_index, row_index, col_index]

    # Sub-sample offsets centered on zero and spanning one `scale`-sized cell.
    offsets = generate_sample_grid((subdivision, subdivision), device)
    offsets = (offsets / subdivision - 0.5) * scale

    # Each flagged pixel becomes its own batch entry holding a small
    # subdivision x subdivision image of finer samples.
    color_fine = run_msaa_pass(
        edge_xy[:, None, None] + offsets,
        color_function,
        scale / subdivision,
        subdivision,
        remaining_passes - 1,
        device,
        batch_size=batch_size,
    )

    # Replace the coarse color of each flagged pixel with its refined average.
    color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)

    return color
| |
|
| |
|
@torch.no_grad()
def render(
    shape: tuple[int, int],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 2,
) -> Float[Tensor, "4 height width"]:
    """Render ``color_function`` into a single anti-aliased 4-channel image.

    Args:
        shape: ``(height, width)`` of the output image.
        color_function: Callable that shades ``(point, 2)`` coordinates.
        device: Device on which the sample grids are created.
        subdivision: Sub-samples per axis used when refining edge pixels.
        num_passes: Number of refinement levels passed to ``run_msaa_pass``.

    Returns:
        The rendered image in channel-first layout.
    """
    pixel_centers = generate_sample_grid(shape, device)

    # run_msaa_pass works on batches; wrap the single image in a batch of one
    # and unwrap the result. A scale of 1.0 means each sample covers one pixel.
    batched = run_msaa_pass(
        xy=pixel_centers[None],
        color_function=color_function,
        scale=1.0,
        subdivision=subdivision,
        remaining_passes=num_passes,
        device=device,
    )
    return batched[0]
| |
|
| |
|
def render_over_image(
    image: Float[Tensor, "3 height width"],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
    """Render ``color_function`` and alpha-composite it on top of ``image``.

    Args:
        image: Background image in channel-first layout.
        color_function: Callable that shades ``(point, 2)`` coordinates.
        device: Device on which the overlay is rendered.
        subdivision: Sub-samples per axis used when refining edge pixels.
        num_passes: Number of refinement levels used for the overlay.

    Returns:
        ``image`` blended with the rendered overlay using the overlay's
        straight alpha.
    """
    _, height, width = image.shape
    rendered = render(
        (height, width),
        color_function,
        device,
        subdivision=subdivision,
        num_passes=num_passes,
    )

    # Straight-alpha "over" operator: the background shows through wherever
    # the overlay is transparent.
    overlay_rgb, overlay_alpha = torch.split(rendered, (3, 1), dim=0)
    return image * (1 - overlay_alpha) + overlay_rgb * overlay_alpha
| |
|