# drmna_wtrm / gradio_app.py
import os
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
import spaces
import torch, time, datetime, numpy as np, cv2
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
# ── Compatibility patch: newer transformers removed FLAX_WEIGHTS_NAME, but diffusers==0.29.2 still imports it ──
import transformers.utils
if not hasattr(transformers.utils, "FLAX_WEIGHTS_NAME"):
    transformers.utils.FLAX_WEIGHTS_NAME = "flax_model.msgpack"

def ensure_weights(repo_id, local_dir, sentinel_files, ignore_patterns=None, max_retries=5):
    """Download repo_id into local_dir, skipping if the sentinel files already exist; retry transient failures."""
    import time as _time
    if all(os.path.exists(os.path.join(local_dir, f)) for f in sentinel_files):
        print(f"[Skip] {repo_id} already present at {local_dir}")
        return
    os.makedirs(local_dir, exist_ok=True)
    for attempt in range(1, max_retries + 1):
        try:
            snapshot_download(repo_id=repo_id, local_dir=local_dir,
                              ignore_patterns=ignore_patterns)
            print(f"[OK] {repo_id}")
            return
        except Exception as e:
            if attempt == max_retries:
                raise
            print(f"[Download] {repo_id} attempt {attempt} failed: {e}\n  Retrying in 5s...")
            _time.sleep(5)
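
# Each ensure_weights call below treats the repo's main checkpoint files as
# sentinels: if they are already on disk (e.g. after a Space restart), the
# whole download is skipped.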

for subfolder in ["diffuEraser", "majicmix-realistic-v7", "PCM_Weights", "propainter", "sd-vae-ft-mse"]:
    os.makedirs(os.path.join("weights", subfolder), exist_ok=True)
ensure_weights("lixiaowen/diffuEraser", "./weights/diffuEraser",
sentinel_files=["brushnet/diffusion_pytorch_model.safetensors",
"unet_main/diffusion_pytorch_model.safetensors"])
ensure_weights("digiplay/majicMIX_realistic_v7", "./weights/majicmix-realistic-v7",
sentinel_files=["unet/diffusion_pytorch_model.safetensors", "model_index.json"],
ignore_patterns=["*.ckpt", "*.msgpack", "*.pb", "*.h5", "flax_*"])
ensure_weights("wangfuyun/PCM_Weights", "./weights/PCM_Weights",
sentinel_files=["sd15/pcm_sd15_smallcfg_2step_converted.safetensors"])
ensure_weights("camenduru/ProPainter", "./weights/propainter",
sentinel_files=["ProPainter.pth"])
ensure_weights("stabilityai/sd-vae-ft-mse", "./weights/sd-vae-ft-mse",
sentinel_files=["diffusion_pytorch_model.safetensors"])
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device
from transformers import Sam3VideoModel, Sam3VideoProcessor
device = get_device()
video_inpainting_sd = DiffuEraser(device, "weights/majicmix-realistic-v7", "weights/sd-vae-ft-mse",
                                  "weights/diffuEraser", ckpt="2-Step")
propainter = Propainter("weights/propainter", device=device)
sam3_model = Sam3VideoModel.from_pretrained("bodhicitta/sam3").to(device, dtype=torch.bfloat16)
sam3_processor = Sam3VideoProcessor.from_pretrained("bodhicitta/sam3")
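# SAM3 is kept in bfloat16 on the GPU; the processor stores video frames and
# runs pre/post-processing on the CPU (see generate_masks_sam3 below) to limit
# GPU memory use.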

def read_video_frames(path):
    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    while True:
        ret, f = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames, fps

def save_frames_as_video(frames, path, fps):
    h, w = frames[0].shape[:2]
    h -= h % 2; w -= w % 2  # crop to even dimensions for video encoding
    out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for f in frames:
        out.write(cv2.cvtColor(f[:h, :w], cv2.COLOR_RGB2BGR))
    out.release()

def save_mask_video(masks, path, fps):
    h, w = masks[0].shape[:2]
    h -= h % 2; w -= w % 2
    out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for m in masks:
        # The writer is opened in color mode, so stack the single-channel mask into 3 channels
        rgb = np.stack([m[:h, :w]] * 3, axis=-1).astype(np.uint8)
        out.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    out.release()

DILATION_PX = 30

def dilate_mask(mask, px=DILATION_PX):
    k = 2 * px + 1
    return cv2.dilate(mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)))
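
# With the default DILATION_PX = 30 the structuring element is 61x61 (k = 2*30 + 1),
# so every mask pixel grows roughly 30 px in each direction. Quick sanity check:
#   m = np.zeros((100, 100), np.uint8); m[50, 50] = 255
#   assert dilate_mask(m).sum() > m.sum()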

def get_union_bbox(masks, H, W):
    union = np.zeros((H, W), dtype=np.uint8)
    for m in masks:
        union = np.maximum(union, m)
    ys, xs = np.where(union > 0)
    if len(ys) == 0:
        return None
    y1, y2 = max(0, int(ys.min())), min(H, int(ys.max()) + 1)
    x1, x2 = max(0, int(xs.min())), min(W, int(xs.max()) + 1)
    y2 = y1 + ((y2 - y1 + 1) // 2) * 2
    x2 = x1 + ((x2 - x1 + 1) // 2) * 2
    return (y1, x1, min(H, y2), min(W, x2))
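
# The bottom/right edges are rounded up to even so the later crop has even
# width and height, keeping the cropped clip encodable as mp4.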

def composite_back(orig_frames, repaired_frames, bbox, dilated_masks):
    y1, x1, y2, x2 = bbox
    roi_h, roi_w = y2 - y1, x2 - x1
    result = [f.copy() for f in orig_frames]
    for i in range(min(len(result), len(repaired_frames))):
        rep = repaired_frames[i]
        if rep.shape[0] != roi_h or rep.shape[1] != roi_w:
            rep = cv2.resize(rep, (roi_w, roi_h))
        alpha = dilated_masks[i].astype(np.float32) / 255.0
        alpha = cv2.GaussianBlur(alpha, (31, 31), 0)[:, :, np.newaxis]
        src = result[i].copy()
        src[y1:y2, x1:x2] = rep
        result[i] = (src.astype(np.float32) * alpha + result[i].astype(np.float32) * (1 - alpha)).astype(np.uint8)
    return result
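
# Blending detail: the dilated mask is blurred with a 31x31 Gaussian to build a
# soft alpha matte, so the repaired ROI feathers into the original frame instead
# of leaving a hard seam at the crop boundary.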

def apply_bbox_filter(masks, filter_bbox):
    if filter_bbox is None:
        return masks
    fy1, fx1, fy2, fx2 = filter_bbox
    filtered = []
    for m in masks:
        mf = np.zeros_like(m)
        mf[fy1:fy2, fx1:fx2] = m[fy1:fy2, fx1:fx2]
        filtered.append(mf)
    return filtered

def generate_masks_sam3(frames, text_prompt):
    H, W = frames[0].shape[:2]
    pil_frames = [Image.fromarray(f) for f in frames]
    with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
        session = sam3_processor.init_video_session(
            video=pil_frames, inference_device=device,
            processing_device="cpu", video_storage_device="cpu", dtype=torch.bfloat16)
        sam3_processor.add_text_prompt(session, text_prompt)
        raw_masks = {}
        for model_out in sam3_model.propagate_in_video_iterator(session, show_progress_bar=True):
            processed = sam3_processor.postprocess_outputs(session, model_out)
            masks_t = processed.get("masks")
            if masks_t is not None and masks_t.shape[0] > 0:
                # OR all detected object tracks into one binary mask per frame
                combined = masks_t.any(dim=0).cpu().numpy().astype(np.uint8) * 255
            else:
                combined = np.zeros((H, W), dtype=np.uint8)
            raw_masks[model_out.frame_idx] = combined
    return [dilate_mask(raw_masks.get(i, np.zeros((H, W), dtype=np.uint8))) for i in range(len(frames))]
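
# Minimal standalone usage sketch (assumes a local clip named "input.mp4"):
#   frames, fps = read_video_frames("input.mp4")
#   masks = generate_masks_sam3(frames, "watermark")
#   save_mask_video(masks, "mask_preview.mp4", fps)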

# ── First-frame helpers ──
_first_frame_cache = {}  # "frame" -> RGB np.ndarray of the most recently loaded first frame

def extract_first_frame(video_path):
    if video_path is None:
        return None, gr.update(), gr.update(), gr.update(), gr.update()
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, gr.update(), gr.update(), gr.update(), gr.update()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    H, W = rgb.shape[:2]
    _first_frame_cache["frame"] = rgb
    return rgb, gr.update(value=0, maximum=W), gr.update(value=0, maximum=H), \
           gr.update(value=W, maximum=W), gr.update(value=H, maximum=H)

def draw_bbox_preview(x1, y1, x2, y2):
    frame = _first_frame_cache.get("frame")
    if frame is None:
        return None
    preview = frame.copy()
    x1, y1, x2, y2 = int(x1 or 0), int(y1 or 0), int(x2 or 0), int(y2 or 0)
    if x2 > x1 and y2 > y1:
        cv2.rectangle(preview, (x1, y1), (x2, y2), (255, 0, 0), 3)
        label = f"({x1},{y1}) -> ({x2},{y2})"
        cv2.putText(preview, label, (x1, max(y1 - 8, 12)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
    return preview

@spaces.GPU(duration=240)
def infer(input_video, text_prompt, x1, y1, x2, y2, use_propainter, mask_upload=None):
    if input_video is None:
        raise gr.Error("Please upload a video first.")
    if not text_prompt or not text_prompt.strip():
        text_prompt = "watermark"
    save_path = "results"
    os.makedirs(save_path, exist_ok=True)
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    print("[1/6] Reading frames...")
    frames, fps = read_video_frames(input_video)
    if not frames:
        raise gr.Error("Cannot read video frames.")
    H, W = frames[0].shape[:2]
    video_length = len(frames)
    print(f"    {video_length} frames {W}x{H} {fps:.1f}fps")

    # Build bbox filter from number inputs
    bx1, by1, bx2, by2 = int(x1 or 0), int(y1 or 0), int(x2 or 0), int(y2 or 0)
    filter_bbox = None
    if bx2 > bx1 and by2 > by1:
        bx1 = max(0, min(bx1, W)); bx2 = max(0, min(bx2, W))
        by1 = max(0, min(by1, H)); by2 = max(0, min(by2, H))
        filter_bbox = (by1, bx1, by2, bx2)
        print(f"    BBox filter: x1={bx1} y1={by1} x2={bx2} y2={by2}")

    # ── MASK: either uploaded or generated via SAM3 ──
    if mask_upload is not None:
        print("[2/6] Using uploaded mask (skipping SAM3)...")
        mask_img = cv2.imread(mask_upload, cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            raise gr.Error("Could not read the uploaded mask image.")
        if mask_img.shape[:2] != (H, W):
            mask_img = cv2.resize(mask_img, (W, H))
        # Apply a small dilation to the uploaded mask for smoother blending
        small_dilation = 4
        dilated_masks = [dilate_mask(mask_img, px=small_dilation)] * len(frames)
    else:
        print(f"[2/6] SAM3 detecting '{text_prompt}'...")
        dilated_masks = generate_masks_sam3(frames, text_prompt.strip())
        if filter_bbox:
            dilated_masks = apply_bbox_filter(dilated_masks, filter_bbox)
            print("    BBox filter applied.")

    print("[3/6] Union BBox...")
    bbox = get_union_bbox(dilated_masks, H, W)
    if bbox is None:
        raise gr.Error(f"No mask found for '{text_prompt}'. Adjust the bbox or text prompt.")
    y1r, x1r, y2r, x2r = bbox
    # Grow the ROI to at least MIN_ROI x MIN_ROI; two passes so a re-clamp after
    # hitting a frame border still ends up at the minimum size.
    MIN_ROI = 256
    for _ in range(2):
        roi_w, roi_h = x2r - x1r, y2r - y1r
        if roi_w < MIN_ROI:
            cx = (x1r + x2r) // 2
            x1r = max(0, cx - MIN_ROI // 2)
            x2r = min(W, x1r + MIN_ROI)
            x1r = max(0, x2r - MIN_ROI)
        if roi_h < MIN_ROI:
            cy = (y1r + y2r) // 2
            y1r = max(0, cy - MIN_ROI // 2)
            y2r = min(H, y1r + MIN_ROI)
            y1r = max(0, y2r - MIN_ROI)
    # Round the ROI up to even dimensions, then clamp to the frame
    roi_w, roi_h = x2r - x1r, y2r - y1r
    x2r = x1r + (roi_w + 1) // 2 * 2
    y2r = y1r + (roi_h + 1) // 2 * 2
    x2r = min(W, x2r); y2r = min(H, y2r)
    bbox = (y1r, x1r, y2r, x2r)
    print(f"    BBox: y1={y1r} x1={x1r} y2={y2r} x2={x2r} roi={x2r-x1r}x{y2r-y1r}")

    print("[4/6] Cropping...")
    cropped_frames = [f[y1r:y2r, x1r:x2r] for f in frames]
    cropped_masks = [m[y1r:y2r, x1r:x2r] for m in dilated_masks]
    crop_video_path = os.path.join(save_path, f"crop_video_{ts}.mp4")
    crop_mask_path = os.path.join(save_path, f"crop_mask_{ts}.mp4")
    save_frames_as_video(cropped_frames, crop_video_path, fps)
    save_mask_video(cropped_masks, crop_mask_path, fps)

    print("[5/6] ProPainter + DiffuEraser...")
    priori_path = os.path.join(save_path, f"priori_{ts}.mp4")
    repaired_path = os.path.join(save_path, f"repaired_{ts}.mp4")
    t0 = time.time()
    if use_propainter:
        propainter.forward(crop_video_path, crop_mask_path, priori_path,
                           resize_ratio=1.0, video_length=video_length, ref_stride=10,
                           neighbor_length=10, subvideo_length=50, mask_dilation=8)
    else:
        import shutil
        shutil.copy2(crop_video_path, priori_path)
        print("    ProPainter skipped, using original crop as priori.")
    video_inpainting_sd.forward(crop_video_path, crop_mask_path, priori_path, repaired_path,
                                max_img_size=960, video_length=video_length,
                                mask_dilation_iter=8, guidance_scale=None)
    print(f"    Done in {time.time()-t0:.1f}s")

    print("[6/6] Compositing back...")
    repaired_frames, _ = read_video_frames(repaired_path)
    final_frames = composite_back(frames, repaired_frames, bbox, dilated_masks)
    output_path = os.path.join(save_path, f"final_{ts}.mp4")
    save_frames_as_video(final_frames, output_path, fps)
    torch.cuda.empty_cache()
    return output_path, priori_path, repaired_path

def on_image_click(evt: gr.SelectData):
    """Single click -> auto-expand into a bbox around the clicked point."""
    frame = _first_frame_cache.get("frame")
    px, py = evt.index[0], evt.index[1]
    if frame is not None:
        H, W = frame.shape[:2]
        hw = max(80, W // 8)
        hh = max(30, H // 10)
    else:
        hw, hh = 100, 40
    new_x1 = max(0, px - hw)
    new_y1 = max(0, py - hh)
    new_x2 = px + hw if frame is None else min(frame.shape[1], px + hw)
    new_y2 = py + hh if frame is None else min(frame.shape[0], py + hh)
    preview = draw_bbox_preview(new_x1, new_y1, new_x2, new_y2)
    return new_x1, new_y1, new_x2, new_y2, preview

with gr.Blocks(title="DiffuEraser + SAM3 Watermark Remover") as demo:
    gr.Markdown("# DiffuEraser — Video Watermark Removal")
    gr.Markdown(
        "**Option A — Automatic (SAM3):** Upload video, type what to remove, click on the area, hit Remove.\n\n"
        "**Option B — Manual mask:** Upload video + upload a mask image (white = remove, black = keep). Skips SAM3."
    )
    with gr.Row():
        # ── Left column: inputs ──
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video (MP4)", format="mp4")
            text_prompt = gr.Textbox(label="SAM3 text prompt (what to remove)", value="watermark",
                                     info="Used when no mask is uploaded")
            mask_upload = gr.Image(label="Or upload a mask image (white=remove, black=keep)",
                                   type="filepath", image_mode="L")
            gr.Markdown("### BBox (click on the preview image to set, or type manually)")
            with gr.Row():
                n_x1 = gr.Number(label="x1 (left)", value=0, precision=0, minimum=0)
                n_y1 = gr.Number(label="y1 (top)", value=0, precision=0, minimum=0)
                n_x2 = gr.Number(label="x2 (right)", value=0, precision=0, minimum=0)
                n_y2 = gr.Number(label="y2 (bottom)", value=0, precision=0, minimum=0)
            use_propainter_chk = gr.Checkbox(label="Use ProPainter prior (better quality)", value=True)
            submit_btn = gr.Button("Remove Watermark", variant="primary")
        # ── Right column: previews ──
        with gr.Column(scale=1):
            bbox_preview = gr.Image(label="First Frame — click to set BBox corners", interactive=False)
            video_result = gr.Video(label="Result")
            priori_result = gr.Video(label="[Debug] ProPainter Priori")
            repaired_result = gr.Video(label="[Debug] DiffuEraser Repaired")

    # Auto-load first frame when video is uploaded
    input_video.change(
        fn=extract_first_frame,
        inputs=[input_video],
        outputs=[bbox_preview, n_x1, n_y1, n_x2, n_y2]
    )
    # Single click -> auto bbox around clicked point
    bbox_preview.select(
        fn=on_image_click,
        inputs=[],
        outputs=[n_x1, n_y1, n_x2, n_y2, bbox_preview]
    )
    # Also update preview when numbers are typed manually
    for comp in [n_x1, n_y1, n_x2, n_y2]:
        comp.change(fn=draw_bbox_preview, inputs=[n_x1, n_y1, n_x2, n_y2], outputs=[bbox_preview])
    submit_btn.click(
        fn=infer,
        inputs=[input_video, text_prompt, n_x1, n_y1, n_x2, n_y2, use_propainter_chk, mask_upload],
        outputs=[video_result, priori_result, repaired_result]
    )

demo.queue().launch(show_error=True, ssr_mode=False)