import gradio as gr
import numpy as np
from PIL import Image, ImageDraw


class ImageMask(gr.components.Image):
    """Image component that emulates the legacy sketch-with-mask image.

    The component keeps its own ``tool`` / ``source`` attributes instead of
    forwarding them to ``gr.components.Image.__init__``, and wraps plain
    payloads into an ``{"image": ..., "mask": ...}`` dict in ``preprocess``.
    """

    is_template = True

    def __init__(self, **kwargs):
        """Build the component, non-interactive unless the caller overrides it.

        Args:
            **kwargs: Forwarded unchanged to ``gr.components.Image.__init__``.
        """
        # Gradio >= 4/5/6: `source` and `tool` are not valid kwargs to
        # Image.__init__, so they are never forwarded from here.
        # `setdefault` keeps "non-interactive by default" while avoiding a
        # duplicate-keyword TypeError when a caller passes `interactive`.
        kwargs.setdefault("interactive", False)
        super().__init__(**kwargs)

        # Emulate the old attributes used elsewhere in this repo.
        self.tool = "sketch"
        if not hasattr(self, "source"):
            self.source = "upload"

    def preprocess(self, x):
        """Wrap a bare image payload into ``{"image", "mask"}`` and preprocess.

        Args:
            x: Payload handed over by Gradio. ``None`` is returned untouched;
                a dict is passed straight to the parent ``preprocess``.

        Returns:
            Whatever the parent class' ``preprocess`` returns for the
            (possibly wrapped) payload, or ``None`` when ``x`` is ``None``.
        """
        if x is None:
            return x

        current_source = getattr(self, "source", "upload")
        needs_wrapping = (
            self.tool == "sketch"
            and current_source in ("upload", "webcam")
            # isinstance (not an exact type comparison) so that dict
            # subclasses are recognised as already-wrapped payloads.
            and not isinstance(x, dict)
        )
        if needs_wrapping:
            # NOTE(review): presumably `x` is a base64-encoded image here and
            # `decode_base64_to_image` exists in the installed Gradio — confirm.
            decode_image = gr.processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            # Default mask: every channel set to 1, alpha channel fully opaque.
            mask = np.ones((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            mask = self.postprocess(mask)
            x = {"image": x, "mask": mask}

        return super().preprocess(x)


def get_valid_mask(mask: np.ndarray):
    """Reduce a mask to one channel and rescale it to the 0..1 range.

    Args:
        mask: Mask array. A 3-D array is converted to a single channel via
            PIL's ``"L"`` mode; any other rank is used as is.

    Returns:
        The single-channel mask. It is divided by 255 only when its maximum
        is exactly 255; otherwise the values are returned unscaled.
    """
    if mask.ndim == 3:
        mask = np.array(Image.fromarray(mask).convert("L"))
    if mask.max() == 255:
        mask = mask / 255
    return mask


def draw_points_on_image(
    image,
    points,
    curr_point=None,
    highlight_all=True,
    radius_scale=0.01,
):
    """Draw start/target point pairs on top of a PIL image.

    Args:
        image: PIL image to annotate; it is not modified.
        points: Mapping of point key to a dict with a ``"target"`` entry and a
            ``"start"`` entry (``"start_temp"`` takes precedence when
            present). Either coordinate may be ``None`` to skip drawing it.
        curr_point: Key of the point currently being edited; its markers get
            a ``"p"`` / ``"t"`` label.
        highlight_all: When true every point is drawn fully opaque; otherwise
            only ``curr_point`` is and the rest are drawn faintly.
        radius_scale: Marker radius as a fraction of the image width.

    Returns:
        A new RGB image with the overlay composited onto ``image``.
    """
    overlay_rgba = Image.new("RGBA", image.size, 0)
    overlay_draw = ImageDraw.Draw(overlay_rgba)
    # The radius depends only on the image width, so compute it once.
    rad_draw = int(image.size[0] * radius_scale)

    for point_key, point in points.items():
        is_current = curr_point is not None and curr_point == point_key
        if is_current or highlight_all:
            p_color = (255, 0, 0)
            t_color = (0, 0, 255)
        else:
            # De-emphasised points use a low alpha value.
            p_color = (255, 0, 0, 35)
            t_color = (0, 0, 255, 35)

        p_start = point.get("start_temp", point["start"])
        p_target = point["target"]

        if p_start is not None and p_target is not None:
            p_draw = int(p_start[0]), int(p_start[1])
            t_draw = int(p_target[0]), int(p_target[1])
            overlay_draw.line(
                (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
                fill=(255, 255, 0),
                width=2,
            )

        if p_start is not None:
            p_draw = int(p_start[0]), int(p_start[1])
            overlay_draw.ellipse(
                (
                    p_draw[0] - rad_draw,
                    p_draw[1] - rad_draw,
                    p_draw[0] + rad_draw,
                    p_draw[1] + rad_draw,
                ),
                fill=p_color,
            )
            if is_current:
                overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0))

        if p_target is not None:
            t_draw = int(p_target[0]), int(p_target[1])
            overlay_draw.ellipse(
                (
                    t_draw[0] - rad_draw,
                    t_draw[1] - rad_draw,
                    t_draw[0] + rad_draw,
                    t_draw[1] + rad_draw,
                ),
                fill=t_color,
            )
            if is_current:
                overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0))

    return Image.alpha_composite(image.convert("RGBA"), overlay_rgba).convert("RGB")


def draw_mask_on_image(image, mask):
    """Composite a faint grayscale rendering of ``mask`` over ``image``.

    Args:
        image: PIL image used as the background; it is not modified.
        mask: Array multiplied by 255 and cast to ``uint8`` to form the
            overlay's gray level.
            NOTE(review): assumes a 2-D array in the 0..1 range with the same
            height/width as ``image`` — confirm against callers.

    Returns:
        A new RGB image with the mask overlay (constant alpha of 45) applied.
    """
    im_mask = np.uint8(mask * 255)
    height, width = im_mask.shape[0], im_mask.shape[1]
    # Gray level replicated over R, G and B.
    rgb = np.tile(im_mask[..., None], [1, 1, 3])
    # Constant, low alpha so the underlying image stays visible.
    alpha = 45 * np.ones((height, width, 1), dtype=np.uint8)
    im_mask_rgba = np.concatenate((rgb, alpha), axis=-1)
    im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA")
    return Image.alpha_composite(image.convert("RGBA"), im_mask_rgba).convert("RGB")


def on_change_single_global_state(
    keys,
    value,
    global_state,
    map_transform=None,
):
    """Store ``value`` at a (possibly nested) key of ``global_state``.

    Args:
        keys: Either a single string key, or a sequence of keys describing a
            path into nested containers; the last element is the key written.
        value: Value to store.
        global_state: Container updated in place.
        map_transform: Optional callable applied to ``value`` before storing.

    Returns:
        The same ``global_state`` object, after the update.
    """
    if map_transform is not None:
        value = map_transform(value)

    curr_state = global_state
    if isinstance(keys, str):
        last_key = keys
    else:
        # Walk down to the container that holds the final key.
        for k in keys[:-1]:
            curr_state = curr_state[k]
        last_key = keys[-1]

    curr_state[last_key] = value
    return global_state


def get_latest_points_pair(points_dict):
    """Return the largest key of ``points_dict``, or ``None`` if it is empty.

    Despite the name, the key (index) is returned, not the point pair itself.
    """
    if not points_dict:
        return None
    # Iterating a dict yields its keys, so no intermediate list is needed.
    return max(points_dict)