John Ho
committed on
Commit
·
95ca774
1
Parent(s):
05f7921
testing video inference
Browse files
app.py
CHANGED
|
@@ -2,6 +2,9 @@ import gradio as gr
|
|
| 2 |
import spaces, torch, os, requests, json
|
| 3 |
from pathlib import Path
|
| 4 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 5 |
from samv2_handler import (
|
| 6 |
load_sam_image_model,
|
| 7 |
run_sam_im_inference,
|
|
@@ -9,9 +12,7 @@ from samv2_handler import (
|
|
| 9 |
run_sam_video_inference,
|
| 10 |
logger,
|
| 11 |
)
|
| 12 |
-
from
|
| 13 |
-
from typing import Union
|
| 14 |
-
import numpy as np
|
| 15 |
|
| 16 |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
| 17 |
if torch.cuda.get_device_properties(0).major >= 8:
|
|
@@ -125,15 +126,22 @@ def process_video(video_path: str, variant: str, masks: Union[list, str]):
|
|
| 125 |
Args:
|
| 126 |
video_path: path to video object
|
| 127 |
variant: SAMv2's model variant
|
| 128 |
-
masks: a list of masks for the first frame of the video, indicating the objects to be tracked
|
| 129 |
Returns:
|
| 130 |
list: a list of masks
|
| 131 |
"""
|
| 132 |
model = load_vid_model(variant=variant)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
return run_sam_video_inference(
|
| 134 |
model,
|
| 135 |
video_path=video_path,
|
| 136 |
-
masks=
|
| 137 |
device="cuda",
|
| 138 |
do_tidy_up=True,
|
| 139 |
drop_mask=False,
|
|
|
|
| 2 |
import spaces, torch, os, requests, json
|
| 3 |
from pathlib import Path
|
| 4 |
from tqdm import tqdm
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from typing import Union
|
| 7 |
+
import numpy as np
|
| 8 |
from samv2_handler import (
|
| 9 |
load_sam_image_model,
|
| 10 |
run_sam_im_inference,
|
|
|
|
| 12 |
run_sam_video_inference,
|
| 13 |
logger,
|
| 14 |
)
|
| 15 |
+
from toolbox.mask_encoding import b64_mask_decode
|
|
|
|
|
|
|
| 16 |
|
| 17 |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
| 18 |
if torch.cuda.get_device_properties(0).major >= 8:
|
|
|
|
| 126 |
Args:
|
| 127 |
video_path: path to video object
|
| 128 |
variant: SAMv2's model variant
|
| 129 |
+
masks: a list of b64 encoded masks for the first frame of the video, indicating the objects to be tracked
|
| 130 |
Returns:
|
| 131 |
list: a list of masks
|
| 132 |
"""
|
| 133 |
model = load_vid_model(variant=variant)
|
| 134 |
+
masks = json.loads(masks) if isinstance(masks, str) else masks
|
| 135 |
+
logger.debug(f"masks---\n{masks}")
|
| 136 |
+
masks = [
|
| 137 |
+
m[2:-1].encode() if m.startswith("b'") and m.endswith("'") else m for m in masks
|
| 138 |
+
] # expect the b'' literal to be included
|
| 139 |
+
masks = np.array([b64_mask_decode(m).astype(np.uint8) for m in masks])
|
| 140 |
+
logger.debug(f"masks---\n{masks}")
|
| 141 |
return run_sam_video_inference(
|
| 142 |
model,
|
| 143 |
video_path=video_path,
|
| 144 |
+
masks=masks,
|
| 145 |
device="cuda",
|
| 146 |
do_tidy_up=True,
|
| 147 |
drop_mask=False,
|