John Ho committed on
Commit
59822ae
·
1 Parent(s): 95ca774

trying async frame load

Browse files
Files changed (2) hide show
  1. app.py +6 -3
  2. samv2_handler.py +4 -1
app.py CHANGED
@@ -128,7 +128,7 @@ def process_video(video_path: str, variant: str, masks: Union[list, str]):
128
  variant: SAMv2's model variant
129
  masks: a list of b64 encoded masks for the first frame of the video, indicating the objects to be tracked
130
  Returns:
131
- list: a list of masks
132
  """
133
  model = load_vid_model(variant=variant)
134
  masks = json.loads(masks) if isinstance(masks, str) else masks
@@ -145,6 +145,7 @@ def process_video(video_path: str, variant: str, masks: Union[list, str]):
145
  device="cuda",
146
  do_tidy_up=True,
147
  drop_mask=False,
 
148
  )
149
 
150
 
@@ -185,10 +186,12 @@ with gr.Blocks() as demo:
185
  choices=["tiny", "small", "base_plus", "large"],
186
  ),
187
  gr.Textbox(
188
- label='Masks for Objects of Interest in the First Frame (JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...])',
189
  value=None,
190
  lines=5,
191
- placeholder='JSON list of dicts: [{"x0":..., "y0":..., "x1":..., "y1":...}, ...]',
 
 
192
  ),
193
  ],
194
  outputs=gr.JSON(label="Output JSON"),
 
128
  variant: SAMv2's model variant
129
  masks: a list of b64 encoded masks for the first frame of the video, indicating the objects to be tracked
130
  Returns:
131
+ list: a list of tracked objects expressed as a list of dictionary [{"frame":..., "track_id":..., "x":..., "y":...,"w":...,"h":...,"conf":..., "mask_b64":...},...]
132
  """
133
  model = load_vid_model(variant=variant)
134
  masks = json.loads(masks) if isinstance(masks, str) else masks
 
145
  device="cuda",
146
  do_tidy_up=True,
147
  drop_mask=False,
148
+ async_frame_load=True,
149
  )
150
 
151
 
 
186
  choices=["tiny", "small", "base_plus", "large"],
187
  ),
188
  gr.Textbox(
189
+ label="Masks for Objects of Interest in the First Frame",
190
  value=None,
191
  lines=5,
192
+ placeholder="""
193
+ JSON list of base64 encoded masks, e.g.: ["b'iVBORw0KGgoAAAANSUhEUgAABDgAAAeAAQAAAAADGtqnAAAXz...'",...]
194
+ """,
195
  ),
196
  ],
197
  outputs=gr.JSON(label="Output JSON"),
samv2_handler.py CHANGED
@@ -160,6 +160,7 @@ def run_sam_video_inference(
160
  every_x: int = None,
161
  do_tidy_up: bool = False,
162
  drop_mask: bool = True,
 
163
  ):
164
  # put video frames into directory
165
  # TODO:
@@ -177,7 +178,9 @@ def run_sam_video_inference(
177
  w = vinfo["frame_width"]
178
  h = vinfo["frame_height"]
179
 
180
- inference_state = model.init_state(video_path=vframes_dir, device=device)
 
 
181
  for i, mask in enumerate(masks):
182
  model.add_new_mask(
183
  inference_state=inference_state, frame_idx=0, obj_id=i, mask=mask
 
160
  every_x: int = None,
161
  do_tidy_up: bool = False,
162
  drop_mask: bool = True,
163
+ async_frame_load: bool = False,
164
  ):
165
  # put video frames into directory
166
  # TODO:
 
178
  w = vinfo["frame_width"]
179
  h = vinfo["frame_height"]
180
 
181
+ inference_state = model.init_state(
182
+ video_path=vframes_dir, device=device, async_loading_frames=async_frame_load
183
+ )
184
  for i, mask in enumerate(masks):
185
  model.add_new_mask(
186
  inference_state=inference_state, frame_idx=0, obj_id=i, mask=mask