Spaces:
Runtime error
Runtime error
jhj0517 commited on
Commit ·
5929ef8
1
Parent(s): 5c7c82a
Refactor to `clean_files_with_extension()`
Browse files- modules/sam_inference.py +4 -4
- modules/video_utils.py +5 -16
modules/sam_inference.py
CHANGED
|
@@ -15,7 +15,7 @@ from modules.model_downloader import (
|
|
| 15 |
download_sam_model_url
|
| 16 |
)
|
| 17 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
|
| 18 |
-
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
|
| 19 |
from modules.mask_utils import (
|
| 20 |
save_psd_with_masks,
|
| 21 |
create_mask_combined_images,
|
|
@@ -24,7 +24,7 @@ from modules.mask_utils import (
|
|
| 24 |
create_solid_color_mask_image
|
| 25 |
)
|
| 26 |
from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
|
| 27 |
-
extract_sound, clean_temp_dir,
|
| 28 |
from modules.utils import save_image
|
| 29 |
from modules.logger_util import get_logger
|
| 30 |
|
|
@@ -277,14 +277,14 @@ class SamInference:
|
|
| 277 |
logger.error(error_message)
|
| 278 |
raise gr.Error(error_message, duration=20)
|
| 279 |
|
| 280 |
-
|
|
|
|
| 281 |
|
| 282 |
prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
|
| 283 |
|
| 284 |
point_labels, point_coords, box = self.handle_prompt_data(prompt)
|
| 285 |
obj_id = frame_idx
|
| 286 |
|
| 287 |
-
self.video_predictor.reset_state(self.video_inference_state)
|
| 288 |
idx, scores, logits = self.add_prediction_to_frame(
|
| 289 |
frame_idx=frame_idx,
|
| 290 |
obj_id=obj_id,
|
|
|
|
| 15 |
download_sam_model_url
|
| 16 |
)
|
| 17 |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR
|
| 18 |
+
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT
|
| 19 |
from modules.mask_utils import (
|
| 20 |
save_psd_with_masks,
|
| 21 |
create_mask_combined_images,
|
|
|
|
| 24 |
create_solid_color_mask_image
|
| 25 |
)
|
| 26 |
from modules.video_utils import (get_frames_from_dir, create_video_from_frames, get_video_info, extract_frames,
|
| 27 |
+
extract_sound, clean_temp_dir, clean_files_with_extension)
|
| 28 |
from modules.utils import save_image
|
| 29 |
from modules.logger_util import get_logger
|
| 30 |
|
|
|
|
| 277 |
logger.error(error_message)
|
| 278 |
raise gr.Error(error_message, duration=20)
|
| 279 |
|
| 280 |
+
clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
|
| 281 |
+
self.video_predictor.reset_state(self.video_inference_state)
|
| 282 |
|
| 283 |
prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
|
| 284 |
|
| 285 |
point_labels, point_coords, box = self.handle_prompt_data(prompt)
|
| 286 |
obj_id = frame_idx
|
| 287 |
|
|
|
|
| 288 |
idx, scores, logits = self.add_prediction_to_frame(
|
| 289 |
frame_idx=frame_idx,
|
| 290 |
obj_id=obj_id,
|
modules/video_utils.py
CHANGED
|
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|
| 7 |
import re
|
| 8 |
|
| 9 |
from modules.logger_util import get_logger
|
|
|
|
| 10 |
from modules.paths import TEMP_DIR, TEMP_OUT_DIR
|
| 11 |
|
| 12 |
logger = get_logger()
|
|
@@ -222,24 +223,12 @@ def clean_temp_dir(temp_dir: Optional[str] = None):
|
|
| 222 |
else:
|
| 223 |
temp_out_dir = os.path.join(temp_dir, "out")
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
|
| 229 |
|
| 230 |
-
def
|
| 231 |
-
"""Removes all sound files from the directory."""
|
| 232 |
-
sound_extensions = ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma']
|
| 233 |
-
_clean_files_with_extension(sound_dir, sound_extensions)
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
def clean_image_files(image_dir: str):
|
| 237 |
-
"""Removes all image files from the dir"""
|
| 238 |
-
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
|
| 239 |
-
_clean_files_with_extension(image_dir, image_extensions)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def _clean_files_with_extension(dir_path: str, extensions: List):
|
| 243 |
for filename in os.listdir(dir_path):
|
| 244 |
if filename.lower().endswith(tuple(extensions)):
|
| 245 |
file_path = os.path.join(dir_path, filename)
|
|
|
|
| 7 |
import re
|
| 8 |
|
| 9 |
from modules.logger_util import get_logger
|
| 10 |
+
from modules.constants import SOUND_FILE_EXT, VIDEO_FILE_EXT, IMAGE_FILE_EXT
|
| 11 |
from modules.paths import TEMP_DIR, TEMP_OUT_DIR
|
| 12 |
|
| 13 |
logger = get_logger()
|
|
|
|
| 223 |
else:
|
| 224 |
temp_out_dir = os.path.join(temp_dir, "out")
|
| 225 |
|
| 226 |
+
clean_files_with_extension(temp_dir, SOUND_FILE_EXT)
|
| 227 |
+
clean_files_with_extension(temp_dir, IMAGE_FILE_EXT)
|
| 228 |
+
clean_files_with_extension(temp_out_dir, IMAGE_FILE_EXT)
|
| 229 |
|
| 230 |
|
| 231 |
+
def clean_files_with_extension(dir_path: str, extensions: List):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
for filename in os.listdir(dir_path):
|
| 233 |
if filename.lower().endswith(tuple(extensions)):
|
| 234 |
file_path = os.path.join(dir_path, filename)
|