jhj0517 committed · Commit 17abb6a
1 Parent(s): a81c70a

Add docstring

Files changed: modules/sam_inference.py (+121 -4)
modules/sam_inference.py CHANGED
@@ -1,7 +1,7 @@
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.build_sam import build_sam2, build_sam2_video_predictor
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Any
 import torch
 import os
 from datetime import datetime
@@ -52,6 +52,13 @@ class SamInference:
     def load_model(self,
                    model_type: Optional[str] = None,
                    load_video_predictor: bool = False):
+        """
+        Load the model from the model directory. If the model is not found, it is downloaded from the URL.
+
+        Args:
+            model_type (str): The model type to load.
+            load_video_predictor (bool): Whether to load the video predictor model as well.
+        """
         if model_type is None:
             model_type = DEFAULT_MODEL_TYPE

@@ -90,6 +97,13 @@ class SamInference:
     def init_video_inference_state(self,
                                    vid_input: str,
                                    model_type: Optional[str] = None):
+        """
+        Initialize the video inference state for the video predictor.
+
+        Args:
+            vid_input (str): The video frames directory.
+            model_type (str): The model type to load.
+        """
         if model_type is None:
             model_type = self.current_model_type

@@ -113,7 +127,19 @@
     def generate_mask(self,
                       image: np.ndarray,
                       model_type: str,
-                      **params):
+                      **params) -> List[Dict[str, Any]]:
+        """
+        Generate masks with automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml'.
+
+        Args:
+            image (np.ndarray): The input image.
+            model_type (str): The model type to load.
+            **params: The hyperparameters for the mask generator.
+
+        Returns:
+            List[Dict[str, Any]]: The auto-generated mask data.
+        """
+
         if self.model is None or self.current_model_type != model_type:
             self.current_model_type = model_type
             self.load_model(model_type=model_type)
@@ -134,7 +160,23 @@
                       box: Optional[np.ndarray] = None,
                       point_coords: Optional[np.ndarray] = None,
                       point_labels: Optional[np.ndarray] = None,
-                      **params):
+                      **params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict a single image with the prompt data.
+
+        Args:
+            image (np.ndarray): The input image.
+            model_type (str): The model type to load.
+            box (np.ndarray): The box prompt data.
+            point_coords (np.ndarray): The point coordinates prompt data.
+            point_labels (np.ndarray): The point labels prompt data.
+            **params: The hyperparameters for the mask generator.
+
+        Returns:
+            np.ndarray: The predicted masks output in CxHxW format.
+            np.ndarray: Array of scores for each mask.
+            np.ndarray: Array of logits in CxHxW format.
+        """
         if self.model is None or self.current_model_type != model_type:
             self.current_model_type = model_type
             self.load_model(model_type=model_type)
@@ -159,7 +201,24 @@
                                 inference_state: Optional[Dict] = None,
                                 points: Optional[np.ndarray] = None,
                                 labels: Optional[np.ndarray] = None,
-                                box: Optional[np.ndarray] = None):
+                                box: Optional[np.ndarray] = None) -> Tuple[int, int, torch.Tensor]:
+        """
+        Add a prediction to the current video inference state. The inference state must be initialized before calling this method.
+
+        Args:
+            frame_idx (int): The frame index of the video.
+            obj_id (int): The object id for the frame.
+            inference_state (Dict): The inference state for the video predictor.
+            points (np.ndarray): The point coordinates prompt data.
+            labels (np.ndarray): The point labels prompt data.
+            box (np.ndarray): The box prompt data.
+
+        Returns:
+            int: The frame index of the corresponding prediction.
+            int: The object id of the corresponding prediction.
+            torch.Tensor: The mask logits output in CxHxW format.
+        """
+
         if (self.video_predictor is None or
                 inference_state is None and self.video_inference_state is None):
             logger.exception("Error while predicting frame from video, load video predictor first")
@@ -184,6 +243,18 @@

     def propagate_in_video(self,
                            inference_state: Optional[Dict] = None,):
+        """
+        Propagate in the video with the tracked predictions for each frame. Currently only supports
+        single frame tracking.
+
+        Args:
+            inference_state (Dict): The inference state for the video predictor. Uses self.video_inference_state if None.
+
+        Returns:
+            Dict: The video segments with the image and mask data. It has the frame index as each key, and each entry has
+            "image" and "mask" data: "image" contains the path of the original image file and "mask" contains
+            the np.ndarray mask output.
+        """
         if inference_state is None and self.video_inference_state is None:
             logger.exception("Error while propagating in video, load video predictor first")

@@ -219,6 +290,20 @@
                               pixel_size: Optional[int] = None,
                               color_hex: Optional[str] = None,
                               ):
+        """
+        Add a filter to the preview image with the prompt data. Made specially for the gradio app.
+        It adds prediction tracking to self.video_inference_state and returns the filtered image.
+
+        Args:
+            image_prompt_input_data (Dict): The image prompt data.
+            filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
+            frame_idx (int): The frame index of the video.
+            pixel_size (int): The pixel size for the pixelize filter.
+            color_hex (str): The color hex code for the solid color filter.
+
+        Returns:
+            np.ndarray: The filtered image output.
+        """
         if self.video_predictor is None or self.video_inference_state is None:
             logger.exception("Error while adding filter to preview, load video predictor first")
             raise f"Error while adding filter to preview"
@@ -262,6 +347,22 @@
                               pixel_size: Optional[int] = None,
                               color_hex: Optional[str] = None
                               ):
+        """
+        Create a whole filtered video with video_inference_state. Currently only single frame tracking is supported.
+        This needs FFmpeg to run. Returns the output path twice because of the gradio app.
+
+        Args:
+            image_prompt_input_data (Dict): The image prompt data.
+            filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
+            frame_idx (int): The frame index of the video.
+            pixel_size (int): The pixel size for the pixelize filter.
+            color_hex (str): The color hex code for the solid color filter.
+
+        Returns:
+            str: The output video path.
+            str: The output video path (returned twice for the gradio app).
+        """
+
         if self.video_predictor is None or self.video_inference_state is None:
             logger.exception("Error while adding filter to preview, load video predictor first")
             raise RuntimeError("Error while adding filter to preview")
@@ -321,6 +422,21 @@
                      input_mode: str,
                      model_type: str,
                      *params):
+        """
+        Divide the layer with the given prompt data and save the result as a PSD file.
+
+        Args:
+            image_input (np.ndarray): The input image.
+            image_prompt_input_data (Dict): The image prompt data.
+            input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
+            model_type (str): The model type to load.
+            *params: The hyperparameters for the mask generator.
+
+        Returns:
+            List[np.ndarray]: List of images by predicted masks.
+            str: The output path of the PSD file.
+        """
+
         timestamp = datetime.now().strftime("%m%d%H%M%S")
         output_file_name = f"result-{timestamp}.psd"
         output_path = os.path.join(self.output_dir, "psd", output_file_name)
@@ -378,6 +494,7 @@
     def format_to_auto_result(
             masks: np.ndarray
     ):
+        """Format the masks to the auto result format for convenience."""
         place_holder = 0
         if len(masks.shape) <= 3:
             masks = np.expand_dims(masks, axis=0)
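The docstrings above document SamInference's public surface, so a few usage sketches follow. They are illustrative only: the import path and the no-argument constructor are assumptions, since neither appears in this diff. Loading a model, per the load_model docstring:

# Minimal sketch; import path and constructor are assumed, not shown in the diff.
from modules.sam_inference import SamInference

inference = SamInference()

# model_type=None falls back to DEFAULT_MODEL_TYPE, and a missing
# checkpoint is downloaded from the URL.
inference.load_model(model_type=None)

# The video workflow additionally needs the video predictor.
inference.load_model(load_video_predictor=True)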
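Initializing the video inference state, per init_video_inference_state above. vid_input is the video frames directory; the path below is made up:

# Assumes the video predictor was loaded as in the sketch above.
inference.init_video_inference_state(vid_input="outputs/frames/sample_video")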
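Automatic segmentation via generate_mask. The hyperparameters default to './configs/default_hparams.yaml'; none are overridden here, and the synthetic input image is a stand-in:

import numpy as np

image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HxWx3 image

# Returns List[Dict[str, Any]] of auto-generated mask data; with SAM2's
# automatic mask generator each dict typically carries a 'segmentation' array.
results = inference.generate_mask(image=image,
                                  model_type=inference.current_model_type)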
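Prompted single-image prediction. Only the trailing parameters of this method are visible in the hunk, so the name predict_image below is an assumption inferred from its docstring; the prompt shapes follow the usual SAM2 conventions (XYXY box, Nx2 points, N labels):

import numpy as np

image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder input

# predict_image is a hypothetical name; the def line sits above this hunk.
masks, scores, logits = inference.predict_image(
    image=image,
    model_type=inference.current_model_type,
    box=np.array([100, 100, 400, 400]),   # XYXY box prompt
    point_coords=np.array([[256, 256]]),  # one point...
    point_labels=np.array([1]),           # ...marked as foreground
)
# masks: CxHxW, scores: one per mask, logits: CxHxW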
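Registering a prompt against a video frame. The method name again sits above the visible hunk, so add_prediction_to_frame is a guess from the docstring; the inference state must have been initialized first:

import numpy as np

# Hypothetical method name; parameters and returns follow the docstring.
out_frame_idx, out_obj_id, mask_logits = inference.add_prediction_to_frame(
    frame_idx=0,                    # prompt the first frame
    obj_id=1,                       # id of the object to track
    inference_state=None,           # None -> self.video_inference_state
    points=np.array([[256, 256]]),
    labels=np.array([1]),
)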
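Propagating the tracked prediction through the whole video. Per the propagate_in_video docstring, the result maps each frame index to "image" (original frame path) and "mask" (np.ndarray) entries:

video_segments = inference.propagate_in_video()

for frame_index, segment in video_segments.items():
    frame_path = segment["image"]   # path of the original frame file
    mask = segment["mask"]          # np.ndarray mask output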
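Previewing a filter on one frame, as the gradio app does. The method name is inferred from the log message ("adding filter to preview"), and the structure of image_prompt_input_data is not shown in the diff, so an empty dict stands in for whatever the UI collects:

# Hypothetical name and placeholder prompt dict; filter_mode is one of
# ["Solid Color", "Pixelize"] per the docstring.
filtered_image = inference.add_filter_to_preview(
    image_prompt_input_data={},  # real structure comes from the gradio UI
    filter_mode="Pixelize",
    frame_idx=0,
    pixel_size=16,               # used only by the pixelize filter
    color_hex="#00ff00",         # used only by the solid color filter
)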
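Rendering the complete filtered video. The docstring notes the FFmpeg requirement and that the output path is returned twice to satisfy the gradio app's two output components; the method name is again inferred:

# Needs FFmpeg available; hypothetical name and placeholder prompt dict.
output_path, output_path_again = inference.create_filtered_video(
    image_prompt_input_data={},
    filter_mode="Solid Color",
    frame_idx=0,
    color_hex="#000000",
)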
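Layer division to PSD. The context lines show the output landing at <output_dir>/psd/result-<timestamp>.psd; the method name is inferred from the docstring:

import numpy as np

image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder input

# Hypothetical name; input_mode is "Automatic" or "Box Prompt".
images_by_mask, psd_path = inference.divide_layer(
    image_input=image,
    image_prompt_input_data={},  # placeholder prompt dict from the UI
    input_mode="Automatic",
    model_type=inference.current_model_type,
)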
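Finally, format_to_auto_result: the context lines show it adding a leading batch axis when the mask array has three or fewer dimensions. It takes no self, so it is presumably a static method (any decorator sits above the visible hunk):

import numpy as np

single_mask = np.zeros((512, 512), dtype=bool)  # one HxW mask

# Presumed staticmethod; wraps raw masks into the auto-result format.
auto_result = SamInference.format_to_auto_result(masks=single_mask)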