"""Magic Mask: build binary masks from keyword prompts with SAM3, for images and videos."""

import contextlib
import os
import re
import time
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
import imageio.v2 as imageio
from PIL import Image, ImageOps

from shared.utils.audio_video import _get_codec_params
from shared.utils.utils import get_resampled_video_transparent, get_video_info_details, has_image_file_extension, rgb_bw_to_rgba_mask, sanitize_file_name
from shared.utils.virtual_media import get_virtual_image, strip_virtual_media_suffix

PROCESS_ID = "magic_mask"
PROCESS_NAME = "Magic Mask"
DOWNLOAD_REPO_ID = "DeepBeepMeep/Wan2.1"
DOWNLOAD_FOLDER = "sam3"
DOWNLOAD_FILES = ["sam3.1_multiplex_bf16.safetensors", "bpe_simple_vocab_16e6.txt.gz"]
DEFAULT_FILL_HOLE_AREA = 2
DEFAULT_POSTPROCESS_BATCH_SIZE = 1
OUTPUT_DIR = "mask_outputs"


def parse_keywords(keyword_text: str | Iterable[str] | None) -> list[str]:
    """Split keyword input into a list of non-empty, stripped keywords.

    A string is split on newlines, commas and semicolons; any other iterable is
    taken item by item (each item is converted with ``str``). ``None`` yields an
    empty list so callers can report "no keywords" instead of a TypeError.
    """
    if keyword_text is None:
        return []
    if isinstance(keyword_text, str):
        candidates = re.split(r"[\n,;]+", keyword_text)
    else:
        candidates = keyword_text
    return [str(keyword).strip() for keyword in candidates if str(keyword).strip()]


def query_download_def():
    """Return the model download descriptor (repo id, source folders, per-folder file lists)."""
    return {"repoId": DOWNLOAD_REPO_ID, "sourceFolderList": [DOWNLOAD_FOLDER], "fileList": [list(DOWNLOAD_FILES)]}


def _fill_hole_area(no_hole):
    """Map the ``no_hole`` toggle to the SAM3 ``fill_hole_area`` value (0 disables filling)."""
    return DEFAULT_FILL_HOLE_AREA if bool(no_hole) else 0


def _open_image(image):
    """Load a control image and return it as an EXIF-orientation-corrected RGB PIL image.

    ``image`` may be a media dict (``path``/``name``/``orig_name``), a path string
    (virtual media paths are resolved first), a numpy array, or a PIL image.

    Raises:
        ValueError: if no usable image can be obtained.
    """
    image = _media_path(image)
    if isinstance(image, str):
        virtual_image = get_virtual_image(image)
        if virtual_image is not None:
            image = virtual_image
        else:
            # exif_transpose/convert return new, fully loaded images, so the file
            # handle opened here can be closed immediately instead of leaking.
            with Image.open(strip_virtual_media_suffix(image)) as opened:
                return ImageOps.exif_transpose(opened).convert("RGB")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if not isinstance(image, Image.Image):
        raise ValueError("Magic Mask needs a control image.")
    return ImageOps.exif_transpose(image).convert("RGB")


def _media_path(path):
    """Unwrap a media dict to its path (``path``, then ``name``, then ``orig_name``); pass other values through."""
    if isinstance(path, dict):
        path = path.get("path") or path.get("name") or path.get("orig_name")
    return path


def _video_to_numpy(video_path):
    """Decode a control video (or a single still image) into an RGB uint8 frame stack.

    Returns:
        ``(frames, fps, width, height)``. A still image yields a single frame and
        ``fps=1``. Frames are resized to the reported display size when it differs.

    Raises:
        ValueError: if no path is given or no frames could be decoded.
    """
    video_path = _media_path(video_path)
    if not video_path:
        raise ValueError("Magic Mask needs a control video.")
    if isinstance(video_path, str) and has_image_file_extension(video_path):
        image = _open_image(video_path)
        width, height = image.size
        return np.asarray(image, dtype=np.uint8)[None], 1, width, height

    details = get_video_info_details(video_path)
    fps = details.get("fps_float") or details.get("fps") or 1
    width = int(details.get("display_width") or details.get("width") or 0)
    height = int(details.get("display_height") or details.get("height") or 0)
    frame_count = int(details.get("frame_count") or 1)
    frames = get_resampled_video_transparent(video_path, 0, frame_count, fps, bridge="torch")
    if torch.is_tensor(frames):
        frames = frames.detach().cpu().numpy()
    elif hasattr(frames, "asnumpy"):
        frames = frames.asnumpy()
    else:
        frames = np.asarray(frames)
    if frames.ndim != 4 or frames.shape[0] == 0:
        raise ValueError("Magic Mask could not read any control video frames.")

    # Assumes channels-last frames (T, H, W, C), as the shape checks below rely on.
    channels = frames.shape[-1]
    if channels > 3:
        frames = frames[..., :3]  # drop alpha
    elif channels in (1, 2):
        # Grey (optionally + alpha): keep the luminance channel and expand to RGB.
        frames = np.repeat(frames[..., :1], 3, axis=-1)

    if width > 0 and height > 0 and frames.shape[1:3] != (height, width):
        frames = np.stack(
            [np.asarray(Image.fromarray(frame).resize((width, height), resample=Image.Resampling.LANCZOS)) for frame in frames],
            axis=0,
        )
    return frames.astype(np.uint8, copy=False), fps, width, height


def _run_sam3(video: np.ndarray, keywords: list[str], batch_size, no_hole, progress_callback=None) -> np.ndarray:
    """Run the SAM3 video preprocessor for ``keywords`` under ``torch.inference_mode()``.

    The SAM3 module is imported lazily so loading this module stays cheap.
    The mask shape/dtype comes from ``run_sam3_video``; presumably (T, H, W) — confirm there.
    """
    from preprocessing.sam3.preprocessor import run_sam3_video

    with torch.inference_mode():
        return run_sam3_video(
            video,
            keywords,
            batched_grounding_batch_size=batch_size,
            postprocess_batch_size=DEFAULT_POSTPROCESS_BATCH_SIZE,
            use_batched_grounding=True,
            fill_hole_area=_fill_hole_area(no_hole),
            progress_callback=progress_callback,
        )


def prepare_image_mask_input(image) -> tuple[Image.Image, np.ndarray]:
    """Load an image and return it together with a one-frame uint8 "video" of it (shape (1, H, W, 3))."""
    image = _open_image(image)
    return image, np.asarray(image, dtype=np.uint8)[None]


def prepare_video_mask_input(video_path) -> tuple[str, np.ndarray, float]:
    """Resolve a control video and decode it.

    Returns:
        ``(video_path, frames, fps)``. ``fps`` may be fractional (``fps_float``).

    Raises:
        ValueError: if no video path is provided or no frames can be read.
    """
    video_path = _media_path(video_path)
    if not video_path:
        raise ValueError("Magic Mask needs a control video.")
    video, fps, _, _ = _video_to_numpy(video_path)
    return video_path, video, fps


def generate_keyword_masks(video: np.ndarray, keyword_text: str | Iterable[str], *, batch_size=None, no_hole=True, progress_callback=None) -> np.ndarray:
    """Segment ``video`` for the given keywords; with no keywords return an all-False (T, H, W) mask."""
    keywords = parse_keywords(keyword_text)
    if len(keywords) == 0:
        return np.zeros(video.shape[:3], dtype=np.bool_)
    return _run_sam3(video, keywords, batch_size, no_hole, progress_callback=progress_callback)


def merge_keyword_masks(current_mask: np.ndarray | None, keyword_mask: np.ndarray) -> np.ndarray:
    """Union ``keyword_mask`` into ``current_mask`` (a copy of ``keyword_mask`` when there is none yet)."""
    keyword_mask = keyword_mask.astype(bool, copy=False)
    return keyword_mask.copy() if current_mask is None else (current_mask | keyword_mask)


def finalize_masks(mask: np.ndarray, *, negative_mask=False) -> np.ndarray:
    """Return ``mask`` as a boolean array, logically inverted when ``negative_mask`` is set.

    The mask is normalized to bool first: ``~`` on an integer mask is a bitwise
    NOT (0/1 -> 255/254), not a logical inversion. Bool input is returned unchanged.
    """
    mask = np.asarray(mask).astype(bool, copy=False)
    return ~mask if negative_mask else mask


def mask_to_image(mask: np.ndarray) -> Image.Image:
    """Convert a 2-D mask to a black/white greyscale PIL image (True -> 255)."""
    # Pillow infers mode "L" from a 2-D uint8 array; passing mode= is deprecated since Pillow 11.3.
    return Image.fromarray(mask.astype(np.uint8) * 255)


def _magic_mask_video_codec_params():
    """Return libx264_10 MP4 writer params with macro_block_size=1 and yuv444p in place of yuv420p."""
    params = dict(_get_codec_params("libx264_10", "mp4"))
    # Keep the exact frame size instead of padding/resizing to a block multiple.
    params["macro_block_size"] = 1
    # Avoid chroma subsampling so the black/white mask edges stay crisp.
    if params.get("pixelformat") == "yuv420p":
        params["pixelformat"] = "yuv444p"
    return params


def save_mask_video(video_path: str, masks: np.ndarray, fps: float, keywords: list[str], *, codec_type=None, output_dir=OUTPUT_DIR, abort_callback=None) -> str:
    """Write ``masks`` as a black/white MP4 in ``output_dir`` and return its path.

    The filename combines the source stem, a keyword suffix and a timestamp.
    ``abort_callback`` is invoked before each frame; if it (or the writer) raises,
    the partially written file is removed before the exception propagates.
    """
    # codec_type is kept for compatibility; Magic Mask outputs are always MP4 libx264_10.
    mask_frames = masks.astype(np.uint8) * 255
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stem = Path(strip_virtual_media_suffix(video_path)).stem
    keywords_suffix = truncate_keywords_for_path(keywords)
    output_path = Path(output_dir) / f"{sanitize_file_name(stem)}_magic_mask_{keywords_suffix}_{time.strftime('%Y%m%d_%H%M%S')}.mp4"
    output_path = os.fspath(output_path)
    writer = imageio.get_writer(output_path, fps=fps, ffmpeg_log_level="error", **_magic_mask_video_codec_params())
    finished = False
    try:
        for frame in mask_frames:
            if abort_callback is not None:
                abort_callback()
            # Expand to 3 channels per frame rather than materialising a 3x copy of the whole video.
            writer.append_data(np.repeat(frame[..., None], 3, axis=-1))
        finished = True
    finally:
        writer.close()
        if not finished:
            # Don't leave a truncated, unusable MP4 behind after an abort or write error.
            with contextlib.suppress(OSError):
                os.remove(output_path)
    return output_path


def generate_image_mask(image, keyword_text, *, batch_size=None, no_hole=True, negative_mask=False) -> tuple[Image.Image, Image.Image, list[str]]:
    """Segment a single image for the keywords.

    Returns:
        ``(rgb_image, mask_image, keywords)``.

    Raises:
        ValueError: if no keyword is given or the image cannot be loaded.
    """
    keywords = parse_keywords(keyword_text)
    if len(keywords) == 0:
        raise ValueError("Enter at least one keyword.")
    image, video = prepare_image_mask_input(image)
    mask = finalize_masks(_run_sam3(video, keywords, batch_size, no_hole)[0], negative_mask=negative_mask)
    mask_image = mask_to_image(mask)
    return image, mask_image, keywords


def generate_video_mask(video_path, keyword_text, *, batch_size=None, no_hole=True, negative_mask=False, codec_type=None, output_dir=OUTPUT_DIR) -> tuple[str, list[str]]:
    """Segment a video for the keywords and save the mask video.

    Returns:
        ``(output_path, keywords)``. ``codec_type`` is accepted for compatibility but unused.

    Raises:
        ValueError: if no keyword is given or the video cannot be read.
    """
    keywords = parse_keywords(keyword_text)
    if len(keywords) == 0:
        raise ValueError("Enter at least one keyword.")
    video_path, video, fps = prepare_video_mask_input(video_path)
    masks = finalize_masks(_run_sam3(video, keywords, batch_size, no_hole), negative_mask=negative_mask)
    return save_mask_video(video_path, masks, fps, keywords, output_dir=output_dir), keywords


def truncate_keywords_for_path(keywords: list[str], max_length: int = 40) -> str:
    """Build a filesystem-safe keyword suffix of at most ``max_length`` chars ("mask" if empty).

    Underscores are stripped again after truncation so the cut never leaves a
    dangling ``_`` (which would produce ``__`` in the output filename).
    """
    suffix = sanitize_file_name("_".join(keywords), "_").strip("_")
    return suffix[:max_length].strip("_") or "mask"


def build_image_editor_value(background: Image.Image, mask_image: Image.Image):
    """Build an image-editor value: the background plus the mask as its single RGBA layer."""
    return {"background": background, "composite": None, "layers": [rgb_bw_to_rgba_mask(mask_image)]}