Transformers
jiamingZ committed on
Commit
c370337
·
verified ·
1 Parent(s): e40e500

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +317 -3
README.md CHANGED
@@ -1,3 +1,317 @@
1
- ---
2
- license: cc-by-nc-4.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ datasets:
4
+ - MCG-NJU/Tracking-Any-Granularity
5
+ library_name: transformers
6
+ ---
7
+
8
+ 🏠 [Homepage](https://tracking-any-granularity.github.io/) | 📄 [Paper](https://arxiv.org/abs/2510.18822) | 🔗 [GitHub](https://github.com/MCG-NJU/SAM2-Plus)
9
+
10
+ This is the model repository for SAM 2++: Tracking Anything at Any Granularity, a unified video tracking framework that extends the SAM 2 model to track any target in videos at any granularity, including masks, bounding boxes, and points.
11
+ See the [SAM 2++ paper](https://arxiv.org/abs/2510.18822) for more information.
12
+
13
+ ## Usage
14
+
15
+ **[Video Object Segmentation (Mask Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/vos_inference_plus.sh)**
16
+ ```
17
+ import os
18
+ import torch
19
+ import matplotlib.pyplot as plt
20
+ from PIL import Image
21
+
22
+ from sam2_plus.build_sam import build_sam2_video_predictor_plus
23
+
24
+ from tools.visualization import show_mask, show_box, show_points
25
+ from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
26
+
27
+ checkpoint = "./checkpoints/SAM2-Plus/checkpoint_phase123.pt"
28
+ model_cfg = "configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml"
29
+ predictor = build_sam2_video_predictor_plus(model_cfg, checkpoint, task="mask")
30
+
31
+ input_video_dir = "./examples/JPEGImages/00001"
32
+ input_mask_path = "./examples/Annotations/00001/00000.png"
33
+ output_mask_dir = "./output/Annotations/"
34
+
35
+ score_thresh = 0
36
+
37
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
38
+ inference_state = predictor.init_state(video_path=input_video_dir)
39
+
40
+ video_name = os.path.basename(input_video_dir)
41
+ frame_names = [
42
+ os.path.splitext(p)[0]
43
+ for p in os.listdir(input_video_dir)
44
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
45
+ ]
46
+ frame_names.sort()
47
+ height = inference_state["video_height"]
48
+ width = inference_state["video_width"]
49
+
50
+ input_frame_idx = 0 # the frame index we interact with
51
+ object_id = 1 # give a unique id to each object we interact with (it can be any integer)
52
+
53
+ input_palette = None
54
+ input_mask, input_palette = load_ann_png(input_mask_path)
55
+ per_obj_input_mask = get_per_obj_mask(input_mask)
56
+ object_mask = per_obj_input_mask[object_id]
57
+
58
+ predictor.add_new_mask(
59
+ inference_state=inference_state,
60
+ frame_idx=input_frame_idx,
61
+ obj_id=object_id,
62
+ mask=object_mask,
63
+ )
64
+
65
+ # run propagation throughout the video and collect the results in a dict
66
+ os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
67
+ output_palette = input_palette or DAVIS_PALETTE
68
+ video_segments = {} # video_segments contains the per-frame segmentation results
69
+ for out_frame_idx, out_obj_ids, out_mask_logits, _, _ in predictor.propagate_in_video(
70
+ inference_state
71
+ ):
72
+ per_obj_output_mask = {
73
+ out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
74
+ for i, out_obj_id in enumerate(out_obj_ids)
75
+ }
76
+ video_segments[out_frame_idx] = per_obj_output_mask
77
+
78
+ # write the output masks as palette PNG files to output_mask_dir
79
+ for out_frame_idx, per_obj_output_mask in video_segments.items():
80
+ save_masks_to_dir(
81
+ output_mask_dir=output_mask_dir,
82
+ video_name=video_name,
83
+ frame_name=frame_names[out_frame_idx],
84
+ per_obj_output_mask=per_obj_output_mask,
85
+ height=height,
86
+ width=width,
87
+ per_obj_png_file=False,
88
+ output_palette=output_palette,
89
+ )
90
+
91
+ # visualize the tracking results
92
+ for out_frame_idx in range(0, len(frame_names)):
93
+ plt.clf()
94
+ plt.figure()
95
+ # plt.title(f"frame {out_frame_idx}")
96
+ plt.imshow(Image.open(os.path.join(input_video_dir, frame_names[out_frame_idx] + ".jpg")))
97
+ for out_obj_id, out_mask in video_segments[out_frame_idx].items():
98
+ show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
99
+ plt.axis('off')
100
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
101
+ plt.savefig(f"{output_mask_dir}/{video_name}/{out_frame_idx:05d}_withMask.png", dpi=300, bbox_inches='tight', pad_inches=0)
102
+ ```
103
+
104
+ **[Video Object Tracking (Box Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/sot_inference_plus.sh)**
105
+ ```
106
+ import os
107
+ import torch
108
+ import matplotlib.pyplot as plt
109
+ from PIL import Image
110
+ import numpy as np
111
+ import logging
112
+
113
+ from sam2_plus.build_sam import build_sam2_video_predictor_plus
114
+
115
+ from tools.visualization import show_mask, show_box, show_points
116
+ from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
117
+ from tools.sot_inference import save_boxes_to_dir, save_masks_and_boxes_to_dir
118
+ from training.dataset_plus.box.utils import np_box_xywh_to_xyxy, np_box_xyxy_to_xywh, np_masks_to_boxes, np_box_clamp_xywh
119
+ from benchmarks.sot_benchmark.datasets.utils import load_text
120
+
121
+ checkpoint = "./checkpoints/SAM2-Plus/checkpoint_phase123.pt"
122
+ model_cfg = "configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml"
123
+ predictor = build_sam2_video_predictor_plus(model_cfg, checkpoint, task="box")
124
+
125
+ input_video_dir = "./examples/JPEGImages/00001"
126
+ input_box_path = "./examples/Boxes/00001.txt"
127
+ output_box_dir = "./output/Boxes/"
128
+
129
+ score_thresh = 0
130
+
131
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
132
+ inference_state = predictor.init_state(video_path=input_video_dir)
133
+
134
+ video_name = os.path.basename(input_video_dir)
135
+ frame_names = [
136
+ os.path.splitext(p)[0]
137
+ for p in os.listdir(input_video_dir)
138
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
139
+ ]
140
+ frame_names.sort()
141
+ height = inference_state["video_height"]
142
+ width = inference_state["video_width"]
143
+
144
+ input_frame_idx = 0 # the frame index we interact with
145
+ object_id = 1 # give a unique id to each object we interact with (it can be any integer)
146
+
147
+ input_palette = None
148
+ if os.path.isfile(input_box_path):
149
+ input_box_xywh = load_text(str(input_box_path), delimiter=',', dtype=np.float64, backend='numpy').reshape(-1, 4)[0]
150
+ else:
151
+ input_box_xywh = [1026,361,222,169]
152
+ per_obj_input_box_xyxy = {1: np_box_xywh_to_xyxy(np.array(input_box_xywh))}
153
+ object_box_xyxy = per_obj_input_box_xyxy[object_id]
154
+
155
+ frame_idx, obj_ids, masks, _ = predictor.add_new_points_or_box(
156
+ inference_state=inference_state,
157
+ frame_idx=input_frame_idx,
158
+ obj_id=object_id,
159
+ box=object_box_xyxy,
160
+ )
161
+
162
+ # run propagation throughout the video and collect the results in a dict
163
+ output_palette = input_palette or DAVIS_PALETTE
164
+ video_segments = {} # video_segments contains the per-frame segmentation results
165
+ video_boxes_xywh = {} # video_boxes_xywh contains the per-frame bounding box results (xywh format)
166
+ for out_frame_idx, out_obj_ids, out_mask_logits, output_box_xyxy, out_obj_score_logits in predictor.propagate_in_video(
167
+ inference_state=inference_state,
168
+ ):
169
+ if torch.any(output_box_xyxy[:,:,0] >= output_box_xyxy[:,:,2]) or torch.any(output_box_xyxy[:,:,1] >= output_box_xyxy[:,:,3]):
170
+ logging.warning(f"Invalid box prediction: {output_box_xyxy}")
171
+
172
+ per_obj_output_mask = {
173
+ out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
174
+ for i, out_obj_id in enumerate(out_obj_ids)
175
+ }
176
+ video_segments[out_frame_idx] = per_obj_output_mask
177
+ per_obj_output_box_xywh = {
178
+ out_obj_id: np_box_clamp_xywh(np_box_xyxy_to_xywh(output_box_xyxy[i].cpu().numpy()))
179
+ for i, out_obj_id in enumerate(out_obj_ids)
180
+ }
181
+ video_boxes_xywh[out_frame_idx] = per_obj_output_box_xywh
182
+
183
+ # save the tracking results
184
+ save_boxes_to_dir(
185
+ output_bbox_dir=output_box_dir,
186
+ video_name=video_name,
187
+ video_boxes_xywh=video_boxes_xywh,
188
+ )
189
+
190
+ # visualize the tracking results
191
+ os.makedirs(os.path.join(output_box_dir, video_name), exist_ok=True)
192
+ for out_frame_idx in range(0, len(frame_names)):
193
+ plt.clf()
194
+ plt.figure()
195
+ # plt.title(f"frame {out_frame_idx}")
196
+ plt.imshow(Image.open(os.path.join(input_video_dir, frame_names[out_frame_idx] + ".jpg")))
197
+ for out_obj_id, out_box in video_boxes_xywh[out_frame_idx].items():
198
+ box_xywh = out_box[0]
199
+ box_xyxy = np_box_xywh_to_xyxy(np.array(box_xywh))
200
+ show_box(box_xyxy, plt.gca())
201
+ plt.axis('off')
202
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
203
+ plt.savefig(os.path.join(output_box_dir, video_name, f"{out_frame_idx:05d}_withbox.png"), dpi=300, bbox_inches='tight', pad_inches=0)
204
+ ```
205
+
206
+ **[Point Tracking (Point Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/pt_inference_plus.sh)**
207
+ ```
208
+ import os
209
+ import torch
210
+ import matplotlib.pyplot as plt
211
+ from PIL import Image
212
+ import numpy as np
213
+ import logging
214
+
215
+ from sam2_plus.build_sam import build_sam2_video_predictor_plus
216
+
217
+ from tools.visualization import show_mask, show_box, show_points
218
+ from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
219
+ from tools.sot_inference import save_boxes_to_dir, save_masks_and_boxes_to_dir
220
+ from tools.pt_inference_plus import load_visible_points_from_npz
221
+ from training.dataset_plus.box.utils import np_box_xywh_to_xyxy, np_box_xyxy_to_xywh, np_masks_to_boxes, np_box_clamp_xywh
222
+ from benchmarks.sot_benchmark.datasets.utils import load_text
223
+
224
+ checkpoint = "./checkpoints/SAM2-Plus/checkpoint_phase123.pt"
225
+ model_cfg = "configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml"
226
+ predictor = build_sam2_video_predictor_plus(model_cfg, checkpoint, task="point")
227
+
228
+ input_video_dir = "./examples/JPEGImages/00001"
229
+ input_point_path = "./examples/Points/00001.npz"
230
+ output_point_dir = "./output/Points/"
231
+
232
+ radius, sigma = 5, 2
233
+
234
+ score_thresh = 0
235
+
236
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
237
+ inference_state = predictor.init_state(video_path=input_video_dir)
238
+
239
+ video_name = os.path.basename(input_video_dir)
240
+ frame_names = [
241
+ os.path.splitext(p)[0]
242
+ for p in os.listdir(input_video_dir)
243
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
244
+ ]
245
+ frame_names.sort()
246
+ height = inference_state["video_height"]
247
+ width = inference_state["video_width"]
248
+
249
+ input_frame_idx = 0 # the frame index we interact with
250
+ object_id = 1 # give a unique id to each object we interact with (it can be any integer)
251
+
252
+ input_palette = None
253
+ gt_data = np.load(input_point_path, allow_pickle=True)
254
+ trajs = gt_data['trajs_2d'].astype(np.float32) # ndarray [N_frames, N_points, 2], (x, y)
255
+ visible = gt_data['visibs'].astype(bool) # ndarray [N_frames, N_points], bool
256
+ input_point, input_visible = torch.tensor(trajs), torch.tensor(visible)
257
+ per_obj_input_point = load_visible_points_from_npz(
258
+ input_points=input_point,
259
+ input_visibles=input_visible,
260
+ frame_idx=input_frame_idx,
261
+ )
262
+ object_point = per_obj_input_point[object_id]
263
+
264
+ predictor.add_new_points_and_generate_gaussian_mask(
265
+ inference_state=inference_state,
266
+ frame_idx=input_frame_idx,
267
+ obj_id=object_id,
268
+ points=object_point.unsqueeze(0).numpy(),
269
+ labels=np.array([1]),
270
+ radius=radius,
271
+ sigma=sigma,
272
+ )
273
+
274
+ # run propagation throughout the video and collect the results in a dict
275
+ num_frames, num_points = len(frame_names), trajs.shape[1]
276
+ point_array = -np.ones((num_frames, num_points, 2), dtype=np.float32)
277
+ visible_array = np.zeros((num_frames, num_points), dtype=bool)
278
+ for out_frame_idx, out_obj_ids, out_mask_logits, out_box_xyxys, out_obj_score_logits in predictor.propagate_in_video(
279
+ inference_state
280
+ ):
281
+ for out_obj_id, out_mask_logit, out_obj_score_logit in zip(out_obj_ids, out_mask_logits, out_obj_score_logits):
282
+ out_mask_logit, out_obj_score_logit = out_mask_logit.squeeze(0), out_obj_score_logit.squeeze(0)
283
+ max_index = torch.argmax(out_mask_logit)
284
+ max_score_y, max_score_x = torch.unravel_index(max_index, out_mask_logit.shape)
285
+ point_array[out_frame_idx, out_obj_id] = np.array([max_score_x.cpu(), max_score_y.cpu()])
286
+ visible_array[out_frame_idx, out_obj_id] = (out_obj_score_logit > score_thresh).cpu().numpy()
287
+
288
+ # save the predicted point trajectories and visibility flags as an .npz file in output_point_dir
289
+ os.makedirs(output_point_dir, exist_ok=True)
290
+ np.savez(os.path.join(output_point_dir, f"{video_name}.npz"), trajs_2d=point_array, visibs=visible_array, size=(width, height))
291
+
292
+ # visualize the tracking results
293
+ os.makedirs(os.path.join(output_point_dir, video_name), exist_ok=True)
294
+ for out_frame_idx in range(0, len(frame_names)):
295
+ plt.clf()
296
+ plt.figure()
297
+ # plt.title(f"frame {out_frame_idx}")
298
+ plt.imshow(Image.open(os.path.join(input_video_dir, frame_names[out_frame_idx] + ".jpg")))
299
+ points = point_array[out_frame_idx, object_id].reshape(1, 2)
300
+ labels = np.array([-1], np.int32)
301
+ show_points(points, labels, plt.gca(), marker_size=20, edgecolor=None)
302
+ plt.axis('off')
303
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
304
+ plt.savefig(os.path.join(output_point_dir, video_name, f"{out_frame_idx:05d}_withPoint.png"), dpi=300, bbox_inches='tight', pad_inches=0)
305
+ ```
306
+
307
+ ### Load from 🤗 Hugging Face
308
+
309
+ Models can alternatively be loaded from [Hugging Face](https://huggingface.co/MCG-NJU/SAM2-Plus):
310
+
311
+ ```
312
+ import torch
313
+ from sam2_plus.sam2_video_predictor import SAM2VideoPredictor_Plus
314
+
315
+ predictor = SAM2VideoPredictor_Plus.from_pretrained("MCG-NJU/SAM2-Plus")
316
+ ```
317
+