""" SAM2 Click Refinement RL Environment """ import gymnasium as gym from gymnasium import spaces import numpy as np import torch from PIL import Image from scipy.ndimage import distance_transform_edt def compute_dice(pred_mask, gt_mask): pred = pred_mask.astype(bool) gt = gt_mask.astype(bool) intersection = (pred & gt).sum() total = pred.sum() + gt.sum() if total == 0: return 1.0 if intersection == 0 else 0.0 return float(2 * intersection / (total + 1e-8)) def compute_iou(pred_mask, gt_mask): pred = pred_mask.astype(bool) gt = gt_mask.astype(bool) intersection = (pred & gt).sum() union = (pred | gt).sum() if union == 0: return 1.0 return float(intersection / (union + 1e-8)) def oracle_click(pred_mask, gt_mask, noise_range=3): pred = pred_mask.astype(bool) gt = gt_mask.astype(bool) fn_mask = (~pred) & gt fp_mask = pred & (~gt) fn_dist = distance_transform_edt(fn_mask) if fn_mask.any() else np.zeros_like(pred_mask, dtype=float) fp_dist = distance_transform_edt(fp_mask) if fp_mask.any() else np.zeros_like(pred_mask, dtype=float) if fn_dist.max() >= fp_dist.max() and fn_mask.any(): click_pos = np.unravel_index(fn_dist.argmax(), fn_dist.shape) label = 1 elif fp_mask.any(): click_pos = np.unravel_index(fp_dist.argmax(), fp_dist.shape) label = 0 else: if gt.any(): gt_dist = distance_transform_edt(gt) click_pos = np.unravel_index(gt_dist.argmax(), gt_dist.shape) label = 1 else: click_pos = (pred_mask.shape[0] // 2, pred_mask.shape[1] // 2) label = 0 if noise_range > 0: noise = np.random.randint(-noise_range, noise_range + 1, size=2) click_pos = ( int(np.clip(click_pos[0] + noise[0], 0, pred_mask.shape[0] - 1)), int(np.clip(click_pos[1] + noise[1], 0, pred_mask.shape[1] - 1)), ) return click_pos[0], click_pos[1], label def binary_erosion_safe(mask): from scipy.ndimage import binary_erosion if not mask.any(): return mask result = binary_erosion(mask, iterations=2) if not result.any(): return mask return result class SAM2ClickEnv(gym.Env): metadata = {"render_modes": ["rgb_array"]} def __init__(self, dataset=None, sam_predictor=None, obs_size=128, grid_size=32, max_clicks=5, click_radius=3, boundary_reward_weight=0.3, initial_click_noise=5, use_sam=True, split="train", render_mode=None): super().__init__() self.dataset = dataset self.sam_predictor = sam_predictor self.obs_size = obs_size self.grid_size = grid_size self.max_clicks = max_clicks self.click_radius = click_radius self.boundary_reward_weight = boundary_reward_weight self.initial_click_noise = initial_click_noise self.use_sam = use_sam self.render_mode = render_mode self.observation_space = spaces.Box(low=0, high=255, shape=(obs_size, obs_size, 6), dtype=np.uint8) self.action_space = spaces.Discrete(grid_size * grid_size * 2) self.current_image = None self.current_gt = None self.current_mask = None self.click_coords = [] self.click_labels = [] self.prev_dice = 0.0 self.n_clicks = 0 self.orig_h = 0 self.orig_w = 0 self._dataset_indices = None self._current_idx = 0 def _load_random_sample(self): if self._dataset_indices is None: self._dataset_indices = np.random.permutation(len(self.dataset)) self._current_idx = 0 idx = int(self._dataset_indices[self._current_idx]) self._current_idx = (self._current_idx + 1) % len(self._dataset_indices) if self._current_idx == 0: self._dataset_indices = np.random.permutation(len(self.dataset)) sample = self.dataset[idx] image = sample["image"] mask = sample["mask"] if isinstance(image, Image.Image): image = image.convert("RGB") image = np.array(image) if isinstance(mask, Image.Image): mask = mask.convert("L") mask = np.array(mask) mask = (mask > 127).astype(np.uint8) return image, mask def _resize_for_obs(self, img, is_mask=False): pil_img = Image.fromarray(img) if is_mask: pil_img = pil_img.resize((self.obs_size, self.obs_size), Image.NEAREST) else: pil_img = pil_img.resize((self.obs_size, self.obs_size), Image.BILINEAR) return np.array(pil_img) def _make_click_heatmap(self, clicks_yx, orig_h, orig_w): heatmap = np.zeros((self.obs_size, self.obs_size), dtype=np.uint8) for (y, x) in clicks_yx: obs_y = int(y * self.obs_size / orig_h) obs_x = int(x * self.obs_size / orig_w) obs_y = np.clip(obs_y, 0, self.obs_size - 1) obs_x = np.clip(obs_x, 0, self.obs_size - 1) for dy in range(-self.click_radius, self.click_radius+1): for dx in range(-self.click_radius, self.click_radius+1): if dy**2 + dx**2 <= self.click_radius**2: ny, nx = obs_y + dy, obs_x + dx if 0 <= ny < self.obs_size and 0 <= nx < self.obs_size: heatmap[ny, nx] = 255 return heatmap def _get_obs(self): img_resized = self._resize_for_obs(self.current_image) mask_resized = self._resize_for_obs((self.current_mask * 255).astype(np.uint8), is_mask=True) fg_yx = [(y, x) for (x, y), l in zip(self.click_coords, self.click_labels) if l == 1] bg_yx = [(y, x) for (x, y), l in zip(self.click_coords, self.click_labels) if l == 0] fg_heatmap = self._make_click_heatmap(fg_yx, self.orig_h, self.orig_w) bg_heatmap = self._make_click_heatmap(bg_yx, self.orig_h, self.orig_w) obs = np.stack([ img_resized[:, :, 0], img_resized[:, :, 1], img_resized[:, :, 2], mask_resized, fg_heatmap, bg_heatmap, ], axis=-1).astype(np.uint8) return obs def _run_sam(self): if not self.use_sam or self.sam_predictor is None: return self._simulate_mask() coords = np.array(self.click_coords, dtype=np.float32) labels = np.array(self.click_labels, dtype=np.int32) device = "cuda" if torch.cuda.is_available() else "cpu" ctx = torch.autocast(device, dtype=torch.bfloat16) if device == "cuda" else torch.inference_mode() with torch.inference_mode(), ctx: masks, scores, logits = self.sam_predictor.predict( point_coords=coords, point_labels=labels, multimask_output=(len(self.click_coords) == 1), ) if len(masks.shape) == 3 and masks.shape[0] > 1: best_idx = np.argmax(scores) mask = masks[best_idx] else: mask = masks[0] if len(masks.shape) == 3 else masks return mask.astype(np.uint8) def _simulate_mask(self): if not hasattr(self, '_noise_mask'): noise = np.random.random(self.current_gt.shape) < 0.15 self._noise_mask = self.current_gt.copy() from scipy.ndimage import binary_dilation, binary_erosion if np.random.random() < 0.5: self._noise_mask = binary_dilation(self._noise_mask, np.ones((7,7))).astype(np.uint8) else: self._noise_mask = binary_erosion(self._noise_mask, np.ones((5,5))).astype(np.uint8) self._noise_mask = (self._noise_mask ^ noise.astype(np.uint8)).astype(np.uint8) mask = self._noise_mask.copy() for (x, y), label in zip(self.click_coords, self.click_labels): radius = 20 y0, y1 = max(0, int(y)-radius), min(mask.shape[0], int(y)+radius) x0, x1 = max(0, int(x)-radius), min(mask.shape[1], int(x)+radius) mask[y0:y1, x0:x1] = self.current_gt[y0:y1, x0:x1] return mask def reset(self, seed=None, options=None): super().reset(seed=seed) self.current_image, self.current_gt = self._load_random_sample() self.orig_h, self.orig_w = self.current_image.shape[:2] self.click_coords = [] self.click_labels = [] self.n_clicks = 0 if hasattr(self, '_noise_mask'): del self._noise_mask if self.use_sam and self.sam_predictor is not None: device = "cuda" if torch.cuda.is_available() else "cpu" ctx = torch.autocast(device, dtype=torch.bfloat16) if device == "cuda" else torch.inference_mode() with torch.inference_mode(), ctx: self.sam_predictor.set_image(self.current_image) init_y, init_x, init_label = oracle_click( np.zeros_like(self.current_gt), self.current_gt, noise_range=self.initial_click_noise ) self.click_coords.append((int(init_x), int(init_y))) self.click_labels.append(init_label) self.current_mask = self._run_sam() self.prev_dice = compute_dice(self.current_mask, self.current_gt) return self._get_obs(), {"dice": self.prev_dice, "iou": compute_iou(self.current_mask, self.current_gt)} def step(self, action): label = action % 2 pos = action // 2 grid_y = pos // self.grid_size grid_x = pos % self.grid_size orig_x = int((grid_x + 0.5) * self.orig_w / self.grid_size) orig_y = int((grid_y + 0.5) * self.orig_h / self.grid_size) orig_x = np.clip(orig_x, 0, self.orig_w - 1) orig_y = np.clip(orig_y, 0, self.orig_h - 1) self.click_coords.append((orig_x, orig_y)) self.click_labels.append(1 if label == 0 else 0) self.current_mask = self._run_sam() new_dice = compute_dice(self.current_mask, self.current_gt) delta_dice = new_dice - self.prev_dice boundary_bonus = 0.0 pred = self.current_mask.astype(bool) gt = self.current_gt.astype(bool) error_mask = pred != gt if error_mask.any(): error_dist = distance_transform_edt(~error_mask) click_error_dist = error_dist[orig_y, orig_x] if click_error_dist < 10: boundary_bonus = 0.05 * (1.0 - click_error_dist / 10.0) reward = delta_dice + self.boundary_reward_weight * boundary_bonus self.prev_dice = new_dice self.n_clicks += 1 terminated = self.n_clicks >= self.max_clicks truncated = False if new_dice > 0.95 and not terminated: reward += 0.1 info = { "dice": new_dice, "iou": compute_iou(self.current_mask, self.current_gt), "delta_dice": delta_dice, "n_clicks": self.n_clicks + 1, "click_pos": (orig_x, orig_y), "click_label": self.click_labels[-1], } return self._get_obs(), float(reward), terminated, truncated, info def render(self): if self.render_mode == "rgb_array": return self._get_obs()[:, :, :3] return None