fantaxy commited on
Commit
62861c6
·
verified ·
1 Parent(s): 0b82ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -101
app.py CHANGED
@@ -1,130 +1,79 @@
1
- import spaces
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
 
5
  import time
6
  import random
7
  from PIL import Image
8
- import torch
9
- from torchvision import transforms
10
  from transparent_background import Remover
11
 
12
- torch.jit.script = lambda f: f
13
-
14
- def apply_temporal_smoothing(current_mask, previous_mask, alpha=0.9):
15
- if previous_mask is None:
16
- return current_mask
17
- return alpha * previous_mask + (1 - alpha) * current_mask
18
-
19
- def post_process_mask(mask, kernel_size=5):
20
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
21
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
22
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
23
- mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
24
- return mask
25
-
26
- @spaces.GPU()
27
- def doo(video, mode, progress=gr.Progress()):
28
- if mode == 'Fast':
29
- remover = Remover(mode='fast')
30
- else:
31
- remover = Remover()
32
-
33
  cap = cv2.VideoCapture(video)
34
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
35
- writer = None
36
- tmpname = random.randint(111111111, 999999999)
37
- processed_frames = 0
38
- start_time = time.time()
39
- previous_mask = None
 
 
 
 
40
 
41
- while cap.isOpened():
42
  ret, frame = cap.read()
43
- if ret is False:
44
  break
45
 
46
- if time.time() - start_time >= 20 * 60 - 5:
47
- print("GPU Timeout is coming")
48
- cap.release()
49
- if writer:
50
- writer.release()
51
- return str(tmpname) + '.mp4'
52
 
53
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
54
- img = Image.fromarray(frame).convert('RGB')
55
-
56
- if writer is None:
57
- writer = cv2.VideoWriter(str(tmpname) + '.mp4', cv2.VideoWriter_fourcc(*'mp4v'), cap.get(cv2.CAP_PROP_FPS), img.size)
58
-
59
- processed_frames += 1
60
- print(f"Processing frame {processed_frames}")
61
- progress(processed_frames / total_frames, desc=f"Processing frame {processed_frames}/{total_frames}")
62
-
63
- # 배경 제거
64
- out = remover.process(img, type='green')
65
-
66
- # 마스크 생성 (RGB to grayscale)
67
- mask = cv2.cvtColor(np.array(out), cv2.COLOR_RGB2GRAY)
68
 
69
- # 마스크를 float32로 변환하고 0-1 범위로 정규화
70
- mask = mask.astype(np.float32) / 255.0
71
-
72
- # 마스크 후처리
73
- mask = post_process_mask(mask)
74
-
75
- # 시간적 평활화 적용
76
- mask = apply_temporal_smoothing(mask, previous_mask)
77
- previous_mask = mask
78
 
79
- # 마스크 0-255 범위의 uint8로 변환
80
- mask = (mask * 255).astype(np.uint8)
81
-
82
- # 마스크 적용 및 색상 보정
83
- mask_3d = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
84
-
85
- # frame을 float32로 변환하고 0-1 범위로 정규화
86
- frame = frame.astype(np.float32) / 255.0
87
- mask_3d = mask_3d.astype(np.float32) / 255.0
88
-
89
- result = cv2.multiply(frame, mask_3d)
90
- result = cv2.addWeighted(result, 1.1, frame, 0, 0) # 색상 보정
91
-
92
- # 0-255 범위로 변환
93
- result = (result * 255).astype(np.uint8)
94
 
95
- writer.write(cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
 
96
 
97
- cap.release()
98
- if writer:
99
- writer.release()
100
- return str(tmpname) + '.mp4'
 
 
101
 
102
- def post_process_mask(mask, kernel_size=5):
103
- kernel = np.ones((kernel_size, kernel_size), np.float32)
104
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
105
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
106
- return cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
107
 
108
- def apply_temporal_smoothing(current_mask, previous_mask, alpha=0.9):
109
- if previous_mask is None:
110
- return current_mask
111
- return cv2.addWeighted(previous_mask, alpha, current_mask, 1-alpha, 0)
 
112
 
113
- title = "🎞️ Enhanced Video Background Removal Tool 🎥"
114
- description = """
115
- *Please note that if your video file is long (has a high number of frames), there is a chance that processing break due to GPU timeout. In this case, consider trying Fast mode.
116
- This enhanced version includes improved mask processing, temporal smoothing, and color correction for better results.*
117
- """
118
 
119
- examples = [['./input.mp4']]
120
 
121
  iface = gr.Interface(
122
- fn=doo,
123
- inputs=["video", gr.components.Radio(['Normal', 'Fast'], label='Select mode', value='Normal', info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.')],
124
  outputs="video",
125
- examples=examples,
126
- title=title,
127
- description=description
128
  )
129
 
130
  iface.launch()
 
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ import tempfile
5
  import time
6
  import random
7
  from PIL import Image
 
 
8
  from transparent_background import Remover
9
 
10
+ def process_video(video, mode, progress=gr.Progress()):
11
+ remover = Remover(mode='fast' if mode == 'Fast' else 'base')
12
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  cap = cv2.VideoCapture(video)
14
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
15
+ fps = cap.get(cv2.CAP_PROP_FPS)
16
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
17
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
18
+
19
+ output_filename = f"{random.randint(111111111, 999999999)}.mp4"
20
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
21
+ out = cv2.VideoWriter(output_filename, fourcc, fps, (width, height))
22
+
23
+ prev_frame = None
24
 
25
+ for frame_num in range(total_frames):
26
  ret, frame = cap.read()
27
+ if not ret:
28
  break
29
 
30
+ progress(frame_num / total_frames, desc=f"Processing frame {frame_num+1}/{total_frames}")
 
 
 
 
 
31
 
32
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
33
+ pil_image = Image.fromarray(rgb_frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # 배경 제거
36
+ output = remover.process(pil_image)
37
+ output = np.array(output)
 
 
 
 
 
 
38
 
39
+ # 알파 채널이 있다면 마스크 사용, 없 그레이스케일로 변환
40
+ if output.shape[2] == 4:
41
+ mask = output[:,:,3]
42
+ else:
43
+ mask = cv2.cvtColor(output, cv2.COLOR_RGB2GRAY)
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # 마스크 임계값 처리
46
+ _, mask = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY)
47
 
48
+ # 움직임 검출
49
+ if prev_frame is not None:
50
+ diff = cv2.absdiff(frame, prev_frame)
51
+ motion_mask = cv2.threshold(cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY), 30, 255, cv2.THRESH_BINARY)[1]
52
+ motion_mask = cv2.dilate(motion_mask, np.ones((5,5), np.uint8), iterations=2)
53
+ mask = cv2.bitwise_or(mask, motion_mask)
54
 
55
+ prev_frame = frame.copy()
 
 
 
 
56
 
57
+ # 마스크 적용
58
+ mask = cv2.GaussianBlur(mask, (5, 5), 0)
59
+ mask = mask.astype(float) / 255.0
60
+ mask = np.stack([mask] * 3, axis=2)
61
+ result = frame.astype(float) * mask
62
 
63
+ # 결과 저장
64
+ out.write(result.astype(np.uint8))
65
+
66
+ cap.release()
67
+ out.release()
68
 
69
+ return output_filename
70
 
71
  iface = gr.Interface(
72
+ fn=process_video,
73
+ inputs=["video", gr.Radio(["Normal", "Fast"], label="Processing mode")],
74
  outputs="video",
75
+ title="Video Background Removal",
76
+ description="Upload a video to remove its background."
 
77
  )
78
 
79
  iface.launch()