fantaxy commited on
Commit
02af6f7
·
verified ·
1 Parent(s): 4fcd804

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import cv2
4
+ import numpy as np
5
+ import time
6
+ import random
7
+ from PIL import Image
8
+ import torch
9
+ from torchvision import transforms
10
+ from transparent_background import Remover
11
+
12
+ torch.jit.script = lambda f: f
13
+
14
+ def apply_temporal_smoothing(current_mask, previous_mask, alpha=0.9):
15
+ if previous_mask is None:
16
+ return current_mask
17
+ return alpha * previous_mask + (1 - alpha) * current_mask
18
+
19
+ def post_process_mask(mask, kernel_size=5):
20
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
21
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
22
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
23
+ mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
24
+ return mask
25
+
26
+ @spaces.GPU()
27
+ def doo(video, mode, progress=gr.Progress()):
28
+ if mode == 'Fast':
29
+ remover = Remover(mode='fast')
30
+ else:
31
+ remover = Remover()
32
+
33
+ cap = cv2.VideoCapture(video)
34
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
35
+ writer = None
36
+ tmpname = random.randint(111111111, 999999999)
37
+ processed_frames = 0
38
+ start_time = time.time()
39
+ previous_mask = None
40
+
41
+ while cap.isOpened():
42
+ ret, frame = cap.read()
43
+ if ret is False:
44
+ break
45
+
46
+ if time.time() - start_time >= 20 * 60 - 5:
47
+ print("GPU Timeout is coming")
48
+ cap.release()
49
+ writer.release()
50
+ return str(tmpname) + '.mp4'
51
+
52
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
53
+ img = Image.fromarray(frame).convert('RGB')
54
+
55
+ if writer is None:
56
+ writer = cv2.VideoWriter(str(tmpname) + '.mp4', cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), img.size)
57
+
58
+ processed_frames += 1
59
+ print(f"Processing frame {processed_frames}")
60
+ progress(processed_frames / total_frames, desc=f"Processing frame {processed_frames}/{total_frames}")
61
+
62
+ # 배경 제거
63
+ out = remover.process(img, type='green')
64
+
65
+ # 마스크 추출 및 후처리
66
+ mask = np.array(out)[:,:,3]
67
+ mask = post_process_mask(mask)
68
+
69
+ # 시간적 평활화 적용
70
+ mask = apply_temporal_smoothing(mask, previous_mask)
71
+ previous_mask = mask
72
+
73
+ # 마스크 적용 및 색상 보정
74
+ result = cv2.multiply(frame, cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB).astype(float) / 255.0)
75
+ result = cv2.addWeighted(result, 1.1, frame, 0, 0) # 색상 보정
76
+
77
+ writer.write(cv2.cvtColor(result.astype(np.uint8), cv2.COLOR_RGB2BGR))
78
+
79
+ cap.release()
80
+ writer.release()
81
+ return str(tmpname) + '.mp4'
82
+
83
+ title = "🎞️ Enhanced Video Background Removal Tool 🎥"
84
+ description = """
85
+ *Please note that if your video file is long (has a high number of frames), there is a chance that processing break due to GPU timeout. In this case, consider trying Fast mode.
86
+ This enhanced version includes improved mask processing, temporal smoothing, and color correction for better results.*
87
+ """
88
+
89
+ examples = [['./input.mp4']]
90
+
91
+ iface = gr.Interface(
92
+ fn=doo,
93
+ inputs=["video", gr.components.Radio(['Normal', 'Fast'], label='Select mode', value='Normal', info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.')],
94
+ outputs="video",
95
+ examples=examples,
96
+ title=title,
97
+ description=description
98
+ )
99
+
100
+ iface.launch()