John Ho committed on
Commit
2e155e5
·
1 Parent(s): 1345eda

updated demo json output

Browse files
Files changed (3) hide show
  1. app.py +49 -8
  2. toolbox/mask_encoding.py +43 -0
  3. visualizer.py +102 -0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sys
2
  import tempfile
3
 
@@ -14,6 +15,10 @@ from transformers import (
14
  Sam3VideoProcessor,
15
  )
16
 
 
 
 
 
17
  logger.remove()
18
  logger.add(
19
  sys.stderr,
@@ -100,7 +105,7 @@ def apply_mask_overlay(base_image, mask_data, object_ids=None, opacity=0.5):
100
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
101
 
102
 
103
- print("Loading Models and Processors...")
104
  try:
105
  VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(DEVICE, dtype=DTYPE)
106
  VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
@@ -113,18 +118,23 @@ except Exception as e:
113
 
114
  # Our Inference Function
115
  @spaces.GPU(duration=120)
116
- def video_inference(input_video, prompt):
117
  """
118
  Segments objects in a video using a text prompt.
119
- Returns a JSON with output video path and status.
120
  """
121
  if VID_MODEL is None or VID_PROCESSOR is None:
122
  return {
123
  "output_video": None,
 
124
  "status": "Video Models failed to load on startup.",
125
  }
126
  if input_video is None or not prompt:
127
- return {"output_video": None, "status": "Missing video or prompt."}
 
 
 
 
128
  try:
129
  # Gradio passes a dict with 'name' key for uploaded files
130
  video_path = (
@@ -133,7 +143,11 @@ def video_inference(input_video, prompt):
133
  else input_video.get("name", None)
134
  )
135
  if not video_path:
136
- return {"output_video": None, "status": "Invalid video input."}
 
 
 
 
137
  video_cap = cv2.VideoCapture(video_path)
138
  vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
139
  vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -146,7 +160,11 @@ def video_inference(input_video, prompt):
146
  video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
147
  video_cap.release()
148
  if len(video_frames) == 0:
149
- return {"output_video": None, "status": "No frames found in video."}
 
 
 
 
150
  session = VID_PROCESSOR.init_video_session(
151
  video=video_frames, inference_device=DEVICE, dtype=DTYPE
152
  )
@@ -155,17 +173,38 @@ def video_inference(input_video, prompt):
155
  video_writer = cv2.VideoWriter(
156
  temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
157
  )
 
 
158
  for model_out in VID_MODEL.propagate_in_video_iterator(
159
  inference_session=session, max_frame_num_to_track=len(video_frames)
160
  ):
161
  post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
162
  f_idx = model_out.frame_idx
163
  original_pil = Image.fromarray(video_frames[f_idx])
 
164
  if "masks" in post_processed:
165
  detected_masks = post_processed["masks"]
166
  object_ids = post_processed["object_ids"]
167
  if detected_masks.ndim == 4:
168
  detected_masks = detected_masks.squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  final_frame = apply_mask_overlay(
170
  original_pil, detected_masks, object_ids=object_ids
171
  )
@@ -175,11 +214,13 @@ def video_inference(input_video, prompt):
175
  video_writer.release()
176
  return {
177
  "output_video": temp_out_path,
 
178
  "status": "Video processing completed successfully.✅",
179
  }
180
  except Exception as e:
181
  return {
182
  "output_video": None,
 
183
  "status": f"Error during video processing: {str(e)}",
184
  }
185
 
@@ -192,8 +233,8 @@ app = gr.Interface(
192
  gr.Textbox(
193
  label="Prompt",
194
  lines=3,
195
- info="Some models like [cam motion](https://huggingface.co/chancharikm/qwen2.5-vl-7b-cam-motion-preview) are trained specific prompts",
196
- value="Describe the camera motion in this video.",
197
  ),
198
  ],
199
  outputs=gr.JSON(label="Output JSON"),
 
1
+ # Import helpers for mask encoding and bbox extraction
2
  import sys
3
  import tempfile
4
 
 
15
  Sam3VideoProcessor,
16
  )
17
 
18
+ # import local helpers
19
+ from toolbox.mask_encoding import b64_mask_encode
20
+ from visualizer import mask_to_xyxy
21
+
22
  logger.remove()
23
  logger.add(
24
  sys.stderr,
 
105
  return Image.alpha_composite(base_image, composite_layer).convert("RGB")
106
 
107
 
108
+ logger.info("Loading Models and Processors...")
109
  try:
110
  VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(DEVICE, dtype=DTYPE)
111
  VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
 
118
 
119
  # Our Inference Function
120
  @spaces.GPU(duration=120)
121
+ def video_inference(input_video, prompt: str):
122
  """
123
  Segments objects in a video using a text prompt.
124
+ Returns a list of detection dicts (one per object per frame) and output video path/status.
125
  """
126
  if VID_MODEL is None or VID_PROCESSOR is None:
127
  return {
128
  "output_video": None,
129
+ "detections": [],
130
  "status": "Video Models failed to load on startup.",
131
  }
132
  if input_video is None or not prompt:
133
+ return {
134
+ "output_video": None,
135
+ "detections": [],
136
+ "status": "Missing video or prompt.",
137
+ }
138
  try:
139
  # Gradio passes a dict with 'name' key for uploaded files
140
  video_path = (
 
143
  else input_video.get("name", None)
144
  )
145
  if not video_path:
146
+ return {
147
+ "output_video": None,
148
+ "detections": [],
149
+ "status": "Invalid video input.",
150
+ }
151
  video_cap = cv2.VideoCapture(video_path)
152
  vid_fps = video_cap.get(cv2.CAP_PROP_FPS)
153
  vid_w = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
160
  video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
161
  video_cap.release()
162
  if len(video_frames) == 0:
163
+ return {
164
+ "output_video": None,
165
+ "detections": [],
166
+ "status": "No frames found in video.",
167
+ }
168
  session = VID_PROCESSOR.init_video_session(
169
  video=video_frames, inference_device=DEVICE, dtype=DTYPE
170
  )
 
173
  video_writer = cv2.VideoWriter(
174
  temp_out_path, cv2.VideoWriter_fourcc(*"mp4v"), vid_fps, (vid_w, vid_h)
175
  )
176
+
177
+ detections = []
178
  for model_out in VID_MODEL.propagate_in_video_iterator(
179
  inference_session=session, max_frame_num_to_track=len(video_frames)
180
  ):
181
  post_processed = VID_PROCESSOR.postprocess_outputs(session, model_out)
182
  f_idx = model_out.frame_idx
183
  original_pil = Image.fromarray(video_frames[f_idx])
184
+ frame_detections = []
185
  if "masks" in post_processed:
186
  detected_masks = post_processed["masks"]
187
  object_ids = post_processed["object_ids"]
188
  if detected_masks.ndim == 4:
189
  detected_masks = detected_masks.squeeze(1)
190
+ # detected_masks: (num_objects, H, W)
191
+ for i, mask in enumerate(detected_masks):
192
+ mask_bin = (mask > 0.0).astype(np.uint8)
193
+ xyxy = mask_to_xyxy(mask_bin)
194
+ if not xyxy:
195
+ continue
196
+ x0, y0, x1, y1 = xyxy
197
+ det = {
198
+ "frame": f_idx,
199
+ "track_id": int(object_ids[i]) if object_ids is not None else i,
200
+ "x": x0 / vid_w,
201
+ "y": y0 / vid_h,
202
+ "w": (x1 - x0) / vid_w,
203
+ "h": (y1 - y0) / vid_h,
204
+ "conf": 1,
205
+ "mask_b64": b64_mask_encode(mask_bin).decode("ascii"),
206
+ }
207
+ detections.append(det)
208
  final_frame = apply_mask_overlay(
209
  original_pil, detected_masks, object_ids=object_ids
210
  )
 
214
  video_writer.release()
215
  return {
216
  "output_video": temp_out_path,
217
+ "detections": detections,
218
  "status": "Video processing completed successfully.✅",
219
  }
220
  except Exception as e:
221
  return {
222
  "output_video": None,
223
+ "detections": [],
224
  "status": f"Error during video processing: {str(e)}",
225
  }
226
 
 
233
  gr.Textbox(
234
  label="Prompt",
235
  lines=3,
236
+ info="Describe the Object(s) you would like to track/ segmentate",
237
+ value="",
238
  ),
239
  ],
240
  outputs=gr.JSON(label="Output JSON"),
toolbox/mask_encoding.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64, os, io, random, time
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
def b64_mask_encode(mask_np_arr, tmp_dir = '/tmp/miro/mask_encoding/'):
    '''
    Turn a binary mask (numpy array) into a base64-encoded PNG string.

    The mask is scaled to 0/255, converted to a 1-bit PIL image (compact
    PNG payload), and PNG-encoded entirely in memory.

    Parameters
    ----------
    mask_np_arr : array-like
        Binary mask (nonzero = foreground).
    tmp_dir : str
        Kept for backward compatibility; no longer used. The previous
        implementation wrote one PNG per call to this directory and never
        deleted it (unbounded disk growth) and also leaked the file handle
        from `open(tmp_fname, 'rb').read()`.

    Returns
    -------
    bytes
        Base64 payload, same as `base64.b64encode(...)` output before.
    '''
    mask_im = Image.fromarray(np.asarray(mask_np_arr).astype(np.uint8) * 255)
    mask_im = mask_im.convert(mode = '1')  # convert to 1bit image

    # Encode in memory instead of round-tripping through a temp file.
    buf = io.BytesIO()
    mask_im.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue())
21
+
22
def b64_mask_decode(b64_string):
    '''
    Decode a base64-encoded PNG string back into a binary mask numpy array.
    '''
    raw_bytes = base64.b64decode(b64_string)
    with io.BytesIO(raw_bytes) as buf:
        decoded_im = Image.open(buf)
        # np.array() forces the lazy PIL load while the buffer is open
        return np.array(decoded_im)
29
+
30
def get_true_mask(mask_arr, im_w_h:tuple, x0, y0, x1, y1):
    '''
    Paste a cropped mask back onto a full-size zero canvas.

    Parameters
    ----------
    mask_arr : np.ndarray
        Cropped binary mask of shape (y1 - y0, x1 - x0).
    im_w_h : tuple
        (width, height) of the source image.
    x0, y0, x1, y1 : int
        Bounding box of the crop within the source image
        (x1/y1 exclusive, numpy-slice convention).

    Returns
    -------
    np.ndarray
        uint8 array of shape (height, width) with `mask_arr` placed
        inside the box and zeros elsewhere.

    Raises
    ------
    ValueError
        If the box has negative coordinates, exceeds the image bounds,
        or does not match the mask's shape.
    '''
    w, h = im_w_h
    # The original only checked the upper bound; negative coordinates
    # would silently wrap via numpy's negative indexing.
    if min(x0, y0, x1, y1) < 0:
        raise ValueError(f'get_true_mask: negative coordinates: {(x0, y0, x1, y1)}')
    if x0 > w or x1 > w or y0 > h or y1 > h:
        raise ValueError(f'get_true_mask: Xs and Ys exceeded im_w_h bound: {im_w_h}')

    if mask_arr.shape != (y1 - y0, x1 - x0):
        raise ValueError(f'get_true_mask: Bounding Box h: {y1-y0} w: {x1-x0} does not match mask shape: {mask_arr.shape}')

    mask = np.zeros((h,w), dtype = np.uint8)
    mask[y0:y1, x0:x1] = mask_arr
    return mask
visualizer.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageColor
2
+
3
+ # import matplotlib.colors as mcolors
4
+ import numpy as np
5
+
6
+ # from toolbox.mask_encoding import b64_mask_decode
7
+ # from toolbox.img_utils import im_draw_bbox, im_draw_point, im_color_mask
8
+
9
+
10
def mask_to_xyxy(mask: np.ndarray, verbose: bool = False) -> tuple:
    """Convert a binary mask of shape (h, w) to
    xyxy bounding box format (top-left and bottom-right coordinates).

    Returns
    -------
    tuple | None
        (x_min, y_min, x_max, y_max) as ints — coordinates of the
        extreme foreground pixels, inclusive — or None if the mask
        contains no foreground pixels.
    """
    ys, xs = np.where(mask)
    if xs.size == 0:
        if verbose:
            # BUGFIX: the original referenced an undefined `logger`
            # (the loguru import in this file is commented out), so
            # verbose=True raised NameError. Use stdlib logging instead.
            import logging
            logging.getLogger(__name__).warning(
                "mask_to_xyxy: No object found in the mask"
            )
        return None
    return (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
26
+
27
+
28
def annotate_detections(
    im: Image.Image,
    l_obj: list,
    color_key: str = "class",
    bbox_width: int = 1,
    label_key: str = "object_id",
    color_dict: dict = {},
):
    """Draw a bbox or point annotation on `im` for each detection dict in `l_obj`.

    Each object dict is drawn with `im_draw_bbox` when it has a
    "boundingBox" key, otherwise with `im_draw_point` on its "point" key.
    Colors are picked per unique `color_key` value, either from
    `color_dict` (when non-empty) or from the XKCD color list.

    NOTE(review): `mcolors`, `im_draw_bbox` and `im_draw_point` are NOT in
    scope — their imports at the top of this file are commented out — so
    calling this function raises NameError until those imports are restored.
    NOTE(review): `color_dict` is a mutable default argument; safe only as
    long as callers never mutate it — consider `None` + fallback.
    NOTE(review): an object missing `color_key` still reaches
    `unique_color_keys.index(obj[color_key])` and would raise — presumably
    callers guarantee the key is present; verify against call sites.
    """
    # color_list is a list of tuple(name, color_hex)
    color_list = list(
        mcolors.XKCD_COLORS.items()
    )  # list(mcolors.TABLEAU_COLORS.items())
    # One stable color slot per distinct color_key value seen in the input.
    unique_color_keys = list(
        set([o[color_key] for o in l_obj if color_key in o.keys()])
    )

    for obj in l_obj:
        color_index = unique_color_keys.index(obj[color_key])
        # Explicit mapping wins over the generated palette when provided.
        bbox_color = (
            color_dict[obj[color_key]] if color_dict else color_list[color_index][1]
        )
        # Re-assign `im` each iteration: the draw helpers appear to return
        # the annotated image rather than mutating in place.
        im = (
            im_draw_bbox(
                im,
                color=bbox_color,
                width=bbox_width,
                caption=(str(obj[label_key]) if label_key else None),
                **obj["boundingBox"],
                use_bbv=True,
            )
            if "boundingBox" in obj.keys()
            else im_draw_point(
                im,
                **obj["point"],
                width=bbox_width,
                caption=(str(obj[label_key]) if label_key else None),
                color=bbox_color,
            )
        )
    return im
68
+
69
+
70
def annotate_masks(
    im: Image.Image, masks: list, mask_alpha: float = 0.9, bbox_width: int = 3
) -> Image.Image:
    """returns an annotated pillow image

    Overlays each mask in color on `im`, then draws the bounding box of
    each mask via `annotate_detections`. Masks given as base64 strings
    are decoded first; array masks are used as-is.

    NOTE(review): `b64_mask_decode`, `mcolors` and `im_color_mask` are NOT
    in scope — their imports at the top of this file are commented out —
    so this function raises NameError until those imports are restored.
    NOTE(review): `mask_to_xyxy` returns None for an empty mask, which
    would crash the tuple unpack below — assumes every mask is non-empty;
    verify against callers.
    NOTE(review): the value passed to `annotate_detections` here is a
    numpy array, while that function is annotated as taking a PIL Image —
    presumably the draw helpers accept both; confirm.
    """
    # Normalize input: decode any base64-encoded masks to uint8 arrays.
    masks = [
        b64_mask_decode(m).astype(np.uint8) if isinstance(m, str) else m for m in masks
    ]
    # Build one detection dict per mask, keyed by its index as object_id.
    segs = []
    for i, m in enumerate(masks):
        x0, y0, x1, y1 = mask_to_xyxy(m)
        segs.append(
            {
                "object_id": i,
                "boundingBox": {"x0": x0, "y0": y0, "x1": x1, "y1": y1},
            }
        )
    # Color each mask region; color index i assumes len(masks) does not
    # exceed the XKCD palette size.
    ann_im = np.array(im)
    for i, m in enumerate(masks):
        m_color = list(mcolors.XKCD_COLORS.items())[i]
        ann_im = im_color_mask(
            ann_im,
            mask_array=m,
            alpha=mask_alpha,
            rbg_tup=ImageColor.getrgb(m_color[1]),
        )
    # Draw the per-mask bounding boxes and object_id labels on top.
    ann_im = annotate_detections(
        ann_im,
        l_obj=segs,
        color_key="object_id",
        label_key="object_id",
        bbox_width=bbox_width,
    )
    return ann_im