John Ho committed on
Commit
6a1ec3d
·
1 Parent(s): 41fa610

removed try-except for better debugging

Browse files
Files changed (1) hide show
  1. app.py +80 -83
app.py CHANGED
@@ -133,95 +133,92 @@ def video_inference(input_video, prompt: str):
133
  "detections": [],
134
  "status": "Missing video or prompt.",
135
  }
136
- try:
137
- # Gradio passes a dict with 'name' key for uploaded files
138
- video_path = (
139
- input_video
140
- if isinstance(input_video, str)
141
- else input_video.get("name", None)
142
- )
143
- if not video_path:
144
- return {
145
- "output_video": None,
146
- "detections": [],
147
- "status": "Invalid video input.",
148
- }
149
- video_cap = cv2.VideoCapture(video_path)
150
- vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
151
- vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
152
- vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
153
- video_frames = []
154
- while video_cap.isOpened():
155
- ret, frame = video_cap.read()
156
- if not ret:
157
- break
158
- video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
159
- video_cap.release()
160
- if len(video_frames) == 0:
161
- return {
162
- "output_video": None,
163
- "detections": [],
164
- "status": "No frames found in video.",
165
- }
166
- session = VID_PROCESSOR.init_video_session(
167
- video=video_frames, inference_device=DEVICE, dtype=DTYPE
168
- )
169
- session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
170
- temp_out_path = tempfile.mktemp(suffix=".mp4")
171
- video_writer = cv2.VideoWriter(
172
- temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
173
- )
174
-
175
- detections = []
176
- for model_out in VID_MODEL.propagate_in_video_iterator(
177
- inference_session=session, max_frame_num_to_track=len(video_frames)
178
- ):
179
- post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
180
- f_idx = model_out.frame_idx
181
- original_pil = Image.fromarray(video_frames[f_idx])
182
- frame_detections = []
183
- if "masks" in post_processed:
184
- detected_masks = post_processed["masks"]
185
- object_ids = post_processed["object_ids"]
186
- if detected_masks.ndim == 4:
187
- detected_masks = detected_masks.squeeze(1)
188
- # detected_masks: (num_objects, H, W)
189
- for i, mask in enumerate(detected_masks):
190
- mask = mask.cpu().numpy()
191
- mask_bin = (mask > 0.0).astype(np.uint8)
192
- xyxy = mask_to_xyxy(mask_bin)
193
- if not xyxy:
194
- continue
195
- x0, y0, x1, y1 = xyxy
196
- det = {
197
- "frame": f_idx,
198
- "track_id": int(object_ids[i]) if object_ids is not None else i,
199
- "x": x0 / vid_w,
200
- "y": y0 / vid_h,
201
- "w": (x1 - x0) / vid_w,
202
- "h": (y1 - y0) / vid_h,
203
- "conf": 1,
204
- "mask_b64": b64_mask_encode(mask_bin).decode("ascii"),
205
- }
206
- detections.append(det)
207
- final_frame = apply_mask_overlay(
208
- original_pil, detected_masks, object_ids=object_ids
209
- )
210
- else:
211
- final_frame = original_pil
212
- video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
213
- video_writer.release()
214
  return {
215
- "output_video": temp_out_path,
216
- "detections": detections,
217
- "status": "Video processing completed successfully.",
218
  }
219
- except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
220
  return {
221
  "output_video": None,
222
  "detections": [],
223
- "status": f"Error during video processing: {str(e)}",
224
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
 
227
  # the Gradio App
 
133
  "detections": [],
134
  "status": "Missing video or prompt.",
135
  }
136
+ # try:
137
+ # Gradio passes a dict with 'name' key for uploaded files
138
+ video_path = (
139
+ input_video if isinstance(input_video, str) else input_video.get("name", None)
140
+ )
141
+ if not video_path:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  return {
143
+ "output_video": None,
144
+ "detections": [],
145
+ "status": "Invalid video input.",
146
  }
147
+ video_cap = cv2.VideoCapture(video_path)
148
+ vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
149
+ vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
150
+ vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
151
+ video_frames = []
152
+ while video_cap.isOpened():
153
+ ret, frame = video_cap.read()
154
+ if not ret:
155
+ break
156
+ video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
157
+ video_cap.release()
158
+ if len(video_frames) == 0:
159
  return {
160
  "output_video": None,
161
  "detections": [],
162
+ "status": "No frames found in video.",
163
  }
164
+ session = VID_PROCESSOR.init_video_session(
165
+ video=video_frames, inference_device=DEVICE, dtype=DTYPE
166
+ )
167
+ session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
168
+ temp_out_path = tempfile.mktemp(suffix=".mp4")
169
+ video_writer = cv2.VideoWriter(
170
+ temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
171
+ )
172
+
173
+ detections = []
174
+ for model_out in VID_MODEL.propagate_in_video_iterator(
175
+ inference_session=session, max_frame_num_to_track=len(video_frames)
176
+ ):
177
+ post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
178
+ f_idx = model_out.frame_idx
179
+ original_pil = Image.fromarray(video_frames[f_idx])
180
+ if "masks" in post_processed:
181
+ detected_masks = post_processed["masks"]
182
+ object_ids = post_processed["object_ids"]
183
+ if detected_masks.ndim == 4:
184
+ detected_masks = detected_masks.squeeze(1)
185
+ # detected_masks: (num_objects, H, W)
186
+ for i, mask in enumerate(detected_masks):
187
+ mask = mask.cpu().numpy()
188
+ mask_bin = (mask > 0.0).astype(np.uint8)
189
+ xyxy = mask_to_xyxy(mask_bin)
190
+ if not xyxy:
191
+ continue
192
+ x0, y0, x1, y1 = xyxy
193
+ det = {
194
+ "frame": f_idx,
195
+ "track_id": int(object_ids[i]) if object_ids is not None else i,
196
+ "x": x0 / vid_w,
197
+ "y": y0 / vid_h,
198
+ "w": (x1 - x0) / vid_w,
199
+ "h": (y1 - y0) / vid_h,
200
+ "conf": 1,
201
+ "mask_b64": b64_mask_encode(mask_bin).decode("ascii"),
202
+ }
203
+ detections.append(det)
204
+ final_frame = apply_mask_overlay(
205
+ original_pil, detected_masks, object_ids=object_ids
206
+ )
207
+ else:
208
+ final_frame = original_pil
209
+ video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
210
+ video_writer.release()
211
+ return {
212
+ "output_video": temp_out_path,
213
+ "detections": detections,
214
+ "status": "Video processing completed successfully.✅",
215
+ }
216
+ # except Exception as e:
217
+ # return {
218
+ # "output_video": None,
219
+ # "detections": [],
220
+ # "status": f"Error during video processing: {str(e)}",
221
+ # }
222
 
223
 
224
  # the Gradio App