Spaces:
Runtime error
Runtime error
import math
from typing import List, Optional, Tuple

import numpy as np
import PIL
import PIL.Image
import PIL.ImageDraw
import torch
def tensor_to_PIL(img: torch.Tensor) -> PIL.Image.Image:
    """
    Turn the first image of a tensor batch into an RGB PIL Image.

    Values are mapped with ``x * 127.5 + 128`` and clamped to [0, 255] before
    being cast to uint8, so inputs in [-1, 1] span the full 8-bit range.

    Args:
        img (torch.Tensor): Batch of images laid out as
            [batch_size, num_channels, height, width].
    Returns:
        A PIL Image built from the first element of the batch.
    """
    # PIL expects channels last, so move them behind height and width.
    channels_last = img.permute(0, 2, 3, 1)
    as_uint8 = torch.clamp(channels_last * 127.5 + 128, 0, 255).to(torch.uint8)
    # Only the first image of the batch is converted.
    first_image = as_uint8[0].cpu().numpy()
    return PIL.Image.fromarray(first_image, "RGB")
def get_ellipse_coords(
    point: Tuple[int, int], radius: int = 5
) -> Tuple[int, int, int, int]:
    """
    Compute the bounding box of a circle of the given radius around a point.

    Args:
        point (Tuple[int, int]): Center of the circle as (x, y).
        radius (int): Distance from the center to each side of the box.
    Returns:
        The bounding box as a tuple (x_min, y_min, x_max, y_max).
    """
    cx, cy = point[0], point[1]
    left, top = cx - radius, cy - radius
    right, bottom = cx + radius, cy + radius
    return (left, top, right, bottom)
def draw_handle_target_points(
    img: PIL.Image.Image,
    handle_points: List[Tuple[int, int]],
    target_points: List[Tuple[int, int]],
    radius: int = 5,
) -> None:
    """
    Draws handle and target points with an arrow pointing towards the target point.

    The drawing happens in place on ``img``. Each handle is a red dot, each
    target a blue dot, and every handle/target pair is joined by a green arrow.
    If there is exactly one more handle than targets, the last handle is drawn
    on its own. Neither input list is modified.

    Args:
        img (PIL.Image.Image): The image to draw on.
        handle_points (List[Tuple[int, int]]): A list of handle [x,y] points.
        target_points (List[Tuple[int, int]]): A list of target [x,y] points.
            A ``None`` entry means "no target yet" for the matching handle.
        radius (int): The radius of the handle and target points.
    """
    arrow_head_length = 10.0

    # Pad a private copy so the caller's list is never mutated; a trailing
    # None marks the handle that has no target yet.
    targets = list(target_points)
    if len(handle_points) == len(targets) + 1:
        targets.append(None)

    draw = PIL.ImageDraw.Draw(img)
    for handle_point, target_point in zip(handle_points, targets):
        # Draw the handle point
        draw.ellipse(get_ellipse_coords(handle_point, radius), fill="red")

        # Explicit None check: truthiness would raise on array-like points.
        if target_point is None:
            continue

        # Draw the target point
        draw.ellipse(get_ellipse_coords(target_point, radius), fill="blue")

        # Direction of the arrow, from handle to target.
        dx = target_point[0] - handle_point[0]
        dy = target_point[1] - handle_point[1]
        distance = math.hypot(dx, dy)
        if distance == 0:
            # Handle and target coincide: there is no direction to point in,
            # and atan2(0, 0) would otherwise yield a bogus horizontal arrow.
            continue
        angle = math.atan2(dy, dx)

        # Main line, stopped short of the target so it does not poke through
        # the arrowhead. Skipped when the points are closer than the head
        # length, because the shortened end would land behind the handle.
        if distance > arrow_head_length:
            shaft_end = (
                target_point[0] - arrow_head_length * math.cos(angle),
                target_point[1] - arrow_head_length * math.sin(angle),
            )
            draw.line([tuple(handle_point), shaft_end], fill="green", width=2)

        # Arrowhead: the tip sits on the target, the two base corners are
        # rotated 30 degrees to either side of the arrow direction.
        arrow_point1 = (
            target_point[0] - arrow_head_length * math.cos(angle - math.pi / 6),
            target_point[1] - arrow_head_length * math.sin(angle - math.pi / 6),
        )
        arrow_point2 = (
            target_point[0] - arrow_head_length * math.cos(angle + math.pi / 6),
            target_point[1] - arrow_head_length * math.sin(angle + math.pi / 6),
        )
        draw.polygon([tuple(target_point), arrow_point1, arrow_point2], fill="green")
| # # Draw shifted coordinates handle + d_i | |
| # for points in all_shifted_coordinates: | |
| # if not torch.isnan(points).any(): | |
| # coords = utils.get_ellipse_coords(points.mean(0).flip(-1).cpu().long().numpy().tolist(), 7) | |
| # draw.ellipse(coords, fill="orange") | |
def create_circular_mask(
    h: int,
    w: int,
    center: Optional[Tuple[int, int]] = None,
    radius: Optional[int] = None,
) -> torch.Tensor:
    """
    Build a boolean mask that is True inside a circle and False elsewhere.

    Args:
        h (int): Height of the mask.
        w (int): Width of the mask.
        center (Optional[Tuple[int, int]]): Circle center as (y, x). Defaults to the middle of the image.
        radius (Optional[int]): Circle radius. Defaults to the distance from the center to the nearest image border.
    Returns:
        A boolean tensor of shape [h, w]; pixels whose distance to the center is at most ``radius`` are True.
    """
    if center is None:
        # Fall back to the middle of the image.
        center = (int(h / 2), int(w / 2))
    cy, cx = center[0], center[1]

    if radius is None:
        # Largest circle around the center that still fits inside the image.
        radius = min(cy, cx, h - cy, w - cx)

    # Column vector of row indices and row vector of column indices; the
    # subtraction below broadcasts them to a full [h, w] distance grid.
    rows = np.arange(h)[:, None]
    cols = np.arange(w)[None, :]
    distance = np.sqrt((rows - cy) ** 2 + (cols - cx) ** 2)

    inside = distance <= radius
    return torch.from_numpy(inside).bool()
def create_square_mask(
    height: int, width: int, center: list, radius: int
) -> torch.Tensor:
    """Build a boolean mask that is True inside an axis-aligned square.

    The square spans ``radius`` pixels on each side of the center, so its side
    length is ``2 * radius + 1``.

    Args:
        height (int): The height of the mask.
        width (int): The width of the mask.
        center (list): Center of the square as a two-element list in [y, x] order.
        radius (int): Half the side length of the square (excluding the center pixel).
    Returns:
        torch.Tensor: Boolean mask of shape (height, width).
    Raises:
        ValueError: If ``center`` is not a two-element list, ``radius`` is not a
            positive integer, or the square would not fit inside the mask.
    """
    center_is_valid = isinstance(center, list) and len(center) == 2
    if not center_is_valid:
        raise ValueError("center must be a list of two integers")

    radius_is_valid = isinstance(radius, int) and radius > 0
    if not radius_is_valid:
        raise ValueError("radius must be a positive integer")

    cy, cx = center
    # The whole square must lie inside the mask, vertically and horizontally.
    if cy < radius or cy >= height - radius:
        raise ValueError("center and radius must be within the bounds of the mask")
    if cx < radius or cx >= width - radius:
        raise ValueError("center and radius must be within the bounds of the mask")

    top, bottom = int(cy) - radius, int(cy) + radius
    left, right = int(cx) - radius, int(cx) + radius

    mask = torch.zeros((height, width), dtype=torch.bool)
    # Slice ends are exclusive, hence the +1 to include the far edge.
    mask[top : bottom + 1, left : right + 1] = True
    return mask