# SAM2 Click Agent — RL-Based Interactive Segmentation Refinement
## Overview
A PPO-trained reinforcement learning (RL) agent that automates corrective click placement for
SAM2-based interactive medical image segmentation. After an initial user click, the agent
places additional refinement clicks intended to improve the segmentation mask.
## Architecture
- **Environment**: SAM2.1-hiera-base-plus (frozen) as the segmentation backbone
- **Agent**: CNN-based PPO policy (Stable-Baselines3)
- **Observation**: 6-channel image (RGB + current mask + foreground/background click heatmaps), 128×128
- **Action**: Discrete(2048) = 32×32 grid positions × 2 click types (fg/bg)
- **Reward**: Delta Dice + boundary-aware bonus (BS-IRIS-inspired)
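The action and observation encodings listed above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `sam2_click_env`: the flattening order of the 2048 actions (foreground half first, then row-major grid cells), placing each click at the centre of its 4×4 pixel cell, and the Gaussian heatmap with `sigma=3.0` are all hypothetical choices. Only the delta-Dice part of the reward is shown, because the exact form of the boundary-aware bonus is not specified here.

```python
import numpy as np

OBS_SIZE = 128                 # observation side length (from the README)
GRID_SIZE = 32                 # click grid side length (from the README)
CELL = OBS_SIZE // GRID_SIZE   # 4 pixels per grid cell


def decode_action(action: int) -> tuple[int, int, bool]:
    """Map a Discrete(2048) action to (row, col, is_foreground) in pixel space.

    Assumed layout (hypothetical): actions [0, 1024) are foreground clicks and
    [1024, 2048) are background clicks; each half is a row-major 32x32 grid.
    """
    n_cells = GRID_SIZE * GRID_SIZE
    if not 0 <= action < 2 * n_cells:
        raise ValueError(f"action {action} out of range")
    is_fg = action < n_cells
    gy, gx = divmod(action % n_cells, GRID_SIZE)
    # Click at the centre of the selected grid cell (assumed convention).
    return gy * CELL + CELL // 2, gx * CELL + CELL // 2, is_fg


def click_heatmap(clicks, size: int = OBS_SIZE, sigma: float = 3.0) -> np.ndarray:
    """Render (row, col) clicks as Gaussian blobs, combined with a pixel-wise max."""
    ys, xs = np.mgrid[0:size, 0:size]
    heat = np.zeros((size, size), dtype=np.float32)
    for r, c in clicks:
        blob = np.exp(-((ys - r) ** 2 + (xs - c) ** 2) / (2 * sigma ** 2))
        heat = np.maximum(heat, blob)
    return heat


def build_observation(rgb, mask, fg_clicks, bg_clicks) -> np.ndarray:
    """Stack RGB (3, H, W), mask (H, W) and fg/bg heatmaps into a (6, H, W) array."""
    return np.concatenate(
        [rgb, mask[None], click_heatmap(fg_clicks)[None], click_heatmap(bg_clicks)[None]],
        axis=0,
    ).astype(np.float32)


def dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> float:
    """Dice coefficient between two binary masks (eps avoids division by zero)."""
    inter = np.logical_and(pred, gt).sum()
    return float((2 * inter + eps) / (pred.sum() + gt.sum() + eps))


def delta_dice_reward(prev_mask, new_mask, gt) -> float:
    """Delta-Dice term only; the boundary-aware bonus is not reproduced here."""
    return dice(new_mask, gt) - dice(prev_mask, gt)
```

With this layout, `decode_action(0)` is a foreground click at pixel (2, 2) and `decode_action(2047)` a background click at (126, 126). Because the grid is 4× coarser than the observation, the agent can only click on a 4-pixel lattice, which keeps the action space small at the cost of placement precision.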
## Training Details
- **Dataset**: [Kvasir-SEG Augmented](https://huggingface.co/datasets/andreribeiro87/kvasir-seg-augmented) (4,800 training samples; polyp segmentation)
- **Total timesteps**: 500,000
- **Training time**: 185.7 minutes
- **Parameters**: 1,886,625
- **PPO Config**: learning rate = 0.00025, clip range = 0.1, entropy coefficient = 0.02, batch size = 128
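The training settings above map onto Stable-Baselines3's `PPO` keyword arguments roughly as sketched below. This is a hedged sketch, not the original training script: the `"CnnPolicy"` string, the `train` helper and the save path are assumptions, and every hyperparameter the README does not list (`n_steps`, `n_epochs`, `gamma`, ...) is assumed to stay at the SB3 default.

```python
# Hyperparameters from Training Details, using SB3 PPO keyword names.
# Anything not listed here is assumed to remain at the SB3 default.
PPO_KWARGS = {
    "learning_rate": 2.5e-4,
    "clip_range": 0.1,
    "ent_coef": 0.02,
    "batch_size": 128,
}
TOTAL_TIMESTEPS = 500_000
TRAINING_MINUTES = 185.7

# Implied throughput: 500,000 / (185.7 * 60 s), about 44.9 environment steps/s.
steps_per_second = TOTAL_TIMESTEPS / (TRAINING_MINUTES * 60)


def train(env, save_path: str = "click_agent_ppo"):
    """Hypothetical helper: train a CNN PPO agent on an already-built SAM2ClickEnv."""
    from stable_baselines3 import PPO  # imported lazily so this module loads without SB3

    model = PPO("CnnPolicy", env, verbose=1, **PPO_KWARGS)  # policy choice assumed
    model.learn(total_timesteps=TOTAL_TIMESTEPS)
    model.save(save_path)
    return model
```

The throughput figure is only arithmetic on the reported numbers; each environment step includes a frozen SAM2 forward pass, which likely dominates the wall-clock time.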
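## Results (test set, 100 samples)
### Oracle Baseline (deterministic heuristic: click at the center of the largest error region)
Here step_0 is the empty initial mask (no clicks yet) and step_k is the mask after k oracle clicks.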
- step_0: Dice = 0.0000 ± 0.0000
- step_1: Dice = 0.7482 ± 0.3160
- step_2: Dice = 0.8480 ± 0.2142
- step_3: Dice = 0.8545 ± 0.2173
- step_4: Dice = 0.8942 ± 0.1692
- step_5: Dice = 0.9170 ± 0.1281
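The oracle heuristic above (center of the largest error region) can be sketched as follows, assuming one particular reading of "center": label the false-negative and false-positive regions, keep the larger connected component by pixel count, and click its interior pixel farthest from the component boundary, so the click cannot land outside a non-convex region (the same idea RITM uses). The name `oracle_click`, 4-connectivity, city-block distance and the "largest by area" rule are assumptions; the repository's oracle may differ.

```python
from collections import deque

import numpy as np

NEIGHBOURS = ((1, 0), (-1, 0), (0, 1), (0, -1))  # 4-connectivity


def _largest_component(region: np.ndarray) -> np.ndarray:
    """Boolean mask of the largest 4-connected component of `region` (BFS labelling)."""
    h, w = region.shape
    seen = np.zeros_like(region, dtype=bool)
    best: list[tuple[int, int]] = []
    for sy, sx in zip(*np.nonzero(region)):
        if seen[sy, sx]:
            continue
        comp, queue = [], deque([(sy, sx)])
        seen[sy, sx] = True
        while queue:
            y, x = queue.popleft()
            comp.append((y, x))
            for dy, dx in NEIGHBOURS:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w and region[ny, nx] and not seen[ny, nx]:
                    seen[ny, nx] = True
                    queue.append((ny, nx))
        if len(comp) > len(best):
            best = comp
    out = np.zeros_like(region, dtype=bool)
    for y, x in best:
        out[y, x] = True
    return out


def _inner_distance(component: np.ndarray) -> np.ndarray:
    """City-block distance from each component pixel to the nearest outside pixel."""
    # Pad with background so the image border also counts as "outside".
    padded = np.pad(component, 1, constant_values=False)
    h, w = padded.shape
    dist = np.where(padded, -1, 0)  # -1 marks component pixels not yet reached
    queue = deque(zip(*np.nonzero(~padded)))  # multi-source BFS from all outside pixels
    while queue:
        y, x = queue.popleft()
        for dy, dx in NEIGHBOURS:
            ny, nx = y + dy, x + dx
            if 0 <= ny < h and 0 <= nx < w and dist[ny, nx] == -1:
                dist[ny, nx] = dist[y, x] + 1
                queue.append((ny, nx))
    return dist[1:-1, 1:-1]


def oracle_click(pred: np.ndarray, gt: np.ndarray):
    """Return (row, col, is_foreground) for the next oracle click, or None if pred == gt."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    fn = _largest_component(gt & ~pred)  # missed object pixels -> foreground click
    fp = _largest_component(pred & ~gt)  # spurious pixels -> background click
    if not fn.any() and not fp.any():
        return None
    is_fg = fn.sum() >= fp.sum()
    dist = _inner_distance(fn if is_fg else fp)
    row, col = np.unravel_index(np.argmax(dist), dist.shape)
    return int(row), int(col), bool(is_fg)
```

For an empty prediction and a 5×5 square object, this returns a foreground click at the square's central pixel. The pure-NumPy BFS keeps the sketch dependency-free; `scipy.ndimage.label` and `scipy.ndimage.distance_transform_edt` are the usual faster substitutes.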
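### RL Click Agent (trained PPO policy)
Here step_0 is the mask after the initial click, and step_k adds k agent clicks. Its step_0 value matches the oracle's step_1, consistent with the initial click being placed by the same oracle. Mean Dice after the fifth agent click (0.6141) is below the post-initial-click value (0.7482) and below every oracle step from step_1 onward.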
- mean_episode_reward: -0.0348
- step_0: Dice = 0.7482 ± 0.3160
- step_1: Dice = 0.6528 ± 0.3375
- step_2: Dice = 0.6242 ± 0.3509
- step_3: Dice = 0.6667 ± 0.3246
- step_4: Dice = 0.6445 ± 0.3087
- step_5: Dice = 0.6141 ± 0.3233
## Based On
- [BS-IRIS](https://arxiv.org/abs/2303.10692) — Boundary-aware reward design (IEEE TMI 2023)
- [IteR-MRL](https://arxiv.org/abs/1911.10334) — Multi-agent RL for interactive segmentation (CVPR 2020)
- [RITM](https://arxiv.org/abs/2102.06583) — Oracle click simulation strategy
## Usage
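```python
from stable_baselines3 import PPO
from sam2_click_env import SAM2ClickEnv, compute_dice  # compute_dice: optional, for your own evaluation

# Placeholders: supply your own objects.
#   your_dataset       - image/mask pairs in the format SAM2ClickEnv expects
#   your_sam_predictor - a SAM2 predictor for the frozen backbone, e.g. (assumed compatible):
#                        SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-base-plus")

# Load the trained agent
model = PPO.load("click_agent_ppo")

# Create the environment with your SAM2 predictor
env = SAM2ClickEnv(
    dataset=your_dataset,
    sam_predictor=your_sam_predictor,
    obs_size=128,   # observation resolution (128x128)
    grid_size=32,   # 32x32 click grid -> Discrete(2048) actions
    max_clicks=5,
    use_sam=True,
)

# Run inference on one episode
obs, info = env.reset()
for step in range(5):  # one iteration per refinement click (max_clicks=5)
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    print(f"Step {step + 1}: Dice={info['dice']:.4f}")
    if terminated or truncated:
        break
```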
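To reproduce per-step summaries like those in the Results section, the inference loop can be wrapped in an evaluation helper. The sketch below is hypothetical: `evaluate` is not part of `sam2_click_env`, and it assumes that both `env.reset()` and `env.step()` report the current Dice in `info["dice"]`, so step_0 is the mask before any agent click. It only needs objects with a Gymnasium-style `reset`/`step` and an SB3-style `predict`.

```python
import numpy as np


def evaluate(model, env, n_episodes: int = 100, n_steps: int = 5) -> dict:
    """Roll out `model` and return {"step_t": (mean Dice, std Dice)} for t = 0..n_steps.

    Assumes info["dice"] is set by env.reset() and env.step() (hypothetical contract).
    """
    per_step = [[] for _ in range(n_steps + 1)]
    for _ in range(n_episodes):
        obs, info = env.reset()
        dice_t = info["dice"]
        per_step[0].append(dice_t)
        done = False
        for t in range(1, n_steps + 1):
            if not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, _, terminated, truncated, info = env.step(action)
                dice_t = info["dice"]
                done = terminated or truncated
            # After an episode ends, carry its last Dice forward so every step
            # averages over the same n_episodes samples.
            per_step[t].append(dice_t)
    return {f"step_{t}": (float(np.mean(v)), float(np.std(v))) for t, v in enumerate(per_step)}
```

Carrying the final Dice forward is a design choice: dropping finished episodes instead would make later steps average over fewer, possibly harder, samples.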