John Ho committed on
Commit
95ca774
·
1 Parent(s): 05f7921

testing video inference

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -2,6 +2,9 @@ import gradio as gr
2
  import spaces, torch, os, requests, json
3
  from pathlib import Path
4
  from tqdm import tqdm
 
 
 
5
  from samv2_handler import (
6
  load_sam_image_model,
7
  run_sam_im_inference,
@@ -9,9 +12,7 @@ from samv2_handler import (
9
  run_sam_video_inference,
10
  logger,
11
  )
12
- from PIL import Image
13
- from typing import Union
14
- import numpy as np
15
 
16
  torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
17
  if torch.cuda.get_device_properties(0).major >= 8:
@@ -125,15 +126,22 @@ def process_video(video_path: str, variant: str, masks: Union[list, str]):
125
  Args:
126
  video_path: path to video object
127
  variant: SAMv2's model variant
128
- masks: a list of masks for the first frame of the video, indicating the objects to be tracked
129
  Returns:
130
  list: a list of masks
131
  """
132
  model = load_vid_model(variant=variant)
 
 
 
 
 
 
 
133
  return run_sam_video_inference(
134
  model,
135
  video_path=video_path,
136
- masks=np.array(masks),
137
  device="cuda",
138
  do_tidy_up=True,
139
  drop_mask=False,
 
2
  import spaces, torch, os, requests, json
3
  from pathlib import Path
4
  from tqdm import tqdm
5
+ from PIL import Image
6
+ from typing import Union
7
+ import numpy as np
8
  from samv2_handler import (
9
  load_sam_image_model,
10
  run_sam_im_inference,
 
12
  run_sam_video_inference,
13
  logger,
14
  )
15
+ from toolbox.mask_encoding import b64_mask_decode
 
 
16
 
17
  torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
18
  if torch.cuda.get_device_properties(0).major >= 8:
 
126
  Args:
127
  video_path: path to video object
128
  variant: SAMv2's model variant
129
+ masks: a list of b64 encoded masks for the first frame of the video, indicating the objects to be tracked
130
  Returns:
131
  list: a list of masks
132
  """
133
  model = load_vid_model(variant=variant)
134
+ masks = json.loads(masks) if isinstance(masks, str) else masks
135
+ logger.debug(f"masks---\n{masks}")
136
+ masks = [
137
+ m[2:-1].encode() if m.startswith("b'") and m.endswith("'") else m for m in masks
138
+ ] # expect the b'' literal to be included
139
+ masks = np.array([b64_mask_decode(m).astype(np.uint8) for m in masks])
140
+ logger.debug(f"masks---\n{masks}")
141
  return run_sam_video_inference(
142
  model,
143
  video_path=video_path,
144
+ masks=masks,
145
  device="cuda",
146
  do_tidy_up=True,
147
  drop_mask=False,