Spaces:
Running
on
Zero
Running
on
Zero
yjwnb6
commited on
Commit
·
7b25808
0
Parent(s):
Initial HF Space upload
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -0
- app.py +428 -0
- requirements.txt +13 -0
- sam2/CODE_OF_CONDUCT.md +80 -0
- sam2/CONTRIBUTING.md +31 -0
- sam2/INSTALL.md +189 -0
- sam2/LICENSE +201 -0
- sam2/LICENSE_cctorch +29 -0
- sam2/README.md +224 -0
- sam2/checkpoints/download_ckpts.sh +59 -0
- sam2/checkpoints/unsamv2_plus_ckpt.pt +3 -0
- sam2/notebooks/cascadepsp.py +61 -0
- sam2/notebooks/interactive_image_segmentation.ipynb +0 -0
- sam2/notebooks/video_segmentation.ipynb +0 -0
- sam2/notebooks/whole_image_segmentation.ipynb +0 -0
- sam2/pyproject.toml +6 -0
- sam2/sam2/__init__.py +11 -0
- sam2/sam2/__pycache__/__init__.cpython-310.pyc +0 -0
- sam2/sam2/__pycache__/build_sam.cpython-310.pyc +0 -0
- sam2/sam2/__pycache__/granularity_embedding.cpython-310.pyc +0 -0
- sam2/sam2/__pycache__/sam2_image_predictor.cpython-310.pyc +0 -0
- sam2/sam2/automatic_mask_generator.py +469 -0
- sam2/sam2/build_sam.py +252 -0
- sam2/sam2/configs/unsamv2_small.yaml +122 -0
- sam2/sam2/configs/unsamv2_small_training.yaml +323 -0
- sam2/sam2/csrc/connected_components.cu +289 -0
- sam2/sam2/granularity_embedding.py +67 -0
- sam2/sam2/modeling/__init__.py +5 -0
- sam2/sam2/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
- sam2/sam2/modeling/__pycache__/memory_attention.cpython-310.pyc +0 -0
- sam2/sam2/modeling/__pycache__/memory_encoder.cpython-310.pyc +0 -0
- sam2/sam2/modeling/__pycache__/position_encoding.cpython-310.pyc +0 -0
- sam2/sam2/modeling/__pycache__/sam2_base.cpython-310.pyc +0 -0
- sam2/sam2/modeling/__pycache__/sam2_utils.cpython-310.pyc +0 -0
- sam2/sam2/modeling/backbones/__init__.py +5 -0
- sam2/sam2/modeling/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
- sam2/sam2/modeling/backbones/__pycache__/hieradet.cpython-310.pyc +0 -0
- sam2/sam2/modeling/backbones/__pycache__/image_encoder.cpython-310.pyc +0 -0
- sam2/sam2/modeling/backbones/__pycache__/utils.cpython-310.pyc +0 -0
- sam2/sam2/modeling/backbones/adapter_hieradet.py +850 -0
- sam2/sam2/modeling/backbones/hieradet.py +321 -0
- sam2/sam2/modeling/backbones/image_encoder.py +136 -0
- sam2/sam2/modeling/backbones/my_adapter.py +317 -0
- sam2/sam2/modeling/backbones/utils.py +93 -0
- sam2/sam2/modeling/memory_attention.py +172 -0
- sam2/sam2/modeling/memory_encoder.py +181 -0
- sam2/sam2/modeling/position_encoding.py +240 -0
- sam2/sam2/modeling/sam/__init__.py +5 -0
- sam2/sam2/modeling/sam/__pycache__/__init__.cpython-310.pyc +0 -0
- sam2/sam2/modeling/sam/__pycache__/gra_mask_decoder.cpython-310.pyc +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
sam2/checkpoints/unsamv2_plus_ckpt.pt filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Gradio demo for UnSAMv2 interactive image segmentation with Hugging Face ZeroGPU support."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import threading
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import List, Optional, Sequence, Tuple
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import spaces # type: ignore
|
| 20 |
+
except ImportError: # pragma: no cover - optional dependency on Spaces runtime
|
| 21 |
+
spaces = None
|
| 22 |
+
|
| 23 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 24 |
+
SAM2_REPO = REPO_ROOT / "sam2"
|
| 25 |
+
if SAM2_REPO.exists():
|
| 26 |
+
sys.path.insert(0, str(SAM2_REPO))
|
| 27 |
+
|
| 28 |
+
from sam2.build_sam import build_sam2 # noqa: E402
|
| 29 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor # noqa: E402
|
| 30 |
+
|
| 31 |
+
logging.basicConfig(level=os.getenv("UNSAMV2_LOGLEVEL", "INFO"))
|
| 32 |
+
LOGGER = logging.getLogger("unsamv2-gradio")
|
| 33 |
+
|
| 34 |
+
CONFIG_PATH = os.getenv("UNSAMV2_CONFIG", "configs/unsamv2_small.yaml")
|
| 35 |
+
CKPT_PATH = Path(
|
| 36 |
+
os.getenv("UNSAMV2_CKPT", SAM2_REPO / "checkpoints" / "unsamv2_plus_ckpt.pt")
|
| 37 |
+
).resolve()
|
| 38 |
+
if not CKPT_PATH.exists():
|
| 39 |
+
raise FileNotFoundError(
|
| 40 |
+
f"Checkpoint not found at {CKPT_PATH}. Set UNSAMV2_CKPT to a valid .pt file."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
GRANULARITY_MIN = float(os.getenv("UNSAMV2_GRAN_MIN", 0.1))
|
| 44 |
+
GRANULARITY_MAX = float(os.getenv("UNSAMV2_GRAN_MAX", 1.0))
|
| 45 |
+
ZERO_GPU_ENABLED = os.getenv("UNSAMV2_ENABLE_ZEROGPU", "1").lower() in {"1", "true", "yes"}
|
| 46 |
+
ZERO_GPU_DURATION = int(os.getenv("UNSAMV2_ZEROGPU_DURATION", "60"))
|
| 47 |
+
|
| 48 |
+
POINT_MODE_TO_LABEL = {"Foreground (+)": 1, "Background (-)": 0}
|
| 49 |
+
POINT_COLORS_BGR = {
|
| 50 |
+
1: (72, 201, 127), # green-ish for positives
|
| 51 |
+
0: (64, 76, 225), # red-ish for negatives
|
| 52 |
+
}
|
| 53 |
+
MASK_COLOR_BGR = (0, 196, 255)
|
| 54 |
+
OUTLINE_COLOR_BGR = (0, 165, 255)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ModelManager:
|
| 58 |
+
"""Keeps SAM2 models on each device and spawns lightweight predictors."""
|
| 59 |
+
|
| 60 |
+
def __init__(self) -> None:
|
| 61 |
+
self._models: dict[str, torch.nn.Module] = {}
|
| 62 |
+
self._lock = threading.Lock()
|
| 63 |
+
|
| 64 |
+
def _build(self, device: torch.device) -> torch.nn.Module:
|
| 65 |
+
LOGGER.info("Loading UnSAMv2 weights onto %s", device)
|
| 66 |
+
return build_sam2(
|
| 67 |
+
CONFIG_PATH,
|
| 68 |
+
ckpt_path=str(CKPT_PATH),
|
| 69 |
+
device=device,
|
| 70 |
+
mode="eval",
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def get_model(self, device: torch.device) -> torch.nn.Module:
|
| 74 |
+
key = (
|
| 75 |
+
f"{device.type}:{device.index}"
|
| 76 |
+
if device.type == "cuda"
|
| 77 |
+
else device.type
|
| 78 |
+
)
|
| 79 |
+
with self._lock:
|
| 80 |
+
if key not in self._models:
|
| 81 |
+
self._models[key] = self._build(device)
|
| 82 |
+
return self._models[key]
|
| 83 |
+
|
| 84 |
+
def make_predictor(self, device: torch.device) -> SAM2ImagePredictor:
|
| 85 |
+
return SAM2ImagePredictor(self.get_model(device))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
MODEL_MANAGER = ModelManager()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def ensure_uint8(image: Optional[np.ndarray]) -> Optional[np.ndarray]:
|
| 92 |
+
if image is None:
|
| 93 |
+
return None
|
| 94 |
+
img = image[..., :3] # drop alpha if present
|
| 95 |
+
if img.dtype == np.float32 or img.dtype == np.float64:
|
| 96 |
+
if img.max() <= 1.0:
|
| 97 |
+
img = (img * 255).clip(0, 255).astype(np.uint8)
|
| 98 |
+
else:
|
| 99 |
+
img = img.clip(0, 255).astype(np.uint8)
|
| 100 |
+
elif img.dtype != np.uint8:
|
| 101 |
+
img = img.clip(0, 255).astype(np.uint8)
|
| 102 |
+
return img
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def choose_device() -> torch.device:
|
| 106 |
+
preference = os.getenv("UNSAMV2_DEVICE", "auto").lower()
|
| 107 |
+
if preference == "cpu":
|
| 108 |
+
return torch.device("cpu")
|
| 109 |
+
if preference.startswith("cuda") or preference == "gpu":
|
| 110 |
+
if torch.cuda.is_available():
|
| 111 |
+
return torch.device(preference if preference.startswith("cuda") else "cuda")
|
| 112 |
+
LOGGER.warning("CUDA requested but not available; defaulting to CPU")
|
| 113 |
+
return torch.device("cpu")
|
| 114 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def build_granularity_tensor(value: float, device: torch.device) -> torch.Tensor:
|
| 118 |
+
tensor = torch.tensor([[[[value]]]], dtype=torch.float32, device=device)
|
| 119 |
+
return tensor
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def draw_overlay(
|
| 123 |
+
image: np.ndarray,
|
| 124 |
+
mask: Optional[np.ndarray],
|
| 125 |
+
points: Sequence[Sequence[float]],
|
| 126 |
+
labels: Sequence[int],
|
| 127 |
+
alpha: float = 0.55,
|
| 128 |
+
) -> np.ndarray:
|
| 129 |
+
canvas_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 130 |
+
if mask is not None:
|
| 131 |
+
mask_bool = mask.astype(bool)
|
| 132 |
+
overlay = np.zeros_like(canvas_bgr, dtype=np.uint8)
|
| 133 |
+
overlay[mask_bool] = MASK_COLOR_BGR
|
| 134 |
+
canvas_bgr = np.where(
|
| 135 |
+
mask_bool[..., None],
|
| 136 |
+
(canvas_bgr * (1.0 - alpha) + overlay * alpha).astype(np.uint8),
|
| 137 |
+
canvas_bgr,
|
| 138 |
+
)
|
| 139 |
+
contours, _ = cv2.findContours(
|
| 140 |
+
mask_bool.astype(np.uint8),
|
| 141 |
+
mode=cv2.RETR_EXTERNAL,
|
| 142 |
+
method=cv2.CHAIN_APPROX_SIMPLE,
|
| 143 |
+
)
|
| 144 |
+
cv2.drawContours(canvas_bgr, contours, -1, OUTLINE_COLOR_BGR, 2)
|
| 145 |
+
for (x, y), lbl in zip(points, labels):
|
| 146 |
+
color = POINT_COLORS_BGR.get(lbl, (255, 255, 255))
|
| 147 |
+
center = (int(round(x)), int(round(y)))
|
| 148 |
+
cv2.circle(canvas_bgr, center, 7, color, thickness=-1, lineType=cv2.LINE_AA)
|
| 149 |
+
cv2.circle(canvas_bgr, center, 9, (255, 255, 255), thickness=2, lineType=cv2.LINE_AA)
|
| 150 |
+
return cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def points_table(points: Sequence[Sequence[float]], labels: Sequence[int]) -> List[List[str]]:
|
| 154 |
+
table = []
|
| 155 |
+
for idx, ((x, y), lbl) in enumerate(zip(points, labels), start=1):
|
| 156 |
+
table.append([
|
| 157 |
+
idx,
|
| 158 |
+
round(float(x), 1),
|
| 159 |
+
round(float(y), 1),
|
| 160 |
+
"fg" if lbl == 1 else "bg",
|
| 161 |
+
])
|
| 162 |
+
return table
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def handle_image_upload(image: Optional[np.ndarray]):
|
| 166 |
+
img = ensure_uint8(image)
|
| 167 |
+
if img is None:
|
| 168 |
+
return (
|
| 169 |
+
None,
|
| 170 |
+
None,
|
| 171 |
+
None,
|
| 172 |
+
[],
|
| 173 |
+
[],
|
| 174 |
+
[],
|
| 175 |
+
"Upload an image to start adding clicks.",
|
| 176 |
+
)
|
| 177 |
+
return (
|
| 178 |
+
img,
|
| 179 |
+
None,
|
| 180 |
+
img,
|
| 181 |
+
[],
|
| 182 |
+
[],
|
| 183 |
+
[],
|
| 184 |
+
"Image loaded. Choose click type, then tap on the image.",
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def handle_click(
|
| 189 |
+
point_mode: str,
|
| 190 |
+
pts: List[Sequence[float]],
|
| 191 |
+
lbls: List[int],
|
| 192 |
+
image: Optional[np.ndarray],
|
| 193 |
+
evt: gr.SelectData,
|
| 194 |
+
):
|
| 195 |
+
if image is None:
|
| 196 |
+
return (
|
| 197 |
+
gr.update(),
|
| 198 |
+
None,
|
| 199 |
+
pts,
|
| 200 |
+
lbls,
|
| 201 |
+
points_table(pts, lbls),
|
| 202 |
+
"Upload an image first.",
|
| 203 |
+
)
|
| 204 |
+
coord = evt.index # (x, y)
|
| 205 |
+
if coord is None:
|
| 206 |
+
return (
|
| 207 |
+
gr.update(),
|
| 208 |
+
None,
|
| 209 |
+
pts,
|
| 210 |
+
lbls,
|
| 211 |
+
points_table(pts, lbls),
|
| 212 |
+
"Couldn't read click position.",
|
| 213 |
+
)
|
| 214 |
+
x, y = coord
|
| 215 |
+
label = POINT_MODE_TO_LABEL.get(point_mode, 1)
|
| 216 |
+
pts = pts + [[float(x), float(y)]]
|
| 217 |
+
lbls = lbls + [label]
|
| 218 |
+
overlay = draw_overlay(image, None, pts, lbls)
|
| 219 |
+
status = f"Added {'positive' if label == 1 else 'negative'} click at ({int(x)}, {int(y)})."
|
| 220 |
+
return overlay, None, pts, lbls, points_table(pts, lbls), status
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def undo_last_click(image: Optional[np.ndarray], pts: List[Sequence[float]], lbls: List[int]):
|
| 224 |
+
if not pts:
|
| 225 |
+
return (
|
| 226 |
+
gr.update(),
|
| 227 |
+
None,
|
| 228 |
+
pts,
|
| 229 |
+
lbls,
|
| 230 |
+
points_table(pts, lbls),
|
| 231 |
+
"No clicks to undo.",
|
| 232 |
+
)
|
| 233 |
+
pts = pts[:-1]
|
| 234 |
+
lbls = lbls[:-1]
|
| 235 |
+
overlay = draw_overlay(image, None, pts, lbls) if image is not None else None
|
| 236 |
+
status = "Removed the last click."
|
| 237 |
+
return overlay, None, pts, lbls, points_table(pts, lbls), status
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def clear_clicks(image: Optional[np.ndarray]):
|
| 241 |
+
overlay = image if image is not None else None
|
| 242 |
+
return overlay, None, [], [], [], "Cleared all clicks."
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _run_segmentation(
|
| 246 |
+
image: Optional[np.ndarray],
|
| 247 |
+
pts: List[Sequence[float]],
|
| 248 |
+
lbls: List[int],
|
| 249 |
+
granularity: float,
|
| 250 |
+
):
|
| 251 |
+
img = ensure_uint8(image)
|
| 252 |
+
if img is None:
|
| 253 |
+
return None, None, "Upload an image to segment."
|
| 254 |
+
if not pts:
|
| 255 |
+
return draw_overlay(img, None, [], []), None, "Add at least one click before running segmentation."
|
| 256 |
+
|
| 257 |
+
device = choose_device()
|
| 258 |
+
predictor = MODEL_MANAGER.make_predictor(device)
|
| 259 |
+
predictor.set_image(img)
|
| 260 |
+
|
| 261 |
+
coords = np.asarray(pts, dtype=np.float32)
|
| 262 |
+
labels = np.asarray(lbls, dtype=np.int32)
|
| 263 |
+
gran_tensor = build_granularity_tensor(granularity, predictor.device)
|
| 264 |
+
|
| 265 |
+
masks, scores, _ = predictor.predict(
|
| 266 |
+
point_coords=coords,
|
| 267 |
+
point_labels=labels,
|
| 268 |
+
multimask_output=True,
|
| 269 |
+
gra=float(granularity),
|
| 270 |
+
granularity=gran_tensor,
|
| 271 |
+
)
|
| 272 |
+
best_idx = int(np.argmax(scores))
|
| 273 |
+
best_mask = masks[best_idx].astype(bool)
|
| 274 |
+
overlay = draw_overlay(img, best_mask, pts, lbls)
|
| 275 |
+
mask_vis = (best_mask.astype(np.uint8) * 255)
|
| 276 |
+
status = f"Best mask #{best_idx + 1} IoU score: {float(scores[best_idx]):.3f} | granularity={granularity:.2f}"
|
| 277 |
+
return overlay, mask_vis, status
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if spaces is not None and ZERO_GPU_ENABLED:
|
| 281 |
+
segment_fn = spaces.GPU(duration=ZERO_GPU_DURATION)(_run_segmentation)
|
| 282 |
+
else:
|
| 283 |
+
segment_fn = _run_segmentation
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def build_demo() -> gr.Blocks:
|
| 287 |
+
with gr.Blocks(title="UnSAMv2 Interactive Segmentation", theme=gr.themes.Soft()) as demo:
|
| 288 |
+
gr.Markdown(
|
| 289 |
+
"""## UnSAMv2 · Interactive Granularity Control
|
| 290 |
+
Upload an image, add positive/negative clicks, tune granularity, and run segmentation.
|
| 291 |
+
ZeroGPU automatically pulls a GPU when available; otherwise the app falls back to CPU."""
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
image_state = gr.State()
|
| 295 |
+
points_state = gr.State([])
|
| 296 |
+
labels_state = gr.State([])
|
| 297 |
+
|
| 298 |
+
with gr.Row():
|
| 299 |
+
image_input = gr.Image(
|
| 300 |
+
label="1 · Upload image & click to add prompts",
|
| 301 |
+
type="numpy",
|
| 302 |
+
height=480,
|
| 303 |
+
)
|
| 304 |
+
overlay_output = gr.Image(
|
| 305 |
+
label="Segmentation preview",
|
| 306 |
+
interactive=False,
|
| 307 |
+
height=480,
|
| 308 |
+
)
|
| 309 |
+
mask_output = gr.Image(
|
| 310 |
+
label="Binary mask",
|
| 311 |
+
interactive=False,
|
| 312 |
+
height=480,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
with gr.Row():
|
| 316 |
+
point_mode = gr.Radio(
|
| 317 |
+
choices=list(POINT_MODE_TO_LABEL.keys()),
|
| 318 |
+
value="Foreground (+)",
|
| 319 |
+
label="Click type",
|
| 320 |
+
)
|
| 321 |
+
granularity_slider = gr.Slider(
|
| 322 |
+
minimum=GRANULARITY_MIN,
|
| 323 |
+
maximum=GRANULARITY_MAX,
|
| 324 |
+
value=0.2,
|
| 325 |
+
step=0.05,
|
| 326 |
+
label="Granularity",
|
| 327 |
+
info="Lower = finer details, Higher = coarser regions",
|
| 328 |
+
)
|
| 329 |
+
segment_button = gr.Button("3 · Segment", variant="primary")
|
| 330 |
+
|
| 331 |
+
with gr.Row():
|
| 332 |
+
undo_button = gr.Button("Undo last click")
|
| 333 |
+
clear_button = gr.Button("Clear clicks")
|
| 334 |
+
|
| 335 |
+
points_table_output = gr.Dataframe(
|
| 336 |
+
headers=["#", "x", "y", "type"],
|
| 337 |
+
datatype=["number", "number", "number", "str"],
|
| 338 |
+
interactive=False,
|
| 339 |
+
label="2 · Click history",
|
| 340 |
+
)
|
| 341 |
+
status_markdown = gr.Markdown(" Ready.")
|
| 342 |
+
|
| 343 |
+
image_input.upload(
|
| 344 |
+
handle_image_upload,
|
| 345 |
+
inputs=[image_input],
|
| 346 |
+
outputs=[
|
| 347 |
+
overlay_output,
|
| 348 |
+
mask_output,
|
| 349 |
+
image_state,
|
| 350 |
+
points_state,
|
| 351 |
+
labels_state,
|
| 352 |
+
points_table_output,
|
| 353 |
+
status_markdown,
|
| 354 |
+
],
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
image_input.clear(
|
| 358 |
+
handle_image_upload,
|
| 359 |
+
inputs=[image_input],
|
| 360 |
+
outputs=[
|
| 361 |
+
overlay_output,
|
| 362 |
+
mask_output,
|
| 363 |
+
image_state,
|
| 364 |
+
points_state,
|
| 365 |
+
labels_state,
|
| 366 |
+
points_table_output,
|
| 367 |
+
status_markdown,
|
| 368 |
+
],
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
image_input.select(
|
| 372 |
+
handle_click,
|
| 373 |
+
inputs=[
|
| 374 |
+
point_mode,
|
| 375 |
+
points_state,
|
| 376 |
+
labels_state,
|
| 377 |
+
image_state,
|
| 378 |
+
],
|
| 379 |
+
outputs=[
|
| 380 |
+
overlay_output,
|
| 381 |
+
mask_output,
|
| 382 |
+
points_state,
|
| 383 |
+
labels_state,
|
| 384 |
+
points_table_output,
|
| 385 |
+
status_markdown,
|
| 386 |
+
],
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
undo_button.click(
|
| 390 |
+
undo_last_click,
|
| 391 |
+
inputs=[image_state, points_state, labels_state],
|
| 392 |
+
outputs=[
|
| 393 |
+
overlay_output,
|
| 394 |
+
mask_output,
|
| 395 |
+
points_state,
|
| 396 |
+
labels_state,
|
| 397 |
+
points_table_output,
|
| 398 |
+
status_markdown,
|
| 399 |
+
],
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
clear_button.click(
|
| 403 |
+
clear_clicks,
|
| 404 |
+
inputs=[image_state],
|
| 405 |
+
outputs=[
|
| 406 |
+
overlay_output,
|
| 407 |
+
mask_output,
|
| 408 |
+
points_state,
|
| 409 |
+
labels_state,
|
| 410 |
+
points_table_output,
|
| 411 |
+
status_markdown,
|
| 412 |
+
],
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
segment_button.click(
|
| 416 |
+
segment_fn,
|
| 417 |
+
inputs=[image_state, points_state, labels_state, granularity_slider],
|
| 418 |
+
outputs=[overlay_output, mask_output, status_markdown],
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
demo.queue(max_size=8)
|
| 422 |
+
return demo
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
demo = build_demo()
|
| 426 |
+
|
| 427 |
+
if __name__ == "__main__":
|
| 428 |
+
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")), share=True)
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.0
|
| 2 |
+
spaces==0.32.0
|
| 3 |
+
torch==2.5.1
|
| 4 |
+
torchvision==0.20.1
|
| 5 |
+
numpy==2.1.2
|
| 6 |
+
opencv-python-headless==4.10.0.84
|
| 7 |
+
Pillow==9.5.0
|
| 8 |
+
hydra-core==1.3.2
|
| 9 |
+
omegaconf==2.3.0
|
| 10 |
+
iopath==0.1.10
|
| 11 |
+
huggingface_hub==0.25.2
|
| 12 |
+
PyYAML==6.0
|
| 13 |
+
tqdm==4.67.1
|
sam2/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
sam2/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to segment-anything
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints, using the `ufmt format` command. Linting requires `black==24.2.0`, `usort==1.0.2`, and `ufmt==2.0.0b2`, which can be installed via `pip install -e ".[dev]"`.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Facebook's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to segment-anything, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
sam2/INSTALL.md
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Installation
|
| 2 |
+
|
| 3 |
+
### Requirements
|
| 4 |
+
|
| 5 |
+
- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
|
| 6 |
+
* Note older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
|
| 7 |
+
- [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
|
| 8 |
+
- If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
| 9 |
+
|
| 10 |
+
Then, install SAM 2 from the root of this repository via
|
| 11 |
+
```bash
|
| 12 |
+
pip install -e ".[notebooks]"
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
Note that you may skip building the SAM 2 CUDA extension during installation via environment variable `SAM2_BUILD_CUDA=0`, as follows:
|
| 16 |
+
```bash
|
| 17 |
+
# skip the SAM 2 CUDA extension
|
| 18 |
+
SAM2_BUILD_CUDA=0 pip install -e ".[notebooks]"
|
| 19 |
+
```
|
| 20 |
+
This would also skip the post-processing step at runtime (removing small holes and sprinkles in the output masks, which requires the CUDA extension), but shouldn't affect the results in most cases.
|
| 21 |
+
|
| 22 |
+
### Building the SAM 2 CUDA extension
|
| 23 |
+
|
| 24 |
+
By default, we allow the installation to proceed even if the SAM 2 CUDA extension fails to build. (In this case, the build errors are hidden unless using `-v` for verbose output in `pip install`.)
|
| 25 |
+
|
| 26 |
+
If you see a message like `Skipping the post-processing step due to the error above` at runtime or `Failed to build the SAM 2 CUDA extension due to the error above` during installation, it indicates that the SAM 2 CUDA extension failed to build in your environment. In this case, **you can still use SAM 2 for both image and video applications**. The post-processing step (removing small holes and sprinkles in the output masks) will be skipped, but this shouldn't affect the results in most cases.
|
| 27 |
+
|
| 28 |
+
If you would like to enable this post-processing step, you can reinstall SAM 2 on a GPU machine with environment variable `SAM2_BUILD_ALLOW_ERRORS=0` to force building the CUDA extension (and raise errors if it fails to build), as follows
|
| 29 |
+
```bash
|
| 30 |
+
pip uninstall -y SAM-2 && \
|
| 31 |
+
rm -f ./sam2/*.so && \
|
| 32 |
+
SAM2_BUILD_ALLOW_ERRORS=0 pip install -v -e ".[notebooks]"
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
Note that PyTorch needs to be installed first before building the SAM 2 CUDA extension. It's also necessary to install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. (This should typically be CUDA 12.1 if you follow the default installation command.) After installing the CUDA toolkits, you can check its version via `nvcc --version`.
|
| 36 |
+
|
| 37 |
+
Please check the section below on common installation issues if the CUDA extension fails to build during installation or load at runtime.
|
| 38 |
+
|
| 39 |
+
### Common Installation Issues
|
| 40 |
+
|
| 41 |
+
Click each issue for its solutions:
|
| 42 |
+
|
| 43 |
+
<details>
|
| 44 |
+
<summary>
|
| 45 |
+
I got `ImportError: cannot import name '_C' from 'sam2'`
|
| 46 |
+
</summary>
|
| 47 |
+
<br/>
|
| 48 |
+
|
| 49 |
+
This is usually because you haven't run the `pip install -e ".[notebooks]"` step above or the installation failed. Please install SAM 2 first, and see the other issues if your installation fails.
|
| 50 |
+
|
| 51 |
+
In some systems, you may need to run `python setup.py build_ext --inplace` in the SAM 2 repo root as suggested in https://github.com/facebookresearch/sam2/issues/77.
|
| 52 |
+
</details>
|
| 53 |
+
|
| 54 |
+
<details>
|
| 55 |
+
<summary>
|
| 56 |
+
I got `MissingConfigException: Cannot find primary config 'configs/sam2.1/sam2.1_hiera_l.yaml'`
|
| 57 |
+
</summary>
|
| 58 |
+
<br/>
|
| 59 |
+
|
| 60 |
+
This is usually because you haven't run the `pip install -e .` step above, so `sam2` isn't in your Python's `sys.path`. Please run this installation step. In case it still fails after the installation step, you may try manually adding the root of this repo to `PYTHONPATH` via
|
| 61 |
+
```bash
|
| 62 |
+
export SAM2_REPO_ROOT=/path/to/sam2 # path to this repo
|
| 63 |
+
export PYTHONPATH="${SAM2_REPO_ROOT}:${PYTHONPATH}"
|
| 64 |
+
```
|
| 65 |
+
to manually add `sam2_configs` into your Python's `sys.path`.
|
| 66 |
+
|
| 67 |
+
</details>
|
| 68 |
+
|
| 69 |
+
<details>
|
| 70 |
+
<summary>
|
| 71 |
+
I got `RuntimeError: Error(s) in loading state_dict for SAM2Base` when loading the new SAM 2.1 checkpoints
|
| 72 |
+
</summary>
|
| 73 |
+
<br/>
|
| 74 |
+
|
| 75 |
+
This is likely because you have installed a previous version of this repo, which doesn't have the new modules to support the SAM 2.1 checkpoints yet. Please try the following steps:
|
| 76 |
+
|
| 77 |
+
1. pull the latest code from the `main` branch of this repo
|
| 78 |
+
2. run `pip uninstall -y SAM-2` to uninstall any previous installations
|
| 79 |
+
3. then install the latest repo again using `pip install -e ".[notebooks]"`
|
| 80 |
+
|
| 81 |
+
In case the steps above still don't resolve the error, please try running in your Python environment the following
|
| 82 |
+
```python
|
| 83 |
+
from sam2.modeling import sam2_base
|
| 84 |
+
|
| 85 |
+
print(sam2_base.__file__)
|
| 86 |
+
```
|
| 87 |
+
and check whether the content in the printed local path of `sam2/modeling/sam2_base.py` matches the latest one in https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam2_base.py (e.g. whether your local file has `no_obj_embed_spatial`) to indentify if you're still using a previous installation.
|
| 88 |
+
|
| 89 |
+
</details>
|
| 90 |
+
|
| 91 |
+
<details>
|
| 92 |
+
<summary>
|
| 93 |
+
My installation failed with `CUDA_HOME environment variable is not set`
|
| 94 |
+
</summary>
|
| 95 |
+
<br/>
|
| 96 |
+
|
| 97 |
+
This usually happens because the installation step cannot find the CUDA toolkits (that contain the NVCC compiler) to build a custom CUDA kernel in SAM 2. Please install [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) or the version that matches the CUDA version for your PyTorch installation. If the error persists after installing CUDA toolkits, you may explicitly specify `CUDA_HOME` via
|
| 98 |
+
```
|
| 99 |
+
export CUDA_HOME=/usr/local/cuda # change to your CUDA toolkit path
|
| 100 |
+
```
|
| 101 |
+
and rerun the installation.
|
| 102 |
+
|
| 103 |
+
Also, you should make sure
|
| 104 |
+
```
|
| 105 |
+
python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(torch.cuda.is_available(), CUDA_HOME)'
|
| 106 |
+
```
|
| 107 |
+
print `(True, a directory with cuda)` to verify that the CUDA toolkits are correctly set up.
|
| 108 |
+
|
| 109 |
+
If you are still having problems after verifying that the CUDA toolkit is installed and the `CUDA_HOME` environment variable is set properly, you may have to add the `--no-build-isolation` flag to the pip command:
|
| 110 |
+
```
|
| 111 |
+
pip install --no-build-isolation -e .
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
</details>
|
| 115 |
+
|
| 116 |
+
<details>
|
| 117 |
+
<summary>
|
| 118 |
+
I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar errors)
|
| 119 |
+
</summary>
|
| 120 |
+
<br/>
|
| 121 |
+
|
| 122 |
+
This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version library while at run time it links against another version. This might be due to that you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to only keep a single PyTorch and CUDA version.
|
| 123 |
+
|
| 124 |
+
In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
|
| 125 |
+
|
| 126 |
+
We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
|
| 127 |
+
</details>
|
| 128 |
+
|
| 129 |
+
<details>
|
| 130 |
+
<summary>
|
| 131 |
+
I got `CUDA error: no kernel image is available for execution on the device`
|
| 132 |
+
</summary>
|
| 133 |
+
<br/>
|
| 134 |
+
|
| 135 |
+
A possible cause could be that the CUDA kernel is somehow not compiled towards your GPU's CUDA [capability](https://developer.nvidia.com/cuda-gpus). This could happen if the installation is done in an environment different from the runtime (e.g. in a slurm system).
|
| 136 |
+
|
| 137 |
+
You can try pulling the latest code from the SAM 2 repo and running the following
|
| 138 |
+
```
|
| 139 |
+
export TORCH_CUDA_ARCH_LIST=9.0 8.0 8.6 8.9 7.0 7.2 7.5 6.0`
|
| 140 |
+
```
|
| 141 |
+
to manually specify the CUDA capability in the compilation target that matches your GPU.
|
| 142 |
+
</details>
|
| 143 |
+
|
| 144 |
+
<details>
|
| 145 |
+
<summary>
|
| 146 |
+
I got `RuntimeError: No available kernel. Aborting execution.` (or similar errors)
|
| 147 |
+
</summary>
|
| 148 |
+
<br/>
|
| 149 |
+
|
| 150 |
+
This is probably because your machine doesn't have a GPU or a compatible PyTorch version for Flash Attention (see also https://discuss.pytorch.org/t/using-f-scaled-dot-product-attention-gives-the-error-runtimeerror-no-available-kernel-aborting-execution/180900 for a discussion in PyTorch forum). You may be able to resolve this error by replacing the line
|
| 151 |
+
```python
|
| 152 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
| 153 |
+
```
|
| 154 |
+
in [`sam2/modeling/sam/transformer.py`](sam2/modeling/sam/transformer.py) with
|
| 155 |
+
```python
|
| 156 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, True, True
|
| 157 |
+
```
|
| 158 |
+
to relax the attention kernel setting and use other kernels than Flash Attention.
|
| 159 |
+
</details>
|
| 160 |
+
|
| 161 |
+
<details>
|
| 162 |
+
<summary>
|
| 163 |
+
I got `Error compiling objects for extension`
|
| 164 |
+
</summary>
|
| 165 |
+
<br/>
|
| 166 |
+
|
| 167 |
+
You may see error log of:
|
| 168 |
+
> unsupported Microsoft Visual Studio version! Only the versions between 2017 and 2022 (inclusive) are supported! The nvcc flag '-allow-unsupported-compiler' can be used to override this version check; however, using an unsupported host compiler may cause compilation failure or incorrect run time execution. Use at your own risk.
|
| 169 |
+
|
| 170 |
+
This is probably because your versions of CUDA and Visual Studio are incompatible. (see also https://stackoverflow.com/questions/78515942/cuda-compatibility-with-visual-studio-2022-version-17-10 for a discussion in stackoverflow).<br>
|
| 171 |
+
You may be able to fix this by adding the `-allow-unsupported-compiler` argument to `nvcc` after L48 in the [setup.py](https://github.com/facebookresearch/sam2/blob/main/setup.py). <br>
|
| 172 |
+
After adding the argument, `get_extension()` will look like this:
|
| 173 |
+
```python
|
| 174 |
+
def get_extensions():
|
| 175 |
+
srcs = ["sam2/csrc/connected_components.cu"]
|
| 176 |
+
compile_args = {
|
| 177 |
+
"cxx": [],
|
| 178 |
+
"nvcc": [
|
| 179 |
+
"-DCUDA_HAS_FP16=1",
|
| 180 |
+
"-D__CUDA_NO_HALF_OPERATORS__",
|
| 181 |
+
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
| 182 |
+
"-D__CUDA_NO_HALF2_OPERATORS__",
|
| 183 |
+
"-allow-unsupported-compiler" # Add this argument
|
| 184 |
+
],
|
| 185 |
+
}
|
| 186 |
+
ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
|
| 187 |
+
return ext_modules
|
| 188 |
+
```
|
| 189 |
+
</details>
|
sam2/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
sam2/LICENSE_cctorch
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 3-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2020, the respective contributors, as shown by the AUTHORS file.
|
| 4 |
+
All rights reserved.
|
| 5 |
+
|
| 6 |
+
Redistribution and use in source and binary forms, with or without
|
| 7 |
+
modification, are permitted provided that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
and/or other materials provided with the distribution.
|
| 15 |
+
|
| 16 |
+
3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
contributors may be used to endorse or promote products derived from
|
| 18 |
+
this software without specific prior written permission.
|
| 19 |
+
|
| 20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
sam2/README.md
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAM 2: Segment Anything in Images and Videos
|
| 2 |
+
|
| 3 |
+
**[AI at Meta, FAIR](https://ai.meta.com/research/)**
|
| 4 |
+
|
| 5 |
+
[Nikhila Ravi](https://nikhilaravi.com/), [Valentin Gabeur](https://gabeur.github.io/), [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en), [Ronghang Hu](https://ronghanghu.com/), [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en), [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en), [Haitham Khedr](https://hkhedr.com/), [Roman Rädle](https://scholar.google.de/citations?user=Tpt57v0AAAAJ&hl=en), [Chloe Rolland](https://scholar.google.com/citations?hl=fr&user=n-SnMhoAAAAJ), [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en), [Eric Mintun](https://ericmintun.github.io/), [Junting Pan](https://junting.github.io/), [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en), [Nicolas Carion](https://www.nicolascarion.com/), [Chao-Yuan Wu](https://chaoyuan.org/), [Ross Girshick](https://www.rossgirshick.info/), [Piotr Dollár](https://pdollar.github.io/), [Christoph Feichtenhofer](https://feichtenhofer.github.io/)
|
| 6 |
+
|
| 7 |
+
[[`Paper`](https://ai.meta.com/research/publications/sam-2-segment-anything-in-images-and-videos/)] [[`Project`](https://ai.meta.com/sam2)] [[`Demo`](https://sam2.metademolab.com/)] [[`Dataset`](https://ai.meta.com/datasets/segment-anything-video)] [[`Blog`](https://ai.meta.com/blog/segment-anything-2)] [[`BibTeX`](#citing-sam-2)]
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
**Segment Anything Model 2 (SAM 2)** is a foundation model towards solving promptable visual segmentation in images and videos. We extend SAM to video by considering images as a video with a single frame. The model design is a simple transformer architecture with streaming memory for real-time video processing. We build a model-in-the-loop data engine, which improves model and data via user interaction, to collect [**our SA-V dataset**](https://ai.meta.com/datasets/segment-anything-video), the largest video segmentation dataset to date. SAM 2 trained on our data provides strong performance across a wide range of tasks and visual domains.
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
## Latest updates
|
| 16 |
+
|
| 17 |
+
**12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking**
|
| 18 |
+
|
| 19 |
+
- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference.
|
| 20 |
+
- We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and adding new objects after tracking starts.
|
| 21 |
+
- See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details.
|
| 22 |
+
|
| 23 |
+
**09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released**
|
| 24 |
+
|
| 25 |
+
- A new suite of improved model checkpoints (denoted as **SAM 2.1**) are released. See [Model Description](#model-description) for details.
|
| 26 |
+
* To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
|
| 27 |
+
- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
|
| 28 |
+
- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
|
| 29 |
+
|
| 30 |
+
## Installation
|
| 31 |
+
|
| 32 |
+
SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
git clone https://github.com/facebookresearch/sam2.git && cd sam2
|
| 36 |
+
|
| 37 |
+
pip install -e .
|
| 38 |
+
```
|
| 39 |
+
If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
|
| 40 |
+
|
| 41 |
+
To use the SAM 2 predictor and run the example notebooks, `jupyter` and `matplotlib` are required and can be installed by:
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
pip install -e ".[notebooks]"
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Note:
|
| 48 |
+
1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
|
| 49 |
+
2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
|
| 50 |
+
3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).
|
| 51 |
+
|
| 52 |
+
Please see [`INSTALL.md`](./INSTALL.md) for FAQs on potential issues and solutions.
|
| 53 |
+
|
| 54 |
+
## Getting Started
|
| 55 |
+
|
| 56 |
+
### Download Checkpoints
|
| 57 |
+
|
| 58 |
+
First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
cd checkpoints && \
|
| 62 |
+
./download_ckpts.sh && \
|
| 63 |
+
cd ..
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
or individually from:
|
| 67 |
+
|
| 68 |
+
- [sam2.1_hiera_tiny.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)
|
| 69 |
+
- [sam2.1_hiera_small.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)
|
| 70 |
+
- [sam2.1_hiera_base_plus.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)
|
| 71 |
+
- [sam2.1_hiera_large.pt](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)
|
| 72 |
+
|
| 73 |
+
(note that these are the improved checkpoints denoted as SAM 2.1; see [Model Description](#model-description) for details.)
|
| 74 |
+
|
| 75 |
+
Then SAM 2 can be used in a few lines as follows for image and video prediction.
|
| 76 |
+
|
| 77 |
+
### Image prediction
|
| 78 |
+
|
| 79 |
+
SAM 2 has all the capabilities of [SAM](https://github.com/facebookresearch/segment-anything) on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The `SAM2ImagePredictor` class has an easy interface for image prompting.
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
import torch
|
| 83 |
+
from sam2.build_sam import build_sam2
|
| 84 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 85 |
+
|
| 86 |
+
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
| 87 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 88 |
+
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
|
| 89 |
+
|
| 90 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 91 |
+
predictor.set_image(<your_image>)
|
| 92 |
+
masks, _, _ = predictor.predict(<input_prompts>)
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Please refer to the examples in [image_predictor_example.ipynb](./notebooks/image_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)) for static image use cases.
|
| 96 |
+
|
| 97 |
+
SAM 2 also supports automatic mask generation on images just like SAM. Please see [automatic_mask_generator_example.ipynb](./notebooks/automatic_mask_generator_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/automatic_mask_generator_example.ipynb)) for automatic mask generation in images.
|
| 98 |
+
|
| 99 |
+
### Video prediction
|
| 100 |
+
|
| 101 |
+
For promptable segmentation and tracking in videos, we provide a video predictor with APIs for example to add prompts and propagate masklets throughout a video. SAM 2 supports video inference on multiple objects and uses an inference state to keep track of the interactions in each video.
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
import torch
|
| 105 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 106 |
+
|
| 107 |
+
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
|
| 108 |
+
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 109 |
+
predictor = build_sam2_video_predictor(model_cfg, checkpoint)
|
| 110 |
+
|
| 111 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 112 |
+
state = predictor.init_state(<your_video>)
|
| 113 |
+
|
| 114 |
+
# add new prompts and instantly get the output on the same frame
|
| 115 |
+
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
| 116 |
+
|
| 117 |
+
# propagate the prompts to get masklets throughout the video
|
| 118 |
+
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
| 119 |
+
...
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) (also in Colab [here](https://colab.research.google.com/github/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb)) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
|
| 123 |
+
|
| 124 |
+
## Load from 🤗 Hugging Face
|
| 125 |
+
|
| 126 |
+
Alternatively, models can also be loaded from [Hugging Face](https://huggingface.co/models?search=facebook/sam2) (requires `pip install huggingface_hub`).
|
| 127 |
+
|
| 128 |
+
For image prediction:
|
| 129 |
+
|
| 130 |
+
```python
|
| 131 |
+
import torch
|
| 132 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 133 |
+
|
| 134 |
+
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
|
| 135 |
+
|
| 136 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 137 |
+
predictor.set_image(<your_image>)
|
| 138 |
+
masks, _, _ = predictor.predict(<input_prompts>)
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
For video prediction:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
import torch
|
| 145 |
+
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
| 146 |
+
|
| 147 |
+
predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
|
| 148 |
+
|
| 149 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 150 |
+
state = predictor.init_state(<your_video>)
|
| 151 |
+
|
| 152 |
+
# add new prompts and instantly get the output on the same frame
|
| 153 |
+
frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
|
| 154 |
+
|
| 155 |
+
# propagate the prompts to get masklets throughout the video
|
| 156 |
+
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
|
| 157 |
+
...
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
## Model Description
|
| 161 |
+
|
| 162 |
+
### SAM 2.1 checkpoints
|
| 163 |
+
|
| 164 |
+
The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
|
| 165 |
+
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
|
| 166 |
+
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
|
| 167 |
+
| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
|
| 168 |
+
| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
|
| 169 |
+
| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
|
| 170 |
+
| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
|
| 171 |
+
|
| 172 |
+
### SAM 2 checkpoints
|
| 173 |
+
|
| 174 |
+
The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:
|
| 175 |
+
|
| 176 |
+
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
|
| 177 |
+
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
|
| 178 |
+
| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
|
| 179 |
+
| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
|
| 180 |
+
| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
|
| 181 |
+
| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
|
| 182 |
+
|
| 183 |
+
Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
|
| 184 |
+
## Segment Anything Video Dataset
|
| 185 |
+
|
| 186 |
+
See [sav_dataset/README.md](sav_dataset/README.md) for details.
|
| 187 |
+
|
| 188 |
+
## Training SAM 2
|
| 189 |
+
|
| 190 |
+
You can train or fine-tune SAM 2 on custom datasets of images, videos, or both. Please check the training [README](training/README.md) on how to get started.
|
| 191 |
+
|
| 192 |
+
## Web demo for SAM 2
|
| 193 |
+
|
| 194 |
+
We have released the frontend + backend code for the SAM 2 web demo (a locally deployable version similar to https://sam2.metademolab.com/demo). Please see the web demo [README](demo/README.md) for details.
|
| 195 |
+
|
| 196 |
+
## License
|
| 197 |
+
|
| 198 |
+
The SAM 2 model checkpoints, SAM 2 demo code (front-end and back-end), and SAM 2 training code are licensed under [Apache 2.0](./LICENSE), however the [Inter Font](https://github.com/rsms/inter?tab=OFL-1.1-1-ov-file) and [Noto Color Emoji](https://github.com/googlefonts/noto-emoji) used in the SAM 2 demo code are made available under the [SIL Open Font License, version 1.1](https://openfontlicense.org/open-font-license-official-text/).
|
| 199 |
+
|
| 200 |
+
## Contributing
|
| 201 |
+
|
| 202 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
| 203 |
+
|
| 204 |
+
## Contributors
|
| 205 |
+
|
| 206 |
+
The SAM 2 project was made possible with the help of many contributors (alphabetical):
|
| 207 |
+
|
| 208 |
+
Karen Bergan, Daniel Bolya, Alex Bosenberg, Kai Brown, Vispi Cassod, Christopher Chedeau, Ida Cheng, Luc Dahlin, Shoubhik Debnath, Rene Martinez Doehner, Grant Gardner, Sahir Gomez, Rishi Godugu, Baishan Guo, Caleb Ho, Andrew Huang, Somya Jain, Bob Kamma, Amanda Kallet, Jake Kinney, Alexander Kirillov, Shiva Koduvayur, Devansh Kukreja, Robert Kuo, Aohan Lin, Parth Malani, Jitendra Malik, Mallika Malhotra, Miguel Martin, Alexander Miller, Sasha Mitts, William Ngan, George Orlin, Joelle Pineau, Kate Saenko, Rodrick Shepard, Azita Shokrpour, David Soofian, Jonathan Torres, Jenny Truong, Sagar Vaze, Meng Wang, Claudette Ward, Pengchuan Zhang.
|
| 209 |
+
|
| 210 |
+
Third-party code: we use a GPU-based connected component algorithm adapted from [`cc_torch`](https://github.com/zsef123/Connected_components_PyTorch) (with its license in [`LICENSE_cctorch`](./LICENSE_cctorch)) as an optional post-processing step for the mask predictions.
|
| 211 |
+
|
| 212 |
+
## Citing SAM 2
|
| 213 |
+
|
| 214 |
+
If you use SAM 2 or the SA-V dataset in your research, please use the following BibTeX entry.
|
| 215 |
+
|
| 216 |
+
```bibtex
|
| 217 |
+
@article{ravi2024sam2,
|
| 218 |
+
title={SAM 2: Segment Anything in Images and Videos},
|
| 219 |
+
author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
|
| 220 |
+
journal={arXiv preprint arXiv:2408.00714},
|
| 221 |
+
url={https://arxiv.org/abs/2408.00714},
|
| 222 |
+
year={2024}
|
| 223 |
+
}
|
| 224 |
+
```
|
sam2/checkpoints/download_ckpts.sh
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
|
| 6 |
+
# This source code is licensed under the license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
# Use either wget or curl to download the checkpoints
|
| 10 |
+
if command -v wget &> /dev/null; then
|
| 11 |
+
CMD="wget"
|
| 12 |
+
elif command -v curl &> /dev/null; then
|
| 13 |
+
CMD="curl -L -O"
|
| 14 |
+
else
|
| 15 |
+
echo "Please install wget or curl to download the checkpoints."
|
| 16 |
+
exit 1
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
# Define the URLs for SAM 2 checkpoints
|
| 20 |
+
# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
|
| 21 |
+
# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
|
| 22 |
+
# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
|
| 23 |
+
# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
|
| 24 |
+
# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
|
| 25 |
+
|
| 26 |
+
# Download each of the four checkpoints using wget
|
| 27 |
+
# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
|
| 28 |
+
# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
|
| 29 |
+
|
| 30 |
+
# echo "Downloading sam2_hiera_small.pt checkpoint..."
|
| 31 |
+
# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
|
| 32 |
+
|
| 33 |
+
# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
|
| 34 |
+
# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
|
| 35 |
+
|
| 36 |
+
# echo "Downloading sam2_hiera_large.pt checkpoint..."
|
| 37 |
+
# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
|
| 38 |
+
|
| 39 |
+
# Define the URLs for SAM 2.1 checkpoints
|
| 40 |
+
SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
|
| 41 |
+
sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
|
| 42 |
+
sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
|
| 43 |
+
sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
|
| 44 |
+
sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
|
| 45 |
+
|
| 46 |
+
# SAM 2.1 checkpoints
|
| 47 |
+
echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
|
| 48 |
+
$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
|
| 49 |
+
|
| 50 |
+
echo "Downloading sam2.1_hiera_small.pt checkpoint..."
|
| 51 |
+
$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
|
| 52 |
+
|
| 53 |
+
echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
|
| 54 |
+
$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
|
| 55 |
+
|
| 56 |
+
echo "Downloading sam2.1_hiera_large.pt checkpoint..."
|
| 57 |
+
$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
|
| 58 |
+
|
| 59 |
+
echo "All checkpoints are downloaded successfully."
|
sam2/checkpoints/unsamv2_plus_ckpt.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4067a6966b06df984828f537da0f02389ff28655a8985a1e1f4a3e1de4077195
|
| 3 |
+
size 188689844
|
sam2/notebooks/cascadepsp.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from pycocotools import mask as mask_util
|
| 3 |
+
|
| 4 |
+
def area(mask):
|
| 5 |
+
return np.count_nonzero(mask) / mask.size
|
| 6 |
+
|
| 7 |
+
def iou(mask1, mask2):
|
| 8 |
+
intersection = np.count_nonzero(np.logical_and(mask1, mask2))
|
| 9 |
+
union = np.count_nonzero(mask1) + np.count_nonzero(mask2) - intersection
|
| 10 |
+
if union == 0: return 0
|
| 11 |
+
return intersection / union
|
| 12 |
+
|
| 13 |
+
def postprocess(args, refiner, annotations, image):
|
| 14 |
+
H, W = image.shape[:2]
|
| 15 |
+
|
| 16 |
+
start_id = annotations["annotations"][0]['id']
|
| 17 |
+
curr_id = 0
|
| 18 |
+
refined_annotations = []
|
| 19 |
+
|
| 20 |
+
for annotation in annotations["annotations"]:
|
| 21 |
+
mask = mask_util.decode(annotation['segmentation'])
|
| 22 |
+
|
| 23 |
+
bbox = annotation['bbox']
|
| 24 |
+
x1, y1, w, h = bbox
|
| 25 |
+
x_center = x1 + w / 2
|
| 26 |
+
y_center = y1 + h / 2
|
| 27 |
+
|
| 28 |
+
longer_side = max(w, h)
|
| 29 |
+
x1_resized = int(max(0, x_center - longer_side))
|
| 30 |
+
y1_resized = int(max(0, y_center - longer_side))
|
| 31 |
+
x2_resized = int(min(W, x_center + longer_side))
|
| 32 |
+
y2_resized = int(min(H, y_center + longer_side))
|
| 33 |
+
|
| 34 |
+
image_crop = image[y1_resized:y2_resized, x1_resized:x2_resized, :]
|
| 35 |
+
mask_crop = mask[y1_resized:y2_resized, x1_resized:x2_resized]
|
| 36 |
+
|
| 37 |
+
L = max(min(max(x2_resized-x1_resized, y2_resized-y1_resized) * args.refine_scale, args.refine_max_L), args.refine_min_L)
|
| 38 |
+
refined_mask_crop = refiner.refine(image_crop, mask_crop * 255, fast=True, L=L)
|
| 39 |
+
refined_mask_crop = (refined_mask_crop > 128).astype(np.uint8)
|
| 40 |
+
|
| 41 |
+
refined_mask = np.zeros((H, W), dtype=np.uint8)
|
| 42 |
+
refined_mask[y1_resized:y2_resized, x1_resized:x2_resized] = refined_mask_crop
|
| 43 |
+
|
| 44 |
+
if area(refined_mask) < args.min_area_thresh or area(refined_mask) > args.max_area_thresh:
|
| 45 |
+
continue
|
| 46 |
+
if iou(mask, refined_mask) < args.iou_thresh:
|
| 47 |
+
continue
|
| 48 |
+
|
| 49 |
+
binary_mask_encoded = mask_util.encode(np.asfortranarray(refined_mask))
|
| 50 |
+
binary_mask_encoded['counts'] = binary_mask_encoded['counts'].decode('ascii')
|
| 51 |
+
|
| 52 |
+
annotation['segmentation'] = binary_mask_encoded
|
| 53 |
+
annotation['bbox'] = mask_util.toBbox(binary_mask_encoded).tolist()
|
| 54 |
+
annotation['area'] = mask_util.area(binary_mask_encoded).tolist()
|
| 55 |
+
annotation['id'] = start_id + curr_id
|
| 56 |
+
curr_id += 0
|
| 57 |
+
|
| 58 |
+
refined_annotations.append(annotation)
|
| 59 |
+
|
| 60 |
+
annotations["annotations"] = refined_annotations
|
| 61 |
+
return annotations
|
sam2/notebooks/interactive_image_segmentation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sam2/notebooks/video_segmentation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sam2/notebooks/whole_image_segmentation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sam2/pyproject.toml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = [
|
| 3 |
+
"setuptools>=61.0",
|
| 4 |
+
"torch>=2.5.1",
|
| 5 |
+
]
|
| 6 |
+
build-backend = "setuptools.build_meta"
|
sam2/sam2/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from hydra import initialize_config_module
|
| 8 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 9 |
+
|
| 10 |
+
if not GlobalHydra.instance().is_initialized():
|
| 11 |
+
initialize_config_module("sam2", version_base="1.2")
|
sam2/sam2/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (337 Bytes). View file
|
|
|
sam2/sam2/__pycache__/build_sam.cpython-310.pyc
ADDED
|
Binary file (6.38 kB). View file
|
|
|
sam2/sam2/__pycache__/granularity_embedding.cpython-310.pyc
ADDED
|
Binary file (2.54 kB). View file
|
|
|
sam2/sam2/__pycache__/sam2_image_predictor.cpython-310.pyc
ADDED
|
Binary file (16.1 kB). View file
|
|
|
sam2/sam2/automatic_mask_generator.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_base import SAM2Base
|
| 15 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 16 |
+
from sam2.utils.amg import (
|
| 17 |
+
area_from_rle,
|
| 18 |
+
batch_iterator,
|
| 19 |
+
batched_mask_to_box,
|
| 20 |
+
box_xyxy_to_xywh,
|
| 21 |
+
build_all_layer_point_grids,
|
| 22 |
+
calculate_stability_score,
|
| 23 |
+
coco_encode_rle,
|
| 24 |
+
generate_crop_boxes,
|
| 25 |
+
is_box_near_crop_edge,
|
| 26 |
+
mask_to_rle_pytorch,
|
| 27 |
+
MaskData,
|
| 28 |
+
remove_small_regions,
|
| 29 |
+
rle_to_mask,
|
| 30 |
+
uncrop_boxes_xyxy,
|
| 31 |
+
uncrop_masks,
|
| 32 |
+
uncrop_points,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SAM2AutomaticMaskGenerator:
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
model: SAM2Base,
|
| 40 |
+
points_per_side: Optional[int] = 32,
|
| 41 |
+
points_per_batch: int = 64,
|
| 42 |
+
pred_iou_thresh: float = 0.8,
|
| 43 |
+
stability_score_thresh: float = 0.95,
|
| 44 |
+
stability_score_offset: float = 1.0,
|
| 45 |
+
mask_threshold: float = 0.0,
|
| 46 |
+
box_nms_thresh: float = 0.7,
|
| 47 |
+
crop_n_layers: int = 0,
|
| 48 |
+
crop_nms_thresh: float = 0.7,
|
| 49 |
+
crop_overlap_ratio: float = 512 / 1500,
|
| 50 |
+
crop_n_points_downscale_factor: int = 1,
|
| 51 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
| 52 |
+
min_mask_region_area: int = 0,
|
| 53 |
+
output_mode: str = "binary_mask",
|
| 54 |
+
use_m2m: bool = False,
|
| 55 |
+
multimask_output: bool = True,
|
| 56 |
+
granularity: Optional[float] = None,
|
| 57 |
+
**kwargs,
|
| 58 |
+
) -> None:
|
| 59 |
+
"""
|
| 60 |
+
Using a SAM 2 model, generates masks for the entire image.
|
| 61 |
+
Generates a grid of point prompts over the image, then filters
|
| 62 |
+
low quality and duplicate masks. The default settings are chosen
|
| 63 |
+
for SAM 2 with a HieraL backbone.
|
| 64 |
+
|
| 65 |
+
Arguments:
|
| 66 |
+
model (Sam): The SAM 2 model to use for mask prediction.
|
| 67 |
+
points_per_side (int or None): The number of points to be sampled
|
| 68 |
+
along one side of the image. The total number of points is
|
| 69 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
| 70 |
+
point sampling.
|
| 71 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
| 72 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
| 73 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
| 74 |
+
model's predicted mask quality.
|
| 75 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
| 76 |
+
the stability of the mask under changes to the cutoff used to binarize
|
| 77 |
+
the model's mask predictions.
|
| 78 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
| 79 |
+
calculated the stability score.
|
| 80 |
+
mask_threshold (float): Threshold for binarizing the mask logits
|
| 81 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 82 |
+
suppression to filter duplicate masks.
|
| 83 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
| 84 |
+
crops of the image. Sets the number of layers to run, where each
|
| 85 |
+
layer has 2**i_layer number of image crops.
|
| 86 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
| 87 |
+
suppression to filter duplicate masks between different crops.
|
| 88 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
| 89 |
+
In the first crop layer, crops will overlap by this fraction of
|
| 90 |
+
the image length. Later layers with more crops scale down this overlap.
|
| 91 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
| 92 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
| 93 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
| 94 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
| 95 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
| 96 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
| 97 |
+
to remove disconnected regions and holes in masks with area smaller
|
| 98 |
+
than min_mask_region_area. Requires opencv.
|
| 99 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
| 100 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
| 101 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
| 102 |
+
memory.
|
| 103 |
+
use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
|
| 104 |
+
multimask_output (bool): Whether to output multimask at each point of the grid.
|
| 105 |
+
granularity (float or None): Granularity parameter to control mask generation detail level.
|
| 106 |
+
If None, uses model default behavior.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
assert (points_per_side is None) != (
|
| 110 |
+
point_grids is None
|
| 111 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
| 112 |
+
if points_per_side is not None:
|
| 113 |
+
self.point_grids = build_all_layer_point_grids(
|
| 114 |
+
points_per_side,
|
| 115 |
+
crop_n_layers,
|
| 116 |
+
crop_n_points_downscale_factor,
|
| 117 |
+
)
|
| 118 |
+
elif point_grids is not None:
|
| 119 |
+
self.point_grids = point_grids
|
| 120 |
+
else:
|
| 121 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
| 122 |
+
|
| 123 |
+
assert output_mode in [
|
| 124 |
+
"binary_mask",
|
| 125 |
+
"uncompressed_rle",
|
| 126 |
+
"coco_rle",
|
| 127 |
+
], f"Unknown output_mode {output_mode}."
|
| 128 |
+
if output_mode == "coco_rle":
|
| 129 |
+
try:
|
| 130 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
| 131 |
+
except ImportError as e:
|
| 132 |
+
print("Please install pycocotools")
|
| 133 |
+
raise e
|
| 134 |
+
|
| 135 |
+
self.predictor = SAM2ImagePredictor(
|
| 136 |
+
model,
|
| 137 |
+
max_hole_area=min_mask_region_area,
|
| 138 |
+
max_sprinkle_area=min_mask_region_area,
|
| 139 |
+
)
|
| 140 |
+
self.points_per_batch = points_per_batch
|
| 141 |
+
self.pred_iou_thresh = pred_iou_thresh
|
| 142 |
+
self.stability_score_thresh = stability_score_thresh
|
| 143 |
+
self.stability_score_offset = stability_score_offset
|
| 144 |
+
self.mask_threshold = mask_threshold
|
| 145 |
+
self.box_nms_thresh = box_nms_thresh
|
| 146 |
+
self.crop_n_layers = crop_n_layers
|
| 147 |
+
self.crop_nms_thresh = crop_nms_thresh
|
| 148 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
| 149 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
| 150 |
+
self.min_mask_region_area = min_mask_region_area
|
| 151 |
+
self.output_mode = output_mode
|
| 152 |
+
self.use_m2m = use_m2m
|
| 153 |
+
self.multimask_output = multimask_output
|
| 154 |
+
self.granularity = granularity
|
| 155 |
+
|
| 156 |
+
@classmethod
|
| 157 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
|
| 158 |
+
"""
|
| 159 |
+
Load a pretrained model from the Hugging Face hub.
|
| 160 |
+
|
| 161 |
+
Arguments:
|
| 162 |
+
model_id (str): The Hugging Face repository ID.
|
| 163 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
(SAM2AutomaticMaskGenerator): The loaded model.
|
| 167 |
+
"""
|
| 168 |
+
from sam2.build_sam import build_sam2_hf
|
| 169 |
+
|
| 170 |
+
sam_model = build_sam2_hf(model_id, **kwargs)
|
| 171 |
+
return cls(sam_model, **kwargs)
|
| 172 |
+
|
| 173 |
+
@torch.no_grad()
|
| 174 |
+
def generate(self, image: np.ndarray, gra: float = 1.0) -> List[Dict[str, Any]]:
|
| 175 |
+
"""
|
| 176 |
+
Generates masks for the given image.
|
| 177 |
+
|
| 178 |
+
Arguments:
|
| 179 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
| 183 |
+
a dict containing the following keys:
|
| 184 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
| 185 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
| 186 |
+
is a dictionary containing the RLE.
|
| 187 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
| 188 |
+
area (int): The area in pixels of the mask.
|
| 189 |
+
predicted_iou (float): The model's own prediction of the mask's
|
| 190 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
| 191 |
+
point_coords (list(list(float))): The point coordinates input
|
| 192 |
+
to the model to generate this mask.
|
| 193 |
+
stability_score (float): A measure of the mask's quality. This
|
| 194 |
+
is filtered on using the stability_score_thresh parameter.
|
| 195 |
+
crop_box (list(float)): The crop of the image used to generate
|
| 196 |
+
the mask, given in XYWH format.
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
# Generate masks
|
| 200 |
+
mask_data = self._generate_masks(image, gra)
|
| 201 |
+
|
| 202 |
+
# Encode masks
|
| 203 |
+
if self.output_mode == "coco_rle":
|
| 204 |
+
mask_data["segmentations"] = [
|
| 205 |
+
coco_encode_rle(rle) for rle in mask_data["rles"]
|
| 206 |
+
]
|
| 207 |
+
elif self.output_mode == "binary_mask":
|
| 208 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
| 209 |
+
else:
|
| 210 |
+
mask_data["segmentations"] = mask_data["rles"]
|
| 211 |
+
|
| 212 |
+
# Write mask records
|
| 213 |
+
curr_anns = []
|
| 214 |
+
for idx in range(len(mask_data["segmentations"])):
|
| 215 |
+
ann = {
|
| 216 |
+
"segmentation": mask_data["segmentations"][idx],
|
| 217 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
| 218 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
| 219 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
| 220 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
| 221 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
| 222 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
| 223 |
+
}
|
| 224 |
+
curr_anns.append(ann)
|
| 225 |
+
|
| 226 |
+
return curr_anns
|
| 227 |
+
|
| 228 |
+
def _generate_masks(self, image: np.ndarray, gra: float = 1.0) -> MaskData:
|
| 229 |
+
orig_size = image.shape[:2]
|
| 230 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
| 231 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Iterate over image crops
|
| 235 |
+
data = MaskData()
|
| 236 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
| 237 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, gra)
|
| 238 |
+
data.cat(crop_data)
|
| 239 |
+
|
| 240 |
+
# Remove duplicate masks between crops
|
| 241 |
+
if len(crop_boxes) > 1:
|
| 242 |
+
# Prefer masks from smaller crops
|
| 243 |
+
scores = 1 / box_area(data["crop_boxes"])
|
| 244 |
+
scores = scores.to(data["boxes"].device)
|
| 245 |
+
keep_by_nms = batched_nms(
|
| 246 |
+
data["boxes"].float(),
|
| 247 |
+
scores,
|
| 248 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 249 |
+
iou_threshold=self.crop_nms_thresh,
|
| 250 |
+
)
|
| 251 |
+
data.filter(keep_by_nms)
|
| 252 |
+
data.to_numpy()
|
| 253 |
+
return data
|
| 254 |
+
|
| 255 |
+
def _process_crop(
|
| 256 |
+
self,
|
| 257 |
+
image: np.ndarray,
|
| 258 |
+
crop_box: List[int],
|
| 259 |
+
crop_layer_idx: int,
|
| 260 |
+
orig_size: Tuple[int, ...],
|
| 261 |
+
gra: float = 1.0,
|
| 262 |
+
) -> MaskData:
|
| 263 |
+
# Crop the image and calculate embeddings
|
| 264 |
+
x0, y0, x1, y1 = crop_box
|
| 265 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
| 266 |
+
cropped_im_size = cropped_im.shape[:2]
|
| 267 |
+
self.predictor.set_image(cropped_im)
|
| 268 |
+
|
| 269 |
+
# Get points for this crop
|
| 270 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
| 271 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
| 272 |
+
|
| 273 |
+
# Generate masks for this crop in batches
|
| 274 |
+
data = MaskData()
|
| 275 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
| 276 |
+
batch_data = self._process_batch(
|
| 277 |
+
points, cropped_im_size, crop_box, orig_size, normalize=True, gra=gra
|
| 278 |
+
)
|
| 279 |
+
data.cat(batch_data)
|
| 280 |
+
del batch_data
|
| 281 |
+
self.predictor.reset_predictor()
|
| 282 |
+
|
| 283 |
+
# Remove duplicates within this crop.
|
| 284 |
+
keep_by_nms = batched_nms(
|
| 285 |
+
data["boxes"].float(),
|
| 286 |
+
data["iou_preds"],
|
| 287 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
| 288 |
+
iou_threshold=self.box_nms_thresh,
|
| 289 |
+
)
|
| 290 |
+
data.filter(keep_by_nms)
|
| 291 |
+
|
| 292 |
+
# Return to the original image frame
|
| 293 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
| 294 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
| 295 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
| 296 |
+
|
| 297 |
+
return data
|
| 298 |
+
|
| 299 |
+
def _process_batch(
|
| 300 |
+
self,
|
| 301 |
+
points: np.ndarray,
|
| 302 |
+
im_size: Tuple[int, ...],
|
| 303 |
+
crop_box: List[int],
|
| 304 |
+
orig_size: Tuple[int, ...],
|
| 305 |
+
normalize=False,
|
| 306 |
+
gra: float = 1.0,
|
| 307 |
+
) -> MaskData:
|
| 308 |
+
orig_h, orig_w = orig_size
|
| 309 |
+
|
| 310 |
+
# Run model on this batch
|
| 311 |
+
points = torch.as_tensor(
|
| 312 |
+
points, dtype=torch.float32, device=self.predictor.device
|
| 313 |
+
)
|
| 314 |
+
in_points = self.predictor._transforms.transform_coords(
|
| 315 |
+
points, normalize=normalize, orig_hw=im_size
|
| 316 |
+
)
|
| 317 |
+
in_labels = torch.ones(
|
| 318 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Use granularity as a simple float or None
|
| 322 |
+
gra_value = gra if gra is not None else 1.0
|
| 323 |
+
|
| 324 |
+
masks, iou_preds, low_res_masks = self.predictor._predict(
|
| 325 |
+
in_points[:, None, :],
|
| 326 |
+
in_labels[:, None],
|
| 327 |
+
multimask_output=self.multimask_output,
|
| 328 |
+
return_logits=True,
|
| 329 |
+
gra=gra_value,
|
| 330 |
+
granularity=None, # Explicitly set to None to avoid issues
|
| 331 |
+
)
|
| 332 |
+
# Serialize predictions and store in MaskData
|
| 333 |
+
data = MaskData(
|
| 334 |
+
masks=masks.flatten(0, 1),
|
| 335 |
+
iou_preds=iou_preds.flatten(0, 1),
|
| 336 |
+
points=points.repeat_interleave(masks.shape[1], dim=0),
|
| 337 |
+
low_res_masks=low_res_masks.flatten(0, 1),
|
| 338 |
+
)
|
| 339 |
+
del masks
|
| 340 |
+
|
| 341 |
+
if not self.use_m2m:
|
| 342 |
+
# Filter by predicted IoU
|
| 343 |
+
if self.pred_iou_thresh > 0.0:
|
| 344 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 345 |
+
data.filter(keep_mask)
|
| 346 |
+
# Calculate and filter by stability score
|
| 347 |
+
data["stability_score"] = calculate_stability_score(
|
| 348 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 349 |
+
)
|
| 350 |
+
if self.stability_score_thresh > 0.0:
|
| 351 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 352 |
+
data.filter(keep_mask)
|
| 353 |
+
else:
|
| 354 |
+
# One step refinement using previous mask predictions
|
| 355 |
+
in_points = self.predictor._transforms.transform_coords(
|
| 356 |
+
data["points"], normalize=normalize, orig_hw=im_size
|
| 357 |
+
)
|
| 358 |
+
labels = torch.ones(
|
| 359 |
+
in_points.shape[0], dtype=torch.int, device=in_points.device
|
| 360 |
+
)
|
| 361 |
+
masks, ious = self.refine_with_m2m(
|
| 362 |
+
in_points, labels, data["low_res_masks"], self.points_per_batch, gra=gra
|
| 363 |
+
)
|
| 364 |
+
data["masks"] = masks.squeeze(1)
|
| 365 |
+
data["iou_preds"] = ious.squeeze(1)
|
| 366 |
+
|
| 367 |
+
if self.pred_iou_thresh > 0.0:
|
| 368 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
| 369 |
+
data.filter(keep_mask)
|
| 370 |
+
|
| 371 |
+
data["stability_score"] = calculate_stability_score(
|
| 372 |
+
data["masks"], self.mask_threshold, self.stability_score_offset
|
| 373 |
+
)
|
| 374 |
+
if self.stability_score_thresh > 0.0:
|
| 375 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
| 376 |
+
data.filter(keep_mask)
|
| 377 |
+
|
| 378 |
+
# Threshold masks and calculate boxes
|
| 379 |
+
data["masks"] = data["masks"] > self.mask_threshold
|
| 380 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
| 381 |
+
|
| 382 |
+
# Filter boxes that touch crop boundaries
|
| 383 |
+
keep_mask = ~is_box_near_crop_edge(
|
| 384 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
| 385 |
+
)
|
| 386 |
+
if not torch.all(keep_mask):
|
| 387 |
+
data.filter(keep_mask)
|
| 388 |
+
|
| 389 |
+
# Compress to RLE
|
| 390 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
| 391 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
| 392 |
+
del data["masks"]
|
| 393 |
+
|
| 394 |
+
return data
|
| 395 |
+
|
| 396 |
+
@staticmethod
|
| 397 |
+
def postprocess_small_regions(
|
| 398 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
| 399 |
+
) -> MaskData:
|
| 400 |
+
"""
|
| 401 |
+
Removes small disconnected regions and holes in masks, then reruns
|
| 402 |
+
box NMS to remove any new duplicates.
|
| 403 |
+
|
| 404 |
+
Edits mask_data in place.
|
| 405 |
+
|
| 406 |
+
Requires open-cv as a dependency.
|
| 407 |
+
"""
|
| 408 |
+
if len(mask_data["rles"]) == 0:
|
| 409 |
+
return mask_data
|
| 410 |
+
|
| 411 |
+
# Filter small disconnected regions and holes
|
| 412 |
+
new_masks = []
|
| 413 |
+
scores = []
|
| 414 |
+
for rle in mask_data["rles"]:
|
| 415 |
+
mask = rle_to_mask(rle)
|
| 416 |
+
|
| 417 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
| 418 |
+
unchanged = not changed
|
| 419 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
| 420 |
+
unchanged = unchanged and not changed
|
| 421 |
+
|
| 422 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
| 423 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
| 424 |
+
# so NMS will prefer ones that didn't need postprocessing
|
| 425 |
+
scores.append(float(unchanged))
|
| 426 |
+
|
| 427 |
+
# Recalculate boxes and remove any new duplicates
|
| 428 |
+
masks = torch.cat(new_masks, dim=0)
|
| 429 |
+
boxes = batched_mask_to_box(masks)
|
| 430 |
+
keep_by_nms = batched_nms(
|
| 431 |
+
boxes.float(),
|
| 432 |
+
torch.as_tensor(scores),
|
| 433 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
| 434 |
+
iou_threshold=nms_thresh,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Only recalculate RLEs for masks that have changed
|
| 438 |
+
for i_mask in keep_by_nms:
|
| 439 |
+
if scores[i_mask] == 0.0:
|
| 440 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
| 441 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
| 442 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
| 443 |
+
mask_data.filter(keep_by_nms)
|
| 444 |
+
|
| 445 |
+
return mask_data
|
| 446 |
+
|
| 447 |
+
def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch, gra=1.0):
|
| 448 |
+
new_masks = []
|
| 449 |
+
new_iou_preds = []
|
| 450 |
+
|
| 451 |
+
for cur_points, cur_point_labels, low_res_mask in batch_iterator(
|
| 452 |
+
points_per_batch, points, point_labels, low_res_masks
|
| 453 |
+
):
|
| 454 |
+
# Use granularity as a simple float for M2M
|
| 455 |
+
gra_value = gra if gra is not None else 1.0
|
| 456 |
+
|
| 457 |
+
best_masks, best_iou_preds, _ = self.predictor._predict(
|
| 458 |
+
cur_points[:, None, :],
|
| 459 |
+
cur_point_labels[:, None],
|
| 460 |
+
mask_input=low_res_mask[:, None, :],
|
| 461 |
+
multimask_output=False,
|
| 462 |
+
return_logits=True,
|
| 463 |
+
gra=gra_value,
|
| 464 |
+
granularity=None, # Explicitly set to None to avoid issues
|
| 465 |
+
)
|
| 466 |
+
new_masks.append(best_masks)
|
| 467 |
+
new_iou_preds.append(best_iou_preds)
|
| 468 |
+
masks = torch.cat(new_masks, dim=0)
|
| 469 |
+
return masks, torch.cat(new_iou_preds, dim=0)
|
sam2/sam2/build_sam.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from hydra import compose
|
| 12 |
+
from hydra.utils import instantiate
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
|
| 15 |
+
import sam2
|
| 16 |
+
from training.utils.checkpoint_utils import (
|
| 17 |
+
apply_lora_state_dict,
|
| 18 |
+
split_lora_state_dict,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Check if the user is running Python from the parent directory of the sam2 repo
|
| 22 |
+
# (i.e. the directory where this repo is cloned into) -- this is not supported since
|
| 23 |
+
# it could shadow the sam2 package and cause issues.
|
| 24 |
+
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
|
| 25 |
+
print(os.path.join(sam2.__path__[0], "sam2"))
|
| 26 |
+
# If the user has "sam2/sam2" in their path, they are likey importing the repo itself
|
| 27 |
+
# as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
|
| 28 |
+
# This typically happens because the user is running Python from the parent directory
|
| 29 |
+
# that contains the sam2 repo they cloned.
|
| 30 |
+
raise RuntimeError(
|
| 31 |
+
"You're likely running Python from the parent directory of the sam2 repository "
|
| 32 |
+
"(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
|
| 33 |
+
"This is not supported since the `sam2` Python package could be shadowed by the "
|
| 34 |
+
"repository name (the repository is also named `sam2` and contains the Python package "
|
| 35 |
+
"in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
|
| 36 |
+
"rather than its parent dir, or from your home directory) after installing SAM 2."
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
HF_MODEL_ID_TO_FILENAMES = {
|
| 41 |
+
"facebook/sam2-hiera-tiny": (
|
| 42 |
+
"configs/sam2/sam2_hiera_t.yaml",
|
| 43 |
+
"sam2_hiera_tiny.pt",
|
| 44 |
+
),
|
| 45 |
+
"facebook/sam2-hiera-small": (
|
| 46 |
+
"configs/sam2/sam2_hiera_s.yaml",
|
| 47 |
+
"sam2_hiera_small.pt",
|
| 48 |
+
),
|
| 49 |
+
"facebook/sam2-hiera-base-plus": (
|
| 50 |
+
"configs/sam2/sam2_hiera_b+.yaml",
|
| 51 |
+
"sam2_hiera_base_plus.pt",
|
| 52 |
+
),
|
| 53 |
+
"facebook/sam2-hiera-large": (
|
| 54 |
+
"configs/sam2/sam2_hiera_l.yaml",
|
| 55 |
+
"sam2_hiera_large.pt",
|
| 56 |
+
),
|
| 57 |
+
"facebook/sam2.1-hiera-tiny": (
|
| 58 |
+
"configs/sam2.1/sam2.1_hiera_t.yaml",
|
| 59 |
+
"sam2.1_hiera_tiny.pt",
|
| 60 |
+
),
|
| 61 |
+
"facebook/sam2.1-hiera-small": (
|
| 62 |
+
"configs/sam2.1/sam2.1_hiera_s.yaml",
|
| 63 |
+
"sam2.1_hiera_small.pt",
|
| 64 |
+
),
|
| 65 |
+
"facebook/sam2.1-hiera-base-plus": (
|
| 66 |
+
"configs/sam2.1/sam2.1_hiera_b+.yaml",
|
| 67 |
+
"sam2.1_hiera_base_plus.pt",
|
| 68 |
+
),
|
| 69 |
+
"facebook/sam2.1-hiera-large": (
|
| 70 |
+
"configs/sam2.1/sam2.1_hiera_l.yaml",
|
| 71 |
+
"sam2.1_hiera_large.pt",
|
| 72 |
+
),
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def build_sam2(
|
| 77 |
+
config_file,
|
| 78 |
+
ckpt_path=None,
|
| 79 |
+
device="cuda",
|
| 80 |
+
mode="eval",
|
| 81 |
+
hydra_overrides_extra=[],
|
| 82 |
+
apply_postprocessing=True,
|
| 83 |
+
**kwargs,
|
| 84 |
+
):
|
| 85 |
+
|
| 86 |
+
if apply_postprocessing:
|
| 87 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 88 |
+
hydra_overrides_extra += [
|
| 89 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 90 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 91 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 92 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 93 |
+
]
|
| 94 |
+
# Read config and init model
|
| 95 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
|
| 96 |
+
OmegaConf.resolve(cfg)
|
| 97 |
+
model = instantiate(cfg.model, _recursive_=True)
|
| 98 |
+
|
| 99 |
+
_load_checkpoint(model, ckpt_path)
|
| 100 |
+
_setup_lora_after_loading(model)
|
| 101 |
+
model = model.to(device)
|
| 102 |
+
if mode == "eval":
|
| 103 |
+
model.eval()
|
| 104 |
+
return model
|
| 105 |
+
|
| 106 |
+
def _setup_lora_after_loading(model):
|
| 107 |
+
"""Setup LoRA modules after loading pretrained weights."""
|
| 108 |
+
if not (
|
| 109 |
+
hasattr(model, 'sam_mask_decoder') and
|
| 110 |
+
hasattr(model.sam_mask_decoder, 'transformer') and
|
| 111 |
+
hasattr(model.sam_mask_decoder.transformer, 'lora_config') and
|
| 112 |
+
model.sam_mask_decoder.transformer.lora_config is not None
|
| 113 |
+
):
|
| 114 |
+
logging.info("No LoRA config found, skipping LoRA setup")
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
logging.info("Setting up LoRA after loading pretrained weights...")
|
| 118 |
+
model.sam_mask_decoder.transformer.setup_lora_after_loading()
|
| 119 |
+
_load_pending_lora_weights(model)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _load_pending_lora_weights(model):
|
| 123 |
+
pending = getattr(model, "_pending_lora_state_dict", None)
|
| 124 |
+
if not pending:
|
| 125 |
+
logging.debug("No pending LoRA weights to load.")
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
missing = apply_lora_state_dict(model, pending)
|
| 129 |
+
if missing:
|
| 130 |
+
preview = missing[:5]
|
| 131 |
+
logging.warning(
|
| 132 |
+
"LoRA weights missing in model: %s%s",
|
| 133 |
+
preview,
|
| 134 |
+
f" ... (+{len(missing) - len(preview)} more)" if len(missing) > len(preview) else "",
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
delattr(model, "_pending_lora_state_dict")
|
| 138 |
+
_log_lora_weight_samples(model)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _log_lora_weight_samples(model, limit: int = 5):
|
| 142 |
+
logged = 0
|
| 143 |
+
for name, param in model.named_parameters():
|
| 144 |
+
if "lora_" not in name:
|
| 145 |
+
continue
|
| 146 |
+
tensor = param.detach()
|
| 147 |
+
mean = float(tensor.mean())
|
| 148 |
+
std = float(tensor.std(unbiased=False))
|
| 149 |
+
max_abs = float(tensor.abs().max())
|
| 150 |
+
logged += 1
|
| 151 |
+
if logged >= limit:
|
| 152 |
+
break
|
| 153 |
+
if logged == 0:
|
| 154 |
+
logging.info("LoRA logging skipped: no lora_ parameters found.")
|
| 155 |
+
|
| 156 |
+
def build_sam2_video_predictor(
|
| 157 |
+
config_file,
|
| 158 |
+
ckpt_path=None,
|
| 159 |
+
device="cuda",
|
| 160 |
+
mode="eval",
|
| 161 |
+
hydra_overrides_extra=[],
|
| 162 |
+
apply_postprocessing=True,
|
| 163 |
+
vos_optimized=False,
|
| 164 |
+
**kwargs,
|
| 165 |
+
):
|
| 166 |
+
hydra_overrides = [
|
| 167 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
|
| 168 |
+
]
|
| 169 |
+
if vos_optimized:
|
| 170 |
+
hydra_overrides = [
|
| 171 |
+
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
|
| 172 |
+
"++model.compile_image_encoder=True", # Let sam2_base handle this
|
| 173 |
+
]
|
| 174 |
+
|
| 175 |
+
if apply_postprocessing:
|
| 176 |
+
hydra_overrides_extra = hydra_overrides_extra.copy()
|
| 177 |
+
hydra_overrides_extra += [
|
| 178 |
+
# dynamically fall back to multi-mask if the single mask is not stable
|
| 179 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
|
| 180 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
|
| 181 |
+
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
|
| 182 |
+
# the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
|
| 183 |
+
"++model.binarize_mask_from_pts_for_mem_enc=true",
|
| 184 |
+
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
|
| 185 |
+
"++model.fill_hole_area=8",
|
| 186 |
+
]
|
| 187 |
+
hydra_overrides.extend(hydra_overrides_extra)
|
| 188 |
+
|
| 189 |
+
# Read config and init model
|
| 190 |
+
cfg = compose(config_name=config_file, overrides=hydra_overrides)
|
| 191 |
+
OmegaConf.resolve(cfg)
|
| 192 |
+
model = instantiate(cfg.model, _recursive_=True)
|
| 193 |
+
|
| 194 |
+
_load_checkpoint(model, ckpt_path)
|
| 195 |
+
_setup_lora_after_loading(model)
|
| 196 |
+
model = model.to(device)
|
| 197 |
+
if mode == "eval":
|
| 198 |
+
model.eval()
|
| 199 |
+
return model
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _hf_download(model_id):
|
| 203 |
+
from huggingface_hub import hf_hub_download
|
| 204 |
+
|
| 205 |
+
config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
|
| 206 |
+
ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
|
| 207 |
+
return config_name, ckpt_path
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def build_sam2_hf(model_id, **kwargs):
|
| 211 |
+
config_name, ckpt_path = _hf_download(model_id)
|
| 212 |
+
return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def build_sam2_video_predictor_hf(model_id, **kwargs):
|
| 216 |
+
config_name, ckpt_path = _hf_download(model_id)
|
| 217 |
+
return build_sam2_video_predictor(
|
| 218 |
+
config_file=config_name, ckpt_path=ckpt_path, **kwargs
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
_OPTIONAL_MISSING_KEYS = {
|
| 223 |
+
"sam_prompt_encoder.granularity_round_embedding.gran_values",
|
| 224 |
+
"sam_prompt_encoder.granularity_round_embedding.embedding.weight",
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _load_checkpoint(model, ckpt_path):
|
| 229 |
+
if ckpt_path is not None:
|
| 230 |
+
sd_full = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
|
| 231 |
+
base_state, lora_state = split_lora_state_dict(sd_full)
|
| 232 |
+
|
| 233 |
+
missing_keys, unexpected_keys = model.load_state_dict(base_state, strict=False)
|
| 234 |
+
|
| 235 |
+
optional_missing = [k for k in missing_keys if k in _OPTIONAL_MISSING_KEYS]
|
| 236 |
+
if optional_missing:
|
| 237 |
+
logging.warning(
|
| 238 |
+
"Checkpoint is missing optional parameters, using defaults: %s",
|
| 239 |
+
optional_missing,
|
| 240 |
+
)
|
| 241 |
+
missing_keys = [k for k in missing_keys if k not in _OPTIONAL_MISSING_KEYS]
|
| 242 |
+
if missing_keys:
|
| 243 |
+
logging.error(missing_keys)
|
| 244 |
+
raise RuntimeError()
|
| 245 |
+
if unexpected_keys:
|
| 246 |
+
logging.error(unexpected_keys)
|
| 247 |
+
raise RuntimeError()
|
| 248 |
+
if lora_state:
|
| 249 |
+
pending = getattr(model, "_pending_lora_state_dict", {})
|
| 250 |
+
pending.update(lora_state)
|
| 251 |
+
setattr(model, "_pending_lora_state_dict", pending)
|
| 252 |
+
logging.info("Loaded checkpoint sucessfully")
|
sam2/sam2/configs/unsamv2_small.yaml
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# Model
|
| 4 |
+
model:
|
| 5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
| 6 |
+
image_encoder:
|
| 7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 8 |
+
scalp: 1
|
| 9 |
+
trunk:
|
| 10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 11 |
+
embed_dim: 96
|
| 12 |
+
num_heads: 1
|
| 13 |
+
stages: [1, 2, 11, 2]
|
| 14 |
+
global_att_blocks: [7, 10, 13]
|
| 15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 16 |
+
neck:
|
| 17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 18 |
+
position_encoding:
|
| 19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 20 |
+
num_pos_feats: 256
|
| 21 |
+
normalize: true
|
| 22 |
+
scale: null
|
| 23 |
+
temperature: 10000
|
| 24 |
+
d_model: 256
|
| 25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 27 |
+
fpn_interp_model: nearest
|
| 28 |
+
|
| 29 |
+
memory_attention:
|
| 30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 31 |
+
d_model: 256
|
| 32 |
+
pos_enc_at_input: true
|
| 33 |
+
layer:
|
| 34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 35 |
+
activation: relu
|
| 36 |
+
dim_feedforward: 2048
|
| 37 |
+
dropout: 0.1
|
| 38 |
+
pos_enc_at_attn: false
|
| 39 |
+
self_attention:
|
| 40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 41 |
+
rope_theta: 10000.0
|
| 42 |
+
feat_sizes: [64, 64]
|
| 43 |
+
embedding_dim: 256
|
| 44 |
+
num_heads: 1
|
| 45 |
+
downsample_rate: 1
|
| 46 |
+
dropout: 0.1
|
| 47 |
+
d_model: 256
|
| 48 |
+
pos_enc_at_cross_attn_keys: true
|
| 49 |
+
pos_enc_at_cross_attn_queries: false
|
| 50 |
+
cross_attention:
|
| 51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 52 |
+
rope_theta: 10000.0
|
| 53 |
+
feat_sizes: [64, 64]
|
| 54 |
+
rope_k_repeat: True
|
| 55 |
+
embedding_dim: 256
|
| 56 |
+
num_heads: 1
|
| 57 |
+
downsample_rate: 1
|
| 58 |
+
dropout: 0.1
|
| 59 |
+
kv_in_dim: 64
|
| 60 |
+
num_layers: 4
|
| 61 |
+
temperature: 100
|
| 62 |
+
fourier_dim: 128
|
| 63 |
+
lora_rank: 8
|
| 64 |
+
|
| 65 |
+
memory_encoder:
|
| 66 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 67 |
+
out_dim: 64
|
| 68 |
+
position_encoding:
|
| 69 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 70 |
+
num_pos_feats: 64
|
| 71 |
+
normalize: true
|
| 72 |
+
scale: null
|
| 73 |
+
temperature: 10000
|
| 74 |
+
mask_downsampler:
|
| 75 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 76 |
+
kernel_size: 3
|
| 77 |
+
stride: 2
|
| 78 |
+
padding: 1
|
| 79 |
+
fuser:
|
| 80 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 81 |
+
layer:
|
| 82 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 83 |
+
dim: 256
|
| 84 |
+
kernel_size: 7
|
| 85 |
+
padding: 3
|
| 86 |
+
layer_scale_init_value: 1e-6
|
| 87 |
+
use_dwconv: True # depth-wise convs
|
| 88 |
+
num_layers: 2
|
| 89 |
+
|
| 90 |
+
num_maskmem: 7
|
| 91 |
+
image_size: 1024
|
| 92 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 93 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 94 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 95 |
+
use_mask_input_as_output_without_sam: true
|
| 96 |
+
# Memory
|
| 97 |
+
directly_add_no_mem_embed: true
|
| 98 |
+
no_obj_embed_spatial: true
|
| 99 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 100 |
+
use_high_res_features_in_sam: true
|
| 101 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 102 |
+
multimask_output_in_sam: false
|
| 103 |
+
# SAM heads
|
| 104 |
+
iou_prediction_use_sigmoid: True
|
| 105 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 106 |
+
use_obj_ptrs_in_encoder: true
|
| 107 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 108 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 109 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 110 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 111 |
+
# object occlusion prediction
|
| 112 |
+
pred_obj_scores: true
|
| 113 |
+
pred_obj_scores_mlp: true
|
| 114 |
+
fixed_no_obj_ptr: true
|
| 115 |
+
# multimask tracking settings
|
| 116 |
+
multimask_output_for_tracking: false
|
| 117 |
+
use_multimask_token_for_obj_ptr: false
|
| 118 |
+
multimask_min_pt_num: 0
|
| 119 |
+
multimask_max_pt_num: 1
|
| 120 |
+
use_mlp_for_obj_ptr_proj: true
|
| 121 |
+
# Compilation flag
|
| 122 |
+
compile_image_encoder: False
|
sam2/sam2/configs/unsamv2_small_training.yaml
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
scratch:
|
| 4 |
+
resolution: 1024
|
| 5 |
+
train_batch_size: 2
|
| 6 |
+
num_train_workers: 10
|
| 7 |
+
num_frames: 1
|
| 8 |
+
max_num_objects: 30
|
| 9 |
+
base_lr: 1e-4
|
| 10 |
+
vision_lr: 3e-6
|
| 11 |
+
phases_per_epoch: 1
|
| 12 |
+
num_epochs: 5
|
| 13 |
+
checkpoint_path: "../checkpoints/sam2.1_hiera_small.pt"
|
| 14 |
+
|
| 15 |
+
dataset:
|
| 16 |
+
# PATHS to Dataset
|
| 17 |
+
img_folder: "/home/yujunwei/UnSAM/sa1b_023_data/images"
|
| 18 |
+
gt_folder: "/home/yujunwei/dinov3_gra/sa1b_gt_results_real/combined"
|
| 19 |
+
multiplier: 2
|
| 20 |
+
|
| 21 |
+
# Video transforms
|
| 22 |
+
vos:
|
| 23 |
+
train_transforms:
|
| 24 |
+
- _target_: training.dataset.transforms.ComposeAPI
|
| 25 |
+
transforms:
|
| 26 |
+
- _target_: training.dataset.transforms.RandomHorizontalFlip
|
| 27 |
+
consistent_transform: True
|
| 28 |
+
- _target_: training.dataset.transforms.RandomAffine
|
| 29 |
+
degrees: 25
|
| 30 |
+
shear: 20
|
| 31 |
+
image_interpolation: bilinear
|
| 32 |
+
consistent_transform: True
|
| 33 |
+
- _target_: training.dataset.transforms.RandomResizeAPI
|
| 34 |
+
sizes: ${scratch.resolution}
|
| 35 |
+
square: true
|
| 36 |
+
consistent_transform: True
|
| 37 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 38 |
+
consistent_transform: True
|
| 39 |
+
brightness: 0.1
|
| 40 |
+
contrast: 0.03
|
| 41 |
+
saturation: 0.03
|
| 42 |
+
hue: null
|
| 43 |
+
- _target_: training.dataset.transforms.RandomGrayscale
|
| 44 |
+
p: 0.05
|
| 45 |
+
consistent_transform: True
|
| 46 |
+
- _target_: training.dataset.transforms.ColorJitter
|
| 47 |
+
consistent_transform: False
|
| 48 |
+
brightness: 0.1
|
| 49 |
+
contrast: 0.05
|
| 50 |
+
saturation: 0.05
|
| 51 |
+
hue: null
|
| 52 |
+
- _target_: training.dataset.transforms.ToTensorAPI
|
| 53 |
+
- _target_: training.dataset.transforms.NormalizeAPI
|
| 54 |
+
mean: [0.485, 0.456, 0.406]
|
| 55 |
+
std: [0.229, 0.224, 0.225]
|
| 56 |
+
|
| 57 |
+
trainer:
|
| 58 |
+
_target_: training.trainer.Trainer
|
| 59 |
+
mode: train_only
|
| 60 |
+
max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
|
| 61 |
+
accelerator: cuda
|
| 62 |
+
seed_value: 1234
|
| 63 |
+
|
| 64 |
+
model:
|
| 65 |
+
_target_: training.model.sam2.SAM2Train
|
| 66 |
+
image_encoder:
|
| 67 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
| 68 |
+
scalp: 1
|
| 69 |
+
trunk:
|
| 70 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
| 71 |
+
embed_dim: 96
|
| 72 |
+
num_heads: 1
|
| 73 |
+
stages: [1, 2, 11, 2]
|
| 74 |
+
global_att_blocks: [7, 10, 13]
|
| 75 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
| 76 |
+
neck:
|
| 77 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
| 78 |
+
position_encoding:
|
| 79 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 80 |
+
num_pos_feats: 256
|
| 81 |
+
normalize: true
|
| 82 |
+
scale: null
|
| 83 |
+
temperature: 10000
|
| 84 |
+
d_model: 256
|
| 85 |
+
backbone_channel_list: [768, 384, 192, 96]
|
| 86 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
| 87 |
+
fpn_interp_model: nearest
|
| 88 |
+
|
| 89 |
+
memory_attention:
|
| 90 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
| 91 |
+
d_model: 256
|
| 92 |
+
pos_enc_at_input: true
|
| 93 |
+
layer:
|
| 94 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
| 95 |
+
activation: relu
|
| 96 |
+
dim_feedforward: 2048
|
| 97 |
+
dropout: 0.1
|
| 98 |
+
pos_enc_at_attn: false
|
| 99 |
+
self_attention:
|
| 100 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 101 |
+
rope_theta: 10000.0
|
| 102 |
+
feat_sizes: [32, 32]
|
| 103 |
+
embedding_dim: 256
|
| 104 |
+
num_heads: 1
|
| 105 |
+
downsample_rate: 1
|
| 106 |
+
dropout: 0.1
|
| 107 |
+
d_model: 256
|
| 108 |
+
pos_enc_at_cross_attn_keys: true
|
| 109 |
+
pos_enc_at_cross_attn_queries: false
|
| 110 |
+
cross_attention:
|
| 111 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
| 112 |
+
rope_theta: 10000.0
|
| 113 |
+
feat_sizes: [32, 32]
|
| 114 |
+
rope_k_repeat: True
|
| 115 |
+
embedding_dim: 256
|
| 116 |
+
num_heads: 1
|
| 117 |
+
downsample_rate: 1
|
| 118 |
+
dropout: 0.1
|
| 119 |
+
kv_in_dim: 64
|
| 120 |
+
num_layers: 4
|
| 121 |
+
|
| 122 |
+
memory_encoder:
|
| 123 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
| 124 |
+
out_dim: 64
|
| 125 |
+
position_encoding:
|
| 126 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
| 127 |
+
num_pos_feats: 64
|
| 128 |
+
normalize: true
|
| 129 |
+
scale: null
|
| 130 |
+
temperature: 10000
|
| 131 |
+
mask_downsampler:
|
| 132 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
| 133 |
+
kernel_size: 3
|
| 134 |
+
stride: 2
|
| 135 |
+
padding: 1
|
| 136 |
+
fuser:
|
| 137 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
| 138 |
+
layer:
|
| 139 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
| 140 |
+
dim: 256
|
| 141 |
+
kernel_size: 7
|
| 142 |
+
padding: 3
|
| 143 |
+
layer_scale_init_value: 1e-6
|
| 144 |
+
use_dwconv: True # depth-wise convs
|
| 145 |
+
num_layers: 2
|
| 146 |
+
freeze_image_encoder: True
|
| 147 |
+
temperature: 100
|
| 148 |
+
fourier_dim: 128
|
| 149 |
+
lora_rank: 8
|
| 150 |
+
use_threshold_adjustment: False
|
| 151 |
+
|
| 152 |
+
num_maskmem: 7
|
| 153 |
+
image_size: 1024
|
| 154 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
| 155 |
+
sigmoid_scale_for_mem_enc: 20.0
|
| 156 |
+
sigmoid_bias_for_mem_enc: -10.0
|
| 157 |
+
use_mask_input_as_output_without_sam: true
|
| 158 |
+
# Memory
|
| 159 |
+
directly_add_no_mem_embed: true
|
| 160 |
+
no_obj_embed_spatial: true
|
| 161 |
+
# use high-resolution feature map in the SAM mask decoder
|
| 162 |
+
use_high_res_features_in_sam: true
|
| 163 |
+
# output 3 masks on the first click on initial conditioning frames
|
| 164 |
+
multimask_output_in_sam: false
|
| 165 |
+
# SAM heads
|
| 166 |
+
iou_prediction_use_sigmoid: True
|
| 167 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
| 168 |
+
use_obj_ptrs_in_encoder: true
|
| 169 |
+
add_tpos_enc_to_obj_ptrs: true
|
| 170 |
+
proj_tpos_enc_in_obj_ptrs: true
|
| 171 |
+
use_signed_tpos_enc_to_obj_ptrs: true
|
| 172 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
| 173 |
+
# object occlusion prediction
|
| 174 |
+
pred_obj_scores: true
|
| 175 |
+
pred_obj_scores_mlp: true
|
| 176 |
+
fixed_no_obj_ptr: true
|
| 177 |
+
# multimask tracking settings
|
| 178 |
+
multimask_output_for_tracking: false
|
| 179 |
+
use_multimask_token_for_obj_ptr: false
|
| 180 |
+
multimask_min_pt_num: 0
|
| 181 |
+
multimask_max_pt_num: 1
|
| 182 |
+
use_mlp_for_obj_ptr_proj: true
|
| 183 |
+
# Compilation flag
|
| 184 |
+
# compile_image_encoder: False
|
| 185 |
+
|
| 186 |
+
####### Training specific params #######
|
| 187 |
+
# box/point input and corrections
|
| 188 |
+
prob_to_use_pt_input_for_train: 1.0
|
| 189 |
+
prob_to_use_pt_input_for_eval: 0.0
|
| 190 |
+
prob_to_use_box_input_for_train: 0.0
|
| 191 |
+
prob_to_use_box_input_for_eval: 0.0
|
| 192 |
+
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
|
| 193 |
+
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
|
| 194 |
+
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
|
| 195 |
+
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
|
| 196 |
+
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
|
| 197 |
+
# maximum 2 initial conditioning frames
|
| 198 |
+
num_init_cond_frames_for_train: 2
|
| 199 |
+
rand_init_cond_frames_for_train: True # random 1~2
|
| 200 |
+
num_correction_pt_per_frame: 3
|
| 201 |
+
use_act_ckpt_iterative_pt_sampling: false
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
num_init_cond_frames_for_eval: 1 # only mask on the first frame
|
| 206 |
+
forward_backbone_per_frame_for_eval: True
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
data:
|
| 210 |
+
train:
|
| 211 |
+
_target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
|
| 212 |
+
phases_per_epoch: ${scratch.phases_per_epoch}
|
| 213 |
+
batch_sizes:
|
| 214 |
+
- ${scratch.train_batch_size}
|
| 215 |
+
|
| 216 |
+
datasets:
|
| 217 |
+
- _target_: training.dataset.vos_dataset.VOSDataset
|
| 218 |
+
training: true
|
| 219 |
+
video_dataset:
|
| 220 |
+
_target_: training.dataset.vos_raw_dataset.UnSAMRawDataset
|
| 221 |
+
img_folder: ${dataset.img_folder}
|
| 222 |
+
gt_folder: ${dataset.gt_folder}
|
| 223 |
+
tsv_file: ${dataset.tsv_file}
|
| 224 |
+
lineidx_file: ${dataset.lineidx_file}
|
| 225 |
+
sampler:
|
| 226 |
+
_target_: training.dataset.vos_sampler.RandomUniformSampler
|
| 227 |
+
num_frames: 1
|
| 228 |
+
max_num_objects: ${scratch.max_num_objects}
|
| 229 |
+
transforms: ${vos.train_transforms}
|
| 230 |
+
multiplier: ${dataset.multiplier}
|
| 231 |
+
shuffle: True
|
| 232 |
+
num_workers: ${scratch.num_train_workers}
|
| 233 |
+
pin_memory: True
|
| 234 |
+
drop_last: True
|
| 235 |
+
collate_fn:
|
| 236 |
+
_target_: training.utils.data_utils.collate_fn
|
| 237 |
+
_partial_: true
|
| 238 |
+
dict_key: all
|
| 239 |
+
|
| 240 |
+
optim:
|
| 241 |
+
amp:
|
| 242 |
+
enabled: True
|
| 243 |
+
amp_dtype: bfloat16
|
| 244 |
+
|
| 245 |
+
optimizer:
|
| 246 |
+
_target_: torch.optim.AdamW
|
| 247 |
+
|
| 248 |
+
options:
|
| 249 |
+
lr:
|
| 250 |
+
- scheduler:
|
| 251 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 252 |
+
start_value: ${scratch.base_lr}
|
| 253 |
+
end_value: ${divide:${scratch.base_lr},10}
|
| 254 |
+
- scheduler:
|
| 255 |
+
_target_: fvcore.common.param_scheduler.CosineParamScheduler
|
| 256 |
+
start_value: ${divide:${scratch.base_lr},10}
|
| 257 |
+
end_value: ${divide:${divide:${scratch.base_lr},10},10}
|
| 258 |
+
param_names:
|
| 259 |
+
- "image_encoder.*"
|
| 260 |
+
weight_decay:
|
| 261 |
+
- scheduler:
|
| 262 |
+
_target_: fvcore.common.param_scheduler.ConstantParamScheduler
|
| 263 |
+
value: 1e-5
|
| 264 |
+
|
| 265 |
+
loss:
|
| 266 |
+
all:
|
| 267 |
+
_target_: training.loss_fns.MultiStepMultiMasksAndIous
|
| 268 |
+
weight_dict:
|
| 269 |
+
loss_mask: 20
|
| 270 |
+
loss_dice: 1
|
| 271 |
+
loss_iou: 1
|
| 272 |
+
loss_class: 1
|
| 273 |
+
supervise_all_iou: true
|
| 274 |
+
iou_use_l1_loss: true
|
| 275 |
+
pred_obj_scores: true
|
| 276 |
+
focal_gamma_obj_score: 0.0
|
| 277 |
+
focal_alpha_obj_score: -1.0
|
| 278 |
+
|
| 279 |
+
distributed:
|
| 280 |
+
backend: nccl
|
| 281 |
+
find_unused_parameters: True
|
| 282 |
+
|
| 283 |
+
logging:
|
| 284 |
+
tensorboard_writer:
|
| 285 |
+
_target_: training.utils.logger.make_tensorboard_logger
|
| 286 |
+
log_dir: ${launcher.experiment_log_dir}/tensorboard
|
| 287 |
+
flush_secs: 120
|
| 288 |
+
should_log: True
|
| 289 |
+
log_dir: ${launcher.experiment_log_dir}/logs
|
| 290 |
+
log_freq: 10
|
| 291 |
+
|
| 292 |
+
# initialize from a SAM 2 checkpoint
|
| 293 |
+
checkpoint:
|
| 294 |
+
save_dir: ${launcher.experiment_log_dir}/checkpoints
|
| 295 |
+
save_freq: 1 # 0 only last checkpoint is saved.
|
| 296 |
+
model_weight_initializer:
|
| 297 |
+
_partial_: True
|
| 298 |
+
_target_: training.utils.checkpoint_utils.load_state_dict_into_model
|
| 299 |
+
strict: False
|
| 300 |
+
ignore_unexpected_keys: null
|
| 301 |
+
ignore_missing_keys: null
|
| 302 |
+
|
| 303 |
+
state_dict:
|
| 304 |
+
_target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
|
| 305 |
+
checkpoint_path: ${scratch.checkpoint_path} # PATH to SAM 2.1 checkpoint
|
| 306 |
+
ckpt_state_dict_keys: ['model']
|
| 307 |
+
|
| 308 |
+
launcher:
|
| 309 |
+
num_nodes: 1
|
| 310 |
+
gpus_per_node: 2
|
| 311 |
+
experiment_log_dir: ./exp_logs/path/to/output/dir # Path to log directory, defaults to ./sam2_logs/${config_name}
|
| 312 |
+
|
| 313 |
+
# SLURM args if running on a cluster
|
| 314 |
+
submitit:
|
| 315 |
+
partition: null
|
| 316 |
+
account: null
|
| 317 |
+
qos: null
|
| 318 |
+
cpus_per_task: 10
|
| 319 |
+
use_cluster: false
|
| 320 |
+
timeout_hour: 24
|
| 321 |
+
name: null
|
| 322 |
+
port_range: [62000, 65000]
|
| 323 |
+
|
sam2/sam2/csrc/connected_components.cu
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
// adapted from https://github.com/zsef123/Connected_components_PyTorch
|
| 8 |
+
// with license found in the LICENSE_cctorch file in the root directory.
|
| 9 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 10 |
+
#include <cuda.h>
|
| 11 |
+
#include <cuda_runtime.h>
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
#include <torch/script.h>
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
// 2d
|
| 17 |
+
#define BLOCK_ROWS 16
|
| 18 |
+
#define BLOCK_COLS 16
|
| 19 |
+
|
| 20 |
+
namespace cc2d {
|
| 21 |
+
|
| 22 |
+
template <typename T>
|
| 23 |
+
__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
|
| 24 |
+
return (bitmap >> pos) & 1;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
__device__ int32_t find(const int32_t* s_buf, int32_t n) {
|
| 28 |
+
while (s_buf[n] != n)
|
| 29 |
+
n = s_buf[n];
|
| 30 |
+
return n;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
|
| 34 |
+
const int32_t id = n;
|
| 35 |
+
while (s_buf[n] != n) {
|
| 36 |
+
n = s_buf[n];
|
| 37 |
+
s_buf[id] = n;
|
| 38 |
+
}
|
| 39 |
+
return n;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
|
| 43 |
+
bool done;
|
| 44 |
+
do {
|
| 45 |
+
a = find(s_buf, a);
|
| 46 |
+
b = find(s_buf, b);
|
| 47 |
+
|
| 48 |
+
if (a < b) {
|
| 49 |
+
int32_t old = atomicMin(s_buf + b, a);
|
| 50 |
+
done = (old == b);
|
| 51 |
+
b = old;
|
| 52 |
+
} else if (b < a) {
|
| 53 |
+
int32_t old = atomicMin(s_buf + a, b);
|
| 54 |
+
done = (old == a);
|
| 55 |
+
a = old;
|
| 56 |
+
} else
|
| 57 |
+
done = true;
|
| 58 |
+
|
| 59 |
+
} while (!done);
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
__global__ void
|
| 63 |
+
init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
|
| 64 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 65 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 66 |
+
const uint32_t idx = row * W + col;
|
| 67 |
+
|
| 68 |
+
if (row < H && col < W)
|
| 69 |
+
label[idx] = idx;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
__global__ void
|
| 73 |
+
merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
|
| 74 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 75 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 76 |
+
const uint32_t idx = row * W + col;
|
| 77 |
+
|
| 78 |
+
if (row >= H || col >= W)
|
| 79 |
+
return;
|
| 80 |
+
|
| 81 |
+
uint32_t P = 0;
|
| 82 |
+
|
| 83 |
+
if (img[idx])
|
| 84 |
+
P |= 0x777;
|
| 85 |
+
if (row + 1 < H && img[idx + W])
|
| 86 |
+
P |= 0x777 << 4;
|
| 87 |
+
if (col + 1 < W && img[idx + 1])
|
| 88 |
+
P |= 0x777 << 1;
|
| 89 |
+
|
| 90 |
+
if (col == 0)
|
| 91 |
+
P &= 0xEEEE;
|
| 92 |
+
if (col + 1 >= W)
|
| 93 |
+
P &= 0x3333;
|
| 94 |
+
else if (col + 2 >= W)
|
| 95 |
+
P &= 0x7777;
|
| 96 |
+
|
| 97 |
+
if (row == 0)
|
| 98 |
+
P &= 0xFFF0;
|
| 99 |
+
if (row + 1 >= H)
|
| 100 |
+
P &= 0xFF;
|
| 101 |
+
|
| 102 |
+
if (P > 0) {
|
| 103 |
+
// If need check about top-left pixel(if flag the first bit) and hit the
|
| 104 |
+
// top-left pixel
|
| 105 |
+
if (hasBit(P, 0) && img[idx - W - 1]) {
|
| 106 |
+
union_(label, idx, idx - 2 * W - 2); // top left block
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
|
| 110 |
+
union_(label, idx, idx - 2 * W); // top bottom block
|
| 111 |
+
|
| 112 |
+
if (hasBit(P, 3) && img[idx + 2 - W])
|
| 113 |
+
union_(label, idx, idx - 2 * W + 2); // top right block
|
| 114 |
+
|
| 115 |
+
if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
|
| 116 |
+
union_(label, idx, idx - 2); // just left block
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
|
| 121 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 122 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 123 |
+
const uint32_t idx = row * W + col;
|
| 124 |
+
|
| 125 |
+
if (row < H && col < W)
|
| 126 |
+
find_n_compress(label, idx);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
__global__ void final_labeling(
|
| 130 |
+
const uint8_t* img,
|
| 131 |
+
int32_t* label,
|
| 132 |
+
const int32_t W,
|
| 133 |
+
const int32_t H) {
|
| 134 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
|
| 135 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
|
| 136 |
+
const uint32_t idx = row * W + col;
|
| 137 |
+
|
| 138 |
+
if (row >= H || col >= W)
|
| 139 |
+
return;
|
| 140 |
+
|
| 141 |
+
int32_t y = label[idx] + 1;
|
| 142 |
+
|
| 143 |
+
if (img[idx])
|
| 144 |
+
label[idx] = y;
|
| 145 |
+
else
|
| 146 |
+
label[idx] = 0;
|
| 147 |
+
|
| 148 |
+
if (col + 1 < W) {
|
| 149 |
+
if (img[idx + 1])
|
| 150 |
+
label[idx + 1] = y;
|
| 151 |
+
else
|
| 152 |
+
label[idx + 1] = 0;
|
| 153 |
+
|
| 154 |
+
if (row + 1 < H) {
|
| 155 |
+
if (img[idx + W + 1])
|
| 156 |
+
label[idx + W + 1] = y;
|
| 157 |
+
else
|
| 158 |
+
label[idx + W + 1] = 0;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if (row + 1 < H) {
|
| 163 |
+
if (img[idx + W])
|
| 164 |
+
label[idx + W] = y;
|
| 165 |
+
else
|
| 166 |
+
label[idx + W] = 0;
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
__global__ void init_counting(
|
| 171 |
+
const int32_t* label,
|
| 172 |
+
int32_t* count_init,
|
| 173 |
+
const int32_t W,
|
| 174 |
+
const int32_t H) {
|
| 175 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
|
| 176 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
|
| 177 |
+
const uint32_t idx = row * W + col;
|
| 178 |
+
|
| 179 |
+
if (row >= H || col >= W)
|
| 180 |
+
return;
|
| 181 |
+
|
| 182 |
+
int32_t y = label[idx];
|
| 183 |
+
if (y > 0) {
|
| 184 |
+
int32_t count_idx = y - 1;
|
| 185 |
+
atomicAdd(count_init + count_idx, 1);
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
__global__ void final_counting(
|
| 190 |
+
const int32_t* label,
|
| 191 |
+
const int32_t* count_init,
|
| 192 |
+
int32_t* count_final,
|
| 193 |
+
const int32_t W,
|
| 194 |
+
const int32_t H) {
|
| 195 |
+
const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
|
| 196 |
+
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
|
| 197 |
+
const uint32_t idx = row * W + col;
|
| 198 |
+
|
| 199 |
+
if (row >= H || col >= W)
|
| 200 |
+
return;
|
| 201 |
+
|
| 202 |
+
int32_t y = label[idx];
|
| 203 |
+
if (y > 0) {
|
| 204 |
+
int32_t count_idx = y - 1;
|
| 205 |
+
count_final[idx] = count_init[count_idx];
|
| 206 |
+
} else {
|
| 207 |
+
count_final[idx] = 0;
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
} // namespace cc2d
|
| 212 |
+
|
| 213 |
+
std::vector<torch::Tensor> get_connected_componnets(
|
| 214 |
+
const torch::Tensor& inputs) {
|
| 215 |
+
AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
|
| 216 |
+
AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
|
| 217 |
+
AT_ASSERTM(
|
| 218 |
+
inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
|
| 219 |
+
|
| 220 |
+
const uint32_t N = inputs.size(0);
|
| 221 |
+
const uint32_t C = inputs.size(1);
|
| 222 |
+
const uint32_t H = inputs.size(2);
|
| 223 |
+
const uint32_t W = inputs.size(3);
|
| 224 |
+
|
| 225 |
+
AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
|
| 226 |
+
AT_ASSERTM((H % 2) == 0, "height must be an even number");
|
| 227 |
+
AT_ASSERTM((W % 2) == 0, "width must be an even number");
|
| 228 |
+
|
| 229 |
+
// label must be uint32_t
|
| 230 |
+
auto label_options =
|
| 231 |
+
torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
|
| 232 |
+
torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
|
| 233 |
+
torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
|
| 234 |
+
torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
|
| 235 |
+
|
| 236 |
+
dim3 grid = dim3(
|
| 237 |
+
((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
|
| 238 |
+
((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
|
| 239 |
+
dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
|
| 240 |
+
dim3 grid_count =
|
| 241 |
+
dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
|
| 242 |
+
dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
|
| 243 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 244 |
+
|
| 245 |
+
for (int n = 0; n < N; n++) {
|
| 246 |
+
uint32_t offset = n * H * W;
|
| 247 |
+
|
| 248 |
+
cc2d::init_labeling<<<grid, block, 0, stream>>>(
|
| 249 |
+
labels.data_ptr<int32_t>() + offset, W, H);
|
| 250 |
+
cc2d::merge<<<grid, block, 0, stream>>>(
|
| 251 |
+
inputs.data_ptr<uint8_t>() + offset,
|
| 252 |
+
labels.data_ptr<int32_t>() + offset,
|
| 253 |
+
W,
|
| 254 |
+
H);
|
| 255 |
+
cc2d::compression<<<grid, block, 0, stream>>>(
|
| 256 |
+
labels.data_ptr<int32_t>() + offset, W, H);
|
| 257 |
+
cc2d::final_labeling<<<grid, block, 0, stream>>>(
|
| 258 |
+
inputs.data_ptr<uint8_t>() + offset,
|
| 259 |
+
labels.data_ptr<int32_t>() + offset,
|
| 260 |
+
W,
|
| 261 |
+
H);
|
| 262 |
+
|
| 263 |
+
// get the counting of each pixel
|
| 264 |
+
cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
|
| 265 |
+
labels.data_ptr<int32_t>() + offset,
|
| 266 |
+
counts_init.data_ptr<int32_t>() + offset,
|
| 267 |
+
W,
|
| 268 |
+
H);
|
| 269 |
+
cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
|
| 270 |
+
labels.data_ptr<int32_t>() + offset,
|
| 271 |
+
counts_init.data_ptr<int32_t>() + offset,
|
| 272 |
+
counts_final.data_ptr<int32_t>() + offset,
|
| 273 |
+
W,
|
| 274 |
+
H);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// returned values are [labels, counts]
|
| 278 |
+
std::vector<torch::Tensor> outputs;
|
| 279 |
+
outputs.push_back(labels);
|
| 280 |
+
outputs.push_back(counts_final);
|
| 281 |
+
return outputs;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 285 |
+
m.def(
|
| 286 |
+
"get_connected_componnets",
|
| 287 |
+
&get_connected_componnets,
|
| 288 |
+
"get_connected_componnets");
|
| 289 |
+
}
|
sam2/sam2/granularity_embedding.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
class FourierGranularityMLP(nn.Module):
|
| 6 |
+
def __init__(self, fourier_dim=128, decoder_dim=256, hidden_dim=None,
|
| 7 |
+
num_layers=2, dropout=0.1, temperature=100):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
self.fourier_embedder = FourierEmbedder(hidden_dim=fourier_dim, temperature=temperature)
|
| 11 |
+
|
| 12 |
+
self.mlp = GranularityMLP(
|
| 13 |
+
granularity_dim=fourier_dim,
|
| 14 |
+
decoder_dim=decoder_dim,
|
| 15 |
+
hidden_dim=hidden_dim,
|
| 16 |
+
num_layers=num_layers,
|
| 17 |
+
dropout=dropout
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
def forward(self, granularity):
|
| 21 |
+
if granularity.dim() == 0:
|
| 22 |
+
granularity = granularity.view(1)
|
| 23 |
+
|
| 24 |
+
fourier_features = self.fourier_embedder(granularity)
|
| 25 |
+
|
| 26 |
+
return self.mlp(fourier_features)
|
| 27 |
+
|
| 28 |
+
class FourierEmbedder():
|
| 29 |
+
def __init__(self, hidden_dim=128, temperature=100):
|
| 30 |
+
self.hidden_dim = hidden_dim
|
| 31 |
+
self.num_freqs = hidden_dim // 2
|
| 32 |
+
self.remaining_dim = hidden_dim % 2
|
| 33 |
+
self.temperature = temperature
|
| 34 |
+
self.freq_bands = temperature ** (torch.arange(self.num_freqs) / self.num_freqs)
|
| 35 |
+
|
| 36 |
+
@torch.no_grad()
|
| 37 |
+
def __call__(self, x, cat_dim=-1):
|
| 38 |
+
out = []
|
| 39 |
+
# Add sin/cos pairs
|
| 40 |
+
for freq in self.freq_bands:
|
| 41 |
+
out.append(torch.sin(freq * x))
|
| 42 |
+
out.append(torch.cos(freq * x))
|
| 43 |
+
|
| 44 |
+
if self.remaining_dim:
|
| 45 |
+
out.append(torch.sin(self.temperature * x))
|
| 46 |
+
|
| 47 |
+
return torch.cat(out, cat_dim)
|
| 48 |
+
|
| 49 |
+
class GranularityMLP(nn.Module):
|
| 50 |
+
def __init__(self, granularity_dim, decoder_dim=256, hidden_dim=None, num_layers=2, dropout=0.1):
|
| 51 |
+
super().__init__()
|
| 52 |
+
if hidden_dim is None:
|
| 53 |
+
hidden_dim = (granularity_dim + decoder_dim) // 2
|
| 54 |
+
|
| 55 |
+
layers = []
|
| 56 |
+
input_dim = granularity_dim
|
| 57 |
+
for _ in range(num_layers - 1):
|
| 58 |
+
layers.append(nn.Linear(input_dim, hidden_dim))
|
| 59 |
+
layers.append(nn.ReLU())
|
| 60 |
+
layers.append(nn.Dropout(dropout))
|
| 61 |
+
input_dim = hidden_dim
|
| 62 |
+
|
| 63 |
+
layers.append(nn.Linear(hidden_dim, decoder_dim))
|
| 64 |
+
self.mlp = nn.Sequential(*layers)
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
return self.mlp(x)
|
sam2/sam2/modeling/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sam2/sam2/modeling/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
sam2/sam2/modeling/__pycache__/memory_attention.cpython-310.pyc
ADDED
|
Binary file (3.96 kB). View file
|
|
|
sam2/sam2/modeling/__pycache__/memory_encoder.cpython-310.pyc
ADDED
|
Binary file (4.97 kB). View file
|
|
|
sam2/sam2/modeling/__pycache__/position_encoding.cpython-310.pyc
ADDED
|
Binary file (7.98 kB). View file
|
|
|
sam2/sam2/modeling/__pycache__/sam2_base.cpython-310.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
sam2/sam2/modeling/__pycache__/sam2_utils.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
sam2/sam2/modeling/backbones/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sam2/sam2/modeling/backbones/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
sam2/sam2/modeling/backbones/__pycache__/hieradet.cpython-310.pyc
ADDED
|
Binary file (7.72 kB). View file
|
|
|
sam2/sam2/modeling/backbones/__pycache__/image_encoder.cpython-310.pyc
ADDED
|
Binary file (3.42 kB). View file
|
|
|
sam2/sam2/modeling/backbones/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.2 kB). View file
|
|
|
sam2/sam2/modeling/backbones/adapter_hieradet.py
ADDED
|
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from iopath.common.file_io import g_pathmgr
|
| 15 |
+
|
| 16 |
+
from sam2.modeling.backbones.utils import (
|
| 17 |
+
PatchEmbed,
|
| 18 |
+
window_partition,
|
| 19 |
+
window_unpartition,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from sam2.modeling.sam2_utils import DropPath, MLP
|
| 23 |
+
from itertools import repeat
|
| 24 |
+
import math
|
| 25 |
+
|
| 26 |
+
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
| 27 |
+
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
| 28 |
+
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
|
| 29 |
+
from torch._six import container_abcs
|
| 30 |
+
else:
|
| 31 |
+
import collections.abc as container_abcs
|
| 32 |
+
|
| 33 |
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
| 34 |
+
if pool is None:
|
| 35 |
+
return x
|
| 36 |
+
# (B, H, W, C) -> (B, C, H, W)
|
| 37 |
+
x = x.permute(0, 3, 1, 2)
|
| 38 |
+
x = pool(x)
|
| 39 |
+
# (B, C, H', W') -> (B, H', W', C)
|
| 40 |
+
x = x.permute(0, 2, 3, 1)
|
| 41 |
+
if norm:
|
| 42 |
+
x = norm(x)
|
| 43 |
+
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class MultiScaleAttention(nn.Module):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
dim: int,
|
| 51 |
+
dim_out: int,
|
| 52 |
+
num_heads: int,
|
| 53 |
+
q_pool: nn.Module = None,
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
|
| 57 |
+
self.dim = dim
|
| 58 |
+
self.dim_out = dim_out
|
| 59 |
+
|
| 60 |
+
self.num_heads = num_heads
|
| 61 |
+
head_dim = dim_out // num_heads
|
| 62 |
+
self.scale = head_dim**-0.5
|
| 63 |
+
|
| 64 |
+
self.q_pool = q_pool
|
| 65 |
+
self.qkv = nn.Linear(dim, dim_out * 3)
|
| 66 |
+
self.proj = nn.Linear(dim_out, dim_out)
|
| 67 |
+
|
| 68 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
B, H, W, _ = x.shape
|
| 70 |
+
# qkv with shape (B, H * W, 3, nHead, C)
|
| 71 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
| 72 |
+
# q, k, v with shape (B, H * W, nheads, C)
|
| 73 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 74 |
+
|
| 75 |
+
# Q pooling (for downsample at stage changes)
|
| 76 |
+
if self.q_pool:
|
| 77 |
+
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
|
| 78 |
+
H, W = q.shape[1:3] # downsampled shape
|
| 79 |
+
q = q.reshape(B, H * W, self.num_heads, -1)
|
| 80 |
+
|
| 81 |
+
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
|
| 82 |
+
x = F.scaled_dot_product_attention(
|
| 83 |
+
q.transpose(1, 2),
|
| 84 |
+
k.transpose(1, 2),
|
| 85 |
+
v.transpose(1, 2),
|
| 86 |
+
)
|
| 87 |
+
# Transpose back
|
| 88 |
+
x = x.transpose(1, 2)
|
| 89 |
+
x = x.reshape(B, H, W, -1)
|
| 90 |
+
|
| 91 |
+
x = self.proj(x)
|
| 92 |
+
|
| 93 |
+
return x
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class MultiScaleBlock(nn.Module):
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
dim: int,
|
| 100 |
+
dim_out: int,
|
| 101 |
+
num_heads: int,
|
| 102 |
+
mlp_ratio: float = 4.0,
|
| 103 |
+
drop_path: float = 0.0,
|
| 104 |
+
norm_layer: Union[nn.Module, str] = "LayerNorm",
|
| 105 |
+
q_stride: Tuple[int, int] = None,
|
| 106 |
+
act_layer: nn.Module = nn.GELU,
|
| 107 |
+
window_size: int = 0,
|
| 108 |
+
):
|
| 109 |
+
super().__init__()
|
| 110 |
+
|
| 111 |
+
if isinstance(norm_layer, str):
|
| 112 |
+
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
| 113 |
+
|
| 114 |
+
self.dim = dim
|
| 115 |
+
self.dim_out = dim_out
|
| 116 |
+
self.norm1 = norm_layer(dim)
|
| 117 |
+
|
| 118 |
+
self.window_size = window_size
|
| 119 |
+
|
| 120 |
+
self.pool, self.q_stride = None, q_stride
|
| 121 |
+
if self.q_stride:
|
| 122 |
+
self.pool = nn.MaxPool2d(
|
| 123 |
+
kernel_size=q_stride, stride=q_stride, ceil_mode=False
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
self.attn = MultiScaleAttention(
|
| 127 |
+
dim,
|
| 128 |
+
dim_out,
|
| 129 |
+
num_heads=num_heads,
|
| 130 |
+
q_pool=self.pool,
|
| 131 |
+
)
|
| 132 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 133 |
+
|
| 134 |
+
self.norm2 = norm_layer(dim_out)
|
| 135 |
+
self.mlp = MLP(
|
| 136 |
+
dim_out,
|
| 137 |
+
int(dim_out * mlp_ratio),
|
| 138 |
+
dim_out,
|
| 139 |
+
num_layers=2,
|
| 140 |
+
activation=act_layer,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if dim != dim_out:
|
| 144 |
+
self.proj = nn.Linear(dim, dim_out)
|
| 145 |
+
|
| 146 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 147 |
+
shortcut = x # B, H, W, C
|
| 148 |
+
x = self.norm1(x)
|
| 149 |
+
|
| 150 |
+
# Skip connection
|
| 151 |
+
if self.dim != self.dim_out:
|
| 152 |
+
shortcut = do_pool(self.proj(x), self.pool)
|
| 153 |
+
|
| 154 |
+
# Window partition
|
| 155 |
+
window_size = self.window_size
|
| 156 |
+
if window_size > 0:
|
| 157 |
+
H, W = x.shape[1], x.shape[2]
|
| 158 |
+
x, pad_hw = window_partition(x, window_size)
|
| 159 |
+
|
| 160 |
+
# Window Attention + Q Pooling (if stage change)
|
| 161 |
+
x = self.attn(x)
|
| 162 |
+
if self.q_stride:
|
| 163 |
+
# Shapes have changed due to Q pooling
|
| 164 |
+
window_size = self.window_size // self.q_stride[0]
|
| 165 |
+
H, W = shortcut.shape[1:3]
|
| 166 |
+
|
| 167 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 168 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 169 |
+
pad_hw = (H + pad_h, W + pad_w)
|
| 170 |
+
|
| 171 |
+
# Reverse window partition
|
| 172 |
+
if self.window_size > 0:
|
| 173 |
+
x = window_unpartition(x, window_size, pad_hw, (H, W))
|
| 174 |
+
|
| 175 |
+
x = shortcut + self.drop_path(x)
|
| 176 |
+
# MLP
|
| 177 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 178 |
+
return x
|
| 179 |
+
|
| 180 |
+
class OverlapPatchEmbed(nn.Module):
|
| 181 |
+
""" Image to Patch Embedding
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
| 185 |
+
super().__init__()
|
| 186 |
+
img_size = to_2tuple(img_size)
|
| 187 |
+
patch_size = to_2tuple(patch_size)
|
| 188 |
+
|
| 189 |
+
self.img_size = img_size
|
| 190 |
+
self.patch_size = patch_size
|
| 191 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
| 192 |
+
self.num_patches = self.H * self.W
|
| 193 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
| 194 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
| 195 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 196 |
+
|
| 197 |
+
self.apply(self._init_weights)
|
| 198 |
+
|
| 199 |
+
def _init_weights(self, m):
|
| 200 |
+
if isinstance(m, nn.Linear):
|
| 201 |
+
trunc_normal_(m.weight, std=.02)
|
| 202 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 203 |
+
nn.init.constant_(m.bias, 0)
|
| 204 |
+
elif isinstance(m, nn.LayerNorm):
|
| 205 |
+
nn.init.constant_(m.bias, 0)
|
| 206 |
+
nn.init.constant_(m.weight, 1.0)
|
| 207 |
+
elif isinstance(m, nn.Conv2d):
|
| 208 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 209 |
+
fan_out //= m.groups
|
| 210 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 211 |
+
if m.bias is not None:
|
| 212 |
+
m.bias.data.zero_()
|
| 213 |
+
|
| 214 |
+
def forward(self, x):
|
| 215 |
+
x = self.proj(x)
|
| 216 |
+
_, _, H, W = x.shape
|
| 217 |
+
x = x.flatten(2).transpose(1, 2)
|
| 218 |
+
x = self.norm(x)
|
| 219 |
+
|
| 220 |
+
return x, H, W
|
| 221 |
+
class Hiera(nn.Module):
|
| 222 |
+
"""
|
| 223 |
+
Reference: https://arxiv.org/abs/2306.00989
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
def __init__(
|
| 227 |
+
self,
|
| 228 |
+
in_chans=3,
|
| 229 |
+
embed_dims=[96, 192, 384, 768],
|
| 230 |
+
img_size: int = 1024,
|
| 231 |
+
patch_size: int = 4,
|
| 232 |
+
embed_dim: int = 96,# initial embed dim
|
| 233 |
+
num_heads: int = 1, # initial number of heads
|
| 234 |
+
drop_path_rate: float = 0.0, # stochastic depth
|
| 235 |
+
q_pool: int = 3, # number of q_pool stages
|
| 236 |
+
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
| 237 |
+
stages: Tuple[int, ...] = (2, 6, 36, 4), # blocks per stage
|
| 238 |
+
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
| 239 |
+
head_mul: float = 2.0, # head_mul factor at stage shift
|
| 240 |
+
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (7, 7),
|
| 241 |
+
# window size per stage, when not using global att.
|
| 242 |
+
window_spec: Tuple[int, ...] = (
|
| 243 |
+
8,
|
| 244 |
+
4,
|
| 245 |
+
16,
|
| 246 |
+
8
|
| 247 |
+
),
|
| 248 |
+
# global attn in these blocks
|
| 249 |
+
global_att_blocks: Tuple[int, ...] = (
|
| 250 |
+
23,
|
| 251 |
+
33,
|
| 252 |
+
43
|
| 253 |
+
),
|
| 254 |
+
return_interm_layers=True, # return feats from every stage
|
| 255 |
+
):
|
| 256 |
+
super().__init__()
|
| 257 |
+
|
| 258 |
+
assert len(stages) == len(window_spec)
|
| 259 |
+
self.window_spec = window_spec
|
| 260 |
+
|
| 261 |
+
depth = sum(stages)
|
| 262 |
+
self.q_stride = q_stride
|
| 263 |
+
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
| 264 |
+
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
| 265 |
+
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
| 266 |
+
self.return_interm_layers = return_interm_layers
|
| 267 |
+
|
| 268 |
+
self.patch_embed = PatchEmbed(
|
| 269 |
+
embed_dim=embed_dim,
|
| 270 |
+
)
|
| 271 |
+
# self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
|
| 272 |
+
# embed_dim=embed_dims[0])
|
| 273 |
+
# self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
|
| 274 |
+
# embed_dim=embed_dims[1])
|
| 275 |
+
# self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
|
| 276 |
+
# embed_dim=embed_dims[2])
|
| 277 |
+
# self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
|
| 278 |
+
# embed_dim=embed_dims[3])
|
| 279 |
+
# Which blocks have global att?
|
| 280 |
+
self.global_att_blocks = global_att_blocks
|
| 281 |
+
|
| 282 |
+
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 283 |
+
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 284 |
+
self.pos_embed = nn.Parameter(
|
| 285 |
+
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
|
| 286 |
+
)
|
| 287 |
+
self.pos_embed_window = nn.Parameter(
|
| 288 |
+
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
dpr = [
|
| 292 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 293 |
+
] # stochastic depth decay rule
|
| 294 |
+
|
| 295 |
+
# patch_embed
|
| 296 |
+
# self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
|
| 297 |
+
# embed_dim=embed_dims[0])
|
| 298 |
+
# self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
|
| 299 |
+
# embed_dim=embed_dims[1])
|
| 300 |
+
# self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
|
| 301 |
+
# embed_dim=embed_dims[2])
|
| 302 |
+
# self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
|
| 303 |
+
# embed_dim=embed_dims[3])
|
| 304 |
+
|
| 305 |
+
cur_stage = 1
|
| 306 |
+
self.embed_dim = embed_dims
|
| 307 |
+
self.depth = stages
|
| 308 |
+
self.blocks = nn.ModuleList()
|
| 309 |
+
self.scale_factor = 32
|
| 310 |
+
self.prompt_type = 'highpass'
|
| 311 |
+
self.tuning_stage = "1234"
|
| 312 |
+
self.input_type = 'fft'
|
| 313 |
+
self.freq_nums = 0.25
|
| 314 |
+
self.handcrafted_tune = False
|
| 315 |
+
self.embedding_tune = True
|
| 316 |
+
self.adaptor = 'adaptor'
|
| 317 |
+
self.prompt_generator = PromptGenerator(self.scale_factor, self.prompt_type, self.embed_dim,
|
| 318 |
+
self.tuning_stage, self.depth,
|
| 319 |
+
self.input_type, self.freq_nums,
|
| 320 |
+
self.handcrafted_tune, self.embedding_tune, self.adaptor,
|
| 321 |
+
img_size)
|
| 322 |
+
|
| 323 |
+
for i in range(depth):
|
| 324 |
+
dim_out = embed_dim
|
| 325 |
+
# lags by a block, so first block of
|
| 326 |
+
# next stage uses an initial window size
|
| 327 |
+
# of previous stage and final window size of current stage
|
| 328 |
+
window_size = self.window_spec[cur_stage - 1]
|
| 329 |
+
|
| 330 |
+
if self.global_att_blocks is not None:
|
| 331 |
+
window_size = 0 if i in self.global_att_blocks else window_size
|
| 332 |
+
|
| 333 |
+
if i - 1 in self.stage_ends:
|
| 334 |
+
dim_out = int(embed_dim * dim_mul)
|
| 335 |
+
num_heads = int(num_heads * head_mul)
|
| 336 |
+
cur_stage += 1
|
| 337 |
+
|
| 338 |
+
block = MultiScaleBlock(
|
| 339 |
+
dim=embed_dim,
|
| 340 |
+
dim_out=dim_out,
|
| 341 |
+
num_heads=num_heads,
|
| 342 |
+
drop_path=dpr[i],
|
| 343 |
+
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
| 344 |
+
window_size=window_size,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
embed_dim = dim_out
|
| 348 |
+
self.blocks.append(block)
|
| 349 |
+
|
| 350 |
+
self.channel_list = (
|
| 351 |
+
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
| 352 |
+
if return_interm_layers
|
| 353 |
+
else [self.blocks[-1].dim_out]
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
| 357 |
+
h, w = hw
|
| 358 |
+
window_embed = self.pos_embed_window
|
| 359 |
+
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 360 |
+
pos_embed = pos_embed + window_embed.tile(
|
| 361 |
+
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
|
| 362 |
+
)
|
| 363 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 364 |
+
return pos_embed
|
| 365 |
+
|
| 366 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 367 |
+
inp = x
|
| 368 |
+
x = self.patch_embed(x)
|
| 369 |
+
# x: (B, H, W, C)
|
| 370 |
+
handcrafted1, handcrafted2, handcrafted3, handcrafted4 = self.prompt_generator.init_handcrafted(inp)
|
| 371 |
+
|
| 372 |
+
self.block1 = []
|
| 373 |
+
self.block2 = []
|
| 374 |
+
self.block3 = []
|
| 375 |
+
self.block4 = []
|
| 376 |
+
outputs = []
|
| 377 |
+
|
| 378 |
+
for i, blk in enumerate(self.blocks):
|
| 379 |
+
if i < 2:
|
| 380 |
+
self.block1.append(blk) # 第一个块包含前2个元素
|
| 381 |
+
elif 1 < i < 4:
|
| 382 |
+
self.block2.append(blk) # 第二个块包含接下来的2个元素
|
| 383 |
+
elif 3 < i < 15:
|
| 384 |
+
self.block3.append(blk) # 第三个块包含接下来的11个元素
|
| 385 |
+
elif 14 < i:
|
| 386 |
+
self.block4.append(blk) # 其余元素组成第四个块
|
| 387 |
+
# print(i)
|
| 388 |
+
|
| 389 |
+
# for i, blk in enumerate(self.blocks):
|
| 390 |
+
# if i < 3:
|
| 391 |
+
# self.block1.append(blk) # 第一个块包含前3个元素
|
| 392 |
+
# elif 2 < i < 9:
|
| 393 |
+
# self.block2.append(blk) # 第二个块包含接下来的6个元素
|
| 394 |
+
# elif 8 < i < 45:
|
| 395 |
+
# self.block3.append(blk) # 第三个块包含接下来的36个元素
|
| 396 |
+
# elif 44 < i:
|
| 397 |
+
# self.block4.append(blk) # 其余元素组成第四个块
|
| 398 |
+
|
| 399 |
+
# Add pos embed
|
| 400 |
+
x = x + self._get_pos_embed(x.shape[1:3])
|
| 401 |
+
|
| 402 |
+
if '1' in self.tuning_stage:
|
| 403 |
+
prompt1 = self.prompt_generator.init_prompt(x, handcrafted1, 1)
|
| 404 |
+
for i, blk in enumerate(self.block1):
|
| 405 |
+
if '1' in self.tuning_stage:
|
| 406 |
+
x = self.prompt_generator.get_prompt(x, prompt1, 1, i)
|
| 407 |
+
x = blk(x)
|
| 408 |
+
# x = self.norm1(x)
|
| 409 |
+
if i == 0:
|
| 410 |
+
feat = x.permute(0, 3, 1, 2)
|
| 411 |
+
outputs.append(feat)
|
| 412 |
+
|
| 413 |
+
if '2' in self.tuning_stage:
|
| 414 |
+
prompt2 = self.prompt_generator.init_prompt(x, handcrafted2, 2)
|
| 415 |
+
for i, blk in enumerate(self.block2):
|
| 416 |
+
if '2' in self.tuning_stage:
|
| 417 |
+
x = self.prompt_generator.get_prompt(x, prompt2, 2, i)
|
| 418 |
+
x = blk(x)
|
| 419 |
+
# x = self.norm2(x)
|
| 420 |
+
if i == 0:
|
| 421 |
+
feat = x.permute(0, 3, 1, 2)
|
| 422 |
+
outputs.append(feat)
|
| 423 |
+
|
| 424 |
+
if '3' in self.tuning_stage:
|
| 425 |
+
prompt3 = self.prompt_generator.init_prompt(x, handcrafted3, 3)
|
| 426 |
+
for i, blk in enumerate(self.block3):
|
| 427 |
+
if '3' in self.tuning_stage:
|
| 428 |
+
x = self.prompt_generator.get_prompt(x,prompt3, 3, i)
|
| 429 |
+
x = blk(x)
|
| 430 |
+
# x = self.norm3(x)
|
| 431 |
+
if i == 9:
|
| 432 |
+
feat = x.permute(0, 3, 1, 2)
|
| 433 |
+
outputs.append(feat)
|
| 434 |
+
|
| 435 |
+
if '4' in self.tuning_stage:
|
| 436 |
+
prompt4 = self.prompt_generator.init_prompt(x, handcrafted4, 4)
|
| 437 |
+
for i, blk in enumerate(self.block4):
|
| 438 |
+
if '4' in self.tuning_stage:
|
| 439 |
+
x = self.prompt_generator.get_prompt(x, prompt4, 4, i)
|
| 440 |
+
x = blk(x)
|
| 441 |
+
# x = self.norm4(x)
|
| 442 |
+
if i == 0:
|
| 443 |
+
feat = x.permute(0, 3, 1, 2)
|
| 444 |
+
outputs.append(feat)
|
| 445 |
+
|
| 446 |
+
return outputs
|
| 447 |
+
def to_2tuple(x):
|
| 448 |
+
if isinstance(x, container_abcs.Iterable):
|
| 449 |
+
return x
|
| 450 |
+
return tuple(repeat(x, 2))
|
| 451 |
+
|
| 452 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 453 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
| 454 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
| 455 |
+
normal distribution. The values are effectively drawn from the
|
| 456 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
| 457 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
| 458 |
+
the bounds. The method used for generating the random values works
|
| 459 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
| 460 |
+
Args:
|
| 461 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 462 |
+
mean: the mean of the normal distribution
|
| 463 |
+
std: the standard deviation of the normal distribution
|
| 464 |
+
a: the minimum cutoff value
|
| 465 |
+
b: the maximum cutoff value
|
| 466 |
+
Examples:
|
| 467 |
+
>>> w = torch.empty(3, 5)
|
| 468 |
+
>>> nn.init.trunc_normal_(w)
|
| 469 |
+
"""
|
| 470 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 471 |
+
|
| 472 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 473 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 474 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 475 |
+
def norm_cdf(x):
|
| 476 |
+
# Computes standard normal cumulative distribution function
|
| 477 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 478 |
+
|
| 479 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 480 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 481 |
+
"The distribution of values may be incorrect.",
|
| 482 |
+
stacklevel=2)
|
| 483 |
+
|
| 484 |
+
with torch.no_grad():
|
| 485 |
+
# Values are generated by using a truncated uniform distribution and
|
| 486 |
+
# then using the inverse CDF for the normal distribution.
|
| 487 |
+
# Get upper and lower cdf values
|
| 488 |
+
l = norm_cdf((a - mean) / std)
|
| 489 |
+
u = norm_cdf((b - mean) / std)
|
| 490 |
+
|
| 491 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 492 |
+
# [2l-1, 2u-1].
|
| 493 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 494 |
+
|
| 495 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 496 |
+
# standard normal
|
| 497 |
+
tensor.erfinv_()
|
| 498 |
+
|
| 499 |
+
# Transform to proper mean, std
|
| 500 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 501 |
+
tensor.add_(mean)
|
| 502 |
+
|
| 503 |
+
# Clamp to ensure it's in the proper range
|
| 504 |
+
tensor.clamp_(min=a, max=b)
|
| 505 |
+
return tensor
|
| 506 |
+
class PromptGenerator(nn.Module):
|
| 507 |
+
def __init__(self, scale_factor, prompt_type, embed_dims, tuning_stage, depths, input_type,
|
| 508 |
+
freq_nums, handcrafted_tune, embedding_tune, adaptor, img_size):
|
| 509 |
+
"""
|
| 510 |
+
Args:
|
| 511 |
+
"""
|
| 512 |
+
super(PromptGenerator, self).__init__()
|
| 513 |
+
self.scale_factor = scale_factor
|
| 514 |
+
self.prompt_type = prompt_type
|
| 515 |
+
self.embed_dims = embed_dims
|
| 516 |
+
self.input_type = input_type
|
| 517 |
+
self.freq_nums = freq_nums
|
| 518 |
+
self.tuning_stage = tuning_stage
|
| 519 |
+
self.depths = depths
|
| 520 |
+
self.handcrafted_tune = handcrafted_tune
|
| 521 |
+
self.embedding_tune = embedding_tune
|
| 522 |
+
self.adaptor = adaptor
|
| 523 |
+
|
| 524 |
+
if self.input_type == 'gaussian':
|
| 525 |
+
self.gaussian_filter = GaussianFilter()
|
| 526 |
+
if self.input_type == 'srm':
|
| 527 |
+
self.srm_filter = SRMFilter()
|
| 528 |
+
if self.input_type == 'all':
|
| 529 |
+
self.prompt = nn.Parameter(torch.zeros(3, img_size, img_size), requires_grad=False)
|
| 530 |
+
|
| 531 |
+
if self.handcrafted_tune:
|
| 532 |
+
if '1' in self.tuning_stage:
|
| 533 |
+
self.handcrafted_generator1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=3,
|
| 534 |
+
embed_dim=self.embed_dims[0] // self.scale_factor)
|
| 535 |
+
if '2' in self.tuning_stage:
|
| 536 |
+
self.handcrafted_generator2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2,
|
| 537 |
+
in_chans=self.embed_dims[0] // self.scale_factor,
|
| 538 |
+
embed_dim=self.embed_dims[1] // self.scale_factor)
|
| 539 |
+
if '3' in self.tuning_stage:
|
| 540 |
+
self.handcrafted_generator3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2,
|
| 541 |
+
in_chans=self.embed_dims[1] // self.scale_factor,
|
| 542 |
+
embed_dim=self.embed_dims[2] // self.scale_factor)
|
| 543 |
+
if '4' in self.tuning_stage:
|
| 544 |
+
self.handcrafted_generator4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2,
|
| 545 |
+
in_chans=self.embed_dims[2] // self.scale_factor,
|
| 546 |
+
embed_dim=self.embed_dims[3] // self.scale_factor)
|
| 547 |
+
|
| 548 |
+
if self.embedding_tune:
|
| 549 |
+
if '1' in self.tuning_stage:
|
| 550 |
+
self.embedding_generator1 = nn.Linear(self.embed_dims[0], self.embed_dims[0] // self.scale_factor)
|
| 551 |
+
if '2' in self.tuning_stage:
|
| 552 |
+
self.embedding_generator2 = nn.Linear(self.embed_dims[1], self.embed_dims[1] // self.scale_factor)
|
| 553 |
+
if '3' in self.tuning_stage:
|
| 554 |
+
self.embedding_generator3 = nn.Linear(self.embed_dims[2], self.embed_dims[2] // self.scale_factor)
|
| 555 |
+
if '4' in self.tuning_stage:
|
| 556 |
+
self.embedding_generator4 = nn.Linear(self.embed_dims[3], self.embed_dims[3] // self.scale_factor)
|
| 557 |
+
|
| 558 |
+
if self.adaptor == 'adaptor':
|
| 559 |
+
if '1' in self.tuning_stage:
|
| 560 |
+
for i in range(self.depths[0]+1):
|
| 561 |
+
lightweight_mlp = nn.Sequential(
|
| 562 |
+
nn.Linear(self.embed_dims[0] // self.scale_factor, self.embed_dims[0] // self.scale_factor),
|
| 563 |
+
nn.GELU(),
|
| 564 |
+
)
|
| 565 |
+
setattr(self, 'lightweight_mlp1_{}'.format(str(i)), lightweight_mlp)
|
| 566 |
+
self.shared_mlp1 = nn.Linear(self.embed_dims[0] // self.scale_factor, self.embed_dims[0])
|
| 567 |
+
|
| 568 |
+
if '2' in self.tuning_stage:
|
| 569 |
+
for i in range(self.depths[1]+1):
|
| 570 |
+
lightweight_mlp = nn.Sequential(
|
| 571 |
+
nn.Linear(self.embed_dims[1] // self.scale_factor, self.embed_dims[1] // self.scale_factor),
|
| 572 |
+
nn.GELU(),
|
| 573 |
+
)
|
| 574 |
+
setattr(self, 'lightweight_mlp2_{}'.format(str(i)), lightweight_mlp)
|
| 575 |
+
self.shared_mlp2 = nn.Linear(self.embed_dims[1] // self.scale_factor, self.embed_dims[1])
|
| 576 |
+
|
| 577 |
+
if '3' in self.tuning_stage:
|
| 578 |
+
for i in range(self.depths[2]+1):
|
| 579 |
+
lightweight_mlp = nn.Sequential(
|
| 580 |
+
nn.Linear(self.embed_dims[2] // self.scale_factor, self.embed_dims[2] // self.scale_factor),
|
| 581 |
+
nn.GELU(),
|
| 582 |
+
)
|
| 583 |
+
setattr(self, 'lightweight_mlp3_{}'.format(str(i)), lightweight_mlp)
|
| 584 |
+
self.shared_mlp3 = nn.Linear(self.embed_dims[2] // self.scale_factor, self.embed_dims[2])
|
| 585 |
+
|
| 586 |
+
if '4' in self.tuning_stage:
|
| 587 |
+
for i in range(self.depths[3]+1):
|
| 588 |
+
lightweight_mlp = nn.Sequential(
|
| 589 |
+
nn.Linear(self.embed_dims[3] // self.scale_factor, self.embed_dims[3] // self.scale_factor),
|
| 590 |
+
nn.GELU(),
|
| 591 |
+
)
|
| 592 |
+
setattr(self, 'lightweight_mlp4_{}'.format(str(i)), lightweight_mlp)
|
| 593 |
+
self.shared_mlp4 = nn.Linear(self.embed_dims[3] // self.scale_factor, self.embed_dims[3])
|
| 594 |
+
|
| 595 |
+
elif self.adaptor == 'fully_shared':
|
| 596 |
+
self.fully_shared_mlp1 = nn.Sequential(
|
| 597 |
+
nn.Linear(self.embed_dims[0] // self.scale_factor, self.embed_dims[0] // self.scale_factor),
|
| 598 |
+
nn.GELU(),
|
| 599 |
+
nn.Linear(self.embed_dims[0] // self.scale_factor, self.embed_dims[0])
|
| 600 |
+
)
|
| 601 |
+
self.fully_shared_mlp2 = nn.Sequential(
|
| 602 |
+
nn.Linear(self.embed_dims[1] // self.scale_factor, self.embed_dims[1] // self.scale_factor),
|
| 603 |
+
nn.GELU(),
|
| 604 |
+
nn.Linear(self.embed_dims[1] // self.scale_factor, self.embed_dims[1])
|
| 605 |
+
)
|
| 606 |
+
self.fully_shared_mlp3 = nn.Sequential(
|
| 607 |
+
nn.Linear(self.embed_dims[2] // self.scale_factor, self.embed_dims[2] // self.scale_factor),
|
| 608 |
+
nn.GELU(),
|
| 609 |
+
nn.Linear(self.embed_dims[2] // self.scale_factor, self.embed_dims[2])
|
| 610 |
+
)
|
| 611 |
+
self.fully_shared_mlp4 = nn.Sequential(
|
| 612 |
+
nn.Linear(self.embed_dims[3] // self.scale_factor, self.embed_dims[3] // self.scale_factor),
|
| 613 |
+
nn.GELU(),
|
| 614 |
+
nn.Linear(self.embed_dims[3] // self.scale_factor, self.embed_dims[3])
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
elif self.adaptor == 'fully_unshared':
|
| 618 |
+
for i in range(self.depths[0]):
|
| 619 |
+
fully_unshared_mlp1 = nn.Sequential(
|
| 620 |
+
nn.Linear(self.embed_dims[0] // self.scale_factor, self.embed_dims[0] // self.scale_factor),
|
| 621 |
+
nn.GELU(),
|
| 622 |
+
nn.Linear(self.embed_dims[0] // self.scale_factor, self.embed_dims[0])
|
| 623 |
+
)
|
| 624 |
+
setattr(self, 'fully_unshared_mlp1_{}'.format(str(i)), fully_unshared_mlp1)
|
| 625 |
+
for i in range(self.depths[1]):
|
| 626 |
+
fully_unshared_mlp1 = nn.Sequential(
|
| 627 |
+
nn.Linear(self.embed_dims[1] // self.scale_factor, self.embed_dims[1] // self.scale_factor),
|
| 628 |
+
nn.GELU(),
|
| 629 |
+
nn.Linear(self.embed_dims[1] // self.scale_factor, self.embed_dims[1])
|
| 630 |
+
)
|
| 631 |
+
setattr(self, 'fully_unshared_mlp2_{}'.format(str(i)), fully_unshared_mlp1)
|
| 632 |
+
for i in range(self.depths[2]):
|
| 633 |
+
fully_unshared_mlp1 = nn.Sequential(
|
| 634 |
+
nn.Linear(self.embed_dims[2] // self.scale_factor, self.embed_dims[2] // self.scale_factor),
|
| 635 |
+
nn.GELU(),
|
| 636 |
+
nn.Linear(self.embed_dims[2] // self.scale_factor, self.embed_dims[2])
|
| 637 |
+
)
|
| 638 |
+
setattr(self, 'fully_unshared_mlp3_{}'.format(str(i)), fully_unshared_mlp1)
|
| 639 |
+
for i in range(self.depths[3]):
|
| 640 |
+
fully_unshared_mlp1 = nn.Sequential(
|
| 641 |
+
nn.Linear(self.embed_dims[3] // self.scale_factor, self.embed_dims[3] // self.scale_factor),
|
| 642 |
+
nn.GELU(),
|
| 643 |
+
nn.Linear(self.embed_dims[3] // self.scale_factor, self.embed_dims[3])
|
| 644 |
+
)
|
| 645 |
+
setattr(self, 'fully_unshared_mlp4_{}'.format(str(i)), fully_unshared_mlp1)
|
| 646 |
+
|
| 647 |
+
self.apply(self._init_weights)
|
| 648 |
+
|
| 649 |
+
def _init_weights(self, m):
|
| 650 |
+
if isinstance(m, nn.Linear):
|
| 651 |
+
trunc_normal_(m.weight, std=.02)
|
| 652 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 653 |
+
nn.init.constant_(m.bias, 0)
|
| 654 |
+
elif isinstance(m, nn.LayerNorm):
|
| 655 |
+
nn.init.constant_(m.bias, 0)
|
| 656 |
+
nn.init.constant_(m.weight, 1.0)
|
| 657 |
+
elif isinstance(m, nn.Conv2d):
|
| 658 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 659 |
+
fan_out //= m.groups
|
| 660 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 661 |
+
if m.bias is not None:
|
| 662 |
+
m.bias.data.zero_()
|
| 663 |
+
|
| 664 |
+
def init_handcrafted(self, x):
|
| 665 |
+
return None, None, None, None
|
| 666 |
+
if self.input_type == 'fft':
|
| 667 |
+
x = self.fft(x, self.freq_nums, self.prompt_type)
|
| 668 |
+
|
| 669 |
+
elif self.input_type == 'all':
|
| 670 |
+
x = self.prompt.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
| 671 |
+
|
| 672 |
+
elif self.input_type == 'gaussian':
|
| 673 |
+
x = self.gaussian_filter.conv_gauss(x)
|
| 674 |
+
|
| 675 |
+
elif self.input_type == 'srm':
|
| 676 |
+
x = self.srm_filter.srm_layer(x)
|
| 677 |
+
|
| 678 |
+
# return x
|
| 679 |
+
B = x.shape[0]
|
| 680 |
+
# get prompting
|
| 681 |
+
|
| 682 |
+
# if '1' in self.tuning_stage:
|
| 683 |
+
# handcrafted1, H1, W1 = self.handcrafted_generator1(x)
|
| 684 |
+
# else:
|
| 685 |
+
# handcrafted1 = None
|
| 686 |
+
|
| 687 |
+
# if '2' in self.tuning_stage:
|
| 688 |
+
# handcrafted2, H2, W2 = self.handcrafted_generator2(handcrafted1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous())
|
| 689 |
+
# else:
|
| 690 |
+
# handcrafted2 = None
|
| 691 |
+
|
| 692 |
+
# if '3' in self.tuning_stage:
|
| 693 |
+
# handcrafted3, H3, W3 = self.handcrafted_generator3(handcrafted2.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous())
|
| 694 |
+
# else:
|
| 695 |
+
# handcrafted3 = None
|
| 696 |
+
|
| 697 |
+
# if '4' in self.tuning_stage:
|
| 698 |
+
# handcrafted4, H4, W4 = self.handcrafted_generator4(handcrafted3.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous())
|
| 699 |
+
# else:
|
| 700 |
+
# handcrafted4 = None
|
| 701 |
+
|
| 702 |
+
return None, None, None, None
|
| 703 |
+
# return handcrafted1, handcrafted2, handcrafted3, handcrafted4
|
| 704 |
+
|
| 705 |
+
def init_prompt(self, embedding_feature, handcrafted_feature, block_num):
|
| 706 |
+
if self.embedding_tune:
|
| 707 |
+
embedding_generator = getattr(self, 'embedding_generator{}'.format(str(block_num)))
|
| 708 |
+
# print(embedding_generator) # 144 -> 4
|
| 709 |
+
# print(embedding_feature.shape) # [1, 256, 256, 96]
|
| 710 |
+
embedding_feature = embedding_generator(embedding_feature)
|
| 711 |
+
if self.handcrafted_tune:
|
| 712 |
+
handcrafted_feature = handcrafted_feature
|
| 713 |
+
|
| 714 |
+
return handcrafted_feature, embedding_feature
|
| 715 |
+
|
| 716 |
+
def get_embedding_feature(self, x, block_num):
|
| 717 |
+
if self.embedding_tune:
|
| 718 |
+
embedding_generator = getattr(self, 'embedding_generator{}'.format(str(block_num)))
|
| 719 |
+
embedding_feature = embedding_generator(x)
|
| 720 |
+
|
| 721 |
+
return embedding_feature
|
| 722 |
+
else:
|
| 723 |
+
return None
|
| 724 |
+
|
| 725 |
+
def get_handcrafted_feature(self, x, block_num):
|
| 726 |
+
if self.handcrafted_tune:
|
| 727 |
+
handcrafted_generator = getattr(self, 'handcrafted_generator{}'.format(str(block_num)))
|
| 728 |
+
handcrafted_feature = handcrafted_generator(x)
|
| 729 |
+
|
| 730 |
+
return handcrafted_feature
|
| 731 |
+
else:
|
| 732 |
+
return None
|
| 733 |
+
|
| 734 |
+
def get_prompt(self, x, prompt, block_num, depth_num):
|
| 735 |
+
feat = 0
|
| 736 |
+
B, H, W = prompt[1].shape[0], prompt[1].shape[1], prompt[1].shape[2]
|
| 737 |
+
if self.handcrafted_tune:
|
| 738 |
+
feat += prompt[0].reshape(B, H, W, -1)
|
| 739 |
+
if self.embedding_tune:
|
| 740 |
+
# if False:
|
| 741 |
+
feat += prompt[1]
|
| 742 |
+
|
| 743 |
+
if self.adaptor == 'adaptor':
|
| 744 |
+
lightweight_mlp = getattr(self, 'lightweight_mlp' + str(block_num) + '_' + str(depth_num))
|
| 745 |
+
shared_mlp = getattr(self, 'shared_mlp' + str(block_num))
|
| 746 |
+
|
| 747 |
+
feat = lightweight_mlp(feat)
|
| 748 |
+
feat = shared_mlp(feat)
|
| 749 |
+
|
| 750 |
+
elif self.adaptor == 'fully_shared':
|
| 751 |
+
fully_shared_mlp = getattr(self, 'fully_shared_mlp' + str(block_num))
|
| 752 |
+
feat = fully_shared_mlp(feat)
|
| 753 |
+
|
| 754 |
+
elif self.adaptor == 'fully_unshared':
|
| 755 |
+
fully_unshared_mlp = getattr(self, 'fully_unshared_mlp' + str(block_num) + '_' + str(depth_num))
|
| 756 |
+
feat = fully_unshared_mlp(feat)
|
| 757 |
+
|
| 758 |
+
x = x + feat
|
| 759 |
+
|
| 760 |
+
return x
|
| 761 |
+
|
| 762 |
+
def fft(self, x, rate, prompt_type):
|
| 763 |
+
mask = torch.zeros(x.shape).to('cuda')
|
| 764 |
+
w, h = x.shape[-2:]
|
| 765 |
+
line = int((w * h * rate) ** .5 // 2)
|
| 766 |
+
mask[:, :, w//2-line:w//2+line, h//2-line:h//2+line] = 1
|
| 767 |
+
|
| 768 |
+
fft = torch.fft.fftshift(torch.fft.fft2(x, norm="forward"))
|
| 769 |
+
|
| 770 |
+
if prompt_type == 'highpass':
|
| 771 |
+
fft = fft * (1 - mask)
|
| 772 |
+
elif prompt_type == 'lowpass':
|
| 773 |
+
fft = fft * mask
|
| 774 |
+
fr = fft.real
|
| 775 |
+
fi = fft.imag
|
| 776 |
+
|
| 777 |
+
fft_hires = torch.fft.ifftshift(torch.complex(fr, fi))
|
| 778 |
+
inv = torch.fft.ifft2(fft_hires, norm="forward").real
|
| 779 |
+
|
| 780 |
+
inv = torch.abs(inv)
|
| 781 |
+
|
| 782 |
+
return inv
|
| 783 |
+
|
| 784 |
+
class GaussianFilter(nn.Module):
|
| 785 |
+
def __init__(self):
|
| 786 |
+
super(GaussianFilter, self).__init__()
|
| 787 |
+
self.kernel = self.gauss_kernel()
|
| 788 |
+
|
| 789 |
+
def gauss_kernel(self, channels=3):
|
| 790 |
+
kernel = torch.tensor([[1., 4., 6., 4., 1],
|
| 791 |
+
[4., 16., 24., 16., 4.],
|
| 792 |
+
[6., 24., 36., 24., 6.],
|
| 793 |
+
[4., 16., 24., 16., 4.],
|
| 794 |
+
[1., 4., 6., 4., 1.]])
|
| 795 |
+
kernel /= 256.
|
| 796 |
+
kernel = kernel.repeat(channels, 1, 1, 1)
|
| 797 |
+
kernel = kernel.to(device)
|
| 798 |
+
return kernel
|
| 799 |
+
|
| 800 |
+
def conv_gauss(self, img):
|
| 801 |
+
img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
|
| 802 |
+
out = torch.nn.functional.conv2d(img, self.kernel, groups=img.shape[1])
|
| 803 |
+
return out
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
class SRMFilter(nn.Module):
|
| 807 |
+
def __init__(self):
|
| 808 |
+
super(SRMFilter, self).__init__()
|
| 809 |
+
self.srm_layer = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=2,)
|
| 810 |
+
filter1 = [[0, 0, 0, 0, 0],
|
| 811 |
+
[0, -1 / 4, 2 / 4, -1 / 4, 0],
|
| 812 |
+
[0, 2 / 4, -4 / 4, 2 / 4, 0],
|
| 813 |
+
[0, -1 / 4, 2 / 4, -1 / 4, 0],
|
| 814 |
+
[0, 0, 0, 0, 0]]
|
| 815 |
+
filter2 = [[-1 / 12, 2 / 12, -2 / 12, 2 / 12, -1 / 12],
|
| 816 |
+
[2 / 12, -6 / 12, 8 / 12, -6 / 12, 2 / 12],
|
| 817 |
+
[-2 / 12, 8 / 12, -12 / 12, 8 / 12, -2 / 12],
|
| 818 |
+
[2 / 12, -6 / 12, 8 / 12, -6 / 12, 2 / 12],
|
| 819 |
+
[-1 / 12, 2 / 12, -2 / 12, 2 / 12, -1 / 12]]
|
| 820 |
+
filter3 = [[0, 0, 0, 0, 0],
|
| 821 |
+
[0, 0, 0, 0, 0],
|
| 822 |
+
[0, 1 / 2, -2 / 2, 1 / 2, 0],
|
| 823 |
+
[0, 0, 0, 0, 0],
|
| 824 |
+
[0, 0, 0, 0, 0]]
|
| 825 |
+
self.srm_layer.weight.data = torch.Tensor(
|
| 826 |
+
[[filter1, filter1, filter1],
|
| 827 |
+
[filter2, filter2, filter2],
|
| 828 |
+
[filter3, filter3, filter3]]
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
for param in self.srm_layer.parameters():
|
| 832 |
+
param.requires_grad = False
|
| 833 |
+
|
| 834 |
+
def conv_srm(self, img):
|
| 835 |
+
out = self.srm_layer(img)
|
| 836 |
+
return out
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
class DWConv(nn.Module):
|
| 840 |
+
def __init__(self, dim=768):
|
| 841 |
+
super(DWConv, self).__init__()
|
| 842 |
+
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
| 843 |
+
|
| 844 |
+
def forward(self, x, H, W):
|
| 845 |
+
B, N, C = x.shape
|
| 846 |
+
x = x.transpose(1, 2).view(B, C, H, W)
|
| 847 |
+
x = self.dwconv(x)
|
| 848 |
+
x = x.flatten(2).transpose(1, 2)
|
| 849 |
+
|
| 850 |
+
return x
|
sam2/sam2/modeling/backbones/hieradet.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from iopath.common.file_io import g_pathmgr
|
| 15 |
+
|
| 16 |
+
from sam2.modeling.backbones.utils import (
|
| 17 |
+
PatchEmbed,
|
| 18 |
+
window_partition,
|
| 19 |
+
window_unpartition,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from sam2.modeling.sam2_utils import DropPath, MLP
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
| 26 |
+
if pool is None:
|
| 27 |
+
return x
|
| 28 |
+
# (B, H, W, C) -> (B, C, H, W)
|
| 29 |
+
x = x.permute(0, 3, 1, 2)
|
| 30 |
+
x = pool(x)
|
| 31 |
+
# (B, C, H', W') -> (B, H', W', C)
|
| 32 |
+
x = x.permute(0, 2, 3, 1)
|
| 33 |
+
if norm:
|
| 34 |
+
x = norm(x)
|
| 35 |
+
|
| 36 |
+
return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiScaleAttention(nn.Module):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
dim: int,
|
| 43 |
+
dim_out: int,
|
| 44 |
+
num_heads: int,
|
| 45 |
+
q_pool: nn.Module = None,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
self.dim = dim
|
| 50 |
+
self.dim_out = dim_out
|
| 51 |
+
self.num_heads = num_heads
|
| 52 |
+
self.q_pool = q_pool
|
| 53 |
+
self.qkv = nn.Linear(dim, dim_out * 3)
|
| 54 |
+
self.proj = nn.Linear(dim_out, dim_out)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 57 |
+
B, H, W, _ = x.shape
|
| 58 |
+
# qkv with shape (B, H * W, 3, nHead, C)
|
| 59 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
| 60 |
+
# q, k, v with shape (B, H * W, nheads, C)
|
| 61 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 62 |
+
|
| 63 |
+
# Q pooling (for downsample at stage changes)
|
| 64 |
+
if self.q_pool:
|
| 65 |
+
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
|
| 66 |
+
H, W = q.shape[1:3] # downsampled shape
|
| 67 |
+
q = q.reshape(B, H * W, self.num_heads, -1)
|
| 68 |
+
|
| 69 |
+
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
|
| 70 |
+
x = F.scaled_dot_product_attention(
|
| 71 |
+
q.transpose(1, 2),
|
| 72 |
+
k.transpose(1, 2),
|
| 73 |
+
v.transpose(1, 2),
|
| 74 |
+
)
|
| 75 |
+
# Transpose back
|
| 76 |
+
x = x.transpose(1, 2)
|
| 77 |
+
x = x.reshape(B, H, W, -1)
|
| 78 |
+
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultiScaleBlock(nn.Module):
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
dim: int,
|
| 88 |
+
dim_out: int,
|
| 89 |
+
num_heads: int,
|
| 90 |
+
mlp_ratio: float = 4.0,
|
| 91 |
+
drop_path: float = 0.0,
|
| 92 |
+
norm_layer: Union[nn.Module, str] = "LayerNorm",
|
| 93 |
+
q_stride: Tuple[int, int] = None,
|
| 94 |
+
act_layer: nn.Module = nn.GELU,
|
| 95 |
+
window_size: int = 0,
|
| 96 |
+
):
|
| 97 |
+
super().__init__()
|
| 98 |
+
|
| 99 |
+
if isinstance(norm_layer, str):
|
| 100 |
+
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
| 101 |
+
|
| 102 |
+
self.dim = dim
|
| 103 |
+
self.dim_out = dim_out
|
| 104 |
+
self.norm1 = norm_layer(dim)
|
| 105 |
+
|
| 106 |
+
self.window_size = window_size
|
| 107 |
+
|
| 108 |
+
self.pool, self.q_stride = None, q_stride
|
| 109 |
+
if self.q_stride:
|
| 110 |
+
self.pool = nn.MaxPool2d(
|
| 111 |
+
kernel_size=q_stride, stride=q_stride, ceil_mode=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.attn = MultiScaleAttention(
|
| 115 |
+
dim,
|
| 116 |
+
dim_out,
|
| 117 |
+
num_heads=num_heads,
|
| 118 |
+
q_pool=self.pool,
|
| 119 |
+
)
|
| 120 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 121 |
+
|
| 122 |
+
self.norm2 = norm_layer(dim_out)
|
| 123 |
+
self.mlp = MLP(
|
| 124 |
+
dim_out,
|
| 125 |
+
int(dim_out * mlp_ratio),
|
| 126 |
+
dim_out,
|
| 127 |
+
num_layers=2,
|
| 128 |
+
activation=act_layer,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
if dim != dim_out:
|
| 132 |
+
self.proj = nn.Linear(dim, dim_out)
|
| 133 |
+
|
| 134 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 135 |
+
shortcut = x # B, H, W, C
|
| 136 |
+
x = self.norm1(x)
|
| 137 |
+
|
| 138 |
+
# Skip connection
|
| 139 |
+
if self.dim != self.dim_out:
|
| 140 |
+
shortcut = do_pool(self.proj(x), self.pool)
|
| 141 |
+
|
| 142 |
+
# Window partition
|
| 143 |
+
window_size = self.window_size
|
| 144 |
+
if window_size > 0:
|
| 145 |
+
H, W = x.shape[1], x.shape[2]
|
| 146 |
+
x, pad_hw = window_partition(x, window_size)
|
| 147 |
+
|
| 148 |
+
# Window Attention + Q Pooling (if stage change)
|
| 149 |
+
x = self.attn(x)
|
| 150 |
+
if self.q_stride:
|
| 151 |
+
# Shapes have changed due to Q pooling
|
| 152 |
+
window_size = self.window_size // self.q_stride[0]
|
| 153 |
+
H, W = shortcut.shape[1:3]
|
| 154 |
+
|
| 155 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 156 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 157 |
+
pad_hw = (H + pad_h, W + pad_w)
|
| 158 |
+
|
| 159 |
+
# Reverse window partition
|
| 160 |
+
if self.window_size > 0:
|
| 161 |
+
x = window_unpartition(x, window_size, pad_hw, (H, W))
|
| 162 |
+
|
| 163 |
+
x = shortcut + self.drop_path(x)
|
| 164 |
+
# MLP
|
| 165 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Hiera(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Reference: https://arxiv.org/abs/2306.00989
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
embed_dim: int = 96, # initial embed dim
|
| 177 |
+
num_heads: int = 1, # initial number of heads
|
| 178 |
+
drop_path_rate: float = 0.0, # stochastic depth
|
| 179 |
+
q_pool: int = 3, # number of q_pool stages
|
| 180 |
+
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
| 181 |
+
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
| 182 |
+
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
| 183 |
+
head_mul: float = 2.0, # head_mul factor at stage shift
|
| 184 |
+
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
| 185 |
+
# window size per stage, when not using global att.
|
| 186 |
+
window_spec: Tuple[int, ...] = (
|
| 187 |
+
8,
|
| 188 |
+
4,
|
| 189 |
+
14,
|
| 190 |
+
7,
|
| 191 |
+
),
|
| 192 |
+
# global attn in these blocks
|
| 193 |
+
global_att_blocks: Tuple[int, ...] = (
|
| 194 |
+
12,
|
| 195 |
+
16,
|
| 196 |
+
20,
|
| 197 |
+
),
|
| 198 |
+
weights_path=None,
|
| 199 |
+
return_interm_layers=True, # return feats from every stage
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
assert len(stages) == len(window_spec)
|
| 204 |
+
self.window_spec = window_spec
|
| 205 |
+
|
| 206 |
+
depth = sum(stages)
|
| 207 |
+
self.q_stride = q_stride
|
| 208 |
+
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
| 209 |
+
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
| 210 |
+
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
| 211 |
+
self.return_interm_layers = return_interm_layers
|
| 212 |
+
|
| 213 |
+
self.patch_embed = PatchEmbed(
|
| 214 |
+
embed_dim=embed_dim,
|
| 215 |
+
)
|
| 216 |
+
# Which blocks have global att?
|
| 217 |
+
self.global_att_blocks = global_att_blocks
|
| 218 |
+
|
| 219 |
+
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 220 |
+
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 221 |
+
self.pos_embed = nn.Parameter(
|
| 222 |
+
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
|
| 223 |
+
)
|
| 224 |
+
self.pos_embed_window = nn.Parameter(
|
| 225 |
+
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
dpr = [
|
| 229 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 230 |
+
] # stochastic depth decay rule
|
| 231 |
+
|
| 232 |
+
cur_stage = 1
|
| 233 |
+
self.blocks = nn.ModuleList()
|
| 234 |
+
|
| 235 |
+
for i in range(depth):
|
| 236 |
+
dim_out = embed_dim
|
| 237 |
+
# lags by a block, so first block of
|
| 238 |
+
# next stage uses an initial window size
|
| 239 |
+
# of previous stage and final window size of current stage
|
| 240 |
+
window_size = self.window_spec[cur_stage - 1]
|
| 241 |
+
|
| 242 |
+
if self.global_att_blocks is not None:
|
| 243 |
+
window_size = 0 if i in self.global_att_blocks else window_size
|
| 244 |
+
|
| 245 |
+
if i - 1 in self.stage_ends:
|
| 246 |
+
dim_out = int(embed_dim * dim_mul)
|
| 247 |
+
num_heads = int(num_heads * head_mul)
|
| 248 |
+
cur_stage += 1
|
| 249 |
+
|
| 250 |
+
block = MultiScaleBlock(
|
| 251 |
+
dim=embed_dim,
|
| 252 |
+
dim_out=dim_out,
|
| 253 |
+
num_heads=num_heads,
|
| 254 |
+
drop_path=dpr[i],
|
| 255 |
+
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
| 256 |
+
window_size=window_size,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
embed_dim = dim_out
|
| 260 |
+
self.blocks.append(block)
|
| 261 |
+
|
| 262 |
+
self.channel_list = (
|
| 263 |
+
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
| 264 |
+
if return_interm_layers
|
| 265 |
+
else [self.blocks[-1].dim_out]
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if weights_path is not None:
|
| 269 |
+
with g_pathmgr.open(weights_path, "rb") as f:
|
| 270 |
+
chkpt = torch.load(f, map_location="cpu")
|
| 271 |
+
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
|
| 272 |
+
|
| 273 |
+
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
| 274 |
+
h, w = hw
|
| 275 |
+
window_embed = self.pos_embed_window
|
| 276 |
+
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 277 |
+
pos_embed = pos_embed + window_embed.tile(
|
| 278 |
+
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
|
| 279 |
+
)
|
| 280 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 281 |
+
return pos_embed
|
| 282 |
+
|
| 283 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 284 |
+
x = self.patch_embed(x)
|
| 285 |
+
# x: (B, H, W, C)
|
| 286 |
+
# print("x.shape", x.shape) # [B, 256, 256, 96]
|
| 287 |
+
# if gra is not None:
|
| 288 |
+
# x += gra
|
| 289 |
+
# print("gra.shape", gra.shape)
|
| 290 |
+
# print("x.shape after gra", x.shape)
|
| 291 |
+
# Add pos embed
|
| 292 |
+
x = x + self._get_pos_embed(x.shape[1:3])
|
| 293 |
+
|
| 294 |
+
outputs = []
|
| 295 |
+
for i, blk in enumerate(self.blocks):
|
| 296 |
+
x = blk(x)
|
| 297 |
+
if (i == self.stage_ends[-1]) or (
|
| 298 |
+
i in self.stage_ends and self.return_interm_layers
|
| 299 |
+
):
|
| 300 |
+
feats = x.permute(0, 3, 1, 2)
|
| 301 |
+
outputs.append(feats)
|
| 302 |
+
|
| 303 |
+
return outputs
|
| 304 |
+
|
| 305 |
+
def get_layer_id(self, layer_name):
|
| 306 |
+
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
| 307 |
+
num_layers = self.get_num_layers()
|
| 308 |
+
|
| 309 |
+
if layer_name.find("rel_pos") != -1:
|
| 310 |
+
return num_layers + 1
|
| 311 |
+
elif layer_name.find("pos_embed") != -1:
|
| 312 |
+
return 0
|
| 313 |
+
elif layer_name.find("patch_embed") != -1:
|
| 314 |
+
return 0
|
| 315 |
+
elif layer_name.find("blocks") != -1:
|
| 316 |
+
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
|
| 317 |
+
else:
|
| 318 |
+
return num_layers + 1
|
| 319 |
+
|
| 320 |
+
def get_num_layers(self) -> int:
|
| 321 |
+
return len(self.blocks)
|
sam2/sam2/modeling/backbones/image_encoder.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ImageEncoder(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
trunk: nn.Module,
|
| 18 |
+
neck: nn.Module,
|
| 19 |
+
scalp: int = 0,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.trunk = trunk
|
| 23 |
+
self.neck = neck
|
| 24 |
+
self.scalp = scalp
|
| 25 |
+
# for n, p in self.named_parameters():
|
| 26 |
+
# p.requires_grad = False
|
| 27 |
+
assert (
|
| 28 |
+
self.trunk.channel_list == self.neck.backbone_channel_list
|
| 29 |
+
), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
|
| 30 |
+
|
| 31 |
+
def forward(self, sample: torch.Tensor):
|
| 32 |
+
# Forward through backbone
|
| 33 |
+
features, pos = self.neck(self.trunk(sample))
|
| 34 |
+
if self.scalp > 0:
|
| 35 |
+
# Discard the lowest resolution features
|
| 36 |
+
features, pos = features[: -self.scalp], pos[: -self.scalp]
|
| 37 |
+
|
| 38 |
+
src = features[-1]
|
| 39 |
+
output = {
|
| 40 |
+
"vision_features": src,
|
| 41 |
+
"vision_pos_enc": pos,
|
| 42 |
+
"backbone_fpn": features,
|
| 43 |
+
}
|
| 44 |
+
return output
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class FpnNeck(nn.Module):
|
| 48 |
+
"""
|
| 49 |
+
A modified variant of Feature Pyramid Network (FPN) neck
|
| 50 |
+
(we remove output conv and also do bicubic interpolation similar to ViT
|
| 51 |
+
pos embed interpolation)
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
position_encoding: nn.Module,
|
| 57 |
+
d_model: int,
|
| 58 |
+
backbone_channel_list: List[int],
|
| 59 |
+
kernel_size: int = 1,
|
| 60 |
+
stride: int = 1,
|
| 61 |
+
padding: int = 0,
|
| 62 |
+
fpn_interp_model: str = "bilinear",
|
| 63 |
+
fuse_type: str = "sum",
|
| 64 |
+
fpn_top_down_levels: Optional[List[int]] = None,
|
| 65 |
+
):
|
| 66 |
+
"""Initialize the neck
|
| 67 |
+
:param trunk: the backbone
|
| 68 |
+
:param position_encoding: the positional encoding to use
|
| 69 |
+
:param d_model: the dimension of the model
|
| 70 |
+
:param neck_norm: the normalization to use
|
| 71 |
+
"""
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.position_encoding = position_encoding
|
| 74 |
+
self.convs = nn.ModuleList()
|
| 75 |
+
self.backbone_channel_list = backbone_channel_list
|
| 76 |
+
self.d_model = d_model
|
| 77 |
+
for dim in backbone_channel_list:
|
| 78 |
+
current = nn.Sequential()
|
| 79 |
+
current.add_module(
|
| 80 |
+
"conv",
|
| 81 |
+
nn.Conv2d(
|
| 82 |
+
in_channels=dim,
|
| 83 |
+
out_channels=d_model,
|
| 84 |
+
kernel_size=kernel_size,
|
| 85 |
+
stride=stride,
|
| 86 |
+
padding=padding,
|
| 87 |
+
),
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
self.convs.append(current)
|
| 91 |
+
self.fpn_interp_model = fpn_interp_model
|
| 92 |
+
assert fuse_type in ["sum", "avg"]
|
| 93 |
+
self.fuse_type = fuse_type
|
| 94 |
+
|
| 95 |
+
# levels to have top-down features in its outputs
|
| 96 |
+
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
| 97 |
+
# have top-down propagation, while outputs of level 0 and level 1 have only
|
| 98 |
+
# lateral features from the same backbone level.
|
| 99 |
+
if fpn_top_down_levels is None:
|
| 100 |
+
# default is to have top-down features on all levels
|
| 101 |
+
fpn_top_down_levels = range(len(self.convs))
|
| 102 |
+
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
| 103 |
+
|
| 104 |
+
def forward(self, xs: List[torch.Tensor]):
|
| 105 |
+
|
| 106 |
+
out = [None] * len(self.convs)
|
| 107 |
+
pos = [None] * len(self.convs)
|
| 108 |
+
assert len(xs) == len(self.convs)
|
| 109 |
+
# fpn forward pass
|
| 110 |
+
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
| 111 |
+
prev_features = None
|
| 112 |
+
# forward in top-down order (from low to high resolution)
|
| 113 |
+
n = len(self.convs) - 1
|
| 114 |
+
for i in range(n, -1, -1):
|
| 115 |
+
x = xs[i]
|
| 116 |
+
lateral_features = self.convs[n - i](x)
|
| 117 |
+
if i in self.fpn_top_down_levels and prev_features is not None:
|
| 118 |
+
top_down_features = F.interpolate(
|
| 119 |
+
prev_features.to(dtype=torch.float32),
|
| 120 |
+
scale_factor=2.0,
|
| 121 |
+
mode=self.fpn_interp_model,
|
| 122 |
+
align_corners=(
|
| 123 |
+
None if self.fpn_interp_model == "nearest" else False
|
| 124 |
+
),
|
| 125 |
+
antialias=False,
|
| 126 |
+
)
|
| 127 |
+
prev_features = lateral_features + top_down_features
|
| 128 |
+
if self.fuse_type == "avg":
|
| 129 |
+
prev_features /= 2
|
| 130 |
+
else:
|
| 131 |
+
prev_features = lateral_features
|
| 132 |
+
x_out = prev_features
|
| 133 |
+
out[i] = x_out
|
| 134 |
+
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
| 135 |
+
|
| 136 |
+
return out, pos
|
sam2/sam2/modeling/backbones/my_adapter.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from iopath.common.file_io import g_pathmgr
|
| 15 |
+
|
| 16 |
+
from sam2.modeling.backbones.utils import (
|
| 17 |
+
PatchEmbed,
|
| 18 |
+
window_partition,
|
| 19 |
+
window_unpartition,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from sam2.modeling.sam2_utils import DropPath, MLP
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
| 26 |
+
if pool is None:
|
| 27 |
+
return x
|
| 28 |
+
# (B, H, W, C) -> (B, C, H, W)
|
| 29 |
+
x = x.permute(0, 3, 1, 2)
|
| 30 |
+
x = pool(x)
|
| 31 |
+
# (B, C, H', W') -> (B, H', W', C)
|
| 32 |
+
x = x.permute(0, 2, 3, 1)
|
| 33 |
+
if norm:
|
| 34 |
+
x = norm(x)
|
| 35 |
+
|
| 36 |
+
return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MultiScaleAttention(nn.Module):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
dim: int,
|
| 43 |
+
dim_out: int,
|
| 44 |
+
num_heads: int,
|
| 45 |
+
q_pool: nn.Module = None,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
self.dim = dim
|
| 50 |
+
self.dim_out = dim_out
|
| 51 |
+
self.num_heads = num_heads
|
| 52 |
+
self.q_pool = q_pool
|
| 53 |
+
self.qkv = nn.Linear(dim, dim_out * 3)
|
| 54 |
+
self.proj = nn.Linear(dim_out, dim_out)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 57 |
+
B, H, W, _ = x.shape
|
| 58 |
+
# qkv with shape (B, H * W, 3, nHead, C)
|
| 59 |
+
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
| 60 |
+
# q, k, v with shape (B, H * W, nheads, C)
|
| 61 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 62 |
+
|
| 63 |
+
# Q pooling (for downsample at stage changes)
|
| 64 |
+
if self.q_pool:
|
| 65 |
+
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
|
| 66 |
+
H, W = q.shape[1:3] # downsampled shape
|
| 67 |
+
q = q.reshape(B, H * W, self.num_heads, -1)
|
| 68 |
+
|
| 69 |
+
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
|
| 70 |
+
x = F.scaled_dot_product_attention(
|
| 71 |
+
q.transpose(1, 2),
|
| 72 |
+
k.transpose(1, 2),
|
| 73 |
+
v.transpose(1, 2),
|
| 74 |
+
)
|
| 75 |
+
# Transpose back
|
| 76 |
+
x = x.transpose(1, 2)
|
| 77 |
+
x = x.reshape(B, H, W, -1)
|
| 78 |
+
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultiScaleBlock(nn.Module):
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
dim: int,
|
| 88 |
+
dim_out: int,
|
| 89 |
+
num_heads: int,
|
| 90 |
+
mlp_ratio: float = 4.0,
|
| 91 |
+
drop_path: float = 0.0,
|
| 92 |
+
norm_layer: Union[nn.Module, str] = "LayerNorm",
|
| 93 |
+
q_stride: Tuple[int, int] = None,
|
| 94 |
+
act_layer: nn.Module = nn.GELU,
|
| 95 |
+
window_size: int = 0,
|
| 96 |
+
):
|
| 97 |
+
super().__init__()
|
| 98 |
+
|
| 99 |
+
if isinstance(norm_layer, str):
|
| 100 |
+
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
| 101 |
+
|
| 102 |
+
self.dim = dim
|
| 103 |
+
self.dim_out = dim_out
|
| 104 |
+
self.norm1 = norm_layer(dim)
|
| 105 |
+
|
| 106 |
+
self.window_size = window_size
|
| 107 |
+
|
| 108 |
+
self.pool, self.q_stride = None, q_stride
|
| 109 |
+
if self.q_stride:
|
| 110 |
+
self.pool = nn.MaxPool2d(
|
| 111 |
+
kernel_size=q_stride, stride=q_stride, ceil_mode=False
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.attn = MultiScaleAttention(
|
| 115 |
+
dim,
|
| 116 |
+
dim_out,
|
| 117 |
+
num_heads=num_heads,
|
| 118 |
+
q_pool=self.pool,
|
| 119 |
+
)
|
| 120 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 121 |
+
|
| 122 |
+
self.norm2 = norm_layer(dim_out)
|
| 123 |
+
self.mlp = MLP(
|
| 124 |
+
dim_out,
|
| 125 |
+
int(dim_out * mlp_ratio),
|
| 126 |
+
dim_out,
|
| 127 |
+
num_layers=2,
|
| 128 |
+
activation=act_layer,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
if dim != dim_out:
|
| 132 |
+
self.proj = nn.Linear(dim, dim_out)
|
| 133 |
+
|
| 134 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 135 |
+
shortcut = x # B, H, W, C
|
| 136 |
+
x = self.norm1(x)
|
| 137 |
+
|
| 138 |
+
# Skip connection
|
| 139 |
+
if self.dim != self.dim_out:
|
| 140 |
+
shortcut = do_pool(self.proj(x), self.pool)
|
| 141 |
+
|
| 142 |
+
# Window partition
|
| 143 |
+
window_size = self.window_size
|
| 144 |
+
if window_size > 0:
|
| 145 |
+
H, W = x.shape[1], x.shape[2]
|
| 146 |
+
x, pad_hw = window_partition(x, window_size)
|
| 147 |
+
|
| 148 |
+
# Window Attention + Q Pooling (if stage change)
|
| 149 |
+
x = self.attn(x)
|
| 150 |
+
if self.q_stride:
|
| 151 |
+
# Shapes have changed due to Q pooling
|
| 152 |
+
window_size = self.window_size // self.q_stride[0]
|
| 153 |
+
H, W = shortcut.shape[1:3]
|
| 154 |
+
|
| 155 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 156 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 157 |
+
pad_hw = (H + pad_h, W + pad_w)
|
| 158 |
+
|
| 159 |
+
# Reverse window partition
|
| 160 |
+
if self.window_size > 0:
|
| 161 |
+
x = window_unpartition(x, window_size, pad_hw, (H, W))
|
| 162 |
+
|
| 163 |
+
x = shortcut + self.drop_path(x)
|
| 164 |
+
# MLP
|
| 165 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Hiera(nn.Module):
|
| 170 |
+
"""
|
| 171 |
+
Reference: https://arxiv.org/abs/2306.00989
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
embed_dim: int = 96, # initial embed dim
|
| 177 |
+
num_heads: int = 1, # initial number of heads
|
| 178 |
+
drop_path_rate: float = 0.0, # stochastic depth
|
| 179 |
+
q_pool: int = 3, # number of q_pool stages
|
| 180 |
+
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
| 181 |
+
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
| 182 |
+
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
| 183 |
+
head_mul: float = 2.0, # head_mul factor at stage shift
|
| 184 |
+
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
| 185 |
+
# window size per stage, when not using global att.
|
| 186 |
+
window_spec: Tuple[int, ...] = (
|
| 187 |
+
8,
|
| 188 |
+
4,
|
| 189 |
+
14,
|
| 190 |
+
7,
|
| 191 |
+
),
|
| 192 |
+
# global attn in these blocks
|
| 193 |
+
global_att_blocks: Tuple[int, ...] = (
|
| 194 |
+
12,
|
| 195 |
+
16,
|
| 196 |
+
20,
|
| 197 |
+
),
|
| 198 |
+
weights_path=None,
|
| 199 |
+
return_interm_layers=True, # return feats from every stage
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
assert len(stages) == len(window_spec)
|
| 204 |
+
self.window_spec = window_spec
|
| 205 |
+
|
| 206 |
+
depth = sum(stages)
|
| 207 |
+
self.q_stride = q_stride
|
| 208 |
+
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
| 209 |
+
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
| 210 |
+
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
| 211 |
+
self.return_interm_layers = return_interm_layers
|
| 212 |
+
|
| 213 |
+
self.patch_embed = PatchEmbed(
|
| 214 |
+
embed_dim=embed_dim,
|
| 215 |
+
)
|
| 216 |
+
# Which blocks have global att?
|
| 217 |
+
self.global_att_blocks = global_att_blocks
|
| 218 |
+
|
| 219 |
+
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
| 220 |
+
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
| 221 |
+
self.pos_embed = nn.Parameter(
|
| 222 |
+
torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
|
| 223 |
+
)
|
| 224 |
+
self.pos_embed_window = nn.Parameter(
|
| 225 |
+
torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
dpr = [
|
| 229 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 230 |
+
] # stochastic depth decay rule
|
| 231 |
+
|
| 232 |
+
cur_stage = 1
|
| 233 |
+
self.blocks = nn.ModuleList()
|
| 234 |
+
|
| 235 |
+
for i in range(depth):
|
| 236 |
+
dim_out = embed_dim
|
| 237 |
+
# lags by a block, so first block of
|
| 238 |
+
# next stage uses an initial window size
|
| 239 |
+
# of previous stage and final window size of current stage
|
| 240 |
+
window_size = self.window_spec[cur_stage - 1]
|
| 241 |
+
|
| 242 |
+
if self.global_att_blocks is not None:
|
| 243 |
+
window_size = 0 if i in self.global_att_blocks else window_size
|
| 244 |
+
|
| 245 |
+
if i - 1 in self.stage_ends:
|
| 246 |
+
dim_out = int(embed_dim * dim_mul)
|
| 247 |
+
num_heads = int(num_heads * head_mul)
|
| 248 |
+
cur_stage += 1
|
| 249 |
+
|
| 250 |
+
block = MultiScaleBlock(
|
| 251 |
+
dim=embed_dim,
|
| 252 |
+
dim_out=dim_out,
|
| 253 |
+
num_heads=num_heads,
|
| 254 |
+
drop_path=dpr[i],
|
| 255 |
+
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
| 256 |
+
window_size=window_size,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
embed_dim = dim_out
|
| 260 |
+
self.blocks.append(block)
|
| 261 |
+
|
| 262 |
+
self.channel_list = (
|
| 263 |
+
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
| 264 |
+
if return_interm_layers
|
| 265 |
+
else [self.blocks[-1].dim_out]
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if weights_path is not None:
|
| 269 |
+
with g_pathmgr.open(weights_path, "rb") as f:
|
| 270 |
+
chkpt = torch.load(f, map_location="cpu")
|
| 271 |
+
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
|
| 272 |
+
|
| 273 |
+
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
| 274 |
+
h, w = hw
|
| 275 |
+
window_embed = self.pos_embed_window
|
| 276 |
+
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
| 277 |
+
pos_embed = pos_embed + window_embed.tile(
|
| 278 |
+
[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
|
| 279 |
+
)
|
| 280 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
| 281 |
+
return pos_embed
|
| 282 |
+
|
| 283 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 284 |
+
x = self.patch_embed(x)
|
| 285 |
+
# x: (B, H, W, C)
|
| 286 |
+
|
| 287 |
+
# Add pos embed
|
| 288 |
+
x = x + self._get_pos_embed(x.shape[1:3])
|
| 289 |
+
|
| 290 |
+
outputs = []
|
| 291 |
+
for i, blk in enumerate(self.blocks):
|
| 292 |
+
x = blk(x)
|
| 293 |
+
if (i == self.stage_ends[-1]) or (
|
| 294 |
+
i in self.stage_ends and self.return_interm_layers
|
| 295 |
+
):
|
| 296 |
+
feats = x.permute(0, 3, 1, 2)
|
| 297 |
+
outputs.append(feats)
|
| 298 |
+
|
| 299 |
+
return outputs
|
| 300 |
+
|
| 301 |
+
def get_layer_id(self, layer_name):
|
| 302 |
+
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
| 303 |
+
num_layers = self.get_num_layers()
|
| 304 |
+
|
| 305 |
+
if layer_name.find("rel_pos") != -1:
|
| 306 |
+
return num_layers + 1
|
| 307 |
+
elif layer_name.find("pos_embed") != -1:
|
| 308 |
+
return 0
|
| 309 |
+
elif layer_name.find("patch_embed") != -1:
|
| 310 |
+
return 0
|
| 311 |
+
elif layer_name.find("blocks") != -1:
|
| 312 |
+
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
|
| 313 |
+
else:
|
| 314 |
+
return num_layers + 1
|
| 315 |
+
|
| 316 |
+
def get_num_layers(self) -> int:
|
| 317 |
+
return len(self.blocks)
|
sam2/sam2/modeling/backbones/utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Some utilities for backbones, in particular for windowing"""
|
| 8 |
+
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def window_partition(x, window_size):
|
| 17 |
+
"""
|
| 18 |
+
Partition into non-overlapping windows with padding if needed.
|
| 19 |
+
Args:
|
| 20 |
+
x (tensor): input tokens with [B, H, W, C].
|
| 21 |
+
window_size (int): window size.
|
| 22 |
+
Returns:
|
| 23 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
| 24 |
+
(Hp, Wp): padded height and width before partition
|
| 25 |
+
"""
|
| 26 |
+
B, H, W, C = x.shape
|
| 27 |
+
|
| 28 |
+
pad_h = (window_size - H % window_size) % window_size
|
| 29 |
+
pad_w = (window_size - W % window_size) % window_size
|
| 30 |
+
if pad_h > 0 or pad_w > 0:
|
| 31 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
| 32 |
+
Hp, Wp = H + pad_h, W + pad_w
|
| 33 |
+
|
| 34 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
| 35 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
|
| 36 |
+
return windows, (Hp, Wp)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def window_unpartition(windows, window_size, pad_hw, hw):
|
| 40 |
+
"""
|
| 41 |
+
Window unpartition into original sequences and removing padding.
|
| 42 |
+
Args:
|
| 43 |
+
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
| 44 |
+
window_size (int): window size.
|
| 45 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
| 46 |
+
hw (Tuple): original height and width (H, W) before padding.
|
| 47 |
+
Returns:
|
| 48 |
+
x: unpartitioned sequences with [B, H, W, C].
|
| 49 |
+
"""
|
| 50 |
+
Hp, Wp = pad_hw
|
| 51 |
+
H, W = hw
|
| 52 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
| 53 |
+
x = windows.reshape(
|
| 54 |
+
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
| 55 |
+
)
|
| 56 |
+
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
|
| 57 |
+
|
| 58 |
+
if Hp > H or Wp > W:
|
| 59 |
+
x = x[:, :H, :W, :]
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PatchEmbed(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Image to Patch Embedding.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
kernel_size: Tuple[int, ...] = (7, 7),
|
| 71 |
+
stride: Tuple[int, ...] = (4, 4),
|
| 72 |
+
padding: Tuple[int, ...] = (3, 3),
|
| 73 |
+
in_chans: int = 3,
|
| 74 |
+
embed_dim: int = 768,
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
Args:
|
| 78 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
| 79 |
+
stride (Tuple): stride of the projection layer.
|
| 80 |
+
padding (Tuple): padding size of the projection layer.
|
| 81 |
+
in_chans (int): Number of input image channels.
|
| 82 |
+
embed_dim (int): embed_dim (int): Patch embedding dimension.
|
| 83 |
+
"""
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.proj = nn.Conv2d(
|
| 86 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 90 |
+
x = self.proj(x)
|
| 91 |
+
# B C H W -> B H W C
|
| 92 |
+
x = x.permute(0, 2, 3, 1)
|
| 93 |
+
return x
|
sam2/sam2/modeling/memory_attention.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn, Tensor
|
| 11 |
+
|
| 12 |
+
from sam2.modeling.sam.transformer import RoPEAttention
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_utils import get_activation_fn, get_clones
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MemoryAttentionLayer(nn.Module):
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
activation: str,
|
| 22 |
+
cross_attention: nn.Module,
|
| 23 |
+
d_model: int,
|
| 24 |
+
dim_feedforward: int,
|
| 25 |
+
dropout: float,
|
| 26 |
+
pos_enc_at_attn: bool,
|
| 27 |
+
pos_enc_at_cross_attn_keys: bool,
|
| 28 |
+
pos_enc_at_cross_attn_queries: bool,
|
| 29 |
+
self_attention: nn.Module,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.d_model = d_model
|
| 33 |
+
self.dim_feedforward = dim_feedforward
|
| 34 |
+
self.dropout_value = dropout
|
| 35 |
+
self.self_attn = self_attention
|
| 36 |
+
self.cross_attn_image = cross_attention
|
| 37 |
+
|
| 38 |
+
# Implementation of Feedforward model
|
| 39 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 40 |
+
self.dropout = nn.Dropout(dropout)
|
| 41 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 42 |
+
|
| 43 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 44 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 45 |
+
self.norm3 = nn.LayerNorm(d_model)
|
| 46 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 47 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 48 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 49 |
+
|
| 50 |
+
self.activation_str = activation
|
| 51 |
+
self.activation = get_activation_fn(activation)
|
| 52 |
+
|
| 53 |
+
# Where to add pos enc
|
| 54 |
+
self.pos_enc_at_attn = pos_enc_at_attn
|
| 55 |
+
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
|
| 56 |
+
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
|
| 57 |
+
|
| 58 |
+
# for n, p in self.named_parameters():
|
| 59 |
+
# p.requires_grad = False
|
| 60 |
+
|
| 61 |
+
def _forward_sa(self, tgt, query_pos):
|
| 62 |
+
# Self-Attention
|
| 63 |
+
tgt2 = self.norm1(tgt)
|
| 64 |
+
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
|
| 65 |
+
tgt2 = self.self_attn(q, k, v=tgt2)
|
| 66 |
+
tgt = tgt + self.dropout1(tgt2)
|
| 67 |
+
return tgt
|
| 68 |
+
|
| 69 |
+
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
|
| 70 |
+
kwds = {}
|
| 71 |
+
if num_k_exclude_rope > 0:
|
| 72 |
+
assert isinstance(self.cross_attn_image, RoPEAttention)
|
| 73 |
+
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
|
| 74 |
+
|
| 75 |
+
# Cross-Attention
|
| 76 |
+
tgt2 = self.norm2(tgt)
|
| 77 |
+
tgt2 = self.cross_attn_image(
|
| 78 |
+
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
|
| 79 |
+
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
| 80 |
+
v=memory,
|
| 81 |
+
**kwds,
|
| 82 |
+
)
|
| 83 |
+
tgt = tgt + self.dropout2(tgt2)
|
| 84 |
+
return tgt
|
| 85 |
+
|
| 86 |
+
def forward(
|
| 87 |
+
self,
|
| 88 |
+
tgt,
|
| 89 |
+
memory,
|
| 90 |
+
pos: Optional[Tensor] = None,
|
| 91 |
+
query_pos: Optional[Tensor] = None,
|
| 92 |
+
num_k_exclude_rope: int = 0,
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
|
| 95 |
+
# Self-Attn, Cross-Attn
|
| 96 |
+
tgt = self._forward_sa(tgt, query_pos)
|
| 97 |
+
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
| 98 |
+
# MLP
|
| 99 |
+
tgt2 = self.norm3(tgt)
|
| 100 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 101 |
+
tgt = tgt + self.dropout3(tgt2)
|
| 102 |
+
return tgt
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class MemoryAttention(nn.Module):
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
d_model: int,
|
| 109 |
+
pos_enc_at_input: bool,
|
| 110 |
+
layer: nn.Module,
|
| 111 |
+
num_layers: int,
|
| 112 |
+
batch_first: bool = True, # Do layers expect batch first input?
|
| 113 |
+
):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.d_model = d_model
|
| 116 |
+
self.layers = get_clones(layer, num_layers)
|
| 117 |
+
self.num_layers = num_layers
|
| 118 |
+
self.norm = nn.LayerNorm(d_model)
|
| 119 |
+
self.pos_enc_at_input = pos_enc_at_input
|
| 120 |
+
self.batch_first = batch_first
|
| 121 |
+
|
| 122 |
+
def forward(
|
| 123 |
+
self,
|
| 124 |
+
curr: torch.Tensor, # self-attention inputs
|
| 125 |
+
memory: torch.Tensor, # cross-attention inputs
|
| 126 |
+
curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
|
| 127 |
+
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
|
| 128 |
+
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
| 129 |
+
):
|
| 130 |
+
if isinstance(curr, list):
|
| 131 |
+
assert isinstance(curr_pos, list)
|
| 132 |
+
assert len(curr) == len(curr_pos) == 1
|
| 133 |
+
curr, curr_pos = (
|
| 134 |
+
curr[0],
|
| 135 |
+
curr_pos[0],
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
assert (
|
| 139 |
+
curr.shape[1] == memory.shape[1]
|
| 140 |
+
), "Batch size must be the same for curr and memory"
|
| 141 |
+
|
| 142 |
+
output = curr
|
| 143 |
+
if self.pos_enc_at_input and curr_pos is not None:
|
| 144 |
+
output = output + 0.1 * curr_pos
|
| 145 |
+
|
| 146 |
+
if self.batch_first:
|
| 147 |
+
# Convert to batch first
|
| 148 |
+
output = output.transpose(0, 1)
|
| 149 |
+
curr_pos = curr_pos.transpose(0, 1)
|
| 150 |
+
memory = memory.transpose(0, 1)
|
| 151 |
+
memory_pos = memory_pos.transpose(0, 1)
|
| 152 |
+
|
| 153 |
+
for layer in self.layers:
|
| 154 |
+
kwds = {}
|
| 155 |
+
if isinstance(layer.cross_attn_image, RoPEAttention):
|
| 156 |
+
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
|
| 157 |
+
|
| 158 |
+
output = layer(
|
| 159 |
+
tgt=output,
|
| 160 |
+
memory=memory,
|
| 161 |
+
pos=memory_pos,
|
| 162 |
+
query_pos=curr_pos,
|
| 163 |
+
**kwds,
|
| 164 |
+
)
|
| 165 |
+
normed_output = self.norm(output)
|
| 166 |
+
|
| 167 |
+
if self.batch_first:
|
| 168 |
+
# Convert back to seq first
|
| 169 |
+
normed_output = normed_output.transpose(0, 1)
|
| 170 |
+
curr_pos = curr_pos.transpose(0, 1)
|
| 171 |
+
|
| 172 |
+
return normed_output
|
sam2/sam2/modeling/memory_encoder.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MaskDownSampler(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Progressively downsample a mask by total_stride, each time by stride.
|
| 20 |
+
Note that LayerNorm is applied per *token*, like in ViT.
|
| 21 |
+
|
| 22 |
+
With each downsample (by a factor stride**2), channel capacity increases by the same factor.
|
| 23 |
+
In the end, we linearly project to embed_dim channels.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
embed_dim=256,
|
| 29 |
+
kernel_size=4,
|
| 30 |
+
stride=4,
|
| 31 |
+
padding=0,
|
| 32 |
+
total_stride=16,
|
| 33 |
+
activation=nn.GELU,
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
num_layers = int(math.log2(total_stride) // math.log2(stride))
|
| 37 |
+
assert stride**num_layers == total_stride
|
| 38 |
+
self.encoder = nn.Sequential()
|
| 39 |
+
mask_in_chans, mask_out_chans = 1, 1
|
| 40 |
+
for _ in range(num_layers):
|
| 41 |
+
mask_out_chans = mask_in_chans * (stride**2)
|
| 42 |
+
self.encoder.append(
|
| 43 |
+
nn.Conv2d(
|
| 44 |
+
mask_in_chans,
|
| 45 |
+
mask_out_chans,
|
| 46 |
+
kernel_size=kernel_size,
|
| 47 |
+
stride=stride,
|
| 48 |
+
padding=padding,
|
| 49 |
+
)
|
| 50 |
+
)
|
| 51 |
+
self.encoder.append(LayerNorm2d(mask_out_chans))
|
| 52 |
+
self.encoder.append(activation())
|
| 53 |
+
mask_in_chans = mask_out_chans
|
| 54 |
+
|
| 55 |
+
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
return self.encoder(x)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
|
| 62 |
+
class CXBlock(nn.Module):
|
| 63 |
+
r"""ConvNeXt Block. There are two equivalent implementations:
|
| 64 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 65 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 66 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
dim (int): Number of input channels.
|
| 70 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 71 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
dim,
|
| 77 |
+
kernel_size=7,
|
| 78 |
+
padding=3,
|
| 79 |
+
drop_path=0.0,
|
| 80 |
+
layer_scale_init_value=1e-6,
|
| 81 |
+
use_dwconv=True,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.dwconv = nn.Conv2d(
|
| 85 |
+
dim,
|
| 86 |
+
dim,
|
| 87 |
+
kernel_size=kernel_size,
|
| 88 |
+
padding=padding,
|
| 89 |
+
groups=dim if use_dwconv else 1,
|
| 90 |
+
) # depthwise conv
|
| 91 |
+
self.norm = LayerNorm2d(dim, eps=1e-6)
|
| 92 |
+
self.pwconv1 = nn.Linear(
|
| 93 |
+
dim, 4 * dim
|
| 94 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 95 |
+
self.act = nn.GELU()
|
| 96 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 97 |
+
self.gamma = (
|
| 98 |
+
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 99 |
+
if layer_scale_init_value > 0
|
| 100 |
+
else None
|
| 101 |
+
)
|
| 102 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
input = x
|
| 106 |
+
x = self.dwconv(x)
|
| 107 |
+
x = self.norm(x)
|
| 108 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 109 |
+
x = self.pwconv1(x)
|
| 110 |
+
x = self.act(x)
|
| 111 |
+
x = self.pwconv2(x)
|
| 112 |
+
if self.gamma is not None:
|
| 113 |
+
x = self.gamma * x
|
| 114 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 115 |
+
|
| 116 |
+
x = input + self.drop_path(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Fuser(nn.Module):
|
| 121 |
+
def __init__(self, layer, num_layers, dim=None, input_projection=False):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.proj = nn.Identity()
|
| 124 |
+
self.layers = get_clones(layer, num_layers)
|
| 125 |
+
|
| 126 |
+
if input_projection:
|
| 127 |
+
assert dim is not None
|
| 128 |
+
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
|
| 129 |
+
|
| 130 |
+
def forward(self, x):
|
| 131 |
+
# normally x: (N, C, H, W)
|
| 132 |
+
x = self.proj(x)
|
| 133 |
+
for layer in self.layers:
|
| 134 |
+
x = layer(x)
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class MemoryEncoder(nn.Module):
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
out_dim,
|
| 142 |
+
mask_downsampler,
|
| 143 |
+
fuser,
|
| 144 |
+
position_encoding,
|
| 145 |
+
in_dim=256, # in_dim of pix_feats
|
| 146 |
+
):
|
| 147 |
+
super().__init__()
|
| 148 |
+
|
| 149 |
+
self.mask_downsampler = mask_downsampler
|
| 150 |
+
|
| 151 |
+
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
| 152 |
+
self.fuser = fuser
|
| 153 |
+
self.position_encoding = position_encoding
|
| 154 |
+
self.out_proj = nn.Identity()
|
| 155 |
+
if out_dim != in_dim:
|
| 156 |
+
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
| 157 |
+
|
| 158 |
+
def forward(
|
| 159 |
+
self,
|
| 160 |
+
pix_feat: torch.Tensor,
|
| 161 |
+
masks: torch.Tensor,
|
| 162 |
+
skip_mask_sigmoid: bool = False,
|
| 163 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 164 |
+
## Process masks
|
| 165 |
+
# sigmoid, so that less domain shift from gt masks which are bool
|
| 166 |
+
if not skip_mask_sigmoid:
|
| 167 |
+
masks = F.sigmoid(masks)
|
| 168 |
+
masks = self.mask_downsampler(masks)
|
| 169 |
+
|
| 170 |
+
## Fuse pix_feats and downsampled masks
|
| 171 |
+
# in case the visual features are on CPU, cast them to CUDA
|
| 172 |
+
pix_feat = pix_feat.to(masks.device)
|
| 173 |
+
|
| 174 |
+
x = self.pix_feat_proj(pix_feat)
|
| 175 |
+
x = x + masks
|
| 176 |
+
x = self.fuser(x)
|
| 177 |
+
x = self.out_proj(x)
|
| 178 |
+
|
| 179 |
+
pos = self.position_encoding(x).to(x.dtype)
|
| 180 |
+
|
| 181 |
+
return {"vision_features": x, "vision_pos_enc": [pos]}
|
sam2/sam2/modeling/position_encoding.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Any, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PositionEmbeddingSine(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
This is a more standard version of the position embedding, very similar to the one
|
| 19 |
+
used by the Attention Is All You Need paper, generalized to work on images.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
num_pos_feats,
|
| 25 |
+
temperature: int = 10000,
|
| 26 |
+
normalize: bool = True,
|
| 27 |
+
scale: Optional[float] = None,
|
| 28 |
+
# Following settings only relevant
|
| 29 |
+
# for warmping up cache for compilation
|
| 30 |
+
warmup_cache: bool = True,
|
| 31 |
+
image_size: int = 1024,
|
| 32 |
+
strides: Tuple[int] = (4, 8, 16, 32),
|
| 33 |
+
):
|
| 34 |
+
super().__init__()
|
| 35 |
+
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
| 36 |
+
self.num_pos_feats = num_pos_feats // 2
|
| 37 |
+
self.temperature = temperature
|
| 38 |
+
self.normalize = normalize
|
| 39 |
+
if scale is not None and normalize is False:
|
| 40 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 41 |
+
if scale is None:
|
| 42 |
+
scale = 2 * math.pi
|
| 43 |
+
self.scale = scale
|
| 44 |
+
|
| 45 |
+
self.cache = {}
|
| 46 |
+
if warmup_cache and torch.cuda.is_available():
|
| 47 |
+
# Warmup cache for cuda, to help with compilation
|
| 48 |
+
device = torch.device("cuda")
|
| 49 |
+
for stride in strides:
|
| 50 |
+
cache_key = (image_size // stride, image_size // stride)
|
| 51 |
+
self._pe(1, device, *cache_key)
|
| 52 |
+
|
| 53 |
+
def _encode_xy(self, x, y):
|
| 54 |
+
# The positions are expected to be normalized
|
| 55 |
+
assert len(x) == len(y) and x.ndim == y.ndim == 1
|
| 56 |
+
x_embed = x * self.scale
|
| 57 |
+
y_embed = y * self.scale
|
| 58 |
+
|
| 59 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 60 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 61 |
+
|
| 62 |
+
pos_x = x_embed[:, None] / dim_t
|
| 63 |
+
pos_y = y_embed[:, None] / dim_t
|
| 64 |
+
pos_x = torch.stack(
|
| 65 |
+
(pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
|
| 66 |
+
).flatten(1)
|
| 67 |
+
pos_y = torch.stack(
|
| 68 |
+
(pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
|
| 69 |
+
).flatten(1)
|
| 70 |
+
return pos_x, pos_y
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
|
| 73 |
+
def encode_boxes(self, x, y, w, h):
|
| 74 |
+
pos_x, pos_y = self._encode_xy(x, y)
|
| 75 |
+
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
| 76 |
+
return pos
|
| 77 |
+
|
| 78 |
+
encode = encode_boxes # Backwards compatibility
|
| 79 |
+
|
| 80 |
+
@torch.no_grad()
|
| 81 |
+
def encode_points(self, x, y, labels):
|
| 82 |
+
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
|
| 83 |
+
assert bx == by and nx == ny and bx == bl and nx == nl
|
| 84 |
+
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
|
| 85 |
+
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
|
| 86 |
+
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
|
| 87 |
+
return pos
|
| 88 |
+
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def _pe(self, B, device, *cache_key):
|
| 91 |
+
H, W = cache_key
|
| 92 |
+
if cache_key in self.cache:
|
| 93 |
+
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
| 94 |
+
|
| 95 |
+
y_embed = (
|
| 96 |
+
torch.arange(1, H + 1, dtype=torch.float32, device=device)
|
| 97 |
+
.view(1, -1, 1)
|
| 98 |
+
.repeat(B, 1, W)
|
| 99 |
+
)
|
| 100 |
+
x_embed = (
|
| 101 |
+
torch.arange(1, W + 1, dtype=torch.float32, device=device)
|
| 102 |
+
.view(1, 1, -1)
|
| 103 |
+
.repeat(B, H, 1)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if self.normalize:
|
| 107 |
+
eps = 1e-6
|
| 108 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 109 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 110 |
+
|
| 111 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
|
| 112 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
| 113 |
+
|
| 114 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 115 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 116 |
+
pos_x = torch.stack(
|
| 117 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
| 118 |
+
).flatten(3)
|
| 119 |
+
pos_y = torch.stack(
|
| 120 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
| 121 |
+
).flatten(3)
|
| 122 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 123 |
+
self.cache[cache_key] = pos[0]
|
| 124 |
+
return pos
|
| 125 |
+
|
| 126 |
+
@torch.no_grad()
|
| 127 |
+
def forward(self, x: torch.Tensor):
|
| 128 |
+
B = x.shape[0]
|
| 129 |
+
cache_key = (x.shape[-2], x.shape[-1])
|
| 130 |
+
return self._pe(B, x.device, *cache_key)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PositionEmbeddingRandom(nn.Module):
|
| 134 |
+
"""
|
| 135 |
+
Positional encoding using random spatial frequencies.
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
| 139 |
+
super().__init__()
|
| 140 |
+
if scale is None or scale <= 0.0:
|
| 141 |
+
scale = 1.0
|
| 142 |
+
self.register_buffer(
|
| 143 |
+
"positional_encoding_gaussian_matrix",
|
| 144 |
+
scale * torch.randn((2, num_pos_feats)),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
| 148 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
| 149 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
| 150 |
+
coords = 2 * coords - 1
|
| 151 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
| 152 |
+
coords = 2 * np.pi * coords
|
| 153 |
+
# outputs d_1 x ... x d_n x C shape
|
| 154 |
+
# for the dummy coords [0, 0], the Fourier features is [0, ..., 0, 1, ..., 1]
|
| 155 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
| 156 |
+
|
| 157 |
+
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
| 158 |
+
"""Generate positional encoding for a grid of the specified size."""
|
| 159 |
+
h, w = size
|
| 160 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
| 161 |
+
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
| 162 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
| 163 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
| 164 |
+
y_embed = y_embed / h
|
| 165 |
+
x_embed = x_embed / w
|
| 166 |
+
|
| 167 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
| 168 |
+
return pe.permute(2, 0, 1) # C x H x W
|
| 169 |
+
|
| 170 |
+
def forward_with_coords(
|
| 171 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
| 172 |
+
) -> torch.Tensor:
|
| 173 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
| 174 |
+
coords = coords_input.clone()
|
| 175 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
| 176 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
| 177 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Rotary Positional Encoding, adapted from:
|
| 181 |
+
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 182 |
+
# 2. https://github.com/naver-ai/rope-vit
|
| 183 |
+
# 3. https://github.com/lucidrains/rotary-embedding-torch
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def init_t_xy(end_x: int, end_y: int):
|
| 187 |
+
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
| 188 |
+
t_x = (t % end_x).float()
|
| 189 |
+
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
| 190 |
+
return t_x, t_y
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
| 194 |
+
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
| 195 |
+
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
| 196 |
+
|
| 197 |
+
t_x, t_y = init_t_xy(end_x, end_y)
|
| 198 |
+
freqs_x = torch.outer(t_x, freqs_x)
|
| 199 |
+
freqs_y = torch.outer(t_y, freqs_y)
|
| 200 |
+
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
|
| 201 |
+
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
|
| 202 |
+
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
| 206 |
+
ndim = x.ndim
|
| 207 |
+
assert 0 <= 1 < ndim
|
| 208 |
+
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
| 209 |
+
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
|
| 210 |
+
return freqs_cis.view(*shape)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def apply_rotary_enc(
|
| 214 |
+
xq: torch.Tensor,
|
| 215 |
+
xk: torch.Tensor,
|
| 216 |
+
freqs_cis: torch.Tensor,
|
| 217 |
+
repeat_freqs_k: bool = False,
|
| 218 |
+
):
|
| 219 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 220 |
+
xk_ = (
|
| 221 |
+
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 222 |
+
if xk.shape[-2] != 0
|
| 223 |
+
else None
|
| 224 |
+
)
|
| 225 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
| 226 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 227 |
+
if xk_ is None:
|
| 228 |
+
# no keys to rotate, due to dropout
|
| 229 |
+
return xq_out.type_as(xq).to(xq.device), xk
|
| 230 |
+
# repeat freqs along seq_len dim to match k seq_len
|
| 231 |
+
if repeat_freqs_k:
|
| 232 |
+
r = xk_.shape[-2] // xq_.shape[-2]
|
| 233 |
+
if freqs_cis.is_cuda:
|
| 234 |
+
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
| 235 |
+
else:
|
| 236 |
+
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
| 237 |
+
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
| 238 |
+
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
| 239 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
| 240 |
+
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
sam2/sam2/modeling/sam/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sam2/sam2/modeling/sam/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (150 Bytes). View file
|
|
|
sam2/sam2/modeling/sam/__pycache__/gra_mask_decoder.cpython-310.pyc
ADDED
|
Binary file (7.72 kB). View file
|
|
|