# SAM2 Click Agent — RL-Based Interactive Segmentation Refinement


## Overview
A PPO-trained RL agent that automates corrective click placement for SAM2-based
interactive medical image segmentation. After an initial user click, the agent
places additional refinement clicks on its own to improve the segmentation mask.


## Architecture
- **Environment**: SAM2.1-hiera-base-plus (frozen) as the segmentation backbone
- **Agent**: CNN-based PPO policy (Stable-Baselines3)
- **Observation**: 6-channel 128x128 image (RGB + current mask + fg/bg click heatmaps)
- **Action**: Discrete(2048) = 32x32 grid positions × 2 (fg/bg); decoded as in the sketch below
- **Reward**: Delta Dice + boundary-aware bonus (BS-IRIS inspired)
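
For concreteness, here is a minimal sketch of how the flat `Discrete(2048)` action could decode into a click and how the delta-Dice reward term can be computed. The index layout (foreground first, row-major grid order) and the grid-to-pixel mapping are assumptions for illustration, not the verified conventions of `sam2_click_env`:

```python
import numpy as np

GRID = 32  # 32x32 click grid
OBS = 128  # observation resolution

def decode_action(action: int):
    """Split a flat Discrete(2048) index into (x, y, is_foreground).

    ASSUMED layout: indices 0..1023 are foreground clicks, 1024..2047
    background clicks, each enumerating the 32x32 grid row-major.
    """
    is_foreground = action < GRID * GRID
    cell = action % (GRID * GRID)
    gy, gx = divmod(cell, GRID)
    scale = OBS // GRID  # place the click at the center of the grid cell
    return gx * scale + scale // 2, gy * scale + scale // 2, is_foreground

def dice(pred, gt, eps=1e-7):
    """Dice coefficient between two binary masks."""
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def delta_dice(prev_mask, new_mask, gt):
    """Per-click Dice improvement; the boundary-aware bonus is added on top."""
    return dice(new_mask, gt) - dice(prev_mask, gt)
```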


## Training Details
- **Dataset**: [Kvasir-SEG Augmented](https://huggingface.co/datasets/andreribeiro87/kvasir-seg-augmented) (4,800 training images, polyp segmentation)
- **Total timesteps**: 500,000
- **Training time**: 185.7 minutes
- **Parameters**: 1,886,625
- **PPO Config**: learning_rate=2.5e-4, clip_range=0.1, ent_coef=0.02, batch_size=128 (see the sketch below)
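
A minimal training sketch with the hyperparameters above, assuming the `SAM2ClickEnv` constructor from the Usage section; `train_dataset` and `sam_predictor` are placeholders, and any PPO settings not listed (n_steps, gamma, ...) fall back to Stable-Baselines3 defaults, which may differ from the original run:

```python
from stable_baselines3 import PPO
from sam2_click_env import SAM2ClickEnv

# Placeholder dataset/predictor; see the Usage section below.
env = SAM2ClickEnv(
    dataset=train_dataset,
    sam_predictor=sam_predictor,
    obs_size=128,
    grid_size=32,
    max_clicks=5,
    use_sam=True,
)

model = PPO(
    "CnnPolicy",           # CNN policy over the 6-channel observation
    env,
    learning_rate=2.5e-4,  # lr=0.00025
    clip_range=0.1,
    ent_coef=0.02,
    batch_size=128,
    verbose=1,
)
model.learn(total_timesteps=500_000)
model.save("click_agent_ppo")
```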


## Results (on test set, 100 samples)

Note on step indexing: the oracle table starts from the empty mask (step_0 = no clicks), while the agent table starts after the initial click, so the agent's step_0 corresponds to the oracle's step_1 (both 0.7482 ± 0.3160).


### Oracle Baseline (deterministic heuristic: center of largest error region, sketched below)
- step_0: Dice = 0.0000 ± 0.0000
- step_1: Dice = 0.7482 ± 0.3160
- step_2: Dice = 0.8480 ± 0.2142
- step_3: Dice = 0.8545 ± 0.2173
- step_4: Dice = 0.8942 ± 0.1692
- step_5: Dice = 0.9170 ± 0.1281
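
A minimal sketch of such an oracle heuristic (RITM-style click simulation). It uses `scipy` for connected components, and interpreting "center" as the interior point farthest from the region boundary is an assumption for illustration:

```python
import numpy as np
from scipy import ndimage

def oracle_click(pred, gt):
    """Click at the center of the largest error region.

    Returns (y, x, is_positive), or None if the masks already agree.
    A positive (foreground) click fixes missed foreground; a negative
    (background) click fixes spurious foreground.
    """
    error = pred.astype(bool) ^ gt.astype(bool)      # disagreement mask
    labels, n = ndimage.label(error)                 # connected error regions
    if n == 0:
        return None
    sizes = ndimage.sum(error, labels, index=range(1, n + 1))
    region = labels == (int(np.argmax(sizes)) + 1)   # largest error region
    dist = ndimage.distance_transform_edt(region)    # depth inside the region
    y, x = np.unravel_index(np.argmax(dist), dist.shape)
    return int(y), int(x), bool(gt[y, x])            # positive iff GT is foreground here
```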


### RL Click Agent (trained PPO policy)
- mean_episode_reward: -0.0348
- step_0: Dice = 0.7482 ± 0.3160
- step_1: Dice = 0.6528 ± 0.3375
- step_2: Dice = 0.6242 ± 0.3509
- step_3: Dice = 0.6667 ± 0.3246
- step_4: Dice = 0.6445 ± 0.3087
- step_5: Dice = 0.6141 ± 0.3233


## Based On
- [BS-IRIS](https://arxiv.org/abs/2303.10692) — Boundary-aware reward design (IEEE TMI 2023)
- [IteR-MRL](https://arxiv.org/abs/1911.10334) — Multi-agent RL for interactive segmentation (CVPR 2020)
- [RITM](https://arxiv.org/abs/2102.06583) — Oracle click simulation strategy


## Usage
```python
from stable_baselines3 import PPO
from sam2_click_env import SAM2ClickEnv

# Load the trained agent
model = PPO.load("click_agent_ppo")

# Create the environment with your SAM2 predictor
env = SAM2ClickEnv(
    dataset=your_dataset,
    sam_predictor=your_sam_predictor,
    obs_size=128,
    grid_size=32,
    max_clicks=5,
    use_sam=True,
)

# Run inference: up to 5 corrective clicks
obs, info = env.reset()
for step in range(5):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    print(f"Step {step + 1}: Dice={info['dice']:.4f}")
    if terminated or truncated:
        break
```