Spaces:
Runtime error
Runtime error
| from typing import Protocol, runtime_checkable | |
| import torch | |
| from einops import rearrange, reduce | |
| from jaxtyping import Bool, Float | |
| from torch import Tensor | |
@runtime_checkable
class ColorFunction(Protocol):
    """Structural type for callables that map 2D sample positions to RGBA colors.

    Any callable whose signature matches ``__call__`` satisfies this protocol. It
    is decorated with ``runtime_checkable`` so ``isinstance(obj, ColorFunction)``
    works. That check only verifies that ``obj`` is callable; it does not check
    the signature.

    ``run_msaa_pass`` calls implementations with flattened chunks of sample
    positions and reshapes the result into straight-alpha RGBA images.
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:  # RGBA color
        """Return one RGBA color (straight alpha) for each ``(x, y)`` row of ``xy``."""
        ...
def generate_sample_grid(
    shape: tuple[int, int],
    device: torch.device,
) -> Float[Tensor, "height width 2"]:
    """Return the ``(x, y)`` center of every cell in a ``height x width`` grid.

    Entry ``[row, col]`` holds ``(col + 0.5, row + 0.5)``, so x runs along the
    width axis and y along the height axis.

    Args:
        shape: ``(height, width)`` of the grid.
        device: Device on which the coordinates are created.

    Returns:
        Tensor of shape ``(height, width, 2)`` with ``(x, y)`` in the last axis.
    """
    height, width = shape
    col_centers = torch.arange(width, device=device) + 0.5
    row_centers = torch.arange(height, device=device) + 0.5
    # "ij" indexing over (rows, cols) yields (height, width) grids directly.
    grid_y, grid_x = torch.meshgrid(row_centers, col_centers, indexing="ij")
    return torch.stack((grid_x, grid_y), dim=-1)
def detect_msaa_pixels(
    image: Float[Tensor, "batch 4 height width"],
) -> Bool[Tensor, "batch height width"]:
    """Flag every pixel whose value differs from any of its 8 neighbors.

    Two adjacent pixels "differ" when any channel compares unequal (exact
    comparison). Both pixels of each differing pair are flagged.

    Args:
        image: Channels-first images.

    Returns:
        Boolean mask of shape ``(batch, height, width)``.
    """
    batch, _, height, width = image.shape
    flagged = torch.zeros(
        (batch, height, width), dtype=torch.bool, device=image.device
    )

    everything = slice(None)
    drop_last = slice(None, -1)
    drop_first = slice(1, None)

    # Each entry is ((rows_a, cols_a), (rows_b, cols_b)). The two slices select
    # the two members of every neighbor pair along one direction.
    neighbor_pairs = (
        ((everything, drop_first), (everything, drop_last)),  # horizontal
        ((drop_first, everything), (drop_last, everything)),  # vertical
        ((drop_first, drop_first), (drop_last, drop_last)),  # top-left/bottom-right
        ((drop_last, drop_first), (drop_first, drop_last)),  # top-right/bottom-left
    )
    for (rows_a, cols_a), (rows_b, cols_b) in neighbor_pairs:
        differs = (
            image[:, :, rows_a, cols_a] != image[:, :, rows_b, cols_b]
        ).any(dim=1)
        flagged[:, rows_a, cols_a] |= differs
        flagged[:, rows_b, cols_b] |= differs
    return flagged
def reduce_straight_alpha(
    rgba: Float[Tensor, "batch 4 height width"],
) -> Float[Tensor, "batch 4"]:
    """Collapse each item's ``height x width`` samples into a single RGBA value.

    The resulting color is the alpha-weighted mean of the sample colors. The
    resulting alpha is the plain mean of the sample alphas, so the output is
    again straight (non-premultiplied) RGBA.

    Args:
        rgba: Straight-alpha samples, channels-first.

    Returns:
        Tensor of shape ``(batch, 4)``.
    """
    rgb = rgba[:, :3]
    alpha = rgba[:, 3:]

    # The epsilon keeps fully transparent items at color 0 instead of 0 / 0.
    premultiplied_total = reduce(rgb * alpha, "b c h w -> b c", "sum")
    alpha_total = reduce(alpha, "b c h w -> b c", "sum")
    mean_rgb = premultiplied_total / (alpha_total + 1e-10)

    mean_alpha = reduce(alpha, "b c h w -> b c", "mean")
    return torch.cat((mean_rgb, mean_alpha), dim=-1)
def run_msaa_pass(
    xy: Float[Tensor, "batch height width 2"],
    color_function: ColorFunction,
    scale: float,
    subdivision: int,
    remaining_passes: int,
    device: torch.device,
    batch_size: int = int(2**16),
) -> Float[Tensor, "batch 4 height width"]:  # color (RGBA with straight alpha)
    """Sample ``color_function`` at ``xy`` and recursively supersample edge pixels.

    Every position in ``xy`` is evaluated once. If passes remain, pixels whose
    color differs from any of their 8 neighbors (``detect_msaa_pixels``) are
    replaced by the alpha-weighted average (``reduce_straight_alpha``) of a
    ``subdivision x subdivision`` grid of sub-samples. The sub-samples sit at
    the cell centers of a square of side ``scale`` centered on the original
    position. Each deeper pass shrinks that square by ``subdivision``.

    Args:
        xy: Sample positions, one per output pixel.
        color_function: Maps ``(points, 2)`` positions to ``(points, 4)`` RGBA.
        scale: Side length of the footprint covered by one sample at this level.
        subdivision: Number of sub-samples per axis when refining a pixel.
        remaining_passes: Number of refinement levels still to run; ``<= 0``
            means no refinement.
        device: Device on which the sub-sample offset grid is created.
        batch_size: Maximum number of points passed to ``color_function`` per
            call.

    Returns:
        Channels-first RGBA colors with straight alpha.
    """
    b, h, w, _ = xy.shape

    # Evaluate in chunks so each call to color_function sees at most
    # ``batch_size`` points.
    points = rearrange(xy, "b h w xy -> (b h w) xy")
    color = torch.cat(
        [color_function(chunk) for chunk in points.split(batch_size)], dim=0
    )
    color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w)

    # No refinement levels left: one sample per pixel is final.
    if remaining_passes <= 0:
        return color

    batch_index, row_index, col_index = torch.where(detect_msaa_pixels(color))

    # Nothing to refine. Returning early avoids recursing with an empty grid,
    # which would call color_function on zero-length batches at every
    # remaining level.
    if batch_index.numel() == 0:
        return color

    # Offsets of the cell centers of a subdivision x subdivision grid spanning
    # a square of side ``scale`` centered on each flagged sample.
    offsets = generate_sample_grid((subdivision, subdivision), device)
    offsets = (offsets / subdivision - 0.5) * scale
    flagged_xy = xy[batch_index, row_index, col_index]
    color_fine = run_msaa_pass(
        flagged_xy[:, None, None] + offsets,
        color_function,
        scale / subdivision,
        subdivision,
        remaining_passes - 1,
        device,
        batch_size=batch_size,
    )

    # Collapse each flagged pixel's sub-samples into one color and write it back.
    color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine)
    return color
def render(
    shape: tuple[int, int],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 2,
) -> Float[Tensor, "4 height width"]:  # color (RGBA with straight alpha)
    """Render ``color_function`` over a pixel grid with adaptive supersampling.

    Pixels are first sampled at their centers (``generate_sample_grid``). Up to
    ``num_passes`` refinement levels then supersample pixels that differ from
    their neighbors (``run_msaa_pass``). The top level uses a footprint of 1.0,
    one pixel.

    Args:
        shape: ``(height, width)`` of the output image.
        color_function: Callable mapping ``(points, 2)`` positions to RGBA.
        device: Device on which sample positions are created.
        subdivision: Sub-samples per axis for each refined pixel.
        num_passes: Number of refinement levels.

    Returns:
        Channels-first straight-alpha RGBA image of shape ``(4, height, width)``.
    """
    pixel_centers = generate_sample_grid(shape, device)
    batched = run_msaa_pass(
        pixel_centers.unsqueeze(0),
        color_function,
        scale=1.0,
        subdivision=subdivision,
        remaining_passes=num_passes,
        device=device,
    )
    return batched[0]
def render_over_image(
    image: Float[Tensor, "3 height width"],
    color_function: ColorFunction,
    device: torch.device,
    subdivision: int = 8,
    num_passes: int = 1,
) -> Float[Tensor, "3 height width"]:
    """Render ``color_function`` at ``image``'s size and composite it over ``image``.

    The overlay's straight alpha blends its color over the image:
    ``image * (1 - alpha) + color * alpha``.

    Args:
        image: Channels-first RGB background.
        color_function: Callable mapping ``(points, 2)`` positions to RGBA.
        device: Device on which the overlay is rendered. NOTE(review): this
            presumably must match ``image``'s device for the blend; confirm.
        subdivision: Sub-samples per axis for each refined pixel.
        num_passes: Number of refinement levels.

    Returns:
        Composited RGB image with the same shape as ``image``.
    """
    _, height, width = image.shape
    overlay = render(
        (height, width),
        color_function,
        device,
        subdivision=subdivision,
        num_passes=num_passes,
    )
    overlay_rgb = overlay[:3]
    overlay_alpha = overlay[3:]
    return image * (1 - overlay_alpha) + overlay_rgb * overlay_alpha