import os
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

import spaces
import torch, time, datetime, numpy as np, cv2
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download

# ── Compatibility patch: newer transformers releases removed FLAX_WEIGHTS_NAME, which diffusers==0.29.2 still needs ──
import transformers.utils
if not hasattr(transformers.utils, "FLAX_WEIGHTS_NAME"):
    transformers.utils.FLAX_WEIGHTS_NAME = "flax_model.msgpack"
def ensure_weights(repo_id, local_dir, sentinel_files, ignore_patterns=None, max_retries=5):
    import time as _time
    if all(os.path.exists(os.path.join(local_dir, f)) for f in sentinel_files):
        print(f"[Skip] {repo_id} already present at {local_dir}")
        return
    os.makedirs(local_dir, exist_ok=True)
    for attempt in range(1, max_retries + 1):
        try:
            snapshot_download(repo_id=repo_id, local_dir=local_dir,
                              ignore_patterns=ignore_patterns)
            print(f"[OK] {repo_id}")
            return
        except Exception as e:
            if attempt == max_retries:
                raise
            print(f"[Download] {repo_id} attempt {attempt} failed: {e}\n  Retrying in 5s...")
            _time.sleep(5)
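# ensure_weights is idempotent: when every sentinel file is already on disk the call
# returns immediately, so a restarted Space skips the multi-GB re-download. A minimal
# usage sketch (the repo id and file name below are placeholders, not real weights):
#   ensure_weights("some-user/some-model", "./weights/some-model",
#                  sentinel_files=["diffusion_pytorch_model.safetensors"])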
for subfolder in ["diffuEraser", "majicmix-realistic-v7", "PCM_Weights", "propainter", "sd-vae-ft-mse"]:
    os.makedirs(os.path.join("weights", subfolder), exist_ok=True)

ensure_weights("lixiaowen/diffuEraser", "./weights/diffuEraser",
               sentinel_files=["brushnet/diffusion_pytorch_model.safetensors",
                               "unet_main/diffusion_pytorch_model.safetensors"])
ensure_weights("digiplay/majicMIX_realistic_v7", "./weights/majicmix-realistic-v7",
               sentinel_files=["unet/diffusion_pytorch_model.safetensors", "model_index.json"],
               ignore_patterns=["*.ckpt", "*.msgpack", "*.pb", "*.h5", "flax_*"])
ensure_weights("wangfuyun/PCM_Weights", "./weights/PCM_Weights",
               sentinel_files=["sd15/pcm_sd15_smallcfg_2step_converted.safetensors"])
ensure_weights("camenduru/ProPainter", "./weights/propainter",
               sentinel_files=["ProPainter.pth"])
ensure_weights("stabilityai/sd-vae-ft-mse", "./weights/sd-vae-ft-mse",
               sentinel_files=["diffusion_pytorch_model.safetensors"])
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device
from transformers import Sam3VideoModel, Sam3VideoProcessor

device = get_device()
video_inpainting_sd = DiffuEraser(device, "weights/majicmix-realistic-v7", "weights/sd-vae-ft-mse",
                                  "weights/diffuEraser", ckpt="2-Step")
propainter = Propainter("weights/propainter", device=device)
sam3_model = Sam3VideoModel.from_pretrained("bodhicitta/sam3").to(device, dtype=torch.bfloat16)
sam3_processor = Sam3VideoProcessor.from_pretrained("bodhicitta/sam3")
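# SAM3 is held in bfloat16, which roughly halves its GPU memory footprint; its outputs
# are binarized into masks downstream, so the reduced precision should not visibly
# change the final result.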
def read_video_frames(path):
    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    while True:
        ret, f = cap.read()
        if not ret: break
        frames.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames, fps
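# OpenCV decodes frames as BGR; everything downstream works in RGB, so the conversion
# happens exactly once on read here and once on write in the save helpers below.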
def save_frames_as_video(frames, path, fps):
    h, w = frames[0].shape[:2]
    h -= h % 2; w -= w % 2
    out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for f in frames: out.write(cv2.cvtColor(f[:h, :w], cv2.COLOR_RGB2BGR))
    out.release()

def save_mask_video(masks, path, fps):
    h, w = masks[0].shape[:2]
    h -= h % 2; w -= w % 2
    out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for m in masks:
        rgb = np.stack([m[:h, :w]] * 3, axis=-1).astype(np.uint8)
        out.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    out.release()
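# Both writers truncate odd dimensions to even ones (h -= h % 2) because mp4v encodes
# 4:2:0 chroma-subsampled video, which generally requires even width and height.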
DILATION_PX = 30

def dilate_mask(mask, px=DILATION_PX):
    k = 2 * px + 1
    return cv2.dilate(mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)))
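# A (2*px+1)-sized elliptical kernel grows the mask by px pixels in every direction,
# giving the inpainting models some context beyond the detected watermark edge.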
def get_union_bbox(masks, H, W):
    union = np.zeros((H, W), dtype=np.uint8)
    for m in masks: union = np.maximum(union, m)
    ys, xs = np.where(union > 0)
    if len(ys) == 0: return None
    y1, y2 = max(0, int(ys.min())), min(H, int(ys.max()) + 1)
    x1, x2 = max(0, int(xs.min())), min(W, int(xs.max()) + 1)
    y2 = y1 + ((y2 - y1 + 1) // 2) * 2; x2 = x1 + ((x2 - x1 + 1) // 2) * 2
    return (y1, x1, min(H, y2), min(W, x2))
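# One bbox over the union of all per-frame masks keeps the crop window fixed for the
# whole clip, so the inpainting models see a temporally stable region rather than a
# jittering crop. Width and height are rounded up to even values for the encoder.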
def composite_back(orig_frames, repaired_frames, bbox, dilated_masks):
    y1, x1, y2, x2 = bbox
    roi_h, roi_w = y2 - y1, x2 - x1
    result = [f.copy() for f in orig_frames]
    for i in range(min(len(result), len(repaired_frames))):
        rep = repaired_frames[i]
        if rep.shape[0] != roi_h or rep.shape[1] != roi_w:
            rep = cv2.resize(rep, (roi_w, roi_h))
        alpha = dilated_masks[i].astype(np.float32) / 255.0
        alpha = cv2.GaussianBlur(alpha, (31, 31), 0)[:, :, np.newaxis]
        src = result[i].copy()
        src[y1:y2, x1:x2] = rep
        result[i] = (src.astype(np.float32) * alpha + result[i].astype(np.float32) * (1 - alpha)).astype(np.uint8)
    return result
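# The blurred mask acts as a soft matte: out = repaired * alpha + original * (1 - alpha).
# With a 31x31 Gaussian the blend ramps over roughly 15 px on either side of the mask
# edge, so only inpainted pixels are taken from the repaired crop and seams fade out
# instead of cutting hard.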
def apply_bbox_filter(masks, filter_bbox):
    if filter_bbox is None:
        return masks
    fy1, fx1, fy2, fx2 = filter_bbox
    filtered = []
    for m in masks:
        mf = np.zeros_like(m)
        mf[fy1:fy2, fx1:fx2] = m[fy1:fy2, fx1:fx2]
        filtered.append(mf)
    return filtered
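# Zeroing the mask outside the user-supplied bbox discards SAM3 detections elsewhere
# in the frame, e.g. text that matches the prompt but is not the watermark.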
def generate_masks_sam3(frames, text_prompt):
    H, W = frames[0].shape[:2]
    pil_frames = [Image.fromarray(f) for f in frames]
    with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
        session = sam3_processor.init_video_session(
            video=pil_frames, inference_device=device,
            processing_device="cpu", video_storage_device="cpu", dtype=torch.bfloat16)
        sam3_processor.add_text_prompt(session, text_prompt)
        raw_masks = {}
        for model_out in sam3_model.propagate_in_video_iterator(session, show_progress_bar=True):
            processed = sam3_processor.postprocess_outputs(session, model_out)
            masks_t = processed.get("masks")
            if masks_t is not None and masks_t.shape[0] > 0:
                combined = masks_t.any(dim=0).cpu().numpy().astype(np.uint8) * 255
            else:
                combined = np.zeros((H, W), dtype=np.uint8)
            raw_masks[model_out.frame_idx] = combined
    return [dilate_mask(raw_masks.get(i, np.zeros((H, W), dtype=np.uint8))) for i in range(len(frames))]
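# Notes on the session setup above: frames are kept on the CPU
# (video_storage_device="cpu"), which should keep long clips from exhausting GPU
# memory, and the masks of all object instances matching the prompt are OR-ed
# together (.any(dim=0)) into a single binary mask per frame before dilation.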
# ── First-frame helpers ──
_first_frame_cache = {}  # "frame" -> RGB np.ndarray of the most recently uploaded video

def extract_first_frame(video_path):
    if video_path is None:
        return None, gr.update(), gr.update(), gr.update(), gr.update()
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, gr.update(), gr.update(), gr.update(), gr.update()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    H, W = rgb.shape[:2]
    _first_frame_cache["frame"] = rgb
    return rgb, gr.update(value=0, maximum=W), gr.update(value=0, maximum=H), \
           gr.update(value=W, maximum=W), gr.update(value=H, maximum=H)
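# The five return values line up with the outputs of input_video.change below: the
# preview image first, then the four bbox Number fields, which are reset to cover the
# full frame whenever a new video is uploaded.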
def draw_bbox_preview(x1, y1, x2, y2):
    frame = _first_frame_cache.get("frame")
    if frame is None:
        return None
    preview = frame.copy()
    x1, y1, x2, y2 = int(x1 or 0), int(y1 or 0), int(x2 or 0), int(y2 or 0)
    if x2 > x1 and y2 > y1:
        cv2.rectangle(preview, (x1, y1), (x2, y2), (255, 0, 0), 3)
        label = f"({x1},{y1}) -> ({x2},{y2})"
        cv2.putText(preview, label, (x1, max(y1 - 8, 12)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
    return preview
# Assumption: this Space targets ZeroGPU hardware (the `spaces` import above is
# otherwise unused). @spaces.GPU requests a GPU for the duration of the call and is a
# no-op on persistent-GPU or CPU hardware; the 300 s budget is a guess for long clips.
@spaces.GPU(duration=300)
def infer(input_video, text_prompt, x1, y1, x2, y2, use_propainter, mask_upload=None):
    if input_video is None: raise gr.Error("Please upload a video first.")
    if not text_prompt.strip(): text_prompt = "watermark"
    save_path = "results"
    os.makedirs(save_path, exist_ok=True)
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    print("[1/6] Reading frames...")
    frames, fps = read_video_frames(input_video)
    if not frames: raise gr.Error("Cannot read video frames.")
    H, W = frames[0].shape[:2]; video_length = len(frames)
    print(f"  {video_length} frames {W}x{H} {fps:.1f}fps")
    # Build the bbox filter from the number inputs
    bx1, by1, bx2, by2 = int(x1 or 0), int(y1 or 0), int(x2 or 0), int(y2 or 0)
    filter_bbox = None
    if bx2 > bx1 and by2 > by1:
        bx1 = max(0, min(bx1, W)); bx2 = max(0, min(bx2, W))
        by1 = max(0, min(by1, H)); by2 = max(0, min(by2, H))
        filter_bbox = (by1, bx1, by2, bx2)
        print(f"  BBox filter: x1={bx1} y1={by1} x2={bx2} y2={by2}")
    # ── MASK: either uploaded or generated via SAM3 ──
    if mask_upload is not None:
        print("[2/6] Using uploaded mask (skipping SAM3)...")
        mask_img = cv2.imread(mask_upload, cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            raise gr.Error("Could not read the uploaded mask image.")
        if mask_img.shape[:2] != (H, W):
            mask_img = cv2.resize(mask_img, (W, H))
        # Apply a small dilation to the uploaded mask for smoother blending
        small_dilation = 4
        dilated_masks = [dilate_mask(mask_img, px=small_dilation)] * len(frames)
    else:
        print(f"[2/6] SAM3 detecting '{text_prompt}'...")
        dilated_masks = generate_masks_sam3(frames, text_prompt.strip())
        if filter_bbox:
            dilated_masks = apply_bbox_filter(dilated_masks, filter_bbox)
            print("  BBox filter applied.")
    print("[3/6] Union BBox...")
    bbox = get_union_bbox(dilated_masks, H, W)
    if bbox is None: raise gr.Error(f"No mask found for '{text_prompt}'. Adjust the bbox or text prompt.")
    y1r, x1r, y2r, x2r = bbox
    MIN_ROI = 256
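    # Grow the ROI to at least MIN_ROI x MIN_ROI so the diffusion model gets enough
    # context; the loop below runs twice, presumably because clamping at an image
    # border on the first pass can shift the window, and the second pass re-checks
    # both dimensions.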
    for _ in range(2):
        roi_w, roi_h = x2r - x1r, y2r - y1r
        if roi_w < MIN_ROI:
            cx = (x1r + x2r) // 2
            x1r = max(0, cx - MIN_ROI // 2)
            x2r = min(W, x1r + MIN_ROI)
            x1r = max(0, x2r - MIN_ROI)
        if roi_h < MIN_ROI:
            cy = (y1r + y2r) // 2
            y1r = max(0, cy - MIN_ROI // 2)
            y2r = min(H, y1r + MIN_ROI)
            y1r = max(0, y2r - MIN_ROI)
    roi_w, roi_h = x2r - x1r, y2r - y1r
    x2r = x1r + (roi_w + 1) // 2 * 2
    y2r = y1r + (roi_h + 1) // 2 * 2
    x2r = min(W, x2r); y2r = min(H, y2r)
    bbox = (y1r, x1r, y2r, x2r)
    print(f"  BBox: y1={y1r} x1={x1r} y2={y2r} x2={x2r} roi={x2r-x1r}x{y2r-y1r}")
    print("[4/6] Cropping...")
    cropped_frames = [f[y1r:y2r, x1r:x2r] for f in frames]
    cropped_masks = [m[y1r:y2r, x1r:x2r] for m in dilated_masks]
    crop_video_path = os.path.join(save_path, f"crop_video_{ts}.mp4")
    crop_mask_path = os.path.join(save_path, f"crop_mask_{ts}.mp4")
    save_frames_as_video(cropped_frames, crop_video_path, fps)
    save_mask_video(cropped_masks, crop_mask_path, fps)
    print("[5/6] ProPainter + DiffuEraser...")
    priori_path = os.path.join(save_path, f"priori_{ts}.mp4")
    repaired_path = os.path.join(save_path, f"repaired_{ts}.mp4")
    t0 = time.time()
    if use_propainter:
        propainter.forward(crop_video_path, crop_mask_path, priori_path,
                           resize_ratio=1.0,
                           video_length=video_length, ref_stride=10,
                           neighbor_length=10, subvideo_length=50, mask_dilation=8)
    else:
        import shutil
        shutil.copy2(crop_video_path, priori_path)
        print("  ProPainter skipped, using original crop as priori.")
    video_inpainting_sd.forward(crop_video_path, crop_mask_path, priori_path, repaired_path,
                                max_img_size=960, video_length=video_length,
                                mask_dilation_iter=8, guidance_scale=None)
    print(f"  Done in {time.time()-t0:.1f}s")
    print("[6/6] Compositing back...")
    repaired_frames, _ = read_video_frames(repaired_path)
    final_frames = composite_back(frames, repaired_frames, bbox, dilated_masks)
    output_path = os.path.join(save_path, f"final_{ts}.mp4")
    save_frames_as_video(final_frames, output_path, fps)
    torch.cuda.empty_cache()
    return output_path, priori_path, repaired_path
def on_image_click(evt: gr.SelectData):
    """Single click -> auto-expand into a bbox around the clicked point."""
    frame = _first_frame_cache.get("frame")
    px, py = evt.index[0], evt.index[1]
    if frame is not None:
        H, W = frame.shape[:2]
        hw = max(80, W // 8)
        hh = max(30, H // 10)
    else:
        hw, hh = 100, 40
    new_x1 = max(0, px - hw)
    new_y1 = max(0, py - hh)
    new_x2 = px + hw if frame is None else min(frame.shape[1], px + hw)
    new_y2 = py + hh if frame is None else min(frame.shape[0], py + hh)
    preview = draw_bbox_preview(new_x1, new_y1, new_x2, new_y2)
    return new_x1, new_y1, new_x2, new_y2, preview
with gr.Blocks(title="DiffuEraser + SAM3 Watermark Remover") as demo:
    gr.Markdown("# DiffuEraser — Video Watermark Removal")
    gr.Markdown(
        "**Option A — Automatic (SAM3):** Upload a video, type what to remove, click on the area, hit Remove.\n\n"
        "**Option B — Manual mask:** Upload a video + a mask image (white = remove, black = keep). Skips SAM3."
    )
    with gr.Row():
        # ── Left column: inputs ──
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video (MP4)", format="mp4")
            text_prompt = gr.Textbox(label="SAM3 text prompt (what to remove)", value="watermark",
                                     info="Used when no mask is uploaded")
            mask_upload = gr.Image(label="Or upload a mask image (white=remove, black=keep)",
                                   type="filepath", image_mode="L")
            gr.Markdown("### BBox (click on the preview image to set, or type manually)")
            with gr.Row():
                n_x1 = gr.Number(label="x1 (left)", value=0, precision=0, minimum=0)
                n_y1 = gr.Number(label="y1 (top)", value=0, precision=0, minimum=0)
                n_x2 = gr.Number(label="x2 (right)", value=0, precision=0, minimum=0)
                n_y2 = gr.Number(label="y2 (bottom)", value=0, precision=0, minimum=0)
            use_propainter_chk = gr.Checkbox(label="Use ProPainter prior (better quality)", value=True)
            submit_btn = gr.Button("Remove Watermark", variant="primary")
        # ── Right column: previews ──
        with gr.Column(scale=1):
            bbox_preview = gr.Image(label="First Frame — click to set BBox corners", interactive=False)
            video_result = gr.Video(label="Result")
            priori_result = gr.Video(label="[Debug] ProPainter Priori")
            repaired_result = gr.Video(label="[Debug] DiffuEraser Repaired")

    # Auto-load the first frame when a video is uploaded
    input_video.change(
        fn=extract_first_frame,
        inputs=[input_video],
        outputs=[bbox_preview, n_x1, n_y1, n_x2, n_y2]
    )
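    # The output list mirrors extract_first_frame's five return values: preview image
    # first, then the four bbox fields (reset to the full frame of the new video).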
    # Single click -> auto bbox around the clicked point
    bbox_preview.select(
        fn=on_image_click,
        inputs=[],
        outputs=[n_x1, n_y1, n_x2, n_y2, bbox_preview]
    )

    # Also update the preview when numbers are typed manually
    for comp in [n_x1, n_y1, n_x2, n_y2]:
        comp.change(fn=draw_bbox_preview, inputs=[n_x1, n_y1, n_x2, n_y2], outputs=[bbox_preview])

    submit_btn.click(
        fn=infer,
        inputs=[input_video, text_prompt, n_x1, n_y1, n_x2, n_y2, use_propainter_chk, mask_upload],
        outputs=[video_result, priori_result, repaired_result]
    )

demo.queue().launch(show_error=True, ssr_mode=False)