|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import spaces |
|
|
import gradio as gr |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import time |
|
|
import random |
|
|
from PIL import Image |
|
|
import torch |
|
|
import re |
|
|
import os |
|
|
import shutil |
|
|
import subprocess |
|
|
import tempfile |
|
|
|
|
|
# Disable TorchScript compilation process-wide: every later call to
# torch.jit.script returns its argument unchanged. This is done before the
# transparent_background import, presumably so that library's scripted
# functions run as plain Python (NOTE(review): confirm the original motivation).
def _skip_torchscript(fn):
    """Stand-in for torch.jit.script that returns *fn* as-is."""
    return fn


torch.jit.script = _skip_torchscript
|
|
|
|
|
from transparent_background import Remover |
|
|
|
|
|
# Seconds of GPU time requested per call. ZeroGPU stops the call once this
# elapses, so the frame-loop deadline inside `doo` is derived from the same
# value instead of an unrelated hard-coded number.
_GPU_DURATION_S = 90

# Seconds reserved before the GPU deadline. The webm path reserves more
# because the VP9 encode (ffmpeg) still has to run after frame processing.
# NOTE(review): both margins are estimates; tune for typical inputs.
_MP4_MARGIN_S = 5
_WEBM_MARGIN_S = 30

# Fallback when the ColorPicker value is empty; matches the UI default.
_DEFAULT_COLOR = "#00FF00"

# "rgb(r, g, b)" or "rgba(r, g, b, a)"; components may be floats and the
# separators may carry any amount of whitespace. Alpha is ignored.
_RGB_FUNC_RE = re.compile(r"rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)")


def _parse_color(color):
    """Convert a ColorPicker value into the "[r, g, b]" string Remover expects.

    Accepts "#RRGGBB", "#RGB", "rgb(...)" and "rgba(...)" (alpha ignored).
    Any other string is returned unchanged so it is passed to Remover as-is.

    Raises:
        ValueError: if a "#..." value or an rgb(a) component is not numeric.
    """
    text = _DEFAULT_COLOR if color is None else str(color).strip()
    if text.startswith("#"):
        digits = text[1:]
        if len(digits) == 3:  # short form "#RGB" -> "#RRGGBB"
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) < 6:
            raise ValueError(f"invalid hex color: {text!r}")
        return str([int(digits[i:i + 2], 16) for i in (0, 2, 4)])
    match = _RGB_FUNC_RE.match(text)
    if match:
        return str([int(float(part)) for part in match.groups()])
    return text


def _report_progress(progress, done, total):
    """Log and report progress; ``total <= 0`` means the frame count is unknown."""
    print(f"Processing frame {done}")
    if total > 0:
        # The container's frame count may be inaccurate; never report > 100%.
        progress(min(done / total, 1.0), desc=f"Processing frame {done}/{total}")
    else:
        # (steps, None) tells Gradio the total is unknown.
        progress((done, None), desc=f"Processing frame {done}")


def _iter_rgb_frames(cap):
    """Yield the remaining frames of an opened cv2.VideoCapture as RGB PIL images."""
    while True:
        ok, frame = cap.read()
        if not ok:
            return
        yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _render_mp4(cap, remover, color, fps, total_frames, deadline, progress, out_path):
    """Composite each frame onto ``color`` and write an mp4v video to ``out_path``.

    Stops early, keeping the frames already written, once ``time.time()``
    reaches ``deadline``. Returns the number of frames written.

    Raises:
        gr.Error: if OpenCV cannot create the output file.
    """
    writer = None
    done = 0
    try:
        for img in _iter_rgb_frames(cap):
            if time.time() >= deadline:
                print("GPU Timeout is coming")
                break
            if writer is None:
                writer = cv2.VideoWriter(
                    out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, img.size
                )
                if not writer.isOpened():
                    raise gr.Error("Could not create the output video file.")
            done += 1
            _report_progress(progress, done, total_frames)
            out = remover.process(img, type=color)
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        # Release even when remover.process raises, so the handle is freed
        # and whatever was written is finalized.
        if writer is not None:
            writer.release()
    return done


def _encode_webm(frame_dir, fps, source_video, out_path):
    """Encode ``frame_dir/frame_%06d.png`` into a VP9 webm with an alpha channel.

    The audio of ``source_video`` (if present) is muxed in as Opus.

    Raises:
        gr.Error: if ffmpeg is not installed or exits with an error.
    """
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        # Pass the exact (possibly fractional) rate; rounding e.g. 29.97 to 30
        # would drift the video out of sync with the source audio.
        "-framerate", str(fps),
        "-i", os.path.join(frame_dir, "frame_%06d.png"),
        "-i", source_video,
        "-map", "0:v",
        "-map", "1:a?",  # trailing "?" makes the audio stream optional
        "-c:v", "libvpx-vp9",
        "-pix_fmt", "yuva420p",  # pixel format with an alpha plane
        "-auto-alt-ref", "0",
        "-metadata:s:v:0", "alpha_mode=1",
        "-c:a", "libopus",
        "-shortest",
        out_path,
    ]
    print("Running ffmpeg:", " ".join(ffmpeg_cmd))
    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except FileNotFoundError as err:
        raise gr.Error("ffmpeg is not available on this server.") from err
    except subprocess.CalledProcessError as err:
        print("ffmpeg failed:", err)
        raise gr.Error("Encoding the webm output failed.") from err


def _render_webm(cap, remover, fps, total_frames, deadline, progress,
                 source_video, out_path, tmp_prefix):
    """Cut out each frame (RGBA) and encode the result as a transparent webm.

    Frames are staged as PNGs in a temporary directory that is always removed.
    When ``deadline`` is reached, the frames processed so far are still
    encoded. Returns the number of frames processed.
    """
    frame_dir = tempfile.mkdtemp(prefix=tmp_prefix)
    try:
        done = 0
        for img in _iter_rgb_frames(cap):
            if time.time() >= deadline:
                print("GPU Timeout is coming")
                break
            done += 1
            _report_progress(progress, done, total_frames)
            out = remover.process(img, type="rgba").convert("RGBA")
            out.save(os.path.join(frame_dir, f"frame_{done:06d}.png"), "PNG")
        if done:
            _encode_webm(frame_dir, fps, source_video, out_path)
        return done
    finally:
        shutil.rmtree(frame_dir, ignore_errors=True)


@spaces.GPU(duration=_GPU_DURATION_S)
def doo(video, color, mode, out_format, progress=gr.Progress()):
    """Remove the background from ``video`` and return the output file path.

    Args:
        video: path of the uploaded video.
        color: ColorPicker value used as the replacement background for mp4
            output. webm output is transparent and ignores it.
        mode: "Fast" selects ``Remover(mode='fast')``; anything else uses the
            default ``Remover()``.
        out_format: "mp4" for a solid-colour background; any other value
            produces a transparent VP9 webm that keeps the source audio.
        progress: Gradio progress tracker.

    Returns:
        Path, relative to the working directory, of "<random>.mp4" or
        "<random>.webm". If the GPU time budget runs out, the video contains
        only the frames processed so far.

    Raises:
        gr.Error: no/unreadable video, invalid colour, no frame processed,
            or a failed webm encode.
    """
    # The GPU time limit covers the whole call, including model loading.
    start_time = time.time()
    if not video:
        raise gr.Error("Please upload a video first.")

    print(str(color))
    try:
        color = _parse_color(color)
    except ValueError as err:
        raise gr.Error(f"Unsupported background color: {color!r}") from err
    print("Parsed color:", color)

    remover = Remover(mode="fast") if mode == "Fast" else Remover()

    tmpname = random.randint(111111111, 999999999)
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        # Either property may be 0, negative or NaN when the container lacks
        # the metadata; "not x > 0" is true for all three.
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        total_frames = int(frame_count) if frame_count > 0 else 0
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 25.0

        if out_format == "mp4":
            out_path = f"{tmpname}.mp4"
            done = _render_mp4(
                cap, remover, color, fps, total_frames,
                deadline=start_time + _GPU_DURATION_S - _MP4_MARGIN_S,
                progress=progress,
                out_path=out_path,
            )
        else:
            out_path = f"{tmpname}.webm"
            done = _render_webm(
                cap, remover, fps, total_frames,
                deadline=start_time + _GPU_DURATION_S - _WEBM_MARGIN_S,
                progress=progress,
                source_video=str(video),
                out_path=out_path,
                tmp_prefix=f"tb_{tmpname}_",
            )
    finally:
        cap.release()

    if done == 0:
        raise gr.Error(
            "No frames were processed (unreadable video or GPU time exhausted)."
        )
    return out_path
|
|
|
|
|
# UI definition: a single Interface wrapping `doo`. The input order must match
# doo's parameters (video, color, mode, out_format).
title = "🎞️ Video Background Removal Tool 🎥"

description = """*Please note that if your video file is long (has a high number of frames), processing may be cut short by the GPU time limit. In this case, consider trying Fast mode.*"""

# Each example row supplies one value per input component, in the same order.
examples = [
    ['./input.mp4', '#00FF00', 'Normal', 'mp4'],
]

iface = gr.Interface(
    fn=doo,
    inputs=[
        "video",
        gr.ColorPicker(label="Background color", value="#00FF00"),
        gr.Radio(
            ['Normal', 'Fast'],
            label='Select mode',
            value='Normal',
            info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.',
        ),
        gr.Radio(
            ['mp4', 'webm'],
            label='Output format',
            value='mp4',
            info='mp4 uses the background color above. webm has a transparent background (the color is ignored) and keeps the original audio.',
        ),
    ],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)

iface.launch()
|
|
|