Spaces:
Paused
Paused
Update modules/sam_inference.py
Browse files- modules/sam_inference.py +31 -116
modules/sam_inference.py
CHANGED
|
@@ -8,6 +8,8 @@ from datetime import datetime
|
|
| 8 |
import numpy as np
|
| 9 |
import gradio as gr
|
| 10 |
|
|
|
|
|
|
|
| 11 |
from modules.model_downloader import (
|
| 12 |
AVAILABLE_MODELS, DEFAULT_MODEL_TYPE,
|
| 13 |
is_sam_exist,
|
|
@@ -21,7 +23,7 @@ from modules.mask_utils import (
|
|
| 21 |
create_mask_combined_images,
|
| 22 |
create_mask_gallery,
|
| 23 |
create_mask_pixelized_image,
|
| 24 |
-
create_solid_color_mask_image
|
| 25 |
)
|
| 26 |
from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
|
| 27 |
extract_sound, clean_temp_dir, clean_files_with_extension)
|
|
@@ -58,10 +60,6 @@ class SamInference:
|
|
| 58 |
load_video_predictor: bool = False):
|
| 59 |
"""
|
| 60 |
Load the model from the model directory. If the model is not found, download it from the URL.
|
| 61 |
-
|
| 62 |
-
Args:
|
| 63 |
-
model_type (str): The model type to load.
|
| 64 |
-
load_video_predictor (bool): Load the video predictor model.
|
| 65 |
"""
|
| 66 |
if model_type is None:
|
| 67 |
model_type = DEFAULT_MODEL_TYPE
|
|
@@ -70,7 +68,6 @@ class SamInference:
|
|
| 70 |
config_dir, config_name = os.path.split(config_path)
|
| 71 |
|
| 72 |
filename, url = AVAILABLE_MODELS[model_type]
|
| 73 |
-
|
| 74 |
model_path = os.path.join(self.model_dir, filename)
|
| 75 |
|
| 76 |
if not is_sam_exist(model_dir=self.model_dir, model_type=model_type):
|
|
@@ -106,10 +103,6 @@ class SamInference:
|
|
| 106 |
model_type: Optional[str] = None):
|
| 107 |
"""
|
| 108 |
Initialize the video inference state for the video predictor.
|
| 109 |
-
|
| 110 |
-
Args:
|
| 111 |
-
vid_input (str): The video frames directory.
|
| 112 |
-
model_type (str): The model type to load.
|
| 113 |
"""
|
| 114 |
if model_type is None:
|
| 115 |
model_type = self.current_model_type
|
|
@@ -137,18 +130,8 @@ class SamInference:
|
|
| 137 |
invert_mask: bool = False,
|
| 138 |
**params) -> List[Dict[str, Any]]:
|
| 139 |
"""
|
| 140 |
-
Generate masks with Automatic segmentation.
|
| 141 |
-
|
| 142 |
-
Args:
|
| 143 |
-
image (np.ndarray): The input image.
|
| 144 |
-
model_type (str): The model type to load.
|
| 145 |
-
invert_mask (bool): Invert the mask output - used for background masking.
|
| 146 |
-
**params: The hyperparameters for the mask generator.
|
| 147 |
-
|
| 148 |
-
Returns:
|
| 149 |
-
List[Dict[str, Any]]: The auto-generated mask data.
|
| 150 |
"""
|
| 151 |
-
|
| 152 |
if self.model is None or self.current_model_type != model_type:
|
| 153 |
self.current_model_type = model_type
|
| 154 |
self.load_model(model_type=model_type)
|
|
@@ -178,20 +161,6 @@ class SamInference:
|
|
| 178 |
**params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 179 |
"""
|
| 180 |
Predict image with prompt data.
|
| 181 |
-
|
| 182 |
-
Args:
|
| 183 |
-
image (np.ndarray): The input image.
|
| 184 |
-
model_type (str): The model type to load.
|
| 185 |
-
box (np.ndarray): The box prompt data.
|
| 186 |
-
point_coords (np.ndarray): The point coordinates prompt data.
|
| 187 |
-
point_labels (np.ndarray): The point labels prompt data.
|
| 188 |
-
invert_mask (bool): Invert the mask output - used for background masking.
|
| 189 |
-
**params: The hyperparameters for the mask generator.
|
| 190 |
-
|
| 191 |
-
Returns:
|
| 192 |
-
np.ndarray: The predicted masks output in CxHxW format.
|
| 193 |
-
np.ndarray: Array of scores for each mask.
|
| 194 |
-
np.ndarray: Array of logits in CxHxW format.
|
| 195 |
"""
|
| 196 |
if self.model is None or self.current_model_type != model_type:
|
| 197 |
self.current_model_type = model_type
|
|
@@ -223,22 +192,8 @@ class SamInference:
|
|
| 223 |
labels: Optional[np.ndarray] = None,
|
| 224 |
box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
|
| 225 |
"""
|
| 226 |
-
Add prediction to the current video inference state.
|
| 227 |
-
|
| 228 |
-
Args:
|
| 229 |
-
frame_idx (int): The frame index of the video.
|
| 230 |
-
obj_id (int): The object id for the frame.
|
| 231 |
-
inference_state (Dict): The inference state for the video predictor.
|
| 232 |
-
points (np.ndarray): The point coordinates prompt data.
|
| 233 |
-
labels (np.ndarray): The point labels prompt data.
|
| 234 |
-
box (np.ndarray): The box prompt data.
|
| 235 |
-
|
| 236 |
-
Returns:
|
| 237 |
-
int: The frame index of the corresponding prediction.
|
| 238 |
-
int: The object id of the corresponding prediction.
|
| 239 |
-
torch.Tensor: The mask logits output in CxHxW format.
|
| 240 |
"""
|
| 241 |
-
|
| 242 |
if (self.video_predictor is None or
|
| 243 |
inference_state is None and self.video_inference_state is None):
|
| 244 |
logger.exception("Error while predicting frame from video, load video predictor first")
|
|
@@ -264,16 +219,7 @@ class SamInference:
|
|
| 264 |
def propagate_in_video(self,
|
| 265 |
inference_state: Optional[Dict] = None,):
|
| 266 |
"""
|
| 267 |
-
Propagate in the video with the tracked predictions for each frame.
|
| 268 |
-
single frame tracking.
|
| 269 |
-
|
| 270 |
-
Args:
|
| 271 |
-
inference_state (Dict): The inference state for the video predictor. Use self.video_inference_state if None.
|
| 272 |
-
|
| 273 |
-
Returns:
|
| 274 |
-
Dict: The video segments with the image and mask data. It has frame index as each key and each key has
|
| 275 |
-
"image" and "mask" data. "image" key contains the path of the original image file and "mask" key contains
|
| 276 |
-
the np.ndarray mask output.
|
| 277 |
"""
|
| 278 |
if inference_state is None and self.video_inference_state is None:
|
| 279 |
logger.exception("Error while propagating in video, load video predictor first")
|
|
@@ -312,19 +258,7 @@ class SamInference:
|
|
| 312 |
invert_mask: bool = False
|
| 313 |
):
|
| 314 |
"""
|
| 315 |
-
Add filter to the preview image with the prompt data.
|
| 316 |
-
It adds prediction tracking to the self.video_inference_state and returns the filtered image.
|
| 317 |
-
|
| 318 |
-
Args:
|
| 319 |
-
image_prompt_input_data (Dict): The image prompt data.
|
| 320 |
-
filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
|
| 321 |
-
frame_idx (int): The frame index of the video.
|
| 322 |
-
pixel_size (int): The pixel size for the pixelize filter.
|
| 323 |
-
color_hex (str): The color hex code for the solid color filter.
|
| 324 |
-
invert_mask (bool): Invert the mask output - used for background masking.
|
| 325 |
-
|
| 326 |
-
Returns:
|
| 327 |
-
np.ndarray: The filtered image output.
|
| 328 |
"""
|
| 329 |
if self.video_predictor is None or self.video_inference_state is None:
|
| 330 |
logger.exception("Error while adding filter to preview, load video predictor first")
|
|
@@ -357,8 +291,18 @@ class SamInference:
|
|
| 357 |
|
| 358 |
generated_masks = self.format_to_auto_result(masks)
|
| 359 |
|
|
|
|
|
|
|
|
|
|
| 360 |
if filter_mode == COLOR_FILTER:
|
| 361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
elif filter_mode == PIXELIZE_FILTER:
|
| 364 |
image = create_mask_pixelized_image(image, generated_masks, pixel_size)
|
|
@@ -374,22 +318,8 @@ class SamInference:
|
|
| 374 |
invert_mask: bool = False
|
| 375 |
):
|
| 376 |
"""
|
| 377 |
-
Create a whole filtered video
|
| 378 |
-
This needs FFmpeg to run. Returns two output path because of the gradio app.
|
| 379 |
-
|
| 380 |
-
Args:
|
| 381 |
-
image_prompt_input_data (Dict): The image prompt data.
|
| 382 |
-
filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
|
| 383 |
-
frame_idx (int): The frame index of the video.
|
| 384 |
-
pixel_size (int): The pixel size for the pixelize filter.
|
| 385 |
-
color_hex (str): The color hex code for the solid color filter.
|
| 386 |
-
invert_mask (bool): Invert the mask output - used for background masking.
|
| 387 |
-
|
| 388 |
-
Returns:
|
| 389 |
-
str: The output video path.
|
| 390 |
-
str: The output video path.
|
| 391 |
"""
|
| 392 |
-
|
| 393 |
if self.video_predictor is None or self.video_inference_state is None:
|
| 394 |
logger.exception("Error while adding filter to preview, load video predictor first")
|
| 395 |
raise RuntimeError("Error while adding filter to preview")
|
|
@@ -399,13 +329,13 @@ class SamInference:
|
|
| 399 |
"Please press the eraser button (on the image prompter) and add your prompts again.")
|
| 400 |
logger.error(error_message)
|
| 401 |
raise gr.Error(error_message, duration=20)
|
|
|
|
| 402 |
output_dir = os.path.join(self.output_dir, "filter")
|
| 403 |
|
| 404 |
clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
|
| 405 |
self.video_predictor.reset_state(self.video_inference_state)
|
| 406 |
|
| 407 |
prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
|
| 408 |
-
|
| 409 |
point_labels, point_coords, box = self.handle_prompt_data(prompt)
|
| 410 |
obj_id = frame_idx
|
| 411 |
|
|
@@ -425,8 +355,16 @@ class SamInference:
|
|
| 425 |
masks = invert_masks(masks)
|
| 426 |
masks = self.format_to_auto_result(masks)
|
| 427 |
|
|
|
|
|
|
|
|
|
|
| 428 |
if filter_mode == COLOR_FILTER:
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
elif filter_mode == PIXELIZE_FILTER:
|
| 432 |
filtered_image = create_mask_pixelized_image(orig_image, masks, pixel_size)
|
|
@@ -454,24 +392,11 @@ class SamInference:
|
|
| 454 |
*params):
|
| 455 |
"""
|
| 456 |
Divide the layer with the given prompt data and save psd file.
|
| 457 |
-
|
| 458 |
-
Args:
|
| 459 |
-
image_input (np.ndarray): The input image.
|
| 460 |
-
image_prompt_input_data (Dict): The image prompt data.
|
| 461 |
-
input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
|
| 462 |
-
model_type (str): The model type to load.
|
| 463 |
-
invert_mask (bool): Invert the mask output.
|
| 464 |
-
*params: The hyperparameters for the mask generator.
|
| 465 |
-
|
| 466 |
-
Returns:
|
| 467 |
-
List[np.ndarray]: List of images by predicted masks.
|
| 468 |
-
str: The output path of the psd file.
|
| 469 |
"""
|
| 470 |
-
|
| 471 |
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 472 |
output_file_name = f"result-{timestamp}.psd"
|
| 473 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 474 |
-
|
| 475 |
hparams = {
|
| 476 |
'points_per_side': int(params[0]),
|
| 477 |
'points_per_batch': int(params[1]),
|
|
@@ -488,14 +413,12 @@ class SamInference:
|
|
| 488 |
|
| 489 |
if input_mode == AUTOMATIC_MODE:
|
| 490 |
image = image_input
|
| 491 |
-
|
| 492 |
generated_masks = self.generate_mask(
|
| 493 |
image=image,
|
| 494 |
model_type=model_type,
|
| 495 |
invert_mask=invert_mask,
|
| 496 |
**hparams
|
| 497 |
)
|
| 498 |
-
|
| 499 |
elif input_mode == BOX_PROMPT_MODE:
|
| 500 |
image = image_prompt_input_data["image"]
|
| 501 |
image = np.array(image.convert("RGB"))
|
|
@@ -540,14 +463,6 @@ class SamInference:
|
|
| 540 |
):
|
| 541 |
"""
|
| 542 |
Handle data from ImageInputPrompter.
|
| 543 |
-
|
| 544 |
-
Args:
|
| 545 |
-
prompt_data (Dict): A dictionary containing the 'prompt' key with a list of prompts.
|
| 546 |
-
|
| 547 |
-
Returns:
|
| 548 |
-
point_labels (List): list of points labels.
|
| 549 |
-
point_coords (List): list of points coords.
|
| 550 |
-
box (List): list of box datas.
|
| 551 |
"""
|
| 552 |
point_labels, point_coords, box = [], [], []
|
| 553 |
|
|
@@ -563,4 +478,4 @@ class SamInference:
|
|
| 563 |
point_coords = np.array(point_coords) if point_coords else None
|
| 564 |
box = np.array(box) if box else None
|
| 565 |
|
| 566 |
-
return point_labels, point_coords, box
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
+
from PIL import ImageColor # <-- We need this to convert hex color to (R,G,B)
|
| 12 |
+
|
| 13 |
from modules.model_downloader import (
|
| 14 |
AVAILABLE_MODELS, DEFAULT_MODEL_TYPE,
|
| 15 |
is_sam_exist,
|
|
|
|
| 23 |
create_mask_combined_images,
|
| 24 |
create_mask_gallery,
|
| 25 |
create_mask_pixelized_image,
|
| 26 |
+
# create_solid_color_mask_image <-- We won't call this anymore
|
| 27 |
)
|
| 28 |
from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
|
| 29 |
extract_sound, clean_temp_dir, clean_files_with_extension)
|
|
|
|
| 60 |
load_video_predictor: bool = False):
|
| 61 |
"""
|
| 62 |
Load the model from the model directory. If the model is not found, download it from the URL.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
"""
|
| 64 |
if model_type is None:
|
| 65 |
model_type = DEFAULT_MODEL_TYPE
|
|
|
|
| 68 |
config_dir, config_name = os.path.split(config_path)
|
| 69 |
|
| 70 |
filename, url = AVAILABLE_MODELS[model_type]
|
|
|
|
| 71 |
model_path = os.path.join(self.model_dir, filename)
|
| 72 |
|
| 73 |
if not is_sam_exist(model_dir=self.model_dir, model_type=model_type):
|
|
|
|
| 103 |
model_type: Optional[str] = None):
|
| 104 |
"""
|
| 105 |
Initialize the video inference state for the video predictor.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
"""
|
| 107 |
if model_type is None:
|
| 108 |
model_type = self.current_model_type
|
|
|
|
| 130 |
invert_mask: bool = False,
|
| 131 |
**params) -> List[Dict[str, Any]]:
|
| 132 |
"""
|
| 133 |
+
Generate masks with Automatic segmentation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
"""
|
|
|
|
| 135 |
if self.model is None or self.current_model_type != model_type:
|
| 136 |
self.current_model_type = model_type
|
| 137 |
self.load_model(model_type=model_type)
|
|
|
|
| 161 |
**params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 162 |
"""
|
| 163 |
Predict image with prompt data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
"""
|
| 165 |
if self.model is None or self.current_model_type != model_type:
|
| 166 |
self.current_model_type = model_type
|
|
|
|
| 192 |
labels: Optional[np.ndarray] = None,
|
| 193 |
box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
|
| 194 |
"""
|
| 195 |
+
Add prediction to the current video inference state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
"""
|
|
|
|
| 197 |
if (self.video_predictor is None or
|
| 198 |
inference_state is None and self.video_inference_state is None):
|
| 199 |
logger.exception("Error while predicting frame from video, load video predictor first")
|
|
|
|
| 219 |
def propagate_in_video(self,
|
| 220 |
inference_state: Optional[Dict] = None,):
|
| 221 |
"""
|
| 222 |
+
Propagate in the video with the tracked predictions for each frame.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
"""
|
| 224 |
if inference_state is None and self.video_inference_state is None:
|
| 225 |
logger.exception("Error while propagating in video, load video predictor first")
|
|
|
|
| 258 |
invert_mask: bool = False
|
| 259 |
):
|
| 260 |
"""
|
| 261 |
+
Add filter to the preview image with the prompt data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
"""
|
| 263 |
if self.video_predictor is None or self.video_inference_state is None:
|
| 264 |
logger.exception("Error while adding filter to preview, load video predictor first")
|
|
|
|
| 291 |
|
| 292 |
generated_masks = self.format_to_auto_result(masks)
|
| 293 |
|
| 294 |
+
# ---------------------------
|
| 295 |
+
# Modified solid color branch
|
| 296 |
+
# ---------------------------
|
| 297 |
if filter_mode == COLOR_FILTER:
|
| 298 |
+
# Make entire background black, fill the mask area with the chosen color
|
| 299 |
+
color_rgb = ImageColor.getcolor(color_hex, "RGB")
|
| 300 |
+
blacked = np.zeros_like(image, dtype=np.uint8)
|
| 301 |
+
# If there are multiple mask segments, fill them all with the chosen color
|
| 302 |
+
for m in generated_masks:
|
| 303 |
+
seg = m["segmentation"]
|
| 304 |
+
blacked[seg > 0] = color_rgb
|
| 305 |
+
image = blacked
|
| 306 |
|
| 307 |
elif filter_mode == PIXELIZE_FILTER:
|
| 308 |
image = create_mask_pixelized_image(image, generated_masks, pixel_size)
|
|
|
|
| 318 |
invert_mask: bool = False
|
| 319 |
):
|
| 320 |
"""
|
| 321 |
+
Create a whole filtered video. Currently only one-frame tracking is supported.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
"""
|
|
|
|
| 323 |
if self.video_predictor is None or self.video_inference_state is None:
|
| 324 |
logger.exception("Error while adding filter to preview, load video predictor first")
|
| 325 |
raise RuntimeError("Error while adding filter to preview")
|
|
|
|
| 329 |
"Please press the eraser button (on the image prompter) and add your prompts again.")
|
| 330 |
logger.error(error_message)
|
| 331 |
raise gr.Error(error_message, duration=20)
|
| 332 |
+
|
| 333 |
output_dir = os.path.join(self.output_dir, "filter")
|
| 334 |
|
| 335 |
clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
|
| 336 |
self.video_predictor.reset_state(self.video_inference_state)
|
| 337 |
|
| 338 |
prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
|
|
|
|
| 339 |
point_labels, point_coords, box = self.handle_prompt_data(prompt)
|
| 340 |
obj_id = frame_idx
|
| 341 |
|
|
|
|
| 355 |
masks = invert_masks(masks)
|
| 356 |
masks = self.format_to_auto_result(masks)
|
| 357 |
|
| 358 |
+
# ---------------------------
|
| 359 |
+
# Modified solid color branch
|
| 360 |
+
# ---------------------------
|
| 361 |
if filter_mode == COLOR_FILTER:
|
| 362 |
+
color_rgb = ImageColor.getcolor(color_hex, "RGB")
|
| 363 |
+
blacked = np.zeros_like(orig_image, dtype=np.uint8)
|
| 364 |
+
for m in masks:
|
| 365 |
+
seg = m["segmentation"]
|
| 366 |
+
blacked[seg > 0] = color_rgb
|
| 367 |
+
filtered_image = blacked
|
| 368 |
|
| 369 |
elif filter_mode == PIXELIZE_FILTER:
|
| 370 |
filtered_image = create_mask_pixelized_image(orig_image, masks, pixel_size)
|
|
|
|
| 392 |
*params):
|
| 393 |
"""
|
| 394 |
Divide the layer with the given prompt data and save psd file.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
"""
|
|
|
|
| 396 |
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 397 |
output_file_name = f"result-{timestamp}.psd"
|
| 398 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
| 399 |
+
|
| 400 |
hparams = {
|
| 401 |
'points_per_side': int(params[0]),
|
| 402 |
'points_per_batch': int(params[1]),
|
|
|
|
| 413 |
|
| 414 |
if input_mode == AUTOMATIC_MODE:
|
| 415 |
image = image_input
|
|
|
|
| 416 |
generated_masks = self.generate_mask(
|
| 417 |
image=image,
|
| 418 |
model_type=model_type,
|
| 419 |
invert_mask=invert_mask,
|
| 420 |
**hparams
|
| 421 |
)
|
|
|
|
| 422 |
elif input_mode == BOX_PROMPT_MODE:
|
| 423 |
image = image_prompt_input_data["image"]
|
| 424 |
image = np.array(image.convert("RGB"))
|
|
|
|
| 463 |
):
|
| 464 |
"""
|
| 465 |
Handle data from ImageInputPrompter.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
"""
|
| 467 |
point_labels, point_coords, box = [], [], []
|
| 468 |
|
|
|
|
| 478 |
point_coords = np.array(point_coords) if point_coords else None
|
| 479 |
box = np.array(box) if box else None
|
| 480 |
|
| 481 |
+
return point_labels, point_coords, box
|