File size: 8,205 Bytes
7344bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | import os
import re
import time
from pathlib import Path
from typing import Iterable
import numpy as np
import torch
import imageio.v2 as imageio
from PIL import Image, ImageOps
from shared.utils.audio_video import _get_codec_params
from shared.utils.utils import get_resampled_video_transparent, get_video_info_details, has_image_file_extension, rgb_bw_to_rgba_mask, sanitize_file_name
from shared.utils.virtual_media import get_virtual_image, strip_virtual_media_suffix
# Identifier and display name of this process.
# NOTE(review): neither is referenced by the visible code; presumably read by the host application.
PROCESS_ID = "magic_mask"
PROCESS_NAME = "Magic Mask"
# Pieces of the download descriptor assembled by query_download_def().
DOWNLOAD_REPO_ID = "DeepBeepMeep/Wan2.1"
DOWNLOAD_FOLDER = "sam3"
DOWNLOAD_FILES = ["sam3.1_multiplex_bf16.safetensors", "bpe_simple_vocab_16e6.txt.gz"]
# fill_hole_area value passed to SAM3 when hole filling is requested (0 is passed otherwise).
DEFAULT_FILL_HOLE_AREA = 2
# postprocess_batch_size value that _run_sam3() always passes to SAM3.
DEFAULT_POSTPROCESS_BATCH_SIZE = 1
# Default directory in which save_mask_video() writes its MP4 files.
OUTPUT_DIR = "mask_outputs"
def parse_keywords(keyword_text: str | Iterable[str]) -> list[str]:
if isinstance(keyword_text, str):
candidates = re.split(r"[\n,;]+", keyword_text)
else:
candidates = keyword_text
return [str(keyword).strip() for keyword in candidates if str(keyword).strip()]
def query_download_def():
    """Build the download descriptor for the files Magic Mask relies on.

    Returns:
        A dict with the repository id under "repoId", a one-element list of
        source folders under "sourceFolderList", and a one-element list
        holding a copy of the file names under "fileList".
    """
    download_def = {}
    download_def["repoId"] = DOWNLOAD_REPO_ID
    download_def["sourceFolderList"] = [DOWNLOAD_FOLDER]
    # A fresh copy keeps callers from mutating the module-level list.
    download_def["fileList"] = [[*DOWNLOAD_FILES]]
    return download_def
def _fill_hole_area(no_hole):
    """Map the "no hole" flag to the fill_hole_area value given to SAM3.

    Returns DEFAULT_FILL_HOLE_AREA when ``no_hole`` is truthy, otherwise 0.
    """
    if bool(no_hole):
        return DEFAULT_FILL_HOLE_AREA
    return 0
def _open_image(image):
    """Resolve the supported image inputs to an upright RGB PIL image.

    Args:
        image: A PIL image, a numpy array, a path string (possibly naming a
            virtual image), or a dict carrying the path under "path", "name"
            or "orig_name".

    Returns:
        The image after EXIF orientation has been applied, converted to RGB.

    Raises:
        ValueError: If the input cannot be resolved to a PIL image.
    """
    if isinstance(image, dict):
        image = image.get("path") or image.get("name") or image.get("orig_name")
    if isinstance(image, str):
        # A registered virtual image takes precedence over reading from disk.
        resolved = get_virtual_image(image)
        if resolved is None:
            resolved = Image.open(strip_virtual_media_suffix(image))
        image = resolved
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if not isinstance(image, Image.Image):
        raise ValueError("Magic Mask needs a control image.")
    upright = ImageOps.exif_transpose(image)
    return upright.convert("RGB")
def _media_path(path):
    """Extract a media path from a path-like value or a dict wrapper.

    Non-dict inputs are returned unchanged. For a dict, the first truthy
    value among the "path", "name" and "orig_name" keys is returned; when
    none is truthy the "orig_name" value (possibly None) is returned.
    """
    if not isinstance(path, dict):
        return path
    for key in ("path", "name"):
        candidate = path.get(key)
        if candidate:
            return candidate
    # Last link of the fallback chain: returned even when it is falsy.
    return path.get("orig_name")
def _video_to_numpy(video_path):
    """Load a control video (or a still image) as a stack of RGB uint8 frames.

    Args:
        video_path: A path string, or a dict carrying the path under "path",
            "name" or "orig_name". A path with an image file extension is
            loaded as a single-frame video.

    Returns:
        A tuple ``(frames, fps, width, height)``. ``frames`` is a 4-D uint8
        array with frames on the first axis and 3 channels on the last. For
        a still image ``fps`` is 1. When the video metadata does not report
        usable dimensions, ``width`` and ``height`` are taken from the
        decoded frames.

    Raises:
        ValueError: If no path was supplied or no frames could be read.
    """
    video_path = _media_path(video_path)
    if not video_path:
        raise ValueError("Magic Mask needs a control video.")
    if isinstance(video_path, str) and has_image_file_extension(video_path):
        image = _open_image(video_path)
        width, height = image.size
        return np.asarray(image, dtype=np.uint8)[None], 1, width, height

    details = get_video_info_details(video_path)
    fps = details.get("fps_float") or details.get("fps") or 1
    # Display dimensions are preferred over the stored ones when present.
    width = int(details.get("display_width") or details.get("width") or 0)
    height = int(details.get("display_height") or details.get("height") or 0)
    frame_count = int(details.get("frame_count") or 1)

    frames = get_resampled_video_transparent(video_path, 0, frame_count, fps, bridge="torch")
    if torch.is_tensor(frames):
        frames = frames.detach().cpu().numpy()
    elif hasattr(frames, "asnumpy"):
        frames = frames.asnumpy()
    else:
        frames = np.asarray(frames)
    if frames.ndim != 4 or frames.shape[0] == 0:
        raise ValueError("Magic Mask could not read any control video frames.")

    # Normalise to exactly three channels: drop extras, replicate a single one.
    if frames.shape[-1] > 3:
        frames = frames[..., :3]
    if frames.shape[-1] == 1:
        frames = np.repeat(frames, 3, axis=-1)

    # Cast before the PIL round-trip below: Image.fromarray needs uint8 data
    # for 3-channel arrays, so the resize path must see the same dtype that
    # the no-resize path returns.
    frames = frames.astype(np.uint8, copy=False)

    if width > 0 and height > 0:
        if frames.shape[1:3] != (height, width):
            frames = np.stack(
                [
                    np.asarray(Image.fromarray(frame).resize((width, height), resample=Image.Resampling.LANCZOS))
                    for frame in frames
                ],
                axis=0,
            )
    else:
        # Metadata gave no usable size; report what was actually decoded
        # instead of returning zeros.
        height, width = int(frames.shape[1]), int(frames.shape[2])
    return frames, fps, width, height
def _run_sam3(video: np.ndarray, keywords: list[str], batch_size, no_hole, progress_callback=None) -> np.ndarray:
    """Run the SAM3 video preprocessor on ``video`` for the given keywords.

    The preprocessor module is imported on first use rather than at module
    import time, and the call is made under ``torch.inference_mode()``.

    Args:
        video: Frames handed to ``run_sam3_video`` unchanged.
        keywords: Keyword prompts handed to ``run_sam3_video`` unchanged.
        batch_size: Forwarded as ``batched_grounding_batch_size``.
        no_hole: Truthy to request hole filling (see ``_fill_hole_area``).
        progress_callback: Forwarded unchanged.

    Returns:
        Whatever ``run_sam3_video`` returns.
    """
    from preprocessing.sam3.preprocessor import run_sam3_video

    sam3_options = {
        "batched_grounding_batch_size": batch_size,
        "postprocess_batch_size": DEFAULT_POSTPROCESS_BATCH_SIZE,
        "use_batched_grounding": True,
        "fill_hole_area": _fill_hole_area(no_hole),
        "progress_callback": progress_callback,
    }
    with torch.inference_mode():
        return run_sam3_video(video, keywords, **sam3_options)
def prepare_image_mask_input(image) -> tuple[Image.Image, np.ndarray]:
    """Open a control image and wrap it as a one-frame video array.

    Returns:
        The opened RGB PIL image, and the same pixels as a uint8 array with
        a leading axis of length 1 added.

    Raises:
        ValueError: Propagated from ``_open_image`` when the input is not a
            usable image.
    """
    pil_image = _open_image(image)
    pixels = np.asarray(pil_image, dtype=np.uint8)
    return pil_image, pixels[None]
def prepare_video_mask_input(video_path) -> tuple[str, np.ndarray, float]:
    """Resolve a control video input and load its frames.

    Args:
        video_path: A path string, or a dict carrying the path under "path",
            "name" or "orig_name".

    Returns:
        A tuple ``(video_path, video, fps)``: the resolved path, the frames
        loaded by ``_video_to_numpy``, and the frame rate. The frame rate is
        whatever the video metadata reports (it may be fractional), or 1 for
        a still image.

    Raises:
        ValueError: If no path was supplied, or propagated from
            ``_video_to_numpy`` when no frames could be read.
    """
    video_path = _media_path(video_path)
    if not video_path:
        raise ValueError("Magic Mask needs a control video.")
    # Width and height are also returned by the loader but are not needed here.
    video, fps, _, _ = _video_to_numpy(video_path)
    return video_path, video, fps
def generate_keyword_masks(video: np.ndarray, keyword_text: str | Iterable[str], *, batch_size=None, no_hole=True, progress_callback=None) -> np.ndarray:
    """Compute masks for ``video`` from free-form keyword input.

    Args:
        video: Frame array; its first three axes give the shape of the empty
            mask returned when there are no keywords.
        keyword_text: Keywords as accepted by ``parse_keywords``.
        batch_size: Forwarded to ``_run_sam3``.
        no_hole: Forwarded to ``_run_sam3``.
        progress_callback: Forwarded to ``_run_sam3``.

    Returns:
        The result of ``_run_sam3`` when at least one keyword was given,
        otherwise an all-False boolean array of shape ``video.shape[:3]``.
    """
    keywords = parse_keywords(keyword_text)
    if keywords:
        return _run_sam3(video, keywords, batch_size, no_hole, progress_callback=progress_callback)
    # Nothing to search for: an empty mask, without running the model.
    return np.zeros(video.shape[:3], dtype=np.bool_)
def merge_keyword_masks(current_mask: np.ndarray | None, keyword_mask: np.ndarray) -> np.ndarray:
keyword_mask = keyword_mask.astype(bool, copy=False)
return keyword_mask.copy() if current_mask is None else (current_mask | keyword_mask)
def finalize_masks(mask: np.ndarray, *, negative_mask=False) -> np.ndarray:
    """Return the mask, inverted with ``~`` when ``negative_mask`` is truthy.

    When no inversion is requested the input object itself is returned.
    """
    return ~mask if negative_mask else mask
def mask_to_image(mask: np.ndarray) -> Image.Image:
    """Convert a mask array to an 8-bit grayscale ("L") PIL image.

    The mask is cast to uint8 and multiplied by 255, so a boolean True
    becomes 255 and False becomes 0.
    """
    grayscale = mask.astype(np.uint8) * 255
    return Image.fromarray(grayscale, mode="L")
def _magic_mask_video_codec_params():
    """Build the writer keyword arguments used for Magic Mask videos.

    Starts from a copy of the shared "libx264_10"/"mp4" codec parameters,
    forces ``macro_block_size`` to 1, and replaces a "yuv420p" pixel format
    with "yuv444p".

    Returns:
        A new dict; the mapping returned by ``_get_codec_params`` is not
        modified.
    """
    codec_params = dict(_get_codec_params("libx264_10", "mp4"))
    # NOTE(review): presumably stops imageio from rescaling frames to a
    # multiple of its default block size — confirm against the imageio docs.
    codec_params.update(macro_block_size=1)
    uses_yuv420p = codec_params.get("pixelformat") == "yuv420p"
    if uses_yuv420p:
        # NOTE(review): presumably chosen to avoid chroma subsampling of the
        # mask — confirm.
        codec_params["pixelformat"] = "yuv444p"
    return codec_params
def save_mask_video(video_path: str, masks: np.ndarray, fps: float, keywords: list[str], *, codec_type=None, output_dir=OUTPUT_DIR, abort_callback=None) -> str:
    """Encode a stack of masks as a black-and-white MP4 and return its path.

    The file is named after the source video's stem, the (truncated)
    keywords and the current local time, and is written into ``output_dir``
    (created if missing).

    Args:
        video_path: Path of the source video; only its stem is used, after
            any virtual-media suffix has been stripped.
        masks: Mask frames, iterated along the first axis. Each frame is
            cast to uint8, scaled by 255 and written as three identical
            channels.
        fps: Frame rate passed to the video writer.
        keywords: Keywords used to build the file-name suffix.
        codec_type: Ignored; accepted for backward compatibility.
        output_dir: Directory that receives the file.
        abort_callback: Optional callable invoked before every frame is
            written. Any exception it raises propagates after the writer
            has been closed.

    Returns:
        The path of the written file, as a string.
    """
    # codec_type is kept for compatibility; Magic Mask outputs are always MP4 libx264_10.
    gray_frames = masks.astype(np.uint8) * 255
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    stem = Path(strip_virtual_media_suffix(video_path)).stem
    keywords_suffix = truncate_keywords_for_path(keywords)
    output_path = target_dir / f"{sanitize_file_name(stem)}_magic_mask_{keywords_suffix}_{time.strftime('%Y%m%d_%H%M%S')}.mp4"
    output_path = os.fspath(output_path)
    writer = imageio.get_writer(output_path, fps=fps, ffmpeg_log_level="error", **_magic_mask_video_codec_params())
    try:
        for gray_frame in gray_frames:
            if abort_callback is not None:
                abort_callback()
            # Expand to three channels one frame at a time: the encoded
            # output is unchanged, but no 3x-sized copy of the whole video
            # is ever held in memory.
            writer.append_data(np.repeat(gray_frame[..., None], 3, axis=-1))
    finally:
        writer.close()
    return output_path
def generate_image_mask(image, keyword_text, *, batch_size=None, no_hole=True, negative_mask=False) -> tuple[Image.Image, Image.Image, list[str]]:
    """Generate a keyword-driven mask for a single control image.

    Args:
        image: Any input accepted by ``prepare_image_mask_input``.
        keyword_text: Keywords as accepted by ``parse_keywords``.
        batch_size: Forwarded to ``_run_sam3``.
        no_hole: Forwarded to ``_run_sam3``.
        negative_mask: When truthy the mask is inverted.

    Returns:
        A tuple of the opened RGB image, the mask as a grayscale PIL image,
        and the parsed keyword list.

    Raises:
        ValueError: If no keyword was supplied, or propagated when the image
            cannot be opened.
    """
    keywords = parse_keywords(keyword_text)
    if not keywords:
        raise ValueError("Enter at least one keyword.")
    image, single_frame_video = prepare_image_mask_input(image)
    frame_masks = _run_sam3(single_frame_video, keywords, batch_size, no_hole)
    # The image was wrapped as a one-frame video, so only the first mask matters.
    final_mask = finalize_masks(frame_masks[0], negative_mask=negative_mask)
    return image, mask_to_image(final_mask), keywords
def generate_video_mask(video_path, keyword_text, *, batch_size=None, no_hole=True, negative_mask=False, codec_type=None, output_dir=OUTPUT_DIR) -> tuple[str, list[str]]:
    """Generate a keyword-driven mask video for a control video.

    Args:
        video_path: Any input accepted by ``prepare_video_mask_input``.
        keyword_text: Keywords as accepted by ``parse_keywords``.
        batch_size: Forwarded to ``_run_sam3``.
        no_hole: Forwarded to ``_run_sam3``.
        negative_mask: When truthy the masks are inverted.
        codec_type: Accepted but not used by this function.
        output_dir: Directory that receives the mask video.

    Returns:
        A tuple of the written mask video's path and the parsed keyword list.

    Raises:
        ValueError: If no keyword was supplied, or propagated when the video
            cannot be read.
    """
    keywords = parse_keywords(keyword_text)
    if not keywords:
        raise ValueError("Enter at least one keyword.")
    video_path, frames, fps = prepare_video_mask_input(video_path)
    raw_masks = _run_sam3(frames, keywords, batch_size, no_hole)
    final_masks = finalize_masks(raw_masks, negative_mask=negative_mask)
    saved_path = save_mask_video(video_path, final_masks, fps, keywords, output_dir=output_dir)
    return saved_path, keywords
def truncate_keywords_for_path(keywords: list[str]) -> str:
    """Build a short, file-name-safe suffix from the keywords.

    The keywords are joined with underscores, passed through
    ``sanitize_file_name`` (with "_" as its second argument), stripped of
    leading and trailing underscores and cut to at most 40 characters.

    Returns:
        The resulting suffix, or "mask" when nothing is left.
    """
    joined = "_".join(keywords)
    cleaned = sanitize_file_name(joined, "_").strip("_")
    shortened = cleaned[:40]
    return shortened if shortened else "mask"
def build_image_editor_value(background: Image.Image, mask_image: Image.Image):
    """Assemble an image-editor value from a background and its mask.

    Returns:
        A dict with the background image under "background", None under
        "composite", and a one-element "layers" list holding the mask after
        conversion by ``rgb_bw_to_rgba_mask``.
    """
    # NOTE(review): presumably the value format of a Gradio ImageEditor — confirm.
    mask_layer = rgb_bw_to_rgba_mask(mask_image)
    return dict(background=background, composite=None, layers=[mask_layer])
|