import sys sys.path.append("../") sys.path.append("../../") # ZeroGPU compatible version import os import json import time import psutil import ffmpeg import imageio import argparse from PIL import Image import cv2 import torch import numpy as np import gradio as gr import spaces from tools.painter import mask_painter from tools.interact_tools import SamControler # get_device is NOT imported at module level to avoid CUDA init via torch.cuda.is_available() from tools.download_util import load_file_from_url from matanyone2_wrapper import matanyone2 from matanyone2.utils.get_default_model import get_matanyone2_model from matanyone2.inference.inference_core import InferenceCore from hydra.core.global_hydra import GlobalHydra import warnings warnings.filterwarnings("ignore") def parse_augment(): parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default=None) parser.add_argument('--sam_model_type', type=str, default="vit_h") parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications") parser.add_argument('--mask_save', default=False) args = parser.parse_args() # ZeroGPU: do NOT call get_device() (which calls torch.cuda.is_available()) at module level. # It can trigger CUDA init in the main process. Default to 'cpu'; GPU functions # determine the actual device at runtime inside @spaces.GPU-decorated functions. if not args.device: args.device = "cpu" return args # SAM generator class MaskGenerator(): def __init__(self, sam_checkpoint, args): self.args = args self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device) def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) return mask, logit, painted_image # convert points input to prompt state def get_prompt(click_state, click_input): inputs = json.loads(click_input) points = click_state[0] labels = click_state[1] for input in inputs: points.append(input[:2]) labels.append(input[2]) click_state[0] = points click_state[1] = labels prompt = { "prompt_type":["click"], "input_point":click_state[0], "input_label":click_state[1], "multimask_output":"True", } return prompt def get_frames_from_image(image_input, image_state): """ Args: video_path:str timestamp:float64 Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ user_name = time.time() frames = [image_input] * 2 # hardcode: mimic a video with 2 frames image_size = (frames[0].shape[0],frames[0].shape[1]) # initialize video_state image_state = { "user_name": user_name, "image_name": "output.png", "origin_images": frames, "painted_images": frames.copy(), "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), "logits": [None]*len(frames), "select_frame_number": 0, "fps": None } image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size) # SAM loading and set_image are deferred to sam_refine() which runs under @spaces.GPU return image_state, image_info, image_state["origin_images"][0], \ gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=True), \ gr.update(visible=True) # extract frames from upload video def get_frames_from_video(video_input, video_state): """ Args: video_path:str timestamp:float64 Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ video_path = video_input frames = [] user_name = time.time() # extract Audio try: audio_path = video_input.replace(".mp4", "_audio.wav") ffmpeg.input(video_path).output(audio_path, format='wav', acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True) except Exception as e: print(f"Audio extraction error: {str(e)}") audio_path = "" # Set to "" if extraction fails # extract frames try: cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) while cap.isOpened(): ret, frame = cap.read() if ret == True: current_memory_usage = psutil.virtual_memory().percent frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) if current_memory_usage > 90: break else: break except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: print("read_frame_source:{} error. {}\n".format(video_path, str(e))) image_size = (frames[0].shape[0],frames[0].shape[1]) # [remove for local demo] resize if resolution too big if image_size[0]>=1080 and image_size[0]>=1080: scale = 1080 / min(image_size) new_w = int(image_size[1] * scale) new_h = int(image_size[0] * scale) # update frames frames = [cv2.resize(f, (new_w, new_h), interpolation=cv2.INTER_AREA) for f in frames] # update image_size image_size = (frames[0].shape[0],frames[0].shape[1]) # initialize video_state video_state = { "user_name": user_name, "video_name": os.path.split(video_path)[-1], "origin_images": frames, "painted_images": frames.copy(), "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames), "logits": [None]*len(frames), "select_frame_number": 0, "fps": fps, "audio": audio_path } video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size) # SAM loading and set_image are deferred to sam_refine() which runs under @spaces.GPU return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=True), \ gr.update(visible=True) # get the select frame from gradio slider def select_video_template(image_selection_slider, video_state, interactive_state): image_selection_slider -= 1 video_state["select_frame_number"] = image_selection_slider # SAM set_image is deferred to sam_refine() which runs under @spaces.GPU return video_state["painted_images"][image_selection_slider], video_state, interactive_state def select_image_template(image_selection_slider, video_state, interactive_state): image_selection_slider = 0 # fixed for image video_state["select_frame_number"] = image_selection_slider # SAM set_image is deferred to sam_refine() which runs under @spaces.GPU return video_state["painted_images"][image_selection_slider], video_state, interactive_state # set the tracking end frame def get_end_number(track_pause_number_slider, video_state, interactive_state): interactive_state["track_end_number"] = track_pause_number_slider return video_state["painted_images"][track_pause_number_slider],interactive_state # use sam to get the mask # ZeroGPU: gr.SelectData cannot be pickled (contains lambdas from Gradio's State.__init__). # We split into an outer wrapper that extracts plain data from the event, # and an inner @spaces.GPU function that receives only picklable arguments. @spaces.GPU(duration=60) def _sam_refine_gpu(video_state, point_prompt, click_state, interactive_state, click_x, click_y): """ Inner GPU function for SAM refinement. Args: video_state: dict with video/image data point_prompt: "Positive" or "Negative" click_state: [[points], [labels]] interactive_state: dict with interaction state click_x, click_y: integer pixel coordinates extracted from gr.SelectData """ if point_prompt == "Positive": coordinate = "[[{},{},1]]".format(click_x, click_y) interactive_state["positive_click_times"] += 1 else: coordinate = "[[{},{},0]]".format(click_x, click_y) interactive_state["negative_click_times"] += 1 # prompt for sam model ensure_sam_on_cuda() model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) prompt = get_prompt(click_state=click_state, click_input=coordinate) mask, logit, painted_image = model.first_frame_click( image=video_state["origin_images"][video_state["select_frame_number"]], points=np.array(prompt["input_point"]), labels=np.array(prompt["input_label"]), multimask=prompt["multimask_output"], ) video_state["masks"][video_state["select_frame_number"]] = mask video_state["logits"][video_state["select_frame_number"]] = logit video_state["painted_images"][video_state["select_frame_number"]] = painted_image return painted_image, video_state, interactive_state def sam_refine(video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData): """ Outer wrapper: extracts plain picklable coordinates from gr.SelectData, then delegates to the @spaces.GPU inner function. """ click_x, click_y = int(evt.index[0]), int(evt.index[1]) return _sam_refine_gpu(video_state, point_prompt, click_state, interactive_state, click_x, click_y) def add_multi_mask(video_state, interactive_state, mask_dropdown): mask = video_state["masks"][video_state["select_frame_number"]] interactive_state["multi_mask"]["masks"].append(mask) interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) select_frame = show_mask(video_state, interactive_state, mask_dropdown) return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]] def clear_click(video_state, click_state): click_state = [[],[]] template_frame = video_state["origin_images"][video_state["select_frame_number"]] return template_frame, click_state def remove_multi_mask(interactive_state, mask_dropdown): interactive_state["multi_mask"]["mask_names"]= [] interactive_state["multi_mask"]["masks"] = [] return interactive_state, gr.update(choices=[],value=[]) def show_mask(video_state, interactive_state, mask_dropdown): mask_dropdown.sort() if video_state["origin_images"]: select_frame = video_state["origin_images"][video_state["select_frame_number"]] for i in range(len(mask_dropdown)): mask_number = int(mask_dropdown[i].split("_")[1]) - 1 mask = interactive_state["multi_mask"]["masks"][mask_number] select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) return select_frame # image matting @spaces.GPU(duration=120) def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter, model_selection): # Load model if not already loaded try: selected_model = load_model(model_selection) except (FileNotFoundError, ValueError) as e: # Fallback to first available model if available_models: print(f"Warning: {str(e)}. Using {available_models[0]} instead.") selected_model = load_model(available_models[0]) else: raise ValueError("No models are available! Please check if the model files exist.") matanyone_processor = InferenceCore(selected_model, cfg=selected_model.cfg) if interactive_state["track_end_number"]: following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] else: following_frames = video_state["origin_images"][video_state["select_frame_number"]:] if interactive_state["multi_mask"]["masks"]: if len(mask_dropdown) == 0: mask_dropdown = ["mask_001"] mask_dropdown.sort() template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) for i in range(1,len(mask_dropdown)): mask_number = int(mask_dropdown[i].split("_")[1]) - 1 template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) video_state["masks"][video_state["select_frame_number"]]= template_mask else: template_mask = video_state["masks"][video_state["select_frame_number"]] # operation error if len(np.unique(template_mask))==1: template_mask[0][0]=1 foreground, alpha = matanyone2(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter) foreground_output = Image.fromarray(foreground[-1]) alpha_output = Image.fromarray(alpha[-1][:,:,0]) return foreground_output, alpha_output # video matting @spaces.GPU(duration=300) def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, model_selection): # Load model if not already loaded try: selected_model = load_model(model_selection) except (FileNotFoundError, ValueError) as e: # Fallback to first available model if available_models: print(f"Warning: {str(e)}. Using {available_models[0]} instead.") selected_model = load_model(available_models[0]) else: raise ValueError("No models are available! Please check if the model files exist.") matanyone_processor = InferenceCore(selected_model, cfg=selected_model.cfg) if interactive_state["track_end_number"]: following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] else: following_frames = video_state["origin_images"][video_state["select_frame_number"]:] if interactive_state["multi_mask"]["masks"]: if len(mask_dropdown) == 0: mask_dropdown = ["mask_001"] mask_dropdown.sort() template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1])) for i in range(1,len(mask_dropdown)): mask_number = int(mask_dropdown[i].split("_")[1]) - 1 template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) video_state["masks"][video_state["select_frame_number"]]= template_mask else: template_mask = video_state["masks"][video_state["select_frame_number"]] fps = video_state["fps"] audio_path = video_state["audio"] # operation error if len(np.unique(template_mask))==1: template_mask[0][0]=1 foreground, alpha = matanyone2(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size) foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video return foreground_output, alpha_output def add_audio_to_video(video_path, audio_path, output_path): try: video_input = ffmpeg.input(video_path) audio_input = ffmpeg.input(audio_path) _ = ( ffmpeg .output(video_input, audio_input, output_path, vcodec="copy", acodec="aac") .run(overwrite_output=True, capture_stdout=True, capture_stderr=True) ) return output_path except ffmpeg.Error as e: print(f"FFmpeg error:\n{e.stderr.decode()}") return None def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""): frames = np.asarray(frames) if gray2rgb: frames = np.repeat(frames, 3, axis=3) _, h, w, _ = frames.shape h = h // 2 * 2 w = w // 2 * 2 if frames.shape[1] != h or frames.shape[2] != w: frames = np.asarray([ cv2.resize(frame, (w, h), interpolation=cv2.INTER_LINEAR) for frame in frames ]) if not os.path.exists(os.path.dirname(output_path)): os.makedirs(os.path.dirname(output_path)) video_temp_path = output_path.replace(".mp4", "_temp.mp4") imageio.mimwrite( video_temp_path, frames, fps=fps, quality=7, codec="libx264", macro_block_size=1 ) if audio_path != "" and os.path.exists(audio_path): output_path = add_audio_to_video(video_temp_path, audio_path, output_path) os.remove(video_temp_path) return output_path return video_temp_path # reset all states for a new input def restart(): return { "user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 30 }, { "inference_times": 0, "negative_click_times" : 0, "positive_click_times": 0, "mask_save": args.mask_save, "multi_mask": { "mask_names": [], "masks": [] }, "track_end_number": None, }, [[],[]], None, None, \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False) # args, defined in track_anything.py args = parse_augment() sam_checkpoint_url_dict = { 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" } checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models') # ZeroGPU: do NOT download or load models at module level. # All model loading is deferred to the first GPU function call. model = None # SAM MaskGenerator — lazily initialized # Model display names to file names mapping model_display_to_file = { "MatAnyone": "matanyone.pth", "MatAnyone 2": "matanyone2.pth" } # Model URLs model_urls = { "matanyone.pth": "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth", "matanyone2.pth": "https://github.com/pq-yang/MatAnyone2/releases/download/v1.0.0/matanyone2.pth" } # MatAnyone model file paths — filled lazily on first download model_paths = {} # Cache for loaded MatAnyone models (lazy loading) loaded_models = {} # All supported models (for the UI) — always show both options available_models = ["MatAnyone 2", "MatAnyone"] default_model = "MatAnyone 2" def ensure_sam_loaded(): """Download SAM checkpoint and init MaskGenerator on CPU (safe to call outside GPU context).""" global model if model is None: sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder) # Always load on CPU here — CUDA placement happens in ensure_sam_on_cuda(), # which is only ever called from within a @spaces.GPU-decorated function. import copy cpu_args = copy.copy(args) cpu_args.device = "cpu" model = MaskGenerator(sam_checkpoint, cpu_args) def ensure_sam_on_cuda(): """Move SAM to CUDA. Must only be called inside a @spaces.GPU-decorated function.""" ensure_sam_loaded() cuda_device = "cuda" if torch.cuda.is_available() else "cpu" model.samcontroler.sam_controler.predictor.model.to(cuda_device) model.samcontroler.sam_controler.device = cuda_device model.samcontroler.sam_controler.torch_dtype = torch.float16 if cuda_device == "cuda" else torch.float32 def _ensure_matanyone_downloaded(model_file): """Download the MatAnyone checkpoint if not already present.""" if model_file not in model_paths: model_paths[model_file] = load_file_from_url(model_urls[model_file], checkpoint_folder) return model_paths[model_file] def load_model(display_name): """Download (if needed) and load a MatAnyone model. Cached after first load.""" # Map display name to file name if display_name in model_display_to_file: model_file = model_display_to_file[display_name] elif display_name in model_urls: model_file = display_name else: raise ValueError(f"Unknown model: {display_name}") if model_file in loaded_models: return loaded_models[model_file] ckpt_path = _ensure_matanyone_downloaded(model_file) # Clear Hydra instance if already initialized (to allow loading different models) try: GlobalHydra.instance().clear() except Exception: pass device = "cuda" if torch.cuda.is_available() else args.device print(f"Loading model: {display_name} ({model_file}) on {device}...") loaded_mat_model = get_matanyone2_model(ckpt_path, device) loaded_mat_model = loaded_mat_model.to(device).eval() loaded_models[model_file] = loaded_mat_model print(f"Model {display_name} loaded successfully.") return loaded_mat_model # download test samples test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/") load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-0-720p.mp4', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-1-720p.mp4', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-2-720p.mp4', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-3-720p.mp4', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-4-720p.mp4', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-5-720p.mp4', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-0.jpg', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-1.jpg', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-2.jpg', test_sample_path) load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-3.jpg', test_sample_path) # download assets assets_path = os.path.join('/home/user/app/hugging_face/', "assets/") load_file_from_url('https://github.com/pq-yang/MatAnyone/releases/download/media/tutorial_single_target.mp4', assets_path) load_file_from_url('https://github.com/pq-yang/MatAnyone/releases/download/media/tutorial_multi_targets.mp4', assets_path) # documents title = r"""