John Ho committed on
Commit ·
5c07489
1
Parent(s): ab54209
added assert statements
Browse files
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: π
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: yellow
|
|
@@ -7,7 +7,7 @@ sdk: gradio
|
|
| 7 |
sdk_version: 5.32.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
short_description:
|
| 11 |
---
|
| 12 |
|
| 13 |
# SAM3 HuggingFace Space Demo
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SAM3
|
| 3 |
emoji: π
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: yellow
|
|
|
|
| 7 |
sdk_version: 5.32.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
short_description: SAM3 Video Inference on ZeroGPU
|
| 11 |
---
|
| 12 |
|
| 13 |
# SAM3 HuggingFace Space Demo
|
app.py
CHANGED
|
@@ -137,37 +137,20 @@ def video_inference(input_video, prompt: str):
|
|
| 137 |
Segments objects in a video using a text prompt.
|
| 138 |
Returns a list of detection dicts (one per object per frame) and output video path/status.
|
| 139 |
"""
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
}
|
| 146 |
-
if input_video is None or not prompt:
|
| 147 |
-
return {
|
| 148 |
-
"output_video": None,
|
| 149 |
-
"detections": [],
|
| 150 |
-
"status": "Missing video or prompt.",
|
| 151 |
-
}
|
| 152 |
-
# try:
|
| 153 |
# Gradio passes a dict with 'name' key for uploaded files
|
| 154 |
video_path = (
|
| 155 |
input_video if isinstance(input_video, str) else input_video.get("name", None)
|
| 156 |
)
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
"output_video": None,
|
| 160 |
-
"detections": [],
|
| 161 |
-
"status": "Invalid video input.",
|
| 162 |
-
}
|
| 163 |
# Use FFmpeg-based helpers for metadata and frame extraction
|
| 164 |
vmeta = get_video_metadata(video_path, bverbose=False)
|
| 165 |
-
|
| 166 |
-
return {
|
| 167 |
-
"output_video": None,
|
| 168 |
-
"detections": [],
|
| 169 |
-
"status": "Failed to extract video metadata.",
|
| 170 |
-
}
|
| 171 |
vid_fps = vmeta["fps"]
|
| 172 |
vid_w = vmeta["width"]
|
| 173 |
vid_h = vmeta["height"]
|
|
@@ -181,12 +164,8 @@ def video_inference(input_video, prompt: str):
|
|
| 181 |
write_frame_num=False,
|
| 182 |
output_dir=None,
|
| 183 |
)
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
"output_video": None,
|
| 187 |
-
"detections": [],
|
| 188 |
-
"status": "No frames found in video.",
|
| 189 |
-
}
|
| 190 |
# Convert PIL Images to numpy arrays (RGB)
|
| 191 |
video_frames = [np.array(frame.convert("RGB")) for frame in pil_frames]
|
| 192 |
|
|
@@ -195,9 +174,6 @@ def video_inference(input_video, prompt: str):
|
|
| 195 |
)
|
| 196 |
session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
|
| 197 |
temp_out_path = tempfile.mktemp(suffix=".mp4")
|
| 198 |
-
# video_writer = cv2.VideoWriter(
|
| 199 |
-
# temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
|
| 200 |
-
# )
|
| 201 |
|
| 202 |
detections = []
|
| 203 |
annotated_frames = []
|
|
@@ -213,7 +189,7 @@ def video_inference(input_video, prompt: str):
|
|
| 213 |
object_ids = [int(oid) for oid in object_ids]
|
| 214 |
if detected_masks.ndim == 4:
|
| 215 |
detected_masks = detected_masks.squeeze(1)
|
| 216 |
-
|
| 217 |
for i, mask in enumerate(detected_masks):
|
| 218 |
mask = mask.cpu().numpy()
|
| 219 |
mask_bin = (mask > 0.0).astype(np.uint8)
|
|
@@ -237,10 +213,8 @@ def video_inference(input_video, prompt: str):
|
|
| 237 |
)
|
| 238 |
else:
|
| 239 |
final_frame = original_pil
|
| 240 |
-
# video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
|
| 241 |
annotated_frames.append(final_frame)
|
| 242 |
|
| 243 |
-
# video_writer.release()
|
| 244 |
return {
|
| 245 |
"output_video": frames_to_vid(
|
| 246 |
annotated_frames,
|
|
@@ -252,12 +226,6 @@ def video_inference(input_video, prompt: str):
|
|
| 252 |
"detections": detections,
|
| 253 |
"status": "Video processing completed successfully.β
",
|
| 254 |
}
|
| 255 |
-
# except Exception as e:
|
| 256 |
-
# return {
|
| 257 |
-
# "output_video": None,
|
| 258 |
-
# "detections": [],
|
| 259 |
-
# "status": f"Error during video processing: {str(e)}",
|
| 260 |
-
# }
|
| 261 |
|
| 262 |
|
| 263 |
# the Gradio App
|
|
|
|
| 137 |
Segments objects in a video using a text prompt.
|
| 138 |
Returns a list of detection dicts (one per object per frame) and output video path/status.
|
| 139 |
"""
|
| 140 |
+
assert type(VID_MODEL) != type(None) and type(VID_PROCESSOR) != type(
|
| 141 |
+
None
|
| 142 |
+
), "Video Models failed to load on startup."
|
| 143 |
+
assert input_video and prompt, "Missing video or prompt."
|
| 144 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
# Gradio passes a dict with 'name' key for uploaded files
|
| 146 |
video_path = (
|
| 147 |
input_video if isinstance(input_video, str) else input_video.get("name", None)
|
| 148 |
)
|
| 149 |
+
assert video_path, "Invalid video input."
|
| 150 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
# Use FFmpeg-based helpers for metadata and frame extraction
|
| 152 |
vmeta = get_video_metadata(video_path, bverbose=False)
|
| 153 |
+
assert vmeta, "Failed to extract video metadata."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
vid_fps = vmeta["fps"]
|
| 155 |
vid_w = vmeta["width"]
|
| 156 |
vid_h = vmeta["height"]
|
|
|
|
| 164 |
write_frame_num=False,
|
| 165 |
output_dir=None,
|
| 166 |
)
|
| 167 |
+
assert len(pil_frames) > 0, "No frames found in video."
|
| 168 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
# Convert PIL Images to numpy arrays (RGB)
|
| 170 |
video_frames = [np.array(frame.convert("RGB")) for frame in pil_frames]
|
| 171 |
|
|
|
|
| 174 |
)
|
| 175 |
session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
|
| 176 |
temp_out_path = tempfile.mktemp(suffix=".mp4")
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
detections = []
|
| 179 |
annotated_frames = []
|
|
|
|
| 189 |
object_ids = [int(oid) for oid in object_ids]
|
| 190 |
if detected_masks.ndim == 4:
|
| 191 |
detected_masks = detected_masks.squeeze(1)
|
| 192 |
+
|
| 193 |
for i, mask in enumerate(detected_masks):
|
| 194 |
mask = mask.cpu().numpy()
|
| 195 |
mask_bin = (mask > 0.0).astype(np.uint8)
|
|
|
|
| 213 |
)
|
| 214 |
else:
|
| 215 |
final_frame = original_pil
|
|
|
|
| 216 |
annotated_frames.append(final_frame)
|
| 217 |
|
|
|
|
| 218 |
return {
|
| 219 |
"output_video": frames_to_vid(
|
| 220 |
annotated_frames,
|
|
|
|
| 226 |
"detections": detections,
|
| 227 |
"status": "Video processing completed successfully.β
",
|
| 228 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
# the Gradio App
|