rollback
Browse files- .gitignore +2 -0
- app.py +10 -30
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.egg-info/
|
| 2 |
+
__pycache__/
|
app.py
CHANGED
|
@@ -71,25 +71,17 @@ examples = [
|
|
| 71 |
OBJ_ID = 0
|
| 72 |
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
model_cfg = "edgetam.yaml"
|
| 79 |
-
predictor = build_sam2_video_predictor(
|
| 80 |
-
model_cfg, sam2_checkpoint, device="cuda"
|
| 81 |
-
)
|
| 82 |
-
print("predictor loaded")
|
| 83 |
-
|
| 84 |
-
# use bfloat16 for the entire demo
|
| 85 |
-
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
| 86 |
-
if torch.cuda.get_device_properties(0).major >= 8:
|
| 87 |
-
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
| 88 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 89 |
-
torch.backends.cudnn.allow_tf32 = True
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
def get_video_fps(video_path):
|
|
@@ -106,10 +98,7 @@ def get_video_fps(video_path):
|
|
| 106 |
return fps
|
| 107 |
|
| 108 |
|
| 109 |
-
@spaces.GPU
|
| 110 |
def reset(session_state):
|
| 111 |
-
predictor = get_predictor(session_state)
|
| 112 |
-
predictor.to("cuda")
|
| 113 |
session_state["input_points"] = []
|
| 114 |
session_state["input_labels"] = []
|
| 115 |
if session_state["inference_state"] is not None:
|
|
@@ -127,10 +116,7 @@ def reset(session_state):
|
|
| 127 |
)
|
| 128 |
|
| 129 |
|
| 130 |
-
@spaces.GPU
|
| 131 |
def clear_points(session_state):
|
| 132 |
-
predictor = get_predictor(session_state)
|
| 133 |
-
predictor.to("cuda")
|
| 134 |
session_state["input_points"] = []
|
| 135 |
session_state["input_labels"] = []
|
| 136 |
if session_state["inference_state"]["tracking_has_started"]:
|
|
@@ -145,8 +131,6 @@ def clear_points(session_state):
|
|
| 145 |
|
| 146 |
@spaces.GPU
|
| 147 |
def preprocess_video_in(video_path, session_state):
|
| 148 |
-
predictor = get_predictor(session_state)
|
| 149 |
-
predictor.to("cuda")
|
| 150 |
if video_path is None:
|
| 151 |
return (
|
| 152 |
gr.update(open=True), # video_in_drawer
|
|
@@ -210,8 +194,6 @@ def segment_with_points(
|
|
| 210 |
session_state,
|
| 211 |
evt: gr.SelectData,
|
| 212 |
):
|
| 213 |
-
predictor = get_predictor(session_state)
|
| 214 |
-
predictor.to("cuda")
|
| 215 |
session_state["input_points"].append(evt.index)
|
| 216 |
print(f"TRACKING INPUT POINT: {session_state['input_points']}")
|
| 217 |
|
|
@@ -285,8 +267,6 @@ def propagate_to_all(
|
|
| 285 |
video_in,
|
| 286 |
session_state,
|
| 287 |
):
|
| 288 |
-
predictor = get_predictor(session_state)
|
| 289 |
-
predictor.to("cuda")
|
| 290 |
if (
|
| 291 |
len(session_state["input_points"]) == 0
|
| 292 |
or video_in is None
|
|
|
|
| 71 |
OBJ_ID = 0
|
| 72 |
|
| 73 |
|
| 74 |
+
sam2_checkpoint = "checkpoints/edgetam.pt"
|
| 75 |
+
model_cfg = "edgetam.yaml"
|
| 76 |
+
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
|
| 77 |
+
print("predictor loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
# use bfloat16 for the entire demo
|
| 80 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
| 81 |
+
if torch.cuda.get_device_properties(0).major >= 8:
|
| 82 |
+
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
| 83 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 84 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 85 |
|
| 86 |
|
| 87 |
def get_video_fps(video_path):
|
|
|
|
| 98 |
return fps
|
| 99 |
|
| 100 |
|
|
|
|
| 101 |
def reset(session_state):
|
|
|
|
|
|
|
| 102 |
session_state["input_points"] = []
|
| 103 |
session_state["input_labels"] = []
|
| 104 |
if session_state["inference_state"] is not None:
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
|
|
|
|
| 119 |
def clear_points(session_state):
|
|
|
|
|
|
|
| 120 |
session_state["input_points"] = []
|
| 121 |
session_state["input_labels"] = []
|
| 122 |
if session_state["inference_state"]["tracking_has_started"]:
|
|
|
|
| 131 |
|
| 132 |
@spaces.GPU
|
| 133 |
def preprocess_video_in(video_path, session_state):
|
|
|
|
|
|
|
| 134 |
if video_path is None:
|
| 135 |
return (
|
| 136 |
gr.update(open=True), # video_in_drawer
|
|
|
|
| 194 |
session_state,
|
| 195 |
evt: gr.SelectData,
|
| 196 |
):
|
|
|
|
|
|
|
| 197 |
session_state["input_points"].append(evt.index)
|
| 198 |
print(f"TRACKING INPUT POINT: {session_state['input_points']}")
|
| 199 |
|
|
|
|
| 267 |
video_in,
|
| 268 |
session_state,
|
| 269 |
):
|
|
|
|
|
|
|
| 270 |
if (
|
| 271 |
len(session_state["input_points"]) == 0
|
| 272 |
or video_in is None
|