ahmad walidurosyad committed on
Commit
9fd445b
·
1 Parent(s): a6408e4
backend/config.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import warnings
from enum import Enum, unique
warnings.filterwarnings('ignore')
import os
import torch
import logging
import platform
import stat
from fsplit.filesplit import Filesplit
import onnxruntime as ort

# Project version number.
VERSION = "1.1.1"
# ×××××××××××××××××××× [do not edit] start ××××××××××××××××××××
logging.disable(logging.DEBUG)  # silence DEBUG-level log output
logging.disable(logging.WARNING)  # silence WARNING-level log output
try:
    import torch_directml
    device = torch_directml.device(torch_directml.default_device())
    USE_DML = True
except Exception:  # was a bare `except:` — don't swallow KeyboardInterrupt/SystemExit
    USE_DML = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
VIDEO_INPAINT_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'video')
MODEL_VERSION = 'V4'
DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')

# The large model weights ship as split chunks; rebuild the merged file on
# first run if it is not already present.
if 'big-lama.pt' not in os.listdir(LAMA_MODEL_PATH):
    Filesplit().merge(input_dir=LAMA_MODEL_PATH)

if 'inference.pdiparams' not in os.listdir(DET_MODEL_PATH):
    Filesplit().merge(input_dir=DET_MODEL_PATH)

if 'ProPainter.pth' not in os.listdir(VIDEO_INPAINT_MODEL_PATH):
    Filesplit().merge(input_dir=VIDEO_INPAINT_MODEL_PATH)

# Choose the bundled ffmpeg binary for the current platform.
sys_str = platform.system()
if sys_str == "Windows":
    ffmpeg_bin = os.path.join('win_x64', 'ffmpeg.exe')
elif sys_str == "Linux":
    ffmpeg_bin = os.path.join('linux_x64', 'ffmpeg')
else:
    ffmpeg_bin = os.path.join('macos', 'ffmpeg')
FFMPEG_PATH = os.path.join(BASE_DIR, 'ffmpeg', ffmpeg_bin)

# NOTE(review): this chunk-merge runs on every platform and assumes the
# ffmpeg/win_x64 directory exists — confirm behavior on non-Windows installs.
if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, 'ffmpeg', 'win_x64')):
    Filesplit().merge(input_dir=os.path.join(BASE_DIR, 'ffmpeg', 'win_x64'))
# Make the bundled ffmpeg executable (rwx for user/group/other).
os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Collect the usable ONNX Runtime accelerators (DirectML/AMD/Intel/Apple/Nvidia),
# keeping only providers this project knows how to drive; CPU is excluded.
_SUPPORTED_ONNX_PROVIDERS = (
    "DmlExecutionProvider",       # DirectML, Windows GPUs
    "ROCMExecutionProvider",      # AMD ROCm
    "MIGraphXExecutionProvider",  # AMD MIGraphX
    "VitisAIExecutionProvider",   # AMD VitisAI (RyzenAI & Windows); perf similar to DirectML in tests
    "OpenVINOExecutionProvider",  # Intel GPU
    "MetalExecutionProvider",     # Apple macOS
    "CoreMLExecutionProvider",    # Apple macOS
    "CUDAExecutionProvider",      # Nvidia GPU
)
ONNX_PROVIDERS = [
    provider for provider in ort.get_available_providers()
    if provider in _SUPPORTED_ONNX_PROVIDERS
]
# ×××××××××××××××××××× [do not edit] end ××××××××××××××××××××
83
+
84
+
85
@unique
class InpaintMode(Enum):
    """Enumeration of the available image/video inpainting algorithms."""
    STTN = 'sttn'
    LAMA = 'lama'
    PROPAINTER = 'propainter'
    # Stable Diffusion Inpainting
    STABLE_DIFFUSION = 'sd'
    # DiffuEraser (diffusion-based)
    DIFFUERASER = 'diffueraser'
    # Flow-guided video inpainting
    E2FGVI = 'e2fgvi'
96
+
97
+
98
# ×××××××××××××××××××× [editable] start ××××××××××××××××××××
# Whether to encode with h264; enable if the output video must be shared to Android phones.
USE_H264 = True

# ×××××××××× general settings start ××××××××××
"""
MODE selects the inpainting algorithm:
- InpaintMode.STTN: works well on live-action video, fast, can skip subtitle detection
- InpaintMode.LAMA: works well on animation, medium speed, cannot skip subtitle detection
- InpaintMode.PROPAINTER: needs a lot of VRAM, slower, best for videos with very heavy motion
"""
# Default inpainting algorithm: sttn/lama/propainter/sd/diffueraser/e2fgvi
MODE = InpaintMode.STTN

# ×××××××××××××××××××× Stable Diffusion Settings ××××××××××××××××××××
SD_MODEL_PATH = 'backend/models/stable-diffusion-inpainting'
SD_STEPS = 50  # Inference steps
SD_GUIDANCE_SCALE = 7.5  # Classifier-free guidance
SD_PROMPT = "natural scene, high quality"  # Text prompt for guidance
SD_USE_FP16 = True  # Use half precision for faster inference

# ×××××××××××××××××××× DiffuEraser Settings ××××××××××××××××××××
DIFFUERASER_MODEL_PATH = 'backend/models/diffueraser'
DIFFUERASER_STEPS = 50  # Diffusion steps
DIFFUERASER_GUIDANCE = 7.5  # Guidance scale
DIFFUERASER_USE_SAM2 = False  # Auto-masking with SAM2
DIFFUERASER_MAX_LOAD_NUM = 80  # Max frames per batch

# ×××××××××××××××××××× E2FGVI Settings ××××××××××××××××××××
E2FGVI_MODEL_PATH = 'backend/models/e2fgvi'
E2FGVI_MAX_LOAD_NUM = 80  # Max frames per batch
E2FGVI_NEIGHBOR_LENGTH = 10  # Temporal window for flow
# [Pixel-difference threshold]
# Used to reject non-subtitle regions: a subtitle text box is normally wider than it
# is tall, so a box whose height exceeds its width by more than this many pixels is
# treated as a false detection.
THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10
# Inflate the mask so an undersized detection box does not leave text edges/residue after inpainting.
SUBTITLE_AREA_DEVIATION_PIXEL = 20
# Two text boxes whose heights differ by at most this many pixels are treated as the same subtitle line.
THRESHOLD_HEIGHT_DIFFERENCE = 20
# Two subtitle boxes are considered the same box when both the X and Y offsets fall within these tolerances.
PIXEL_TOLERANCE_Y = 20  # allowed vertical deviation of a detection box, in pixels
PIXEL_TOLERANCE_X = 20  # allowed horizontal deviation of a detection box, in pixels
# ×××××××××× general settings end ××××××××××

# ×××××××××× InpaintMode.STTN settings start ××××××××××
# The parameters below only take effect when the STTN algorithm is selected.
"""
1. STTN_SKIP_DETECTION
   Meaning: skip subtitle detection entirely.
   Effect: True saves a lot of time, but may touch frames without subtitles or miss some subtitles.

2. STTN_NEIGHBOR_STRIDE
   Meaning: stride between reference frames; e.g. to fill frame 50 with stride 5 the
   algorithm uses frames 45, 40, ... as references.
   Effect: controls reference-frame density — a larger stride means fewer, more
   spread-out references; a smaller stride means more, denser references.

3. STTN_REFERENCE_LENGTH
   Meaning: number of surrounding frames STTN inspects for context around each frame being repaired.
   Effect: larger values raise VRAM usage and quality while lowering speed.

4. STTN_MAX_LOAD_NUM
   Meaning: maximum number of video frames STTN loads at once.
   Effect: larger is slower but produces better results.
   Note: keep STTN_MAX_LOAD_NUM larger than STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH.
"""
STTN_SKIP_DETECTION = True
# Reference-frame stride
STTN_NEIGHBOR_STRIDE = 5
# Number of reference frames
STTN_REFERENCE_LENGTH = 10
# Maximum number of frames STTN processes at the same time
STTN_MAX_LOAD_NUM = 50
# NOTE(review): the code below enforces max-load >= stride * reference-length (the
# product), while the doc above only asks for it to exceed each factor — confirm intent.
if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
    STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
# ×××××××××× InpaintMode.STTN settings end ××××××××××

# ×××××××××× InpaintMode.PROPAINTER settings start ××××××××××
# [Tune to your GPU VRAM] Max number of frames processed at once; higher = better results but more VRAM.
# 1280x720 video: 80 frames needs ~25 GB VRAM, 50 frames ~19 GB.
# 720x480 video: 80 frames needs ~8 GB VRAM, 50 frames ~7 GB.
PROPAINTER_MAX_LOAD_NUM = 70
# ×××××××××× InpaintMode.PROPAINTER settings end ××××××××××

# ×××××××××× InpaintMode.LAMA settings start ××××××××××
# Turbo mode: much faster but inpainting quality is not guaranteed; only removes the regions containing text.
LAMA_SUPER_FAST = False
# ×××××××××× InpaintMode.LAMA settings end ××××××××××
# ×××××××××××××××××××× [editable] end ××××××××××××××××××××
backend/scenedetect/detectors/motion_detector.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # PySceneDetect: Python-Based Video Scene Detector
4
+ # -------------------------------------------------------------------
5
+ # [ Site: https://scenedetect.com ]
6
+ # [ Docs: https://scenedetect.com/docs/ ]
7
+ # [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
8
+ #
9
+ # Copyright (C) 2014-2023 Brandon Castellano <http://www.bcastell.com>.
10
+ # PySceneDetect is licensed under the BSD 3-Clause License; see the
11
+ # included LICENSE file, or visit one of the above pages for details.
12
+ #
13
+ """:class:`MotionDetector`, detects motion events using background subtraction, morphological
14
+ transforms, and thresholding."""
15
+
16
+ # Third-Party Library Imports
17
+ import cv2
18
+
19
+ # PySceneDetect Library Imports
20
+ from scenedetect.scene_detector import SparseSceneDetector
21
+
22
+
23
class MotionDetector(SparseSceneDetector):
    """Detects motion events in scenes containing a static background.

    Uses background subtraction followed by noise removal (via morphological
    opening) to generate a frame score compared against the set threshold.

    NOTE(review): this detector has not been ported to the v0.5 API —
    ``__init__`` raises ``NotImplementedError`` and the pre-port implementation
    is preserved below inside string literals for reference only.

    Attributes:
        threshold: floating point value compared to each frame's score, which
            represents average intensity change per pixel (lower values are
            more sensitive to motion changes). Default 0.5, must be > 0.0.
        num_frames_post_scene: Number of frames to include in each motion
            event after the frame score falls below the threshold, adding any
            subsequent motion events to the same scene.
        kernel_size: Size of morphological opening kernel for noise removal.
            Setting to -1 (default) will auto-compute based on video resolution
            (typically 3 for SD, 5-7 for HD). Must be an odd integer > 1.
    """

    def __init__(self, threshold=0.50, num_frames_post_scene=30, kernel_size=-1):
        """Initializes motion-based scene detector object."""
        # TODO: Requires porting to v0.5 API.
        raise NotImplementedError()
        """
        self.threshold = float(threshold)
        self.num_frames_post_scene = int(num_frames_post_scene)

        self.kernel_size = int(kernel_size)
        if self.kernel_size < 0:
            # Set kernel size when process_frame first runs based on
            # video resolution (480p = 3x3, 720p = 5x5, 1080p = 7x7).
            pass

        self.bg_subtractor = cv2.createBackgroundSubtractorMOG2(
            detectShadows = False )

        self.last_frame_score = 0.0

        self.in_motion_event = False
        self.first_motion_frame_index = -1
        self.last_motion_frame_index = -1
        """

    def process_frame(self, frame_num, frame_img):
        # TODO. Unported — always reports no motion events for now.
        """
        frame_grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        masked_frame = self.bg_subtractor.apply(frame_grayscale)

        kernel = numpy.ones((self.kernel_size, self.kernel_size), numpy.uint8)
        filtered_frame = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, kernel)

        frame_score = numpy.sum(filtered_frame) / float(
            filtered_frame.shape[0] * filtered_frame.shape[1] )
        """
        return []

    def post_process(self, frame_num):
        """Writes the last scene if the video ends while in a motion event.
        """

        # If the last fade detected was a fade out, we add a corresponding new
        # scene break to indicate the end of the scene. This is only done for
        # fade-outs, as a scene cut is already added when a fade-in is found.
        """
        if self.in_motion_event:
            # Write new scene based on first and last motion event frames.
            pass
        return self.in_motion_event
        """
        return []
backend/scenedetect/detectors/threshold_detector.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # PySceneDetect: Python-Based Video Scene Detector
4
+ # -------------------------------------------------------------------
5
+ # [ Site: https://scenedetect.com ]
6
+ # [ Docs: https://scenedetect.com/docs/ ]
7
+ # [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
8
+ #
9
+ # Copyright (C) 2014-2023 Brandon Castellano <http://www.bcastell.com>.
10
+ # PySceneDetect is licensed under the BSD 3-Clause License; see the
11
+ # included LICENSE file, or visit one of the above pages for details.
12
+ #
13
+ """:class:`ThresholdDetector` uses a set intensity as a threshold to scene_detect cuts, which are
14
+ triggered when the average pixel intensity exceeds or falls below this threshold.
15
+
16
+ This detector is available from the command-line as the `scene_detect-threshold` command.
17
+ """
18
+
19
+ from enum import Enum
20
+ from logging import getLogger
21
+ from typing import List, Optional
22
+
23
+ import numpy
24
+
25
+ from backend.scenedetect.scene_detector import SceneDetector
26
+
27
+ logger = getLogger('pyscenedetect')
28
+
29
+ ##
30
+ ## ThresholdDetector Helper Functions
31
+ ##
32
+
33
+
34
+ def _compute_frame_average(frame: numpy.ndarray) -> float:
35
+ """Computes the average pixel value/intensity for all pixels in a frame.
36
+
37
+ The value is computed by adding up the 8-bit R, G, and B values for
38
+ each pixel, and dividing by the number of pixels multiplied by 3.
39
+
40
+ Arguments:
41
+ frame: Frame representing the RGB pixels to average.
42
+
43
+ Returns:
44
+ Average pixel intensity across all 3 channels of `frame`
45
+ """
46
+ num_pixel_values = float(frame.shape[0] * frame.shape[1] * frame.shape[2])
47
+ avg_pixel_value = numpy.sum(frame[:, :, :]) / num_pixel_values
48
+ return avg_pixel_value
49
+
50
+
51
+ ##
52
+ ## ThresholdDetector Class Implementation
53
+ ##
54
+
55
+
56
class ThresholdDetector(SceneDetector):
    """Detects fast cuts/slow fades in from and out to a given threshold level.

    Detects both fast cuts and slow fades so long as an appropriate threshold
    is chosen (especially taking into account the minimum grey/black level).
    """

    class Method(Enum):
        """Method for ThresholdDetector to use when comparing frame brightness to the threshold."""
        FLOOR = 0
        """Fade out happens when frame brightness falls below threshold."""
        CEILING = 1
        """Fade out happens when frame brightness rises above threshold."""

    # Metric key under which the per-frame average RGB value is stored.
    THRESHOLD_VALUE_KEY = 'average_rgb'

    def __init__(
        self,
        threshold: float = 12,
        min_scene_len: int = 15,
        fade_bias: float = 0.0,
        add_final_scene: bool = False,
        method: Method = Method.FLOOR,
        block_size=None,
    ):
        """
        Arguments:
            threshold: 8-bit intensity value that each pixel value (R, G, and B)
                must be <= to in order to trigger a fade in/out.
            min_scene_len: FrameTimecode object or integer greater than 0 of the
                minimum length, in frames, of a scene (or subsequent scene cut).
            fade_bias: Float between -1.0 and +1.0 representing the percentage of
                timecode skew for the start of a scene (-1.0 causing a cut at the
                fade-to-black, 0.0 in the middle, and +1.0 causing the cut to be
                right at the position where the threshold is passed).
            add_final_scene: Boolean indicating if the video ends on a fade-out to
                generate an additional scene at this timecode.
            method: How to treat `threshold` when detecting fade events.
            block_size: [DEPRECATED] DO NOT USE. For backwards compatibility.
        """
        # TODO(v0.7): Replace with DeprecationWarning that `block_size` will be removed in v0.8.
        if block_size is not None:
            logger.error('block_size is deprecated.')

        super().__init__()
        self.threshold = int(threshold)
        self.method = ThresholdDetector.Method(method)
        self.fade_bias = fade_bias
        self.min_scene_len = min_scene_len
        self.processed_frame = False
        self.last_scene_cut = None
        # Whether to add an additional scene or not when ending on a fade out
        # (as cuts are only added on fade ins; see post_process() for details).
        self.add_final_scene = add_final_scene
        # Where the last fade (threshold crossing) was detected.
        self.last_fade = {
            'frame': 0,   # frame number where the last detected fade is
            'type': None  # type of fade, can be either 'in' or 'out'
        }
        self._metric_keys = [ThresholdDetector.THRESHOLD_VALUE_KEY]

    def get_metrics(self) -> List[str]:
        # List of metric keys this detector writes to the stats manager.
        return self._metric_keys

    def process_frame(self, frame_num: int, frame_img: Optional[numpy.ndarray]) -> List[int]:
        """
        Args:
            frame_num (int): Frame number of frame that is being passed.
            frame_img (numpy.ndarray or None): Decoded frame image (numpy.ndarray) to perform
                scene detection with. Can be None *only* if the self.is_processing_required()
                method (inherited from the base SceneDetector class) returns True.
        Returns:
            List[int]: List of frames where scene cuts have been detected. There may be 0
            or more frames in the list, and not necessarily the same as frame_num.
        """

        # Initialize last scene cut point at the beginning of the frames of interest.
        if self.last_scene_cut is None:
            self.last_scene_cut = frame_num

        # Compare the # of pixels under threshold in current_frame & last_frame.
        # If absolute value of pixel intensity delta is above the threshold,
        # then we trigger a new scene cut/break.

        # List of cuts to return.
        cut_list = []

        # The metric used here to detect scene breaks is the percent of pixels
        # less than or equal to the threshold; however, since this differs on
        # user-supplied values, we supply the average pixel intensity as this
        # frame metric instead (to assist with manually selecting a threshold)
        if (self.stats_manager is not None) and (self.stats_manager.metrics_exist(
                frame_num, self._metric_keys)):
            frame_avg = self.stats_manager.get_metrics(frame_num, self._metric_keys)[0]
        else:
            frame_avg = _compute_frame_average(frame_img)
            if self.stats_manager is not None:
                self.stats_manager.set_metrics(frame_num, {self._metric_keys[0]: frame_avg})

        if self.processed_frame:
            if self.last_fade['type'] == 'in' and ((
                    (self.method == ThresholdDetector.Method.FLOOR and frame_avg < self.threshold) or
                    (self.method == ThresholdDetector.Method.CEILING and frame_avg >= self.threshold))):
                # Just faded out of a scene, wait for next fade in.
                self.last_fade['type'] = 'out'
                self.last_fade['frame'] = frame_num

            elif self.last_fade['type'] == 'out' and (
                    (self.method == ThresholdDetector.Method.FLOOR and frame_avg >= self.threshold) or
                    (self.method == ThresholdDetector.Method.CEILING and frame_avg < self.threshold)):
                # Only add the scene if min_scene_len frames have passed.
                if (frame_num - self.last_scene_cut) >= self.min_scene_len:
                    # Just faded into a new scene, compute timecode for the scene
                    # split based on the fade bias.
                    f_out = self.last_fade['frame']
                    f_split = int(
                        (frame_num + f_out + int(self.fade_bias * (frame_num - f_out))) / 2)
                    cut_list.append(f_split)
                    self.last_scene_cut = frame_num
                self.last_fade['type'] = 'in'
                self.last_fade['frame'] = frame_num
        else:
            # First processed frame: classify it as already faded out or in.
            self.last_fade['frame'] = 0
            if frame_avg < self.threshold:
                self.last_fade['type'] = 'out'
            else:
                self.last_fade['type'] = 'in'
        self.processed_frame = True
        return cut_list

    def post_process(self, frame_num: int):
        """Writes a final scene cut if the last detected fade was a fade-out.

        Only writes the scene cut if add_final_scene is true, and the last fade
        that was detected was a fade-out. There is no bias applied to this cut
        (since there is no corresponding fade-in) so it will be located at the
        exact frame where the fade-out crossed the detection threshold.
        """

        # If the last fade detected was a fade out, we add a corresponding new
        # scene break to indicate the end of the scene. This is only done for
        # fade-outs, as a scene cut is already added when a fade-in is found.
        cut_times = []
        if self.last_fade['type'] == 'out' and self.add_final_scene and (
                (self.last_scene_cut is None and frame_num >= self.min_scene_len) or
                (frame_num - self.last_scene_cut) >= self.min_scene_len):
            cut_times.append(self.last_fade['frame'])
        return cut_times
backend/tools/common_tools.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# Known video file extensions (lower-case, with leading dot).
video_extensions = {
    '.mp4', '.m4a', '.m4v', '.f4v', '.f4a', '.m4b', '.m4r', '.f4b', '.mov',
    '.3gp', '.3gp2', '.3g2', '.3gpp', '.3gpp2', '.ogg', '.oga', '.ogv', '.ogx',
    '.wmv', '.wma', '.asf', '.webm', '.flv', '.avi', '.gifv', '.mkv', '.rm',
    '.rmvb', '.vob', '.dvd', '.mpg', '.mpeg', '.mp2', '.mpe', '.mpv', '.mpg',
    '.mpeg', '.m2v', '.svi', '.3gp', '.mxf', '.roq', '.nsv', '.flv', '.f4v',
    '.f4p', '.f4a', '.f4b'
}

# Known image file extensions (lower-case, with leading dot).
image_extensions = {
    '.jpg', '.jpeg', '.jpe', '.jif', '.jfif', '.jfi', '.png', '.gif',
    '.webp', '.tiff', '.tif', '.psd', '.raw', '.arw', '.cr2', '.nrw',
    '.k25', '.bmp', '.dib', '.heif', '.heic', '.ind', '.indd', '.indt',
    '.jp2', '.j2k', '.jpf', '.jpx', '.jpm', '.mj2', '.svg', '.svgz',
    '.ai', '.eps', '.ico'
}


def _extension_of(filename):
    """Return the lower-cased extension of *filename*, including the dot."""
    return os.path.splitext(filename)[-1].lower()


def is_video_file(filename):
    """True when *filename* carries a known video extension (case-insensitive)."""
    return _extension_of(filename) in video_extensions


def is_image_file(filename):
    """True when *filename* carries a known image extension (case-insensitive)."""
    return _extension_of(filename) in image_extensions


def is_video_or_image(filename):
    """True when *filename* looks like either a video or an image file."""
    extension = _extension_of(filename)
    return extension in video_extensions or extension in image_extensions
backend/tools/inpaint_tools.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ import cv2
3
+ import numpy as np
4
+
5
+ from backend import config
6
+ from backend.inpaint.lama_inpaint import LamaInpaint
7
+
8
+
9
+ def batch_generator(data, max_batch_size):
10
+ """
11
+ 根据data大小,生成最大长度不超过max_batch_size的均匀批次数据
12
+ """
13
+ n_samples = len(data)
14
+ # 尝试找到一个比MAX_BATCH_SIZE小的batch_size,以使得所有的批次数量尽量接近
15
+ batch_size = max_batch_size
16
+ num_batches = n_samples // batch_size
17
+
18
+ # 处理最后一批可能不足batch_size的情况
19
+ # 如果最后一批少于其他批次,则减小batch_size尝试平衡每批的数量
20
+ while n_samples % batch_size < batch_size / 2.0 and batch_size > 1:
21
+ batch_size -= 1 # 减小批次大小
22
+ num_batches = n_samples // batch_size
23
+
24
+ # 生成前num_batches个批次
25
+ for i in range(num_batches):
26
+ yield data[i * batch_size:(i + 1) * batch_size]
27
+
28
+ # 将剩余的数据作为最后一个批次
29
+ last_batch_start = num_batches * batch_size
30
+ if last_batch_start < n_samples:
31
+ yield data[last_batch_start:]
32
+
33
+
34
def inference_task(batch_data):
    """Inpaint one batch of frames.

    Each item of *batch_data* is (frame_index, frame, subtitle_boxes); the
    result maps frame_index -> inpainted frame.
    """
    results = dict()
    for frame_index, frame, boxes in batch_data:
        mask = create_mask(frame.shape[:2], boxes)
        results[frame_index] = inpaint(frame, mask)
    return results
43
+
44
+
45
def parallel_inference(inputs, batch_size=None, pool_size=None):
    """Run inpainting inference in parallel while preserving input order.

    Args:
        inputs: list of (frame_index, frame, subtitle_boxes) tuples.
        batch_size: max frames per worker batch; derived from the pool size
            when omitted (previously passing None crashed batch_generator).
        pool_size: number of worker processes, defaults to the CPU count.

    Returns:
        List of (frame_index, inpainted_frame) pairs in input order.
    """
    if pool_size is None:
        pool_size = multiprocessing.cpu_count()
    if batch_size is None:
        # Spread the work roughly evenly across the pool.
        batch_size = max(1, len(inputs) // pool_size) if inputs else 1
    # Context manager tears the pool down automatically.
    with multiprocessing.Pool(processes=pool_size) as pool:
        batched_inputs = list(batch_generator(inputs, batch_size))
        # map() guarantees results come back in submission order.
        batch_results = pool.map(inference_task, batched_inputs)
    # Each batch result is a dict {index: frame}. Flatten to (index, frame)
    # pairs — iterating the dicts directly (as before) yielded only the keys
    # and silently dropped every inpainted frame.
    return [pair for result in batch_results for pair in result.items()]
59
+
60
+
61
def inpaint(img, mask):
    """Inpaint the masked region of *img* using the LAMA model.

    NOTE(review): a fresh LamaInpaint instance is constructed on every call —
    if model loading is expensive, consider caching one per process.
    """
    lama_inpaint_instance = LamaInpaint()
    img_inpainted = lama_inpaint_instance(img, mask)
    return img_inpainted
65
+
66
+
67
def inpaint_with_multiple_masks(censored_img, mask_list):
    """Apply LAMA inpainting once per mask, feeding each result into the next pass."""
    result = censored_img
    for mask in (mask_list or ()):
        result = inpaint(result, mask)
    return result
73
+
74
+
75
def create_mask(size, coords_list):
    """Build a binary (0/255) uint8 mask of shape *size* covering the given boxes.

    Each entry of *coords_list* is (xmin, xmax, ymin, ymax); every box is
    padded by config.SUBTITLE_AREA_DEVIATION_PIXEL on all sides so undersized
    detections do not leave text residue after inpainting.
    """
    mask = np.zeros(size, dtype="uint8")
    if not coords_list:
        return mask
    pad = config.SUBTITLE_AREA_DEVIATION_PIXEL
    for xmin, xmax, ymin, ymax in coords_list:
        # Clamp the padded top-left corner so it never leaves the image.
        top_left = (max(xmin - pad, 0), max(ymin - pad, 0))
        bottom_right = (xmax + pad, ymax + pad)
        cv2.rectangle(mask, top_left, bottom_right, (255, 255, 255), thickness=-1)
    return mask
92
+
93
+
94
def inpaint_video(video_path, sub_list):
    """Debug utility: inpaint the subtitled frames of *video_path* and dump PNGs.

    Args:
        video_path: path of the input video.
        sub_list: dict mapping 1-based frame index -> subtitle box list.

    NOTE(review): the output directory is hardcoded to a developer machine path.
    """
    def _flush(batch):
        # Run the pending batch through the worker pool and write results to disk.
        for frame_index, frame in parallel_inference(batch):
            file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{frame_index}.png'
            cv2.imwrite(file_name, frame)
            print(f"success write: {file_name}")
        batch.clear()

    index = 0
    frame_to_inpaint_list = []
    video_cap = cv2.VideoCapture(video_path)
    try:
        while True:
            ret, frame = video_cap.read()
            if not ret:
                break
            index += 1
            if index in sub_list.keys():
                frame_to_inpaint_list.append((index, frame, sub_list[index]))
            if len(frame_to_inpaint_list) > config.PROPAINTER_MAX_LOAD_NUM:
                _flush(frame_to_inpaint_list)
        # Bug fix: frames buffered after the last full batch were previously
        # dropped — flush the tail before finishing.
        if frame_to_inpaint_list:
            _flush(frame_to_inpaint_list)
    finally:
        video_cap.release()  # previously leaked the capture handle
    print('finished')
114
+
115
+
116
if __name__ == '__main__':
    # NOTE(review): presumably "spawn" is required so worker processes start
    # cleanly (e.g. with CUDA contexts) — confirm against the worker code.
    multiprocessing.set_start_method("spawn")
backend/tools/makedist.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from qpt.executor import CreateExecutableModule as CEM
4
+ from qpt.modules.cuda import CopyCUDAPackage
5
+ from qpt.smart_opt import set_default_pip_source
6
+ from qpt.kernel.qinterpreter import PYPI_PIP_SOURCE
7
+ from qpt.modules.package import CustomPackage, DEFAULT_DEPLOY_MODE
8
+
9
+
10
+
11
def main():
    """Build a distributable QPT package for the GUI, optionally bundling
    CUDA- or DirectML-enabled torch wheels selected via command-line flags."""
    WORK_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    LAUNCH_PATH = os.path.join(WORK_DIR, 'gui.py')
    SAVE_PATH = os.path.join(os.path.dirname(WORK_DIR), 'vsr_out')
    ICON_PATH = os.path.join(WORK_DIR, "design", "vsr.ico")

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description="打包程序")
    parser.add_argument(
        "--cuda",
        nargs="?",        # value is optional
        const="11.8",     # bare `--cuda` defaults to 11.8 (comment previously said 10.2)
        default=None,     # omitted entirely -> None
        help="是否包含CUDA模块,可指定版本,如 --cuda 或 --cuda=11.8"
    )
    parser.add_argument(
        "--directml",
        nargs="?",        # value is optional
        const=True,       # bare `--directml` enables it
        default=None,     # omitted entirely -> None
        help="是否使用DirectML加速,仅指定 --directml 即可启用"
    )

    args = parser.parse_args()

    sub_modules = []

    # Pin torch/torchvision to the wheel index matching the requested CUDA version.
    if args.cuda == "11.8":
        sub_modules.append(CustomPackage("torch==2.7.0 torchvision==0.22.0", deploy_mode=DEFAULT_DEPLOY_MODE, find_links=PYPI_PIP_SOURCE, opts="--index-url https://download.pytorch.org/whl/cu118 "))
    elif args.cuda == "12.6":
        sub_modules.append(CustomPackage("torch==2.7.0 torchvision==0.22.0", deploy_mode=DEFAULT_DEPLOY_MODE, find_links=PYPI_PIP_SOURCE, opts="--index-url https://download.pytorch.org/whl/cu126 "))
    elif args.cuda == "12.8":
        sub_modules.append(CustomPackage("torch==2.7.0 torchvision==0.22.0", deploy_mode=DEFAULT_DEPLOY_MODE, find_links=PYPI_PIP_SOURCE, opts="--index-url https://download.pytorch.org/whl/cu128 "))

    if args.directml:
        sub_modules.append(CustomPackage("torch_directml==0.2.5.dev240914", deploy_mode=DEFAULT_DEPLOY_MODE))

    # In CI (QPT_Action set) force the public PyPI source.
    if os.getenv("QPT_Action") == "True":
        set_default_pip_source(PYPI_PIP_SOURCE)

    module = CEM(
        work_dir=WORK_DIR,
        launcher_py_path=LAUNCH_PATH,
        save_path=SAVE_PATH,
        icon=ICON_PATH,
        hidden_terminal=False,
        requirements_file="./requirements.txt",
        sub_modules=sub_modules,
    )

    module.make()
62
+
63
+
64
if __name__ == '__main__':
    # Script entry point.
    main()
backend/tools/merge_video.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
+ def merge_video(video_input_path0, video_input_path1, video_output_path):
5
+ """
6
+ 将两个视频文件安装水平方向合并
7
+ """
8
+ input_video_cap0 = cv2.VideoCapture(video_input_path0)
9
+ input_video_cap1 = cv2.VideoCapture(video_input_path1)
10
+ fps = input_video_cap1.get(cv2.CAP_PROP_FPS)
11
+ size = (int(input_video_cap1.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video_cap1.get(cv2.CAP_PROP_FRAME_HEIGHT)) * 2)
12
+ video_writer = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
13
+ while True:
14
+ ret0, frame0 = input_video_cap0.read()
15
+ ret1, frame1 = input_video_cap1.read()
16
+ if not ret1 and not ret0:
17
+ break
18
+ else:
19
+ show = cv2.vconcat([frame0, frame1])
20
+ video_writer.write(show)
21
+ video_writer.release()
22
+
23
+
24
if __name__ == '__main__':
    v0_path = '../../test/test4.mp4'
    v1_path = '../../test/test4_no_sub(1).mp4'
    video_out_path = '../../test/demo.mp4'
    merge_video(v0_path, v1_path, video_out_path)
    # ffmpeg recipe: convert mp4 to gif
    # ffmpeg -i demo3.mp4 -vf "scale=w=720:h=-1,fps=15,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 -r 15 -f gif output.gif
    # Fixed width of 400 with proportional height:
    # ffmpeg - i input.avi -vf scale=400:-2
backend/tools/train/dataset_sttn.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from torch.utils.data import DataLoader
7
+ from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
8
+ from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
9
+
10
+
11
# Custom dataset serving (frames, masks) tensor pairs for STTN training.
class Dataset(torch.utils.data.Dataset):
    """Video dataset: frames read from per-video zip archives, paired with
    randomly generated moving masks."""

    def __init__(self, args: dict, split='train', debug=False):
        """
        Args:
            args: config dict; uses 'sample_length', 'w', 'h', 'data_root', 'name'.
            split: dataset split name ('train' by default).
            debug: when True, restrict to the first 100 videos.
        """
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']  # frames per sample
        self.size = self.w, self.h = (args['w'], args['h'])  # target frame size

        # The split json maps video name -> frame count.
        with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
            self.video_dict = json.load(f)
        self.video_names = list(self.video_dict.keys())
        if debug or split != 'train':
            # Keep evaluation/debug runs small.
            self.video_names = self.video_names[:100]

        # Convert a list of PIL images into a stacked torch tensor.
        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])

    def __len__(self):
        # Number of videos in this split.
        return len(self.video_names)

    def __getitem__(self, index):
        try:
            item = self.load_item(index)
        except Exception:  # was a bare except: don't swallow KeyboardInterrupt/SystemExit
            print('Loading error in video {}'.format(self.video_names[index]))
            item = self.load_item(0)  # fall back to the first sample
        return item

    def load_item(self, index):
        """Load one sample: (frame_tensors, mask_tensors)."""
        video_name = self.video_names[index]
        # Frame file names for every frame of the video.
        all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
        # Random moving masks, one per frame.
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        # Pick the reference frame indexes for this sample.
        ref_index = get_ref_index(len(all_frames), self.sample_length)
        frames = []
        masks = []
        for idx in ref_index:
            # Read the frame from the zip archive, convert to RGB and resize.
            img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
                self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])
        if self.split == 'train':
            # Random horizontal flip augmentation (applied to the whole group).
            frames = GroupRandomHorizontalFlip()(frames)
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0  # rescale to [-1, 1]
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors
73
+
74
+
75
def get_ref_index(length, sample_length):
    """Pick *sample_length* sorted frame indexes out of *length* frames.

    Half of the time the indexes are drawn at random; otherwise a contiguous
    window with a random start position is used.
    """
    use_scattered_frames = random.uniform(0, 1) > 0.5
    if use_scattered_frames:
        # Randomly scattered frames, sorted to keep temporal order.
        chosen = sorted(random.sample(range(length), sample_length))
    else:
        # A contiguous run of frames starting at a random pivot.
        start = random.randint(0, length - sample_length)
        chosen = [start + offset for offset in range(sample_length)]
    return chosen
backend/tools/train/loss_sttn.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class AdversarialLoss(nn.Module):
    """Adversarial loss for GAN training.

    Implements the objectives compared in https://arxiv.org/abs/1711.10337.
    Supported variants: 'nsgan' | 'lsgan' | 'hinge'.
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        """
        type: which GAN loss to use ('nsgan' | 'lsgan' | 'hinge').
        target_real_label: target label value for real samples.
        target_fake_label: target label value for generated samples.

        Raises:
            ValueError: if ``type`` is not one of the supported variants.
        """
        super(AdversarialLoss, self).__init__()
        self.type = type  # which loss variant is active
        # Register the labels as buffers so they are saved/loaded with the
        # module and moved along with it by .to()/.cuda().
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        # Pick the underlying criterion for the chosen variant.
        if type == 'nsgan':
            self.criterion = nn.BCELoss()  # binary cross entropy (non-saturating GAN)
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()  # mean squared error (least-squares GAN)
        elif type == 'hinge':
            self.criterion = nn.ReLU()  # ReLU used to build the hinge loss
        else:
            # Fail fast instead of raising AttributeError on first use.
            raise ValueError("Unsupported GAN loss type: {!r}".format(type))

    def forward(self, outputs, is_real, is_disc=None):
        """Compute the loss.

        outputs: network (discriminator) outputs.
        is_real: True for real samples, False for generated ones.
        is_disc: True when optimizing the discriminator (hinge only).

        NOTE: defined as ``forward`` (instead of overriding ``__call__``) so
        nn.Module's hook machinery works; calling the instance behaves the same.
        """
        if self.type == 'hinge':
            if is_disc:
                # Discriminator hinge: max(0, 1 - out) for real, max(0, 1 + out) for fake.
                if is_real:
                    outputs = -outputs  # flip sign for real samples
                return self.criterion(1 + outputs).mean()
            else:
                # Generator hinge: -min(0, -out) = maximize discriminator output.
                return (-outputs).mean()
        else:
            # nsgan / lsgan: compare outputs against the real/fake target labels.
            labels = (self.real_label if is_real else self.fake_label).expand_as(
                outputs)
            loss = self.criterion(outputs, labels)
            return loss
backend/tools/train/train_sttn.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ from shutil import copyfile
5
+ import torch
6
+ import torch.multiprocessing as mp
7
+
8
+ from backend.tools.train.trainer_sttn import Trainer
9
+ from backend.tools.train.utils_sttn import (
10
+ get_world_size,
11
+ get_local_rank,
12
+ get_global_rank,
13
+ get_master_ip,
14
+ )
15
+
16
# Command-line interface: config path, model name, DDP rendezvous port, and an
# --exam flag that switches the trainer into quick debug mode.
parser = argparse.ArgumentParser(description='STTN')
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
parser.add_argument('-m', '--model', default='sttn', type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
parser.add_argument('-e', '--exam', action='store_true')
args = parser.parse_args()
22
+
23
+
24
def main_worker(rank, config):
    """Per-process training entry point.

    rank: process index assigned by mp.spawn, or -1 when launched externally.
    config: configuration dict prepared by the ``__main__`` block.
    """
    # If no local_rank was provided, use the spawned rank for both local and
    # global rank (single-node mp.spawn case).
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    # Distributed training setup.
    if config['distributed']:
        # Pin this process to the GPU matching its local rank.
        torch.cuda.set_device(int(config['local_rank']))
        # Join the process group through the NCCL backend.
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch'
        )
        # Report which GPU this process uses (global-local rank pair).
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), int(config['local_rank']))
        )

    # Checkpoint directory: <save_dir>/<model>_<config file stem>.
    config['save_dir'] = os.path.join(
        config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
    )

    # Pick the compute device for this process.
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only the master process (or a non-distributed run) touches the filesystem.
    if (not config['distributed']) or config['global_rank'] == 0:
        # Create the checkpoint directory (ignore if it already exists).
        os.makedirs(config['save_dir'], exist_ok=True)
        # Keep a copy of the config file next to the checkpoints.
        config_path = os.path.join(
            config['save_dir'], config['config'].split('/')[-1]
        )
        if not os.path.isfile(config_path):
            copyfile(config['config'], config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    # Build the trainer (debug mode when --exam was given) and run it.
    trainer = Trainer(config, debug=args.exam)
    trainer.train()
75
+
76
+
77
+ if __name__ == "__main__":
78
+ # 加载配置文件
79
+ config = json.load(open(args.config))
80
+ config['model'] = args.model # 设置模型名称
81
+ config['config'] = args.config # 设置配置文件路径
82
+
83
+ # 设置分布式训练的相关配置
84
+ config['world_size'] = get_world_size() # 获取全局进程数,即训练过程中参与计算的总GPU数量
85
+ config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" # 设置初始化方法,包括主节点IP和端口
86
+ config['distributed'] = True if config['world_size'] > 1 else False # 根据世界规模确定是否启用分布式训练
87
+
88
+ # 设置分布式并行训练环境
89
+ if get_master_ip() == "127.0.0.1":
90
+ # 如果主节点IP是本机地址,那么手动启动多个分布式训练进程
91
+ mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
92
+ else:
93
+ # 如果是由其他工具如OpenMPI启动的多个进程,不需手动创建进程。
94
+ config['local_rank'] = get_local_rank() # 获取本地(单个节点)排名
95
+ config['global_rank'] = get_global_rank() # 获取全局排名
96
+ main_worker(-1, config) # 启动主工作函数
backend/tools/train/trainer_sttn.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from tqdm import tqdm
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+ from tensorboardX import SummaryWriter
10
+
11
+ from backend.inpaint.sttn.auto_sttn import Discriminator
12
+ from backend.inpaint.sttn.auto_sttn import InpaintGenerator
13
+ from backend.tools.train.dataset_sttn import Dataset
14
+ from backend.tools.train.loss_sttn import AdversarialLoss
15
+
16
+
17
class Trainer:
    """GAN training driver for the STTN video-inpainting model.

    Wires together the dataset/dataloader, generator and discriminator,
    their Adam optimizers, adversarial + L1 losses, optional
    DistributedDataParallel wrapping, TensorBoard logging, and checkpoint
    save/load; ``train()`` runs the main loop.
    """

    def __init__(self, config, debug=False):
        # Trainer initialization.
        self.config = config  # full configuration dict
        self.epoch = 0  # current epoch
        self.iteration = 0  # current iteration count
        if debug:
            # Debug mode: save/validate/stop far more frequently.
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # Dataset and dataloader setup.
        self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug)  # training set
        self.train_sampler = None  # no sampler unless distributed
        self.train_args = config['trainer']  # training hyperparameters
        if config['distributed']:
            # Distributed run: shard the dataset across ranks.
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank']
            )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),  # shuffle only when no sampler
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler
        )

        # Loss functions.
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])  # adversarial loss
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])  # move loss to device
        self.l1_loss = nn.L1Loss()  # L1 reconstruction loss

        # Generator and discriminator models.
        self.netG = InpaintGenerator()  # generator network
        self.netG = self.netG.to(self.config['device'])  # move to device
        self.netD = Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
        )
        self.netD = self.netD.to(self.config['device'])  # discriminator network
        # Optimizers (one per network, shared hyperparameters).
        self.optimG = torch.optim.Adam(
            self.netG.parameters(),  # generator parameters
            lr=config['trainer']['lr'],  # learning rate
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.optimD = torch.optim.Adam(
            self.netD.parameters(),  # discriminator parameters
            lr=config['trainer']['lr'],  # learning rate
            betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
        )
        self.load()  # resume from the latest checkpoint when available

        if config['distributed']:
            # Wrap both networks in DistributedDataParallel.
            self.netG = DDP(
                self.netG,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )
            self.netD = DDP(
                self.netD,
                device_ids=[self.config['local_rank']],
                output_device=self.config['local_rank'],
                broadcast_buffers=True,
                find_unused_parameters=False
            )

        # TensorBoard loggers (master process only).
        self.dis_writer = None  # discriminator writer
        self.gen_writer = None  # generator writer
        self.summary = {}  # running summary accumulators
        if self.config['global_rank'] == 0 or (not config['distributed']):
            # Master process (or non-distributed run) owns the writers.
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis')
            )
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen')
            )

    # Current learning rate (read off the generator's optimizer).
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # Step-decay learning-rate schedule.
    def adjust_learning_rate(self):
        # Decay by 10x every `niter` iterations, frozen after `niter_steady`.
        decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        # Only touch the optimizers when the rate actually changes.
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # Accumulate a scalar and log its 100-iteration average to TensorBoard.
    def add_summary(self, writer, name, val):
        # Accumulate the value on every call.
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        # Flush the averaged value every 100 iterations.
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # Load netG, netD and optimizer state from the latest checkpoint.
    def load(self):
        model_path = self.config['save_dir']  # checkpoint directory
        # Prefer the epoch recorded in latest.ckpt when it exists.
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            # The last line of latest.ckpt holds the newest epoch tag.
            latest_epoch = open(os.path.join(
                model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
        else:
            # Otherwise scan the saved *.pth files and take the newest.
            # NOTE(review): this is a lexicographic sort of file stems like
            # "gen_00005" — relies on zero-padding for correct ordering.
            ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
                os.path.join(model_path, '*.pth'))]
            ckpts.sort()  # sort so the last entry is the newest
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None  # newest tag, if any
        if latest_epoch is not None:
            # Build the generator/discriminator/optimizer checkpoint paths.
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            # Only the master process announces the load.
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            # Load generator weights.
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            # Load discriminator weights.
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            # Load optimizer state.
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            # Restore epoch and iteration counters.
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            # No checkpoint found: warn and start from scratch.
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. An initialized model will be used.')

    # Save model parameters; called once per eval period.
    def save(self, it):
        # Only the master process (global rank 0) writes checkpoints.
        if self.config['global_rank'] == 0:
            # Generator state-dict path.
            gen_path = os.path.join(
                self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
            # Discriminator state-dict path.
            dis_path = os.path.join(
                self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
            # Optimizer state-dict path.
            opt_path = os.path.join(
                self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))

            # Announce the save.
            print('\nsaving model to {} ...'.format(gen_path))

            # Unwrap DataParallel/DDP to save the raw module weights.
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD

            # Save generator and discriminator weights.
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            # Save epoch, iteration and optimizer state together.
            torch.save({
                'epoch': self.epoch,
                'iteration': self.iteration,
                'optimG': self.optimG.state_dict(),
                'optimD': self.optimD.state_dict()
            }, opt_path)

            # Record the newest checkpoint tag in "latest.ckpt".
            # NOTE(review): shell `echo` is non-portable; a plain
            # open().write() would be safer — confirm before changing.
            os.system('echo {} > {}'.format(str(it).zfill(5),
                                            os.path.join(self.config['save_dir'], 'latest.ckpt')))

    # Training entry point.

    def train(self):
        # Progress range over the configured iteration budget.
        pbar = range(int(self.train_args['iterations']))
        # Only the global rank-0 process displays a tqdm bar.
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)

        # Main training loop: epochs until the iteration budget is spent.
        while True:
            self.epoch += 1  # epoch counter
            if self.config['distributed']:
                # Re-seed the sampler so each epoch shuffles differently per rank.
                self.train_sampler.set_epoch(self.epoch)

            # Run one epoch.
            self._train_epoch(pbar)
            # Stop once the iteration budget is exhausted.
            if self.iteration > self.train_args['iterations']:
                break
        # Training finished.
        print('\nEnd training....')

    # One epoch: forward passes, losses, and optimizer steps.

    def _train_epoch(self, pbar):
        device = self.config['device']  # target device

        # Iterate over training batches.
        for frames, masks in self.train_loader:
            # Update the learning rate for this step.
            self.adjust_learning_rate()
            # Advance the global iteration counter.
            self.iteration += 1

            # Move batch to the device.
            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()  # batch, time, channels, height, width
            masked_frame = (frames * (1 - masks).float())  # zero out the masked region
            pred_img = self.netG(masked_frame, masks)  # generator fills the holes
            # Flatten the time dimension to match the network output layout.
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            comp_img = frames * (1. - masks) + masks * pred_img  # composite: real outside, predicted inside

            gen_loss = 0  # generator loss accumulator
            dis_loss = 0  # discriminator loss accumulator

            # Discriminator adversarial loss.
            real_vid_feat = self.netD(frames)  # discriminator on real frames
            fake_vid_feat = self.netD(comp_img.detach())  # on composites; detach blocks generator grads
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)  # loss on real
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)  # loss on fake
            dis_loss += (dis_real_loss + dis_fake_loss) / 2  # average the two terms
            # Log discriminator losses.
            self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            # Discriminator optimization step.
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # Generator adversarial loss.
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)  # generator wants "real" verdicts
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']  # weight the term
            gen_loss += gan_loss  # accumulate
            # Log the generator adversarial loss.
            self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # Generator L1 loss inside the holes.
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)  # masked-region reconstruction
            # Normalize by mask coverage, then weight.
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss  # accumulate
            # Log hole_loss.
            self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())

            # L1 loss outside the holes.
            valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
            # Normalize by the unmasked coverage, then weight.
            valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss  # accumulate
            # Log valid_loss.
            self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())

            # Generator optimization step.
            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            # Console progress output (master only).
            if self.config['global_rank'] == 0:
                pbar.update(1)  # advance the progress bar
                pbar.set_description((  # show the current loss values
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # Periodic checkpointing.
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            # Stop mid-epoch once the iteration budget is spent.
            if self.iteration > self.train_args['iterations']:
                break
319
+
backend/tools/train/utils_sttn.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.patches as patches
4
+ from matplotlib.path import Path
5
+ import io
6
+ import cv2
7
+ import random
8
+ import zipfile
9
+ import numpy as np
10
+ from PIL import Image, ImageOps
11
+ import torch
12
+ import matplotlib
13
+ from matplotlib import pyplot as plt
14
+ matplotlib.use('agg')
15
+
16
+
17
class ZipReader(object):
    """Reads images out of zip archives, caching open archive handles."""

    # Shared cache mapping archive path -> open ZipFile handle.
    file_dict = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        """Return the (cached) ZipFile handle for *path*, opening it on first use."""
        cache = ZipReader.file_dict
        if path not in cache:
            cache[path] = zipfile.ZipFile(path, 'r')
        return cache[path]

    @staticmethod
    def imread(path, image_name):
        """Read entry *image_name* from the archive at *path* as a PIL image."""
        archive = ZipReader.build_file_dict(path)
        raw = archive.read(image_name)
        return Image.open(io.BytesIO(raw))
39
+
40
+
41
class GroupRandomHorizontalFlip(object):
    """Horizontally flip every image in a group with probability 0.5.

    For optical-flow groups (``is_flow=True`` at construction) the
    horizontal component (every even-indexed image) is value-inverted
    after flipping so the flow stays consistent with the mirrored frames.
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # NOTE: the call-time `is_flow` argument is unused; the constructor
        # flag decides (kept for signature compatibility).
        if random.random() >= 0.5:
            return img_group
        flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
        if self.is_flow:
            for i in range(0, len(flipped), 2):
                # Invert flow pixel values when flipping.
                flipped[i] = ImageOps.invert(flipped[i])
        return flipped
59
+
60
+
61
class Stack(object):
    """Stack a group of PIL images into one numpy array along axis 2."""

    def __init__(self, roll=False):
        # roll=True reverses the channel order (RGB -> BGR) for RGB input.
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            # Promote binary images to 8-bit grayscale first.
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            expanded = [np.expand_dims(x, 2) for x in img_group]
            return np.stack(expanded, axis=2)
        if mode == 'RGB':
            if self.roll:
                reversed_channels = [np.array(x)[:, :, ::-1] for x in img_group]
                return np.stack(reversed_channels, axis=2)
            return np.stack(img_group, axis=2)
        raise NotImplementedError(f"Image mode {mode}")
79
+
80
+
81
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        # div=True rescales byte values into [0, 1]; otherwise keep [0, 255].
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy input: stacked group shaped [H, W, L, C] -> [L, C, H, W].
            tensor = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            # PIL input: raw bytes -> HWC byte tensor -> CHW.
            tensor = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            tensor = tensor.view(pic.size[1], pic.size[0], len(pic.mode))
            # HWC -> CHW (this transpose dominates loading time).
            tensor = tensor.transpose(0, 1).transpose(0, 2).contiguous()
        return tensor.float().div(255) if self.div else tensor.float()
102
+
103
+
104
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
    """Build a list of ``video_length`` 'L'-mode PIL mask images.

    A single random blob is generated; with probability 0.5 it stays fixed
    across all frames, otherwise it drifts with a random velocity.
    """
    # get a random shape
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size
    # get random position
    x, y = random.randint(
        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    # NOTE: x indexes image rows, y columns, hence the (y, x, ...) paste box.
    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
    masks = [m.convert('L')]
    # return fixed masks (stationary case, 50% of the time)
    if random.uniform(0, 1) > 0.5:
        return masks*video_length
    # return moving masks: shift the region frame by frame
    for _ in range(video_length-1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
        masks.append(m.convert('L'))
    return masks
132
+
133
+
134
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    '''
    Draw a random closed bezier blob and return it as an 'L'-mode PIL image
    cropped to its bounding box.

    There is the initial point and 3 points per cubic bezier curve.
    Thus, the curve will only pass though n points, which will be the sharp edges.
    The other 2 modify the shape of the bezier curve.
    edge_num, Number of possibly sharp edges
    points_num, number of points in the Path
    ratio, (0, 1) magnitude of the perturbation from the unit circle,
    '''
    points_num = edge_num*3 + 1
    angles = np.linspace(0, 2*np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)
    # draw the path into an off-screen matplotlib figure
    fig = plt.figure()
    ax = fig.add_subplot(111)
    patch = patches.PathPatch(path, facecolor='black', lw=2)
    ax.add_patch(patch)
    ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
    ax.axis('off')  # removes the axis to leave only the shape
    fig.canvas.draw()
    # convert the rendered figure into a numpy image
    # NOTE(review): fig.canvas.tostring_rgb is removed in newer matplotlib
    # (3.8+); buffer_rgba would be the modern replacement — confirm version.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
    plt.close(fig)
    # postprocess: resize, binarize (shape -> 255, background -> 0)
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8))*255
    corrdinates = np.where(data > 0)
    xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
        corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
    # crop to the blob's bounding box
    region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
    return region
173
+
174
+
175
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    """Perturb a (speed, angle) velocity by a random acceleration.

    dist: 'uniform' draws from [-d, d]; 'guassian' (sic) draws from N(0, d/2).
    Raises NotImplementedError for any other distribution name.
    """
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        return (speed + np.random.uniform(-d_speed, d_speed),
                angle + np.random.uniform(-d_angle, d_angle))
    if dist == 'guassian':
        return (speed + np.random.normal(0, d_speed / 2),
                angle + np.random.normal(0, d_angle / 2))
    raise NotImplementedError(
        f'Distribution type {dist} is not supported.')
188
+
189
+
190
def get_random_velocity(max_speed=3, dist='uniform'):
    """Sample a random (speed, angle) velocity.

    speed is drawn from [0, max_speed) ('uniform') or |N(0, max_speed/2)|
    ('guassian'); angle is uniform over [0, 2*pi).

    Raises:
        NotImplementedError: for any other distribution name.
    """
    if dist == 'uniform':
        # BUG FIX: np.random.uniform(max_speed) means uniform(low=max_speed,
        # high=1.0) due to numpy's (low, high) signature; the intended range
        # is [0, max_speed).
        speed = np.random.uniform(0, max_speed)
    elif dist == 'guassian':
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)
200
+
201
+
202
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
    """Advance the mask region one step along its velocity.

    Returns (new_X, new_Y, new_velocity) with the position clipped to keep
    the region inside the image; the velocity is randomly re-initialized
    whenever the unclipped position left the image bounds.
    """
    region_width, region_height = region_size
    speed, angle = lineVelocity
    # Move by the integer projection of the velocity onto each axis.
    X += int(speed * np.cos(angle))
    Y += int(speed * np.sin(angle))
    # Apply a random (gaussian) acceleration for the next step.
    lineVelocity = random_accelerate(
        lineVelocity, maxLineAcceleration, dist='guassian')
    max_x = imageHeight - region_height
    max_y = imageWidth - region_width
    out_of_bounds = X < 0 or X > max_x or Y < 0 or Y > max_y
    if out_of_bounds:
        # Bounced off an edge: restart with a fresh random velocity.
        lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
    return np.clip(X, 0, max_x), np.clip(Y, 0, max_y), lineVelocity
214
+
215
+
216
def get_world_size():
    """Find OMPI world size without calling mpi functions
    :rtype: int
    """
    for var in ('PMI_SIZE', 'OMPI_COMM_WORLD_SIZE'):
        value = os.environ.get(var)
        if value is not None:
            # Empty string falls back to 1, matching `int(... or 1)`.
            return int(value or 1)
    # No MPI environment: one process per visible GPU.
    return torch.cuda.device_count()
226
+
227
+
228
def get_global_rank():
    """Find OMPI world rank without calling mpi functions
    :rtype: int
    """
    for var in ('PMI_RANK', 'OMPI_COMM_WORLD_RANK'):
        value = os.environ.get(var)
        if value is not None:
            # Empty string falls back to 0, matching `int(... or 0)`.
            return int(value or 0)
    return 0
238
+
239
+
240
def get_local_rank():
    """Find OMPI local rank without calling mpi functions
    :rtype: int
    """
    for var in ('MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        value = os.environ.get(var)
        if value is not None:
            # Empty string falls back to 0, matching `int(... or 0)`.
            return int(value or 0)
    return 0
250
+
251
+
252
def get_master_ip():
    """Return the master-node IP for distributed init.

    Prefers Azure Batch / BatchAI environment variables; falls back to
    localhost for single-node runs.
    """
    az_batch = os.environ.get('AZ_BATCH_MASTER_NODE')
    if az_batch is not None:
        # AZ_BATCH_MASTER_NODE is "ip:port"; keep only the ip part.
        return az_batch.split(':')[0]
    az_batchai = os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    if az_batchai is not None:
        return az_batchai
    return "127.0.0.1"
259
+
260
if __name__ == '__main__':
    # Visual smoke test: generate a few random mask sequences and display
    # them with OpenCV (requires a display environment).
    trials = 10
    for _ in range(trials):
        video_length = 10
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(
            video_length, imageHeight=240, imageWidth=432)

        for m in masks:
            cv2.imshow('mask', np.array(m))
            cv2.waitKey(500)
271
+
docker/Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12

# System deps: libGL is required by OpenCV at import time.
RUN --mount=type=cache,target=/root/.cache,sharing=private \
    apt update && \
    apt install -y libgl1-mesa-glx && \
    true

ADD . /vsr
ARG CUDA_VERSION=11.8
ARG USE_DIRECTML=0

# CUDA build: install PaddlePaddle + CUDA-enabled torch wheels, then the
# project requirements (CUDA_VERSION "11.8" maps to the "cu118" wheel index).
RUN --mount=type=cache,target=/root/.cache,sharing=private \
    if [ "${USE_DIRECTML:-0}" != "1" ]; then \
    pip install paddlepaddle==3.0 && \
    pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu$(echo ${CUDA_VERSION} | tr -d '.') && \
    pip install -r /vsr/requirements.txt; \
    fi

# DirectML build: install torch_directml instead of the CUDA torch wheels.
RUN --mount=type=cache,target=/root/.cache,sharing=private \
    if [ "${USE_DIRECTML:-0}" = "1" ]; then \
    pip install paddlepaddle==3.0 && \
    pip install torch_directml==0.2.5.dev240914 && \
    pip install -r /vsr/requirements.txt; \
    fi

# Make the pip-installed cuDNN libraries visible to the dynamic linker.
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/
WORKDIR /vsr
CMD ["python", "/vsr/backend/main.py"]
google_colabs/README.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Google Colab Gradio Interface
2
+
3
+ This folder contains two versions of the Google Colab notebook:
4
+
5
+ ## Files
6
+
7
+ ### 1. `Video_Subtitle_Remover_Gradio.ipynb` ⭐ **NEW - Recommended**
8
+ **Gradio Web Interface** - Easy-to-use browser-based UI
9
+
10
+ **Features:**
11
+ - 🖱️ Click-and-upload interface (no coding required)
12
+ - 🎨 Visual algorithm selection
13
+ - ⚙️ Adjustable parameters with sliders
14
+ - 📊 Real-time progress tracking
15
+ - 📥 One-click download
16
+
17
+ **Best for:**
18
+ - Users who prefer GUI
19
+ - Quick testing
20
+ - Non-technical users
21
+ - Multiple video processing
22
+
23
+ **Usage:**
24
+ 1. Open in Colab
25
+ 2. Run all cells
26
+ 3. Click the generated link
27
+ 4. Use web interface in browser
28
+
29
+ ---
30
+
31
+ ### 2. `Video_Subtitle_Remover.ipynb`
32
+ **Traditional Notebook** - Code-based approach
33
+
34
+ **Features:**
35
+ - Step-by-step execution
36
+ - Full control over parameters
37
+ - Good for understanding the process
38
+ - Batch processing scripts
39
+
40
+ **Best for:**
41
+ - Users comfortable with code
42
+ - Custom workflows
43
+ - Debugging
44
+ - Learning the internals
45
+
46
+ ---
47
+
48
+ ## Quick Start
49
+
50
+ ### For Gradio Interface (Recommended):
51
+
52
+ ```bash
53
+ 1. Open Video_Subtitle_Remover_Gradio.ipynb in Colab
54
+ 2. Runtime → Change runtime type → GPU
55
+ 3. Run all cells (Ctrl+F9)
56
+ 4. Click the gradio.live URL
57
+ 5. Upload video and click "Remove Subtitles"
58
+ ```
59
+
60
+ ### For Traditional Notebook:
61
+
62
+ ```bash
63
+ 1. Open Video_Subtitle_Remover.ipynb in Colab
64
+ 2. Runtime → Change runtime type → GPU
65
+ 3. Run cells step by step
66
+ 4. Configure settings in Step 5
67
+ 5. Run processing in Step 7
68
+ ```
69
+
70
+ ## Algorithm Recommendations
71
+
72
+ | Use Case | Algorithm | Quality | Speed |
73
+ |----------|-----------|---------|-------|
74
+ | **Best Quality** | DiffuEraser | ⭐⭐⭐⭐⭐ | ⭐⭐ |
75
+ | **Fastest** | STTN | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
76
+ | **Balanced** | Stable Diffusion | ⭐⭐⭐⭐ | ⭐⭐⭐ |
77
+ | **High Motion** | ProPainter | ⭐⭐⭐⭐⭐ | ⭐ |
78
+
79
+ ## System Requirements
80
+
81
+ - **GPU**: Required (T4/P100/V100)
82
+ - **Storage**: 10-20GB for models
83
+ - **VRAM**:
84
+ - STTN: 4GB
85
+ - DiffuEraser: 12GB
86
+ - Stable Diffusion: 8GB
87
+
88
+ ## Performance (Colab T4 GPU)
89
+
90
+ | Video | Algorithm | Time |
91
+ |-------|-----------|------|
92
+ | 1 min 720p | STTN | ~30s |
93
+ | 1 min 720p | DiffuEraser | ~3-5min |
94
+ | 5 min 720p | STTN | ~2min |
95
+ | 5 min 720p | DiffuEraser | ~15-20min |
96
+
97
+ ## Troubleshooting
98
+
99
+ ### Gradio not loading
100
+ - Wait 30-60 seconds for models to load
101
+ - Check if all cells ran successfully
102
+ - Restart runtime and try again
103
+
104
+ ### Out of Memory
105
+ - Reduce batch size in settings
106
+ - Use STTN instead of DiffuEraser
107
+ - Process shorter videos
108
+
109
+ ### Slow processing
110
+ - Use STTN for preview
111
+ - Enable GPU in Colab settings
112
+ - Consider Colab Pro for unlimited runtime
113
+
114
+ ## Links
115
+
116
+ - **GitHub**: https://github.com/YaoFANGUK/video-subtitle-remover
117
+ - **Documentation**: See `docs/` folder
118
+ - **Issues**: Report on GitHub
119
+
120
+ ## Tips
121
+
122
+ 1. **Start with STTN** to test quickly
123
+ 2. **Use DiffuEraser** for final high-quality output
124
+ 3. **Keep videos under 10 minutes** on free tier
125
+ 4. **Save to Google Drive** to avoid data loss
126
+ 5. **Monitor GPU usage** with `!nvidia-smi`
google_colabs/Video_Subtitle_Remover_Gradio.ipynb ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🎬 Video Subtitle Remover - Gradio Interface\n",
8
+ "\n",
9
+ "**Easy-to-use web interface for removing hardcoded subtitles from videos**\n",
10
+ "\n",
11
+ "This notebook provides a Gradio web UI that runs in your browser.\n",
12
+ "\n",
13
+ "**Features:**\n",
14
+ "- 🖱️ Click-and-upload interface\n",
15
+ "- 🎨 Multiple AI algorithms (STTN, LAMA, DiffuEraser, etc.)\n",
16
+ "- ⚙️ Adjustable parameters\n",
17
+ "- 📊 Real-time progress\n",
18
+ "- 📥 Direct download\n",
19
+ "\n",
20
+ "**Requirements:**\n",
21
+ "- Google Colab with GPU (Runtime → Change runtime type → GPU)\n",
22
+ "- ~10-20GB storage (for models)"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "metadata": {},
28
+ "source": [
29
+ "## Step 1: Check GPU"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "!nvidia-smi"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "metadata": {},
44
+ "source": [
45
+ "## Step 2: Clone Repository"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "!git clone https://github.com/walidurrosyad/sub-remover.git\n",
55
+ "%cd sub-remover"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "## Step 3: Install Dependencies"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "# Core dependencies\n",
72
+ "!pip install -q filesplit==3.0.2 albumentations scikit-image imgaug pyclipper lmdb\n",
73
+ "!pip install -q PyYAML omegaconf tqdm easydict scikit-learn pandas webdataset\n",
74
+ "!pip install -q protobuf av einops paddleocr paddle2onnx onnxruntime-gpu\n",
75
+ "!pip install -q paddlepaddle-gpu==2.6.2\n",
76
+ "\n",
77
+ "# Gradio for web interface\n",
78
+ "!pip install -q gradio\n",
79
+ "\n",
80
+ "# Advanced models (required for Stable Diffusion and DiffuEraser; comment out to skip)\n",
81
+ "!pip install -q diffusers transformers accelerate\n",
82
+ "\n",
83
+ "print(\"✓ All dependencies installed!\")"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## Step 4: Launch Gradio Interface\n",
91
+ "\n",
92
+ "This will create a web interface you can use in your browser!\n",
93
+ "\n",
94
+ "**Click the public URL that appears below to access the interface.**"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "# Launch Gradio interface\n",
104
+ "import sys\n",
105
+ "import os\n",
106
+ "\n",
107
+ "# Add paths\n",
108
+ "sys.path.insert(0, '/content/sub-remover')\n",
109
+ "sys.path.insert(0, '/content/sub-remover/backend')\n",
110
+ "\n",
111
+ "# Change to google_colabs directory to import gradio_app\n",
112
+ "os.chdir('/content/sub-remover/google_colabs')\n",
113
+ "\n",
114
+ "from gradio_app import create_interface\n",
115
+ "\n",
116
+ "demo = create_interface()\n",
117
+ "demo.launch(share=True, debug=True)"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "metadata": {},
123
+ "source": [
124
+ "## Alternative: Run Gradio in Notebook\n",
125
+ "\n",
126
+ "If the above doesn't work, run Gradio directly in the notebook:"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "import sys\n",
136
+ "sys.path.insert(0, '/content/sub-remover')\n",
137
+ "sys.path.insert(0, '/content/sub-remover/backend')\n",
138
+ "\n",
139
+ "# Import and run\n",
140
+ "from google_colabs.gradio_app import create_interface\n",
141
+ "\n",
142
+ "demo = create_interface()\n",
143
+ "demo.launch(share=True, debug=True)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "metadata": {},
149
+ "source": [
150
+ "## How to Use the Gradio Interface\n",
151
+ "\n",
152
+ "1. **Click the public URL** (looks like: https://xxxxx.gradio.live)\n",
153
+ "2. **Upload your video** using the upload button\n",
154
+ "3. **Select algorithm**:\n",
155
+ " - **DiffuEraser (Recommended)**: Best quality for subtitles\n",
156
+ " - **STTN (Fast)**: Quickest processing\n",
157
+ " - **Stable Diffusion**: High quality alternative\n",
158
+ "4. **Adjust settings** (optional) in \"Advanced Settings\"\n",
159
+ "5. **Click \"Remove Subtitles\"**\n",
160
+ "6. **Wait for processing** (progress shown)\n",
161
+ "7. **Download result** using the download button\n",
162
+ "\n",
163
+ "## Performance Guide\n",
164
+ "\n",
165
+ "### Colab T4 GPU (Free Tier)\n",
166
+ "\n",
167
+ "| Video Length | Algorithm | Time |\n",
168
+ "|--------------|-----------|------|\n",
169
+ "| 1 min 720p | STTN | ~30s |\n",
170
+ "| 1 min 720p | DiffuEraser | ~3-5min |\n",
171
+ "| 5 min 720p | STTN | ~2min |\n",
172
+ "| 5 min 720p | DiffuEraser | ~15-20min |\n",
173
+ "\n",
174
+ "### Tips\n",
175
+ "\n",
176
+ "- **Use STTN for preview**, DiffuEraser for final\n",
177
+ "- **Keep videos under 10 minutes** to avoid Colab timeout\n",
178
+ "- **Enable GPU** in Runtime settings\n",
179
+ "- **Reduce batch size** if you get OOM errors\n",
180
+ "\n",
181
+ "## Troubleshooting\n",
182
+ "\n",
183
+ "### \"Out of Memory\" Error\n",
184
+ "Reduce batch size in Advanced Settings:\n",
185
+ "- DiffuEraser: Set \"Max Frames per Batch\" to 40\n",
186
+ "- STTN: Set to 30\n",
187
+ "\n",
188
+ "### Gradio Not Loading\n",
189
+ "Restart runtime and run all cells again.\n",
190
+ "\n",
191
+ "### Slow Processing\n",
192
+ "- Use STTN algorithm for faster results\n",
193
+ "- Process shorter video clips\n",
194
+ "\n",
195
+ "### Session Timeout\n",
196
+ "Colab free tier has time limits. Process shorter videos or use Colab Pro."
197
+ ]
198
+ }
199
+ ],
200
+ "metadata": {
201
+ "accelerator": "GPU",
202
+ "colab": {
203
+ "gpuType": "T4",
204
+ "provenance": []
205
+ },
206
+ "kernelspec": {
207
+ "display_name": "Python 3",
208
+ "language": "python",
209
+ "name": "python3"
210
+ },
211
+ "language_info": {
212
+ "name": "python",
213
+ "version": "3.10.12"
214
+ }
215
+ },
216
+ "nbformat": 4,
217
+ "nbformat_minor": 0
218
+ }
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations~=1.4.10
2
+ filesplit==3.0.2
3
+ opencv-python==4.11.0.86
4
+ scikit-image==0.25.2
5
+ imgaug==0.4.0
6
+ pyclipper==1.3.0.post6
7
+ lmdb==1.6.2
8
+ PyYAML==6.0.2
9
+ omegaconf==2.3.0
10
+ tqdm==4.67.1
11
+ PySimpleGUI-4-foss==4.60.4.1
12
+ easydict==1.13
13
+ scikit-learn==1.6.1
14
+ pandas==2.2.3
15
+ webdataset==0.2.111
16
+ numpy==2.2.5
17
+ protobuf==6.30.2
18
+ av==14.3.0
19
+ einops==0.8.1
20
+ paddleocr==2.10.0
21
+ paddle2onnx==1.3.1
22
+ onnxruntime-gpu==1.20.1
23
+ onnxruntime-directml==1.20.1; sys_platform == 'win32'
24
+
25
+ # Advanced Inpainting Models
26
+ diffusers>=0.27.0 # For Stable Diffusion & DiffuEraser
27
+ transformers>=4.36.0 # Required by diffusers
28
+ accelerate>=0.25.0 # For faster inference
29
+ xformers>=0.0.23; sys_platform != 'darwin' # Memory optimization (not on macOS)