John Ho committed on
Commit ·
6a1ec3d
1
Parent(s): 41fa610
removed try-except for better debugging
Browse files
app.py
CHANGED
|
@@ -133,95 +133,92 @@ def video_inference(input_video, prompt: str):
|
|
| 133 |
"detections": [],
|
| 134 |
"status": "Missing video or prompt.",
|
| 135 |
}
|
| 136 |
-
try:
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
)
|
| 143 |
-
if not video_path:
|
| 144 |
-
return {
|
| 145 |
-
"output_video": None,
|
| 146 |
-
"detections": [],
|
| 147 |
-
"status": "Invalid video input.",
|
| 148 |
-
}
|
| 149 |
-
video_cap = cv2.VideoCapture(video_path)
|
| 150 |
-
vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
|
| 151 |
-
vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 152 |
-
vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 153 |
-
video_frames = []
|
| 154 |
-
while video_cap.isOpened():
|
| 155 |
-
ret, frame = video_cap.read()
|
| 156 |
-
if not ret:
|
| 157 |
-
break
|
| 158 |
-
video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 159 |
-
video_cap.release()
|
| 160 |
-
if len(video_frames) == 0:
|
| 161 |
-
return {
|
| 162 |
-
"output_video": None,
|
| 163 |
-
"detections": [],
|
| 164 |
-
"status": "No frames found in video.",
|
| 165 |
-
}
|
| 166 |
-
session = VID_PROCESSOR.init_video_session(
|
| 167 |
-
video=video_frames, inference_device=DEVICE, dtype=DTYPE
|
| 168 |
-
)
|
| 169 |
-
session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
|
| 170 |
-
temp_out_path = tempfile.mktemp(suffix=".mp4")
|
| 171 |
-
video_writer = cv2.VideoWriter(
|
| 172 |
-
temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
detections = []
|
| 176 |
-
for model_out in VID_MODEL.propagate_in_video_iterator(
|
| 177 |
-
inference_session=session, max_frame_num_to_track=len(video_frames)
|
| 178 |
-
):
|
| 179 |
-
post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
|
| 180 |
-
f_idx = model_out.frame_idx
|
| 181 |
-
original_pil = Image.fromarray(video_frames[f_idx])
|
| 182 |
-
frame_detections = []
|
| 183 |
-
if "masks" in post_processed:
|
| 184 |
-
detected_masks = post_processed["masks"]
|
| 185 |
-
object_ids = post_processed["object_ids"]
|
| 186 |
-
if detected_masks.ndim == 4:
|
| 187 |
-
detected_masks = detected_masks.squeeze(1)
|
| 188 |
-
# detected_masks: (num_objects, H, W)
|
| 189 |
-
for i, mask in enumerate(detected_masks):
|
| 190 |
-
mask = mask.cpu().numpy()
|
| 191 |
-
mask_bin = (mask > 0.0).astype(np.uint8)
|
| 192 |
-
xyxy = mask_to_xyxy(mask_bin)
|
| 193 |
-
if not xyxy:
|
| 194 |
-
continue
|
| 195 |
-
x0, y0, x1, y1 = xyxy
|
| 196 |
-
det = {
|
| 197 |
-
"frame": f_idx,
|
| 198 |
-
"track_id": int(object_ids[i]) if object_ids is not None else i,
|
| 199 |
-
"x": x0 / vid_w,
|
| 200 |
-
"y": y0 / vid_h,
|
| 201 |
-
"w": (x1 - x0) / vid_w,
|
| 202 |
-
"h": (y1 - y0) / vid_h,
|
| 203 |
-
"conf": 1,
|
| 204 |
-
"mask_b64": b64_mask_encode(mask_bin).decode("ascii"),
|
| 205 |
-
}
|
| 206 |
-
detections.append(det)
|
| 207 |
-
final_frame = apply_mask_overlay(
|
| 208 |
-
original_pil, detected_masks, object_ids=object_ids
|
| 209 |
-
)
|
| 210 |
-
else:
|
| 211 |
-
final_frame = original_pil
|
| 212 |
-
video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
|
| 213 |
-
video_writer.release()
|
| 214 |
return {
|
| 215 |
-
"output_video":
|
| 216 |
-
"detections":
|
| 217 |
-
"status": "
|
| 218 |
}
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
return {
|
| 221 |
"output_video": None,
|
| 222 |
"detections": [],
|
| 223 |
-
"status":
|
| 224 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
# the Gradio App
|
|
|
|
| 133 |
"detections": [],
|
| 134 |
"status": "Missing video or prompt.",
|
| 135 |
}
|
| 136 |
+
# try:
|
| 137 |
+
# Gradio passes a dict with 'name' key for uploaded files
|
| 138 |
+
video_path = (
|
| 139 |
+
input_video if isinstance(input_video, str) else input_video.get("name", None)
|
| 140 |
+
)
|
| 141 |
+
if not video_path:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
return {
|
| 143 |
+
"output_video": None,
|
| 144 |
+
"detections": [],
|
| 145 |
+
"status": "Invalid video input.",
|
| 146 |
}
|
| 147 |
+
video_cap = cv2.VideoCapture(video_path)
|
| 148 |
+
vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
|
| 149 |
+
vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 150 |
+
vid_h = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 151 |
+
video_frames = []
|
| 152 |
+
while video_cap.isOpened():
|
| 153 |
+
ret, frame = video_cap.read()
|
| 154 |
+
if not ret:
|
| 155 |
+
break
|
| 156 |
+
video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 157 |
+
video_cap.release()
|
| 158 |
+
if len(video_frames) == 0:
|
| 159 |
return {
|
| 160 |
"output_video": None,
|
| 161 |
"detections": [],
|
| 162 |
+
"status": "No frames found in video.",
|
| 163 |
}
|
| 164 |
+
session = VID_PROCESSOR.init_video_session(
|
| 165 |
+
video=video_frames, inference_device=DEVICE, dtype=DTYPE
|
| 166 |
+
)
|
| 167 |
+
session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=prompt)
|
| 168 |
+
temp_out_path = tempfile.mktemp(suffix=".mp4")
|
| 169 |
+
video_writer = cv2.VideoWriter(
|
| 170 |
+
temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
detections = []
|
| 174 |
+
for model_out in VID_MODEL.propagate_in_video_iterator(
|
| 175 |
+
inference_session=session, max_frame_num_to_track=len(video_frames)
|
| 176 |
+
):
|
| 177 |
+
post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
|
| 178 |
+
f_idx = model_out.frame_idx
|
| 179 |
+
original_pil = Image.fromarray(video_frames[f_idx])
|
| 180 |
+
if "masks" in post_processed:
|
| 181 |
+
detected_masks = post_processed["masks"]
|
| 182 |
+
object_ids = post_processed["object_ids"]
|
| 183 |
+
if detected_masks.ndim == 4:
|
| 184 |
+
detected_masks = detected_masks.squeeze(1)
|
| 185 |
+
# detected_masks: (num_objects, H, W)
|
| 186 |
+
for i, mask in enumerate(detected_masks):
|
| 187 |
+
mask = mask.cpu().numpy()
|
| 188 |
+
mask_bin = (mask > 0.0).astype(np.uint8)
|
| 189 |
+
xyxy = mask_to_xyxy(mask_bin)
|
| 190 |
+
if not xyxy:
|
| 191 |
+
continue
|
| 192 |
+
x0, y0, x1, y1 = xyxy
|
| 193 |
+
det = {
|
| 194 |
+
"frame": f_idx,
|
| 195 |
+
"track_id": int(object_ids[i]) if object_ids is not None else i,
|
| 196 |
+
"x": x0 / vid_w,
|
| 197 |
+
"y": y0 / vid_h,
|
| 198 |
+
"w": (x1 - x0) / vid_w,
|
| 199 |
+
"h": (y1 - y0) / vid_h,
|
| 200 |
+
"conf": 1,
|
| 201 |
+
"mask_b64": b64_mask_encode(mask_bin).decode("ascii"),
|
| 202 |
+
}
|
| 203 |
+
detections.append(det)
|
| 204 |
+
final_frame = apply_mask_overlay(
|
| 205 |
+
original_pil, detected_masks, object_ids=object_ids
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
final_frame = original_pil
|
| 209 |
+
video_writer.write(cv2.cvtColor(np.array(final_frame), cv2.COLOR_RGB2BGR))
|
| 210 |
+
video_writer.release()
|
| 211 |
+
return {
|
| 212 |
+
"output_video": temp_out_path,
|
| 213 |
+
"detections": detections,
|
| 214 |
+
"status": "Video processing completed successfully.✅",
|
| 215 |
+
}
|
| 216 |
+
# except Exception as e:
|
| 217 |
+
# return {
|
| 218 |
+
# "output_video": None,
|
| 219 |
+
# "detections": [],
|
| 220 |
+
# "status": f"Error during video processing: {str(e)}",
|
| 221 |
+
# }
|
| 222 |
|
| 223 |
|
| 224 |
# the Gradio App
|