RedmondHosting committed on
Commit
6691005
·
verified ·
1 Parent(s): ed7d6af

Update modules/sam_inference.py

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +31 -116
modules/sam_inference.py CHANGED
@@ -8,6 +8,8 @@ from datetime import datetime
8
  import numpy as np
9
  import gradio as gr
10
 
 
 
11
  from modules.model_downloader import (
12
  AVAILABLE_MODELS, DEFAULT_MODEL_TYPE,
13
  is_sam_exist,
@@ -21,7 +23,7 @@ from modules.mask_utils import (
21
  create_mask_combined_images,
22
  create_mask_gallery,
23
  create_mask_pixelized_image,
24
- create_solid_color_mask_image
25
  )
26
  from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
27
  extract_sound, clean_temp_dir, clean_files_with_extension)
@@ -58,10 +60,6 @@ class SamInference:
58
  load_video_predictor: bool = False):
59
  """
60
  Load the model from the model directory. If the model is not found, download it from the URL.
61
-
62
- Args:
63
- model_type (str): The model type to load.
64
- load_video_predictor (bool): Load the video predictor model.
65
  """
66
  if model_type is None:
67
  model_type = DEFAULT_MODEL_TYPE
@@ -70,7 +68,6 @@ class SamInference:
70
  config_dir, config_name = os.path.split(config_path)
71
 
72
  filename, url = AVAILABLE_MODELS[model_type]
73
-
74
  model_path = os.path.join(self.model_dir, filename)
75
 
76
  if not is_sam_exist(model_dir=self.model_dir, model_type=model_type):
@@ -106,10 +103,6 @@ class SamInference:
106
  model_type: Optional[str] = None):
107
  """
108
  Initialize the video inference state for the video predictor.
109
-
110
- Args:
111
- vid_input (str): The video frames directory.
112
- model_type (str): The model type to load.
113
  """
114
  if model_type is None:
115
  model_type = self.current_model_type
@@ -137,18 +130,8 @@ class SamInference:
137
  invert_mask: bool = False,
138
  **params) -> List[Dict[str, Any]]:
139
  """
140
- Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
141
-
142
- Args:
143
- image (np.ndarray): The input image.
144
- model_type (str): The model type to load.
145
- invert_mask (bool): Invert the mask output - used for background masking.
146
- **params: The hyperparameters for the mask generator.
147
-
148
- Returns:
149
- List[Dict[str, Any]]: The auto-generated mask data.
150
  """
151
-
152
  if self.model is None or self.current_model_type != model_type:
153
  self.current_model_type = model_type
154
  self.load_model(model_type=model_type)
@@ -178,20 +161,6 @@ class SamInference:
178
  **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
179
  """
180
  Predict image with prompt data.
181
-
182
- Args:
183
- image (np.ndarray): The input image.
184
- model_type (str): The model type to load.
185
- box (np.ndarray): The box prompt data.
186
- point_coords (np.ndarray): The point coordinates prompt data.
187
- point_labels (np.ndarray): The point labels prompt data.
188
- invert_mask (bool): Invert the mask output - used for background masking.
189
- **params: The hyperparameters for the mask generator.
190
-
191
- Returns:
192
- np.ndarray: The predicted masks output in CxHxW format.
193
- np.ndarray: Array of scores for each mask.
194
- np.ndarray: Array of logits in CxHxW format.
195
  """
196
  if self.model is None or self.current_model_type != model_type:
197
  self.current_model_type = model_type
@@ -223,22 +192,8 @@ class SamInference:
223
  labels: Optional[np.ndarray] = None,
224
  box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
225
  """
226
- Add prediction to the current video inference state. inference state must be initialized before calling this method.
227
-
228
- Args:
229
- frame_idx (int): The frame index of the video.
230
- obj_id (int): The object id for the frame.
231
- inference_state (Dict): The inference state for the video predictor.
232
- points (np.ndarray): The point coordinates prompt data.
233
- labels (np.ndarray): The point labels prompt data.
234
- box (np.ndarray): The box prompt data.
235
-
236
- Returns:
237
- int: The frame index of the corresponding prediction.
238
- int: The object id of the corresponding prediction.
239
- torch.Tensor: The mask logits output in CxHxW format.
240
  """
241
-
242
  if (self.video_predictor is None or
243
  inference_state is None and self.video_inference_state is None):
244
  logger.exception("Error while predicting frame from video, load video predictor first")
@@ -264,16 +219,7 @@ class SamInference:
264
  def propagate_in_video(self,
265
  inference_state: Optional[Dict] = None,):
266
  """
267
- Propagate in the video with the tracked predictions for each frame. Currently only supports
268
- single frame tracking.
269
-
270
- Args:
271
- inference_state (Dict): The inference state for the video predictor. Use self.video_inference_state if None.
272
-
273
- Returns:
274
- Dict: The video segments with the image and mask data. It has frame index as each key and each key has
275
- "image" and "mask" data. "image" key contains the path of the original image file and "mask" key contains
276
- the np.ndarray mask output.
277
  """
278
  if inference_state is None and self.video_inference_state is None:
279
  logger.exception("Error while propagating in video, load video predictor first")
@@ -312,19 +258,7 @@ class SamInference:
312
  invert_mask: bool = False
313
  ):
314
  """
315
- Add filter to the preview image with the prompt data. Specially made for gradio app.
316
- It adds prediction tracking to the self.video_inference_state and returns the filtered image.
317
-
318
- Args:
319
- image_prompt_input_data (Dict): The image prompt data.
320
- filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
321
- frame_idx (int): The frame index of the video.
322
- pixel_size (int): The pixel size for the pixelize filter.
323
- color_hex (str): The color hex code for the solid color filter.
324
- invert_mask (bool): Invert the mask output - used for background masking.
325
-
326
- Returns:
327
- np.ndarray: The filtered image output.
328
  """
329
  if self.video_predictor is None or self.video_inference_state is None:
330
  logger.exception("Error while adding filter to preview, load video predictor first")
@@ -357,8 +291,18 @@ class SamInference:
357
 
358
  generated_masks = self.format_to_auto_result(masks)
359
 
 
 
 
360
  if filter_mode == COLOR_FILTER:
361
- image = create_solid_color_mask_image(image, generated_masks, color_hex)
 
 
 
 
 
 
 
362
 
363
  elif filter_mode == PIXELIZE_FILTER:
364
  image = create_mask_pixelized_image(image, generated_masks, pixel_size)
@@ -374,22 +318,8 @@ class SamInference:
374
  invert_mask: bool = False
375
  ):
376
  """
377
- Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
378
- This needs FFmpeg to run. Returns two output path because of the gradio app.
379
-
380
- Args:
381
- image_prompt_input_data (Dict): The image prompt data.
382
- filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
383
- frame_idx (int): The frame index of the video.
384
- pixel_size (int): The pixel size for the pixelize filter.
385
- color_hex (str): The color hex code for the solid color filter.
386
- invert_mask (bool): Invert the mask output - used for background masking.
387
-
388
- Returns:
389
- str: The output video path.
390
- str: The output video path.
391
  """
392
-
393
  if self.video_predictor is None or self.video_inference_state is None:
394
  logger.exception("Error while adding filter to preview, load video predictor first")
395
  raise RuntimeError("Error while adding filter to preview")
@@ -399,13 +329,13 @@ class SamInference:
399
  "Please press the eraser button (on the image prompter) and add your prompts again.")
400
  logger.error(error_message)
401
  raise gr.Error(error_message, duration=20)
 
402
  output_dir = os.path.join(self.output_dir, "filter")
403
 
404
  clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
405
  self.video_predictor.reset_state(self.video_inference_state)
406
 
407
  prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
408
-
409
  point_labels, point_coords, box = self.handle_prompt_data(prompt)
410
  obj_id = frame_idx
411
 
@@ -425,8 +355,16 @@ class SamInference:
425
  masks = invert_masks(masks)
426
  masks = self.format_to_auto_result(masks)
427
 
 
 
 
428
  if filter_mode == COLOR_FILTER:
429
- filtered_image = create_solid_color_mask_image(orig_image, masks, color_hex)
 
 
 
 
 
430
 
431
  elif filter_mode == PIXELIZE_FILTER:
432
  filtered_image = create_mask_pixelized_image(orig_image, masks, pixel_size)
@@ -454,24 +392,11 @@ class SamInference:
454
  *params):
455
  """
456
  Divide the layer with the given prompt data and save psd file.
457
-
458
- Args:
459
- image_input (np.ndarray): The input image.
460
- image_prompt_input_data (Dict): The image prompt data.
461
- input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
462
- model_type (str): The model type to load.
463
- invert_mask (bool): Invert the mask output.
464
- *params: The hyperparameters for the mask generator.
465
-
466
- Returns:
467
- List[np.ndarray]: List of images by predicted masks.
468
- str: The output path of the psd file.
469
  """
470
-
471
  timestamp = datetime.now().strftime("%m%d%H%M%S")
472
  output_file_name = f"result-{timestamp}.psd"
473
  output_path = os.path.join(self.output_dir, "psd", output_file_name)
474
- # Pre-processed gradio components
475
  hparams = {
476
  'points_per_side': int(params[0]),
477
  'points_per_batch': int(params[1]),
@@ -488,14 +413,12 @@ class SamInference:
488
 
489
  if input_mode == AUTOMATIC_MODE:
490
  image = image_input
491
-
492
  generated_masks = self.generate_mask(
493
  image=image,
494
  model_type=model_type,
495
  invert_mask=invert_mask,
496
  **hparams
497
  )
498
-
499
  elif input_mode == BOX_PROMPT_MODE:
500
  image = image_prompt_input_data["image"]
501
  image = np.array(image.convert("RGB"))
@@ -540,14 +463,6 @@ class SamInference:
540
  ):
541
  """
542
  Handle data from ImageInputPrompter.
543
-
544
- Args:
545
- prompt_data (Dict): A dictionary containing the 'prompt' key with a list of prompts.
546
-
547
- Returns:
548
- point_labels (List): list of points labels.
549
- point_coords (List): list of points coords.
550
- box (List): list of box datas.
551
  """
552
  point_labels, point_coords, box = [], [], []
553
 
@@ -563,4 +478,4 @@ class SamInference:
563
  point_coords = np.array(point_coords) if point_coords else None
564
  box = np.array(box) if box else None
565
 
566
- return point_labels, point_coords, box
 
8
  import numpy as np
9
  import gradio as gr
10
 
11
+ from PIL import ImageColor # <-- We need this to convert hex color to (R,G,B)
12
+
13
  from modules.model_downloader import (
14
  AVAILABLE_MODELS, DEFAULT_MODEL_TYPE,
15
  is_sam_exist,
 
23
  create_mask_combined_images,
24
  create_mask_gallery,
25
  create_mask_pixelized_image,
26
+ # create_solid_color_mask_image <-- We won't call this anymore
27
  )
28
  from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
29
  extract_sound, clean_temp_dir, clean_files_with_extension)
 
60
  load_video_predictor: bool = False):
61
  """
62
  Load the model from the model directory. If the model is not found, download it from the URL.
 
 
 
 
63
  """
64
  if model_type is None:
65
  model_type = DEFAULT_MODEL_TYPE
 
68
  config_dir, config_name = os.path.split(config_path)
69
 
70
  filename, url = AVAILABLE_MODELS[model_type]
 
71
  model_path = os.path.join(self.model_dir, filename)
72
 
73
  if not is_sam_exist(model_dir=self.model_dir, model_type=model_type):
 
103
  model_type: Optional[str] = None):
104
  """
105
  Initialize the video inference state for the video predictor.
 
 
 
 
106
  """
107
  if model_type is None:
108
  model_type = self.current_model_type
 
130
  invert_mask: bool = False,
131
  **params) -> List[Dict[str, Any]]:
132
  """
133
+ Generate masks with Automatic segmentation.
 
 
 
 
 
 
 
 
 
134
  """
 
135
  if self.model is None or self.current_model_type != model_type:
136
  self.current_model_type = model_type
137
  self.load_model(model_type=model_type)
 
161
  **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
162
  """
163
  Predict image with prompt data.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  """
165
  if self.model is None or self.current_model_type != model_type:
166
  self.current_model_type = model_type
 
192
  labels: Optional[np.ndarray] = None,
193
  box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
194
  """
195
+ Add prediction to the current video inference state.
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  """
 
197
  if (self.video_predictor is None or
198
  inference_state is None and self.video_inference_state is None):
199
  logger.exception("Error while predicting frame from video, load video predictor first")
 
219
  def propagate_in_video(self,
220
  inference_state: Optional[Dict] = None,):
221
  """
222
+ Propagate in the video with the tracked predictions for each frame.
 
 
 
 
 
 
 
 
 
223
  """
224
  if inference_state is None and self.video_inference_state is None:
225
  logger.exception("Error while propagating in video, load video predictor first")
 
258
  invert_mask: bool = False
259
  ):
260
  """
261
+ Add filter to the preview image with the prompt data.
 
 
 
 
 
 
 
 
 
 
 
 
262
  """
263
  if self.video_predictor is None or self.video_inference_state is None:
264
  logger.exception("Error while adding filter to preview, load video predictor first")
 
291
 
292
  generated_masks = self.format_to_auto_result(masks)
293
 
294
+ # ---------------------------
295
+ # Modified solid color branch
296
+ # ---------------------------
297
  if filter_mode == COLOR_FILTER:
298
+ # Make entire background black, fill the mask area with the chosen color
299
+ color_rgb = ImageColor.getcolor(color_hex, "RGB")
300
+ blacked = np.zeros_like(image, dtype=np.uint8)
301
+ # If there are multiple mask segments, fill them all with the chosen color
302
+ for m in generated_masks:
303
+ seg = m["segmentation"]
304
+ blacked[seg > 0] = color_rgb
305
+ image = blacked
306
 
307
  elif filter_mode == PIXELIZE_FILTER:
308
  image = create_mask_pixelized_image(image, generated_masks, pixel_size)
 
318
  invert_mask: bool = False
319
  ):
320
  """
321
+ Create a whole filtered video. Currently only one-frame tracking is supported.
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  """
 
323
  if self.video_predictor is None or self.video_inference_state is None:
324
  logger.exception("Error while adding filter to preview, load video predictor first")
325
  raise RuntimeError("Error while adding filter to preview")
 
329
  "Please press the eraser button (on the image prompter) and add your prompts again.")
330
  logger.error(error_message)
331
  raise gr.Error(error_message, duration=20)
332
+
333
  output_dir = os.path.join(self.output_dir, "filter")
334
 
335
  clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
336
  self.video_predictor.reset_state(self.video_inference_state)
337
 
338
  prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
 
339
  point_labels, point_coords, box = self.handle_prompt_data(prompt)
340
  obj_id = frame_idx
341
 
 
355
  masks = invert_masks(masks)
356
  masks = self.format_to_auto_result(masks)
357
 
358
+ # ---------------------------
359
+ # Modified solid color branch
360
+ # ---------------------------
361
  if filter_mode == COLOR_FILTER:
362
+ color_rgb = ImageColor.getcolor(color_hex, "RGB")
363
+ blacked = np.zeros_like(orig_image, dtype=np.uint8)
364
+ for m in masks:
365
+ seg = m["segmentation"]
366
+ blacked[seg > 0] = color_rgb
367
+ filtered_image = blacked
368
 
369
  elif filter_mode == PIXELIZE_FILTER:
370
  filtered_image = create_mask_pixelized_image(orig_image, masks, pixel_size)
 
392
  *params):
393
  """
394
  Divide the layer with the given prompt data and save psd file.
 
 
 
 
 
 
 
 
 
 
 
 
395
  """
 
396
  timestamp = datetime.now().strftime("%m%d%H%M%S")
397
  output_file_name = f"result-{timestamp}.psd"
398
  output_path = os.path.join(self.output_dir, "psd", output_file_name)
399
+
400
  hparams = {
401
  'points_per_side': int(params[0]),
402
  'points_per_batch': int(params[1]),
 
413
 
414
  if input_mode == AUTOMATIC_MODE:
415
  image = image_input
 
416
  generated_masks = self.generate_mask(
417
  image=image,
418
  model_type=model_type,
419
  invert_mask=invert_mask,
420
  **hparams
421
  )
 
422
  elif input_mode == BOX_PROMPT_MODE:
423
  image = image_prompt_input_data["image"]
424
  image = np.array(image.convert("RGB"))
 
463
  ):
464
  """
465
  Handle data from ImageInputPrompter.
 
 
 
 
 
 
 
 
466
  """
467
  point_labels, point_coords, box = [], [], []
468
 
 
478
  point_coords = np.array(point_coords) if point_coords else None
479
  box = np.array(box) if box else None
480
 
481
+ return point_labels, point_coords, box