Transformers
jiamingZ commited on
Commit
b20fdb1
·
verified ·
1 Parent(s): c370337

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +67 -44
README.md CHANGED
@@ -18,18 +18,27 @@ import os
18
  import torch
19
  import matplotlib.pyplot as plt
20
  from PIL import Image
 
 
21
 
22
  from sam2_plus.build_sam import build_sam2_video_predictor_plus
23
 
24
  from tools.visualization import show_mask, show_box, show_points
25
  from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
26
 
27
- checkpoint = "./checkpoints/SAM2-Plus/checkpoint_phase123.pt"
28
- model_cfg = "configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml"
29
- predictor = build_sam2_video_predictor_plus(model_cfg, checkpoint, task="mask")
30
-
31
- input_video_dir = "./examples/JPEGImages/00001"
32
- input_mask_path = "./examples/Annotations/00001/00000.png"
 
 
 
 
 
 
 
33
  output_mask_dir = "./output/Annotations/"
34
 
35
  score_thresh = 0
@@ -41,9 +50,9 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
41
  frame_names = [
42
  os.path.splitext(p)[0]
43
  for p in os.listdir(input_video_dir)
44
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
45
  ]
46
- frame_names.sort()
47
  height = inference_state["video_height"]
48
  width = inference_state["video_width"]
49
 
@@ -89,7 +98,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
89
  )
90
 
91
  # visualize the tracking results
92
- for out_frame_idx in range(0, len(frame_names)):
93
  plt.clf()
94
  plt.figure()
95
  # plt.title(f"frame {out_frame_idx}")
@@ -99,6 +108,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
99
  plt.axis('off')
100
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
101
  plt.savefig(f"{output_mask_dir}/{video_name}/{out_frame_idx:05d}_withMask.png", dpi=300, bbox_inches='tight', pad_inches=0)
 
102
  ```
103
 
104
  **[Video Object Tracking (Box Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/sot_inference_plus.sh)**
@@ -107,6 +117,8 @@ import os
107
  import torch
108
  import matplotlib.pyplot as plt
109
  from PIL import Image
 
 
110
  import numpy as np
111
  import logging
112
 
@@ -118,12 +130,19 @@ from tools.sot_inference import save_boxes_to_dir, save_masks_and_boxes_to_dir
118
  from training.dataset_plus.box.utils import np_box_xywh_to_xyxy, np_box_xyxy_to_xywh, np_masks_to_boxes, np_box_clamp_xywh
119
  from benchmarks.sot_benchmark.datasets.utils import load_text
120
 
121
- checkpoint = "./checkpoints/SAM2-Plus/checkpoint_phase123.pt"
122
- model_cfg = "configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml"
123
- predictor = build_sam2_video_predictor_plus(model_cfg, checkpoint, task="box")
124
-
125
- input_video_dir = "./examples/JPEGImages/00001"
126
- input_box_path = "./examples/Boxes/00001.txt"
 
 
 
 
 
 
 
127
  output_box_dir = "./output/Boxes/"
128
 
129
  score_thresh = 0
@@ -135,9 +154,9 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
135
  frame_names = [
136
  os.path.splitext(p)[0]
137
  for p in os.listdir(input_video_dir)
138
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
139
  ]
140
- frame_names.sort()
141
  height = inference_state["video_height"]
142
  width = inference_state["video_width"]
143
 
@@ -148,7 +167,8 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
148
  if os.path.isfile(input_box_path):
149
  input_box_xywh = load_text(str(input_box_path), delimiter=',', dtype=np.float64, backend='numpy').reshape(-1, 4)[0]
150
  else:
151
- input_box_xywh = [1026,361,222,169]
 
152
  per_obj_input_box_xyxy = {1: np_box_xywh_to_xyxy(np.array(input_box_xywh))}
153
  object_box_xyxy = per_obj_input_box_xyxy[object_id]
154
 
@@ -189,7 +209,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
189
 
190
  # visualize the tracking results
191
  os.makedirs(os.path.join(output_box_dir, video_name), exist_ok=True)
192
- for out_frame_idx in range(0, len(frame_names)):
193
  plt.clf()
194
  plt.figure()
195
  # plt.title(f"frame {out_frame_idx}")
@@ -201,6 +221,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
201
  plt.axis('off')
202
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
203
  plt.savefig(os.path.join(output_box_dir, video_name, f"{out_frame_idx:05d}_withbox.png"), dpi=300, bbox_inches='tight', pad_inches=0)
 
204
  ```
205
 
206
  **[Point Tracking (Point Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/pt_inference_plus.sh)**
@@ -209,51 +230,53 @@ import os
209
  import torch
210
  import matplotlib.pyplot as plt
211
  from PIL import Image
 
212
  import numpy as np
213
- import logging
214
 
215
  from sam2_plus.build_sam import build_sam2_video_predictor_plus
216
 
217
  from tools.visualization import show_mask, show_box, show_points
218
  from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
219
- from tools.sot_inference import save_boxes_to_dir, save_masks_and_boxes_to_dir
220
  from tools.pt_inference_plus import load_visible_points_from_npz
221
- from training.dataset_plus.box.utils import np_box_xywh_to_xyxy, np_box_xyxy_to_xywh, np_masks_to_boxes, np_box_clamp_xywh
222
- from benchmarks.sot_benchmark.datasets.utils import load_text
223
-
224
- checkpoint = "./checkpoints/SAM2-Plus/checkpoint_phase123.pt"
225
- model_cfg = "configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml"
226
- predictor = build_sam2_video_predictor_plus(model_cfg, checkpoint, task="point")
227
 
228
- input_video_dir = "./examples/JPEGImages/00001"
229
- input_point_path = "./examples/Points/00001.npz"
 
 
 
 
 
 
 
 
 
 
 
230
  output_point_dir = "./output/Points/"
231
 
232
  radius, sigma = 5, 2
233
-
234
  score_thresh = 0
235
 
236
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
237
- inference_state = predictor.init_state(video_path=input_video_dir)
238
-
239
  video_name = os.path.basename(input_video_dir)
240
  frame_names = [
241
  os.path.splitext(p)[0]
242
  for p in os.listdir(input_video_dir)
243
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
244
  ]
245
- frame_names.sort()
 
 
246
  height = inference_state["video_height"]
247
  width = inference_state["video_width"]
248
 
249
  input_frame_idx = 0 # the frame index we interact with
250
- object_id = 1 # give a unique id to each object we interact with (it can be any integers)
 
251
 
252
- input_palette = None
253
- gt_data = np.load(input_point_path, allow_pickle=True)
254
- trajs = gt_data['trajs_2d'].astype(np.float32) # ndarray [N_frames, N_points, 2], xyxy
255
- visible = gt_data['visibs'].astype(bool) # ndarray [N_frames, N_points], bool
256
- input_point, input_visible = torch.tensor(trajs), torch.tensor(visible)
257
  per_obj_input_point = load_visible_points_from_npz(
258
  input_points=input_point,
259
  input_visibles=input_visible,
@@ -272,7 +295,6 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
272
  )
273
 
274
  # run propagation throughout the video and collect the results in a dict
275
- num_frames, num_points = len(frame_names), trajs.shape[1]
276
  point_array = -np.ones((num_frames, num_points, 2), dtype=np.float32)
277
  visible_array = np.zeros((num_frames, num_points), dtype=bool)
278
  for out_frame_idx, out_obj_ids, out_mask_logits, out_box_xyxys, out_obj_score_logits in predictor.propagate_in_video(
@@ -284,14 +306,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
284
  max_score_y, max_score_x = torch.unravel_index(max_index, out_mask_logit.shape)
285
  point_array[out_frame_idx, out_obj_id] = np.array([max_score_x.cpu(), max_score_y.cpu()])
286
  visible_array[out_frame_idx, out_obj_id] = (out_obj_score_logit > score_thresh).cpu().numpy()
287
-
288
  # write the output masks as palette PNG files to output_mask_dir
289
  os.makedirs(output_point_dir, exist_ok=True)
290
  np.savez(os.path.join(output_point_dir, f"{video_name}.npz"), trajs_2d=point_array, visibs=visible_array, size=(width, height))
291
-
292
  # visualize the tracking results
293
  os.makedirs(os.path.join(output_point_dir, video_name), exist_ok=True)
294
- for out_frame_idx in range(0, len(frame_names)):
295
  plt.clf()
296
  plt.figure()
297
  # plt.title(f"frame {out_frame_idx}")
@@ -302,6 +324,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
302
  plt.axis('off')
303
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
304
  plt.savefig(os.path.join(output_point_dir, video_name, f"{out_frame_idx:05d}_withPoint.png"), dpi=300, bbox_inches='tight', pad_inches=0)
 
305
  ```
306
 
307
  ### Load from 🤗 Hugging Face
 
18
  import torch
19
  import matplotlib.pyplot as plt
20
  from PIL import Image
21
+ from tqdm import tqdm
22
+ from natsort import natsorted
23
 
24
  from sam2_plus.build_sam import build_sam2_video_predictor_plus
25
 
26
  from tools.visualization import show_mask, show_box, show_points
27
  from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
28
 
29
+ predictor = build_sam2_video_predictor_plus(
30
+ config_file="configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml",
31
+ ckpt_path="./checkpoints/SAM2-Plus/checkpoint_phase123.pt",
32
+ apply_postprocessing=False,
33
+ hydra_overrides_extra=[
34
+ "++model.non_overlap_masks=" + ("false")
35
+ ],
36
+ vos_optimized=False,
37
+ task='mask'
38
+ )
39
+
40
+ input_video_dir = "./examples/JPEGImages/horsejump-low"
41
+ input_mask_path = "./examples/Annotations/horsejump-low/00000.png"
42
  output_mask_dir = "./output/Annotations/"
43
 
44
  score_thresh = 0
 
50
  frame_names = [
51
  os.path.splitext(p)[0]
52
  for p in os.listdir(input_video_dir)
53
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
54
  ]
55
+ frame_names = natsorted(frame_names)
56
  height = inference_state["video_height"]
57
  width = inference_state["video_width"]
58
 
 
98
  )
99
 
100
  # visualize the tracking results
101
+ for out_frame_idx in tqdm(range(0, len(frame_names)), desc="Visualization Results"):
102
  plt.clf()
103
  plt.figure()
104
  # plt.title(f"frame {out_frame_idx}")
 
108
  plt.axis('off')
109
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
110
  plt.savefig(f"{output_mask_dir}/{video_name}/{out_frame_idx:05d}_withMask.png", dpi=300, bbox_inches='tight', pad_inches=0)
111
+ plt.close()
112
  ```
113
 
114
  **[Video Object Tracking (Box Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/sot_inference_plus.sh)**
 
117
  import torch
118
  import matplotlib.pyplot as plt
119
  from PIL import Image
120
+ from tqdm import tqdm
121
+ from natsort import natsorted
122
  import numpy as np
123
  import logging
124
 
 
130
  from training.dataset_plus.box.utils import np_box_xywh_to_xyxy, np_box_xyxy_to_xywh, np_masks_to_boxes, np_box_clamp_xywh
131
  from benchmarks.sot_benchmark.datasets.utils import load_text
132
 
133
+ predictor = build_sam2_video_predictor_plus(
134
+ config_file="configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml",
135
+ ckpt_path="./checkpoints/SAM2-Plus/checkpoint_phase123.pt",
136
+ apply_postprocessing=False,
137
+ hydra_overrides_extra=[
138
+ "++model.non_overlap_masks=" + ("false")
139
+ ],
140
+ vos_optimized=False,
141
+ task='box'
142
+ )
143
+
144
+ input_video_dir = "./examples/JPEGImages/horsejump-low"
145
+ input_box_path = "./examples/Boxes/horsejump-low.txt"
146
  output_box_dir = "./output/Boxes/"
147
 
148
  score_thresh = 0
 
154
  frame_names = [
155
  os.path.splitext(p)[0]
156
  for p in os.listdir(input_video_dir)
157
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
158
  ]
159
+ frame_names = natsorted(frame_names)
160
  height = inference_state["video_height"]
161
  width = inference_state["video_width"]
162
 
 
167
  if os.path.isfile(input_box_path):
168
  input_box_xywh = load_text(str(input_box_path), delimiter=',', dtype=np.float64, backend='numpy').reshape(-1, 4)[0]
169
  else:
170
+ print(f"Box file {input_box_path} not found. Using default box.")
171
+ input_box_xywh = [316,385,742,488]
172
  per_obj_input_box_xyxy = {1: np_box_xywh_to_xyxy(np.array(input_box_xywh))}
173
  object_box_xyxy = per_obj_input_box_xyxy[object_id]
174
 
 
209
 
210
  # visualize the tracking results
211
  os.makedirs(os.path.join(output_box_dir, video_name), exist_ok=True)
212
+ for out_frame_idx in tqdm(range(0, len(frame_names)), desc="Visualization Results"):
213
  plt.clf()
214
  plt.figure()
215
  # plt.title(f"frame {out_frame_idx}")
 
221
  plt.axis('off')
222
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
223
  plt.savefig(os.path.join(output_box_dir, video_name, f"{out_frame_idx:05d}_withbox.png"), dpi=300, bbox_inches='tight', pad_inches=0)
224
+ plt.close()
225
  ```
226
 
227
  **[Point Tracking (Point Granularity)](https://github.com/MCG-NJU/SAM2-Plus/blob/main/tools/pt_inference_plus.sh)**
 
230
  import torch
231
  import matplotlib.pyplot as plt
232
  from PIL import Image
233
+ from tqdm import tqdm
234
  import numpy as np
235
+ from natsort import natsorted
236
 
237
  from sam2_plus.build_sam import build_sam2_video_predictor_plus
238
 
239
  from tools.visualization import show_mask, show_box, show_points
240
  from tools.vos_inference import load_ann_png, get_per_obj_mask, DAVIS_PALETTE, save_masks_to_dir
 
241
  from tools.pt_inference_plus import load_visible_points_from_npz
 
 
 
 
 
 
242
 
243
+ predictor = build_sam2_video_predictor_plus(
244
+ config_file="configs/sam2.1/sam2.1_hiera_b+_predmasks_decoupled_MAME.yaml",
245
+ ckpt_path="./checkpoints/SAM2-Plus/checkpoint_phase123.pt",
246
+ apply_postprocessing=False,
247
+ hydra_overrides_extra=[
248
+ "++model.non_overlap_masks=" + ("false")
249
+ ],
250
+ vos_optimized=False,
251
+ task='point'
252
+ )
253
+
254
+ input_video_dir = "./examples/JPEGImages/horsejump-low"
255
+ input_point_path = "./examples/Points/horsejump-low.npz"
256
  output_point_dir = "./output/Points/"
257
 
258
  radius, sigma = 5, 2
 
259
  score_thresh = 0
260
 
261
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 
 
262
  video_name = os.path.basename(input_video_dir)
263
  frame_names = [
264
  os.path.splitext(p)[0]
265
  for p in os.listdir(input_video_dir)
266
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
267
  ]
268
+ frame_names = natsorted(frame_names)
269
+
270
+ inference_state = predictor.init_state(video_path=input_video_dir)
271
  height = inference_state["video_height"]
272
  width = inference_state["video_width"]
273
 
274
  input_frame_idx = 0 # the frame index we interact with
275
+ object_id = 0 # give a unique id to each object we interact with (it can be any integers)
276
+ num_frames, num_points = len(frame_names), 1
277
 
278
+ input_data = np.load(input_point_path, allow_pickle=True)
279
+ input_point, input_visible = torch.tensor(input_data['trajs_2d'].astype(np.float32)), torch.tensor(input_data['visibs'].astype(bool))
 
 
 
280
  per_obj_input_point = load_visible_points_from_npz(
281
  input_points=input_point,
282
  input_visibles=input_visible,
 
295
  )
296
 
297
  # run propagation throughout the video and collect the results in a dict
 
298
  point_array = -np.ones((num_frames, num_points, 2), dtype=np.float32)
299
  visible_array = np.zeros((num_frames, num_points), dtype=bool)
300
  for out_frame_idx, out_obj_ids, out_mask_logits, out_box_xyxys, out_obj_score_logits in predictor.propagate_in_video(
 
306
  max_score_y, max_score_x = torch.unravel_index(max_index, out_mask_logit.shape)
307
  point_array[out_frame_idx, out_obj_id] = np.array([max_score_x.cpu(), max_score_y.cpu()])
308
  visible_array[out_frame_idx, out_obj_id] = (out_obj_score_logit > score_thresh).cpu().numpy()
309
+
310
  # write the output masks as palette PNG files to output_mask_dir
311
  os.makedirs(output_point_dir, exist_ok=True)
312
  np.savez(os.path.join(output_point_dir, f"{video_name}.npz"), trajs_2d=point_array, visibs=visible_array, size=(width, height))
313
+
314
  # visualize the tracking results
315
  os.makedirs(os.path.join(output_point_dir, video_name), exist_ok=True)
316
+ for out_frame_idx in tqdm(range(0, len(frame_names)), desc="Visualization Results"):
317
  plt.clf()
318
  plt.figure()
319
  # plt.title(f"frame {out_frame_idx}")
 
324
  plt.axis('off')
325
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
326
  plt.savefig(os.path.join(output_point_dir, video_name, f"{out_frame_idx:05d}_withPoint.png"), dpi=300, bbox_inches='tight', pad_inches=0)
327
+ plt.close()
328
  ```
329
 
330
  ### Load from 🤗 Hugging Face