|
|
| """ |
| SAM2 Click Refinement RL Environment |
| """ |
|
|
| import gymnasium as gym |
| from gymnasium import spaces |
| import numpy as np |
| import torch |
| from PIL import Image |
| from scipy.ndimage import distance_transform_edt |
|
|
|
|
def compute_dice(pred_mask, gt_mask):
    """Return the Dice coefficient between two binary masks.

    Both inputs are converted to boolean arrays, so any non-zero value
    counts as foreground.

    Args:
        pred_mask: Predicted mask (array-like, any numeric or bool dtype).
        gt_mask: Ground-truth mask with the same shape as ``pred_mask``.

    Returns:
        A Python float in ``[0, 1]``. Two empty masks are treated as a
        perfect match and yield ``1.0``.
    """
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)

    total = int(pred.sum()) + int(gt.sum())
    if total == 0:
        # Both masks are empty, hence the intersection is empty too.
        return 1.0

    intersection = int(np.logical_and(pred, gt).sum())
    # total > 0 here, so no epsilon is required; omitting it lets identical
    # masks score exactly 1.0.
    return 2.0 * intersection / total
|
|
|
|
def compute_iou(pred_mask, gt_mask):
    """Return the intersection-over-union of two binary masks.

    Inputs are cast to bool, so every non-zero entry is foreground. When
    both masks are empty the union is zero and the score is defined as 1.0.
    """
    pred_fg = pred_mask.astype(bool)
    gt_fg = gt_mask.astype(bool)

    overlap_px = np.logical_and(pred_fg, gt_fg).sum()
    union_px = np.logical_or(pred_fg, gt_fg).sum()

    if union_px != 0:
        return float(overlap_px / (union_px + 1e-8))
    return 1.0
|
|
|
|
def oracle_click(pred_mask, gt_mask, noise_range=3, rng=None):
    """Pick a corrective click from the disagreement between two masks.

    The click is placed at the point deepest inside the larger error region
    (measured with a Euclidean distance transform):

    * false negatives (ground truth not covered) -> positive click, label 1
    * false positives (prediction outside ground truth) -> negative click,
      label 0

    If the masks agree exactly, the click falls back to the deepest point of
    the ground truth (label 1), or to the image centre with label 0 when the
    ground truth is empty.

    Args:
        pred_mask: Current predicted mask (2-D array-like).
        gt_mask: Ground-truth mask with the same shape.
        noise_range: The click is jittered by an integer offset drawn
            uniformly from ``[-noise_range, noise_range]`` per axis and then
            clipped to the image. ``0`` (or less) disables jitter.
        rng: Optional ``numpy.random.Generator`` or ``RandomState`` used for
            the jitter so callers can make it reproducible. ``None`` keeps
            the historical behaviour of drawing from the global ``np.random``.

    Returns:
        ``(y, x, label)`` as plain Python ints.
    """
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)
    height, width = pred.shape[0], pred.shape[1]

    fn_mask = ~pred & gt
    fp_mask = pred & ~gt
    has_fn = bool(fn_mask.any())
    has_fp = bool(fp_mask.any())

    # Depth of every error pixel inside its region; all zeros when the
    # region is empty.
    fn_dist = distance_transform_edt(fn_mask) if has_fn else np.zeros(pred.shape, dtype=float)
    fp_dist = distance_transform_edt(fp_mask) if has_fp else np.zeros(pred.shape, dtype=float)

    if has_fn and fn_dist.max() >= fp_dist.max():
        click_pos = np.unravel_index(fn_dist.argmax(), fn_dist.shape)
        label = 1
    elif has_fp:
        click_pos = np.unravel_index(fp_dist.argmax(), fp_dist.shape)
        label = 0
    elif gt.any():
        # Masks agree: reinforce the object with a click at its deepest point.
        gt_dist = distance_transform_edt(gt)
        click_pos = np.unravel_index(gt_dist.argmax(), gt_dist.shape)
        label = 1
    else:
        # Nothing to segment at all: neutral background click at the centre.
        click_pos = (height // 2, width // 2)
        label = 0

    y, x = int(click_pos[0]), int(click_pos[1])

    if noise_range > 0:
        if rng is None:
            noise = np.random.randint(-noise_range, noise_range + 1, size=2)
        elif hasattr(rng, "integers"):
            # numpy.random.Generator API
            noise = rng.integers(-noise_range, noise_range + 1, size=2)
        else:
            # legacy numpy.random.RandomState API
            noise = rng.randint(-noise_range, noise_range + 1, size=2)
        y = int(np.clip(y + noise[0], 0, height - 1))
        x = int(np.clip(x + noise[1], 0, width - 1))

    return y, x, label
|
|
|
|
def binary_erosion_safe(mask):
    """Erode ``mask`` twice, but never return an empty result.

    The input is handed back unchanged when it is already empty or when two
    erosion iterations would wipe out every foreground pixel.
    """
    from scipy.ndimage import binary_erosion

    if mask.any():
        shrunk = binary_erosion(mask, iterations=2)
        if shrunk.any():
            return shrunk
    return mask
|
|
|
|
class SAM2ClickEnv(gym.Env):
    """Gymnasium environment for refining a segmentation mask with clicks.

    Each episode loads one ``{"image", "mask"}`` sample from ``dataset``,
    places an initial positive click near the centre of the ground-truth
    object and produces a first mask (via ``sam_predictor`` or, when no
    predictor is used, a noisy simulation). The agent then adds up to
    ``max_clicks`` further clicks.

    Observation: ``(obs_size, obs_size, 6)`` uint8 image whose channels are
    R, G, B, current mask (0/255), foreground-click heatmap and
    background-click heatmap.

    Action: ``Discrete(grid_size * grid_size * 2)``. ``action % 2`` selects
    the click type (0 = foreground/positive, 1 = background/negative) and
    ``action // 2`` is a row-major cell index on a ``grid_size`` x
    ``grid_size`` grid laid over the original image.

    Reward: change in Dice score, plus a small weighted bonus for clicking
    close to an erroneous pixel, plus 0.1 whenever Dice exceeds 0.95 on a
    non-final step.

    All randomness is drawn from ``self.np_random`` so ``reset(seed=...)``
    makes episodes reproducible.
    """

    metadata = {"render_modes": ["rgb_array"]}

    # Half-width (in pixels) of the ground-truth patch that a click reveals
    # in the simulated (predictor-free) mode.
    _SIM_PATCH_RADIUS = 20

    def __init__(self, dataset=None, sam_predictor=None, obs_size=128, grid_size=32,
                 max_clicks=5, click_radius=3, boundary_reward_weight=0.3,
                 initial_click_noise=5, use_sam=True, split="train", render_mode=None):
        super().__init__()
        self.dataset = dataset
        self.sam_predictor = sam_predictor
        self.obs_size = obs_size
        self.grid_size = grid_size
        self.max_clicks = max_clicks
        self.click_radius = click_radius
        self.boundary_reward_weight = boundary_reward_weight
        self.initial_click_noise = initial_click_noise
        self.use_sam = use_sam
        self.split = split  # kept for callers' bookkeeping; not used internally
        self.render_mode = render_mode

        self.observation_space = spaces.Box(
            low=0, high=255, shape=(obs_size, obs_size, 6), dtype=np.uint8
        )
        self.action_space = spaces.Discrete(grid_size * grid_size * 2)

        # Per-episode state, populated by reset().
        self.current_image = None
        self.current_gt = None
        self.current_mask = None
        self.click_coords = []  # (x, y) in original-image pixels
        self.click_labels = []  # 1 = foreground, 0 = background
        self.prev_dice = 0.0
        self.n_clicks = 0       # agent clicks only; the initial click is not counted
        self.orig_h = 0
        self.orig_w = 0
        self._dataset_indices = None
        self._current_idx = 0
        self._noise_mask = None  # cached degraded mask for simulated mode

    # ------------------------------------------------------------------
    # Data loading and observation helpers
    # ------------------------------------------------------------------

    def _load_random_sample(self):
        """Return the next ``(image, mask)`` pair from a shuffled pass over the dataset.

        The dataset is traversed in a random permutation that is reshuffled
        after every complete pass.

        Raises:
            ValueError: If no dataset was supplied or the dataset is empty.
        """
        if self.dataset is None:
            raise ValueError("SAM2ClickEnv requires a dataset, got dataset=None")
        n_samples = len(self.dataset)
        if n_samples == 0:
            raise ValueError("SAM2ClickEnv dataset is empty")

        if self._dataset_indices is None:
            self._dataset_indices = self.np_random.permutation(n_samples)
            self._current_idx = 0

        idx = int(self._dataset_indices[self._current_idx])
        self._current_idx = (self._current_idx + 1) % len(self._dataset_indices)
        if self._current_idx == 0:
            # Completed a pass: reshuffle for the next one.
            self._dataset_indices = self.np_random.permutation(n_samples)

        sample = self.dataset[idx]
        image = sample["image"]
        mask = sample["mask"]

        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        image = np.asarray(image)
        if image.ndim == 2:
            # Grayscale array: replicate to three channels so the observation
            # can always read R, G and B planes.
            image = np.stack([image, image, image], axis=-1)

        if isinstance(mask, Image.Image):
            mask = mask.convert("L")
        mask = np.asarray(mask)

        if mask.dtype == np.bool_:
            # A boolean mask is already binary; thresholding at 127 would
            # wipe it out.
            mask = mask.astype(np.uint8)
        else:
            # NOTE(review): assumes masks are stored on a 0-255 scale; an
            # integer 0/1 mask would become empty here - confirm the dataset
            # format.
            mask = (mask > 127).astype(np.uint8)
        return image, mask

    def _resize_for_obs(self, img, is_mask=False):
        """Resize ``img`` to ``obs_size`` x ``obs_size``.

        Masks use nearest-neighbour so they stay binary; images use bilinear.
        """
        resample = Image.NEAREST if is_mask else Image.BILINEAR
        pil_img = Image.fromarray(img).resize((self.obs_size, self.obs_size), resample)
        return np.array(pil_img)

    def _make_click_heatmap(self, clicks_yx, orig_h, orig_w):
        """Draw a filled disk of value 255 at every click, in observation space.

        Args:
            clicks_yx: Iterable of ``(y, x)`` click positions in original-image
                pixels.
            orig_h: Original image height, used to rescale ``y``.
            orig_w: Original image width, used to rescale ``x``.
        """
        size = self.obs_size
        radius = self.click_radius
        heatmap = np.zeros((size, size), dtype=np.uint8)

        # Offsets of the disk are identical for every click; compute them once.
        disk = [
            (dy, dx)
            for dy in range(-radius, radius + 1)
            for dx in range(-radius, radius + 1)
            if dy * dy + dx * dx <= radius * radius
        ]

        for y, x in clicks_yx:
            obs_y = int(np.clip(int(y * size / orig_h), 0, size - 1))
            obs_x = int(np.clip(int(x * size / orig_w), 0, size - 1))
            for dy, dx in disk:
                ny, nx = obs_y + dy, obs_x + dx
                if 0 <= ny < size and 0 <= nx < size:
                    heatmap[ny, nx] = 255
        return heatmap

    def _get_obs(self):
        """Assemble the 6-channel observation for the current state."""
        img_resized = self._resize_for_obs(self.current_image)
        mask_resized = self._resize_for_obs(
            (self.current_mask * 255).astype(np.uint8), is_mask=True
        )

        # Clicks are stored as (x, y); the heatmap helper expects (y, x).
        fg_yx = [(y, x) for (x, y), lab in zip(self.click_coords, self.click_labels) if lab == 1]
        bg_yx = [(y, x) for (x, y), lab in zip(self.click_coords, self.click_labels) if lab == 0]

        fg_heatmap = self._make_click_heatmap(fg_yx, self.orig_h, self.orig_w)
        bg_heatmap = self._make_click_heatmap(bg_yx, self.orig_h, self.orig_w)

        obs = np.stack([
            img_resized[:, :, 0], img_resized[:, :, 1], img_resized[:, :, 2],
            mask_resized, fg_heatmap, bg_heatmap,
        ], axis=-1).astype(np.uint8)
        return obs

    # ------------------------------------------------------------------
    # Mask production
    # ------------------------------------------------------------------

    @staticmethod
    def _sam_context():
        """Return the extra context manager used around predictor calls.

        bfloat16 autocast when CUDA is available, otherwise a plain
        inference-mode context.
        """
        if torch.cuda.is_available():
            return torch.autocast("cuda", dtype=torch.bfloat16)
        return torch.inference_mode()

    def _run_sam(self):
        """Predict a mask from all clicks so far (or simulate one).

        Returns:
            A uint8 mask. With a single click the predictor is asked for
            multiple candidates and the highest-scoring one is kept.
        """
        if not self.use_sam or self.sam_predictor is None:
            return self._simulate_mask()

        coords = np.array(self.click_coords, dtype=np.float32)
        labels = np.array(self.click_labels, dtype=np.int32)

        with torch.inference_mode(), self._sam_context():
            masks, scores, _logits = self.sam_predictor.predict(
                point_coords=coords, point_labels=labels,
                multimask_output=(len(self.click_coords) == 1),
            )

        if len(masks.shape) == 3 and masks.shape[0] > 1:
            mask = masks[np.argmax(scores)]
        else:
            mask = masks[0] if len(masks.shape) == 3 else masks

        return mask.astype(np.uint8)

    def _simulate_mask(self):
        """Stand-in for the predictor: a degraded ground truth repaired around clicks.

        On first use in an episode the ground truth is randomly dilated or
        eroded and roughly 15% of its pixels are flipped; that degraded mask
        is cached. Every click then restores the ground truth in a square
        patch around it.
        """
        if getattr(self, "_noise_mask", None) is None:
            from scipy.ndimage import binary_dilation, binary_erosion

            flip = self.np_random.random(self.current_gt.shape) < 0.15
            if self.np_random.random() < 0.5:
                base = binary_dilation(self.current_gt, np.ones((7, 7)))
            else:
                base = binary_erosion(self.current_gt, np.ones((5, 5)))
            self._noise_mask = (base.astype(np.uint8) ^ flip.astype(np.uint8)).astype(np.uint8)

        mask = self._noise_mask.copy()
        radius = self._SIM_PATCH_RADIUS
        for x, y in self.click_coords:
            y0, y1 = max(0, int(y) - radius), min(mask.shape[0], int(y) + radius)
            x0, x1 = max(0, int(x) - radius), min(mask.shape[1], int(x) + radius)
            mask[y0:y1, x0:x1] = self.current_gt[y0:y1, x0:x1]
        return mask

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def _jitter_click(self, y, x):
        """Perturb a click by up to ``initial_click_noise`` pixels per axis.

        Uses the environment's seeded generator and clips the result to the
        ground-truth mask bounds.
        """
        spread = self.initial_click_noise
        if spread > 0:
            dy, dx = self.np_random.integers(-spread, spread + 1, size=2)
            max_y = self.current_gt.shape[0] - 1
            max_x = self.current_gt.shape[1] - 1
            y = np.clip(y + dy, 0, max_y)
            x = np.clip(x + dx, 0, max_x)
        return int(y), int(x)

    def reset(self, seed=None, options=None):
        """Start a new episode on the next dataset sample.

        Args:
            seed: Optional seed for the environment's random generator.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where ``info`` holds the initial ``dice``
            and ``iou`` of the first mask.
        """
        super().reset(seed=seed)
        self.current_image, self.current_gt = self._load_random_sample()
        self.orig_h, self.orig_w = self.current_image.shape[:2]
        self.click_coords = []
        self.click_labels = []
        self.n_clicks = 0
        self._noise_mask = None  # force a fresh simulated mask for this episode

        if self.use_sam and self.sam_predictor is not None:
            with torch.inference_mode(), self._sam_context():
                self.sam_predictor.set_image(self.current_image)

        # Initial click: the oracle's choice against an empty prediction.
        # Jitter is applied here (not inside oracle_click) so that it comes
        # from the seeded generator.
        init_y, init_x, init_label = oracle_click(
            np.zeros_like(self.current_gt), self.current_gt, noise_range=0
        )
        init_y, init_x = self._jitter_click(init_y, init_x)
        self.click_coords.append((init_x, init_y))
        self.click_labels.append(int(init_label))

        self.current_mask = self._run_sam()
        self.prev_dice = compute_dice(self.current_mask, self.current_gt)

        info = {
            "dice": self.prev_dice,
            "iou": compute_iou(self.current_mask, self.current_gt),
        }
        return self._get_obs(), info

    def step(self, action):
        """Apply one click and return the Gymnasium 5-tuple.

        Args:
            action: Integer in ``[0, grid_size * grid_size * 2)``; see the
                class docstring for the encoding.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` becomes True once ``max_clicks`` agent clicks have
            been made; ``truncated`` is always False.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self.current_gt is None:
            raise RuntimeError("reset() must be called before step()")

        action = int(action)
        label_bit = action % 2
        cell = action // 2
        grid_y = cell // self.grid_size
        grid_x = cell % self.grid_size

        # Map the centre of the chosen grid cell to original-image pixels.
        orig_x = int((grid_x + 0.5) * self.orig_w / self.grid_size)
        orig_y = int((grid_y + 0.5) * self.orig_h / self.grid_size)
        orig_x = int(np.clip(orig_x, 0, self.orig_w - 1))
        orig_y = int(np.clip(orig_y, 0, self.orig_h - 1))

        self.click_coords.append((orig_x, orig_y))
        # Action bit 0 means a foreground click, which is label 1 for the predictor.
        self.click_labels.append(1 if label_bit == 0 else 0)

        self.current_mask = self._run_sam()

        new_dice = compute_dice(self.current_mask, self.current_gt)
        delta_dice = new_dice - self.prev_dice

        # Shaping bonus: up to 0.05 for clicking within 10 px of an erroneous pixel.
        # NOTE(review): the error map is taken from the mask *after* this
        # click was applied - confirm that is intended rather than the
        # pre-click error.
        boundary_bonus = 0.0
        error_mask = self.current_mask.astype(bool) != self.current_gt.astype(bool)
        if error_mask.any():
            error_dist = distance_transform_edt(~error_mask)
            click_error_dist = error_dist[orig_y, orig_x]
            if click_error_dist < 10:
                boundary_bonus = 0.05 * (1.0 - click_error_dist / 10.0)

        reward = delta_dice + self.boundary_reward_weight * boundary_bonus

        self.prev_dice = new_dice
        self.n_clicks += 1
        terminated = self.n_clicks >= self.max_clicks
        truncated = False

        # NOTE(review): a high Dice earns a bonus but does not end the
        # episode early - confirm this is the intended design.
        if new_dice > 0.95 and not terminated:
            reward += 0.1

        info = {
            "dice": new_dice,
            "iou": compute_iou(self.current_mask, self.current_gt),
            "delta_dice": delta_dice,
            # +1 accounts for the initial click placed by reset().
            "n_clicks": self.n_clicks + 1,
            "click_pos": (orig_x, orig_y),
            "click_label": self.click_labels[-1],
        }
        return self._get_obs(), float(reward), terminated, truncated, info

    def render(self):
        """Return the RGB planes of the observation in ``rgb_array`` mode, else None."""
        if self.render_mode == "rgb_array":
            return self._get_obs()[:, :, :3]
        return None
|
|