SoraWatermarkCleaner / profile /run_whole.py
孙振宇
Initial HF Spaces deployment
060fbda
from contextlib import contextmanager
from pathlib import Path
from typing import Callable
import numpy as np
from loguru import logger
from torch.cuda.nvtx import range_pop, range_push
from tqdm import tqdm
import ffmpeg
from sorawm.core import SoraWM
from sorawm.schemas import CleanerType
from sorawm.utils.imputation_utils import (
find_2d_data_bkps,
find_idxs_interval,
get_interval_average_bbox,
)
from sorawm.utils.video_utils import VideoLoader, merge_frames_with_overlap
from sorawm.watermark_cleaner import WaterMarkCleaner
from sorawm.watermark_detector import SoraWaterMarkDetector
@contextmanager
def nvtx(msg: str):
range_push(msg)
try:
yield
finally:
range_pop()
class ProfileSoraWM(SoraWM):
def run(
self,
input_video_path: Path,
output_video_path: Path,
progress_callback: Callable[[int], None] | None = None,
quiet: bool = False,
):
"""
Run the watermark detection and removal pipeline on an input video and write the processed video (with audio merged) to the given output path.
Detects watermark bounding boxes per frame, fills missing detections by interval averaging or neighboring frames, processes the video in breakpoint-based segments with overlap using the configured cleaner, encodes the cleaned frames to an intermediate video file, then merges the original audio into the final output.
Parameters:
input_video_path (Path): Path to the source video to process.
output_video_path (Path): Path where the final video with merged audio will be written.
progress_callback (Callable[[int], None] | None): Optional callback invoked periodically with an integer progress percentage (0–100). Progress values generally advance through detection and cleaning phases and report a final near-completion value before audio merge.
quiet (bool): If True, suppresses progress bar and most debug logging.
"""
with nvtx("ProfileSoraWM.run"):
with nvtx("init video loader"):
input_video_loader = VideoLoader(input_video_path)
width = input_video_loader.width
height = input_video_loader.height
fps = input_video_loader.fps
total_frames = input_video_loader.total_frames
temp_output_path = (
output_video_path.parent / f"temp_{output_video_path.name}"
)
output_options = {
"pix_fmt": "yuv420p",
"vcodec": "libx264",
"preset": "slow",
}
if input_video_loader.original_bitrate:
output_options["video_bitrate"] = str(
int(int(input_video_loader.original_bitrate) * 1.2)
)
else:
output_options["crf"] = "18"
process_out = (
ffmpeg.input(
"pipe:",
format="rawvideo",
pix_fmt="bgr24",
s=f"{width}x{height}",
r=fps,
)
.output(str(temp_output_path), **output_options)
.overwrite_output()
.global_args("-loglevel", "error")
.run_async(pipe_stdin=True)
)
range_push("detect watermarks")
frame_bboxes = {}
detect_missed = []
bbox_centers = []
bboxes = []
if not quiet:
logger.debug(
f"total frames: {total_frames}, fps: {fps}, width: {width}, height: {height}"
)
for idx, frame in enumerate(
tqdm(
input_video_loader,
total=total_frames,
desc="Detect watermarks",
disable=quiet,
)
):
detection_result = self.detector.detect(frame)
if detection_result["detected"]:
frame_bboxes[idx] = {"bbox": detection_result["bbox"]}
x1, y1, x2, y2 = detection_result["bbox"]
bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
bboxes.append((x1, y1, x2, y2))
else:
frame_bboxes[idx] = {"bbox": None}
detect_missed.append(idx)
bbox_centers.append(None)
bboxes.append(None)
if progress_callback and idx % 10 == 0:
progress = 10 + int((idx / total_frames) * 40)
progress_callback(progress)
if not quiet:
logger.debug(f"detect missed frames: {detect_missed}")
range_pop()
range_push("find bkps")
bkps_full = [0, total_frames]
if detect_missed:
bkps = find_2d_data_bkps(bbox_centers)
bkps_full = [0] + bkps + [total_frames]
interval_bboxes = get_interval_average_bbox(bboxes, bkps_full)
missed_intervals = find_idxs_interval(detect_missed, bkps_full)
for missed_idx, interval_idx in zip(detect_missed, missed_intervals):
if (
interval_idx < len(interval_bboxes)
and interval_bboxes[interval_idx] is not None
):
frame_bboxes[missed_idx]["bbox"] = interval_bboxes[interval_idx]
if not quiet:
logger.debug(
f"Filled missed frame {missed_idx} with bbox:\n"
f" {interval_bboxes[interval_idx]}"
)
else:
before = max(missed_idx - 1, 0)
after = min(missed_idx + 1, total_frames - 1)
before_box = frame_bboxes[before]["bbox"]
after_box = frame_bboxes[after]["bbox"]
if before_box:
frame_bboxes[missed_idx]["bbox"] = before_box
elif after_box:
frame_bboxes[missed_idx]["bbox"] = after_box
else:
del bboxes, bbox_centers, detect_missed
range_pop()
range_push("remove watermarks")
if self.cleaner_type == CleanerType.LAMA:
raise NotImplementedError("Lama cleaner is not implemented yet.")
elif self.cleaner_type == CleanerType.E2FGVI_HQ:
input_video_loader = VideoLoader(input_video_path)
frame_counter = 0
overlap_ratio = self.cleaner.config.overlap_ratio
all_cleaned_frames = None
num_segments = len(bkps_full) - 1
for segment_idx in range(num_segments):
# with nvtx(f"process segment {segment_idx}"):
range_push(f"process segment {segment_idx}")
seg_start = bkps_full[segment_idx]
seg_end = bkps_full[segment_idx + 1]
seg_length = seg_end - seg_start
segment_overlap = max(1, int(overlap_ratio * seg_length))
start = seg_start
end = seg_end
if segment_idx > 0:
start = max(
seg_start - segment_overlap,
bkps_full[segment_idx - 1],
)
if segment_idx < num_segments - 1:
end = min(
seg_end + segment_overlap,
bkps_full[segment_idx + 2],
)
if not quiet:
logger.debug(
f"Segment {segment_idx}: original=[{seg_start}, {seg_end}), "
f"with_overlap=[{start}, {end}), overlap={segment_overlap}"
)
frames = np.array(input_video_loader.get_slice(start, end))
frames = frames[:, :, :, ::-1].copy()
masks = np.zeros((len(frames), height, width), dtype=np.uint8)
for idx in range(start, end):
bbox = frame_bboxes[idx]["bbox"]
if bbox is not None:
x1, y1, x2, y2 = bbox
idx_offset = idx - start
masks[idx_offset][y1:y2, x1:x2] = 255
# with nvtx(f"clean frames [{start},{end})"):
range_push(f"clean frames [{start},{end})")
# masks_npy_path = Path("masks.npy")
# frames_np_path = Path("frames.npy")
# np.save(masks_npy_path, masks)
# np.save(frames_np_path, frames)
# raise Exception("Stop here")
cleaned_frames = self.cleaner.clean(frames, masks)
range_pop()
# with nvtx("merge frames"):
range_push("merge frames")
all_cleaned_frames = merge_frames_with_overlap(
result_frames=all_cleaned_frames,
chunk_frames=cleaned_frames,
start_idx=start,
overlap_size=segment_overlap,
is_first_chunk=(segment_idx == 0),
)
range_pop()
# with nvtx("write frames"):
range_push("write frames")
write_start = seg_start
write_end = seg_end
for write_idx in range(write_start, write_end):
if (
write_idx < len(all_cleaned_frames)
and all_cleaned_frames[write_idx] is not None
):
cleaned_frame = all_cleaned_frames[write_idx]
cleaned_frame_bgr = cleaned_frame[:, :, ::-1]
process_out.stdin.write(
cleaned_frame_bgr.astype(np.uint8).tobytes()
)
frame_counter += 1
if progress_callback and frame_counter % 10 == 0:
progress = 50 + int((frame_counter / total_frames) * 45)
progress_callback(progress)
range_pop()
range_pop()
range_pop()
range_push("finalize ffmpeg")
process_out.stdin.close()
process_out.wait()
if progress_callback:
progress_callback(95)
range_pop()
range_push("merge audio track")
self.merge_audio_track(
input_video_path, temp_output_path, output_video_path
)
range_pop()
if __name__ == "__main__":
input_video_path = Path("resources/dog_vs_sam.mp4")
output_stem = Path("outputs/sora_watermark_removed")
sora_wm = ProfileSoraWM(cleaner_type=CleanerType.E2FGVI_HQ)
sora_wm.run(
input_video_path, output_stem.parent / (output_stem.name + "_e2fgvi_hq.mp4")
)