John Ho committed on
Commit
5c07489
Β·
1 Parent(s): ab54209

added assert statements

Browse files
Files changed (2) hide show
  1. README.md +2 -2
  2. app.py +11 -43
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Name for you Space App
3
  emoji: πŸ“š
4
  colorFrom: blue
5
  colorTo: yellow
@@ -7,7 +7,7 @@ sdk: gradio
7
  sdk_version: 5.32.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: short description for your Space App
11
  ---
12
 
13
  # SAM3 HuggingFace Space Demo
 
1
  ---
2
+ title: SAM3
3
  emoji: πŸ“š
4
  colorFrom: blue
5
  colorTo: yellow
 
7
  sdk_version: 5.32.0
8
  app_file: app.py
9
  pinned: false
10
+ short_description: SAM3 Video Inference on ZeroGPU
11
  ---
12
 
13
  # SAM3 HuggingFace Space Demo
app.py CHANGED
@@ -137,37 +137,20 @@ def video_inference(input_video, prompt: str):
137
  Segments objects in a video using a text prompt.
138
  Returns a list of detection dicts (one per object per frame) and output video path/status.
139
  """
140
- if VID_MODEL is None or VID_PROCESSOR is None:
141
- return {
142
- "output_video": None,
143
- "detections": [],
144
- "status": "Video Models failed to load on startup.",
145
- }
146
- if input_video is None or not prompt:
147
- return {
148
- "output_video": None,
149
- "detections": [],
150
- "status": "Missing video or prompt.",
151
- }
152
- # try:
153
  # Gradio passes a dict with 'name' key for uploaded files
154
  video_path = (
155
  input_video if isinstance(input_video, str) else input_video.get("name", None)
156
  )
157
- if not video_path:
158
- return {
159
- "output_video": None,
160
- "detections": [],
161
- "status": "Invalid video input.",
162
- }
163
  # Use FFmpeg-based helpers for metadata and frame extraction
164
  vmeta = get_video_metadata(video_path, bverbose=False)
165
- if not vmeta:
166
- return {
167
- "output_video": None,
168
- "detections": [],
169
- "status": "Failed to extract video metadata.",
170
- }
171
  vid_fps = vmeta["fps"]
172
  vid_w = vmeta["width"]
173
  vid_h = vmeta["height"]
@@ -181,12 +164,8 @@ def video_inference(input_video, prompt: str):
181
  write_frame_num=False,
182
  output_dir=None,
183
  )
184
- if len(pil_frames) == 0:
185
- return {
186
- "output_video": None,
187
- "detections": [],
188
- "status": "No frames found in video.",
189
- }
190
  # Convert PIL Images to numpy arrays (RGB)
191
  video_frames = [np.array(frame.convert("RGB")) for frame in pil_frames]
192
 
@@ -195,9 +174,6 @@ def video_inference(input_video, prompt: str):
195
  )
196
  session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
197
  temp_out_path = tempfile.mktemp(suffix=".mp4")
198
- # video_writer = cv2.VideoWriter(
199
- # temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
200
- # )
201
 
202
  detections = []
203
  annotated_frames = []
@@ -213,7 +189,7 @@ def video_inference(input_video, prompt: str):
213
  object_ids = [int(oid) for oid in object_ids]
214
  if detected_masks.ndim == 4:
215
  detected_masks = detected_masks.squeeze(1)
216
- # detected_masks: (num_objects, H, W)
217
  for i, mask in enumerate(detected_masks):
218
  mask = mask.cpu().numpy()
219
  mask_bin = (mask > 0.0).astype(np.uint8)
@@ -237,10 +213,8 @@ def video_inference(input_video, prompt: str):
237
  )
238
  else:
239
  final_frame = original_pil
240
- # video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
241
  annotated_frames.append(final_frame)
242
 
243
- # video_writer.release()
244
  return {
245
  "output_video": frames_to_vid(
246
  annotated_frames,
@@ -252,12 +226,6 @@ def video_inference(input_video, prompt: str):
252
  "detections": detections,
253
  "status": "Video processing completed successfully.βœ…",
254
  }
255
- # except Exception as e:
256
- # return {
257
- # "output_video": None,
258
- # "detections": [],
259
- # "status": f"Error during video processing: {str(e)}",
260
- # }
261
 
262
 
263
  # the Gradio App
 
137
  Segments objects in a video using a text prompt.
138
  Returns a list of detection dicts (one per object per frame) and output video path/status.
139
  """
140
+ assert type(VID_MODEL) != type(None) and type(VID_PROCESSOR) != type(
141
+ None
142
+ ), "Video Models failed to load on startup."
143
+ assert input_video and prompt, "Missing video or prompt."
144
+
 
 
 
 
 
 
 
 
145
  # Gradio passes a dict with 'name' key for uploaded files
146
  video_path = (
147
  input_video if isinstance(input_video, str) else input_video.get("name", None)
148
  )
149
+ assert video_path, "Invalid video input."
150
+
 
 
 
 
151
  # Use FFmpeg-based helpers for metadata and frame extraction
152
  vmeta = get_video_metadata(video_path, bverbose=False)
153
+ assert vmeta, "Failed to extract video metadata."
 
 
 
 
 
154
  vid_fps = vmeta["fps"]
155
  vid_w = vmeta["width"]
156
  vid_h = vmeta["height"]
 
164
  write_frame_num=False,
165
  output_dir=None,
166
  )
167
+ assert len(pil_frames) > 0, "No frames found in video."
168
+
 
 
 
 
169
  # Convert PIL Images to numpy arrays (RGB)
170
  video_frames = [np.array(frame.convert("RGB")) for frame in pil_frames]
171
 
 
174
  )
175
  session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
176
  temp_out_path = tempfile.mktemp(suffix=".mp4")
 
 
 
177
 
178
  detections = []
179
  annotated_frames = []
 
189
  object_ids = [int(oid) for oid in object_ids]
190
  if detected_masks.ndim == 4:
191
  detected_masks = detected_masks.squeeze(1)
192
+
193
  for i, mask in enumerate(detected_masks):
194
  mask = mask.cpu().numpy()
195
  mask_bin = (mask > 0.0).astype(np.uint8)
 
213
  )
214
  else:
215
  final_frame = original_pil
 
216
  annotated_frames.append(final_frame)
217
 
 
218
  return {
219
  "output_video": frames_to_vid(
220
  annotated_frames,
 
226
  "detections": detections,
227
  "status": "Video processing completed successfully.βœ…",
228
  }
 
 
 
 
 
 
229
 
230
 
231
  # the Gradio App