Spaces:
Runtime error
Runtime error
| from typing import Literal, Optional | |
| import torch | |
| from einops import einsum, repeat | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from .coordinate_conversion import generate_conversions | |
| from .rendering import render_over_image | |
| from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector | |
def draw_lines(
    image: Float[Tensor, "3 height width"],
    start: Vector,
    end: Vector,
    color: Vector,
    width: Scalar,
    cap: Literal["butt", "round", "square"] = "round",
    num_msaa_passes: int = 1,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> Float[Tensor, "3 height width"]:
    """Rasterize one or more line segments over ``image``.

    ``start``, ``end``, ``color`` and ``width`` are broadcast against each other
    along their leading (per-line) dimension. Where lines overlap, the line with
    the highest index is drawn on top.

    Args:
        image: Image to draw over; its device is used for every tensor created.
        start: Segment start point(s) in world space (sanitized to 2-vectors).
        end: Segment end point(s) in world space (sanitized to 2-vectors).
        color: Per-line color(s) (sanitized to 3-vectors).
        width: Per-line stroke width(s). Widths are compared against distances
            measured after the world-to-pixel conversion, so they are presumably
            in pixels — confirm against ``render_over_image``'s sample space.
        cap: End-cap style. "butt" stops exactly at the endpoints, "square"
            extends each end by half the width, "round" adds a disc of radius
            width / 2 at each endpoint. Zero-length segments use the x axis as
            their direction, so a square cap renders an axis-aligned square.
        num_msaa_passes: Forwarded to ``render_over_image`` as ``num_passes``.
        x_range: Forwarded to ``generate_conversions`` (world x extent).
        y_range: Forwarded to ``generate_conversions`` (world y extent).

    Returns:
        The result of ``render_over_image`` with the lines composited over
        ``image``.

    Raises:
        ValueError: If ``cap`` is not "butt", "round" or "square".
    """
    if cap not in ("butt", "round", "square"):
        raise ValueError(f"Unsupported line cap: {cap!r}")

    device = image.device
    start = sanitize_vector(start, 2, device)
    end = sanitize_vector(end, 2, device)
    color = sanitize_vector(color, 3, device)
    width = sanitize_scalar(width, device)
    (num_lines,) = torch.broadcast_shapes(
        start.shape[0],
        end.shape[0],
        color.shape[0],
        width.shape,
    )

    # Convert world-space points to pixel space.
    _, h, w = image.shape
    world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
    start = world_to_pixel(start)
    end = world_to_pixel(end)

    # Per-line geometry is independent of the sample positions, so compute it
    # once here rather than on every MSAA pass.
    delta = end - start
    delta_norm = delta.norm(dim=-1, keepdim=True)
    # A zero-length segment has no direction: dividing by its norm yields NaN,
    # which made every inside test False (square caps vanished). Use the x axis
    # instead, and divide by 1 so no NaN is ever produced.
    is_degenerate = delta_norm == 0
    safe_norm = torch.where(is_degenerate, torch.ones_like(delta_norm), delta_norm)
    x_axis = torch.tensor([1.0, 0.0], dtype=delta.dtype, device=device)
    u_delta = torch.where(is_degenerate, x_axis, delta / safe_norm)

    half_width = 0.5 * width[:, None]
    # Square caps push both ends out by half the stroke width.
    extra = half_width if cap == "square" else 0

    selectable_color = color.broadcast_to((num_lines, 3))
    line_index = torch.arange(num_lines, device=device)[:, None]

    def color_function(
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        # Vector from each line's start point to each sample.
        indicator = xy - start[:, None]

        # Is the sample within the segment's extent along its direction?
        parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s")
        parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra)

        # Is the sample within half the width of the segment's axis?
        perpendicular = indicator - parallel[..., None] * u_delta[:, None]
        perpendicular_inside_line = perpendicular.norm(dim=-1) < half_width
        inside_line = parallel_inside_line & perpendicular_inside_line

        # Round caps: discs of radius width / 2 centred on both endpoints.
        # Out-of-place "|" so either operand may be the one that broadcasts.
        if cap == "round":
            near_start = indicator.norm(dim=-1) < half_width
            end_indicator = xy - end[:, None]
            near_end = end_indicator.norm(dim=-1) < half_width
            inside_line = inside_line | near_start | near_end

        # Weight each covering line by its index so argmax picks the
        # highest-index (topmost) line. Uncovered samples fall back to index 0,
        # but their alpha below is 0, so the color is irrelevant.
        arrangement = inside_line * line_index
        top_color = selectable_color.gather(
            dim=0,
            index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3),
        )
        alpha = inside_line.any(dim=0).float()[:, None]
        return torch.cat((top_color, alpha), dim=-1)

    return render_over_image(image, color_function, device, num_passes=num_msaa_passes)