ahmad walidurosyad
commited on
Commit
·
9fd445b
1
Parent(s):
a6408e4
add
Browse files- backend/config.py +184 -0
- backend/scenedetect/detectors/motion_detector.py +92 -0
- backend/scenedetect/detectors/threshold_detector.py +203 -0
- backend/tools/common_tools.py +32 -0
- backend/tools/inpaint_tools.py +117 -0
- backend/tools/makedist.py +65 -0
- backend/tools/merge_video.py +32 -0
- backend/tools/train/dataset_sttn.py +85 -0
- backend/tools/train/loss_sttn.py +56 -0
- backend/tools/train/train_sttn.py +96 -0
- backend/tools/train/trainer_sttn.py +319 -0
- backend/tools/train/utils_sttn.py +271 -0
- docker/Dockerfile +30 -0
- google_colabs/README.md +126 -0
- google_colabs/Video_Subtitle_Remover_Gradio.ipynb +218 -0
- requirements.txt +29 -0
backend/config.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from enum import Enum, unique
|
| 3 |
+
warnings.filterwarnings('ignore')
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import logging
|
| 7 |
+
import platform
|
| 8 |
+
import stat
|
| 9 |
+
from fsplit.filesplit import Filesplit
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
|
| 12 |
+
# 项目版本号
|
| 13 |
+
VERSION = "1.1.1"
|
| 14 |
+
# ×××××××××××××××××××× [不要改] start ××××××××××××××××××××
|
| 15 |
+
logging.disable(logging.DEBUG) # 关闭DEBUG日志的打印
|
| 16 |
+
logging.disable(logging.WARNING) # 关闭WARNING日志的打印
|
| 17 |
+
try:
|
| 18 |
+
import torch_directml
|
| 19 |
+
device = torch_directml.device(torch_directml.default_device())
|
| 20 |
+
USE_DML = True
|
| 21 |
+
except:
|
| 22 |
+
USE_DML = False
|
| 23 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 25 |
+
LAMA_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'big-lama')
|
| 26 |
+
STTN_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'sttn', 'infer_model.pth')
|
| 27 |
+
VIDEO_INPAINT_MODEL_PATH = os.path.join(BASE_DIR, 'models', 'video')
|
| 28 |
+
MODEL_VERSION = 'V4'
|
| 29 |
+
DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
|
| 30 |
+
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')
|
| 31 |
+
|
| 32 |
+
# 查看该路径下是否有模型完整文件,没有的话合并小文件生成完整文件
|
| 33 |
+
if 'big-lama.pt' not in (os.listdir(LAMA_MODEL_PATH)):
|
| 34 |
+
fs = Filesplit()
|
| 35 |
+
fs.merge(input_dir=LAMA_MODEL_PATH)
|
| 36 |
+
|
| 37 |
+
if 'inference.pdiparams' not in os.listdir(DET_MODEL_PATH):
|
| 38 |
+
fs = Filesplit()
|
| 39 |
+
fs.merge(input_dir=DET_MODEL_PATH)
|
| 40 |
+
|
| 41 |
+
if 'ProPainter.pth' not in os.listdir(VIDEO_INPAINT_MODEL_PATH):
|
| 42 |
+
fs = Filesplit()
|
| 43 |
+
fs.merge(input_dir=VIDEO_INPAINT_MODEL_PATH)
|
| 44 |
+
|
| 45 |
+
# 指定ffmpeg可执行程序路径
|
| 46 |
+
sys_str = platform.system()
|
| 47 |
+
if sys_str == "Windows":
|
| 48 |
+
ffmpeg_bin = os.path.join('win_x64', 'ffmpeg.exe')
|
| 49 |
+
elif sys_str == "Linux":
|
| 50 |
+
ffmpeg_bin = os.path.join('linux_x64', 'ffmpeg')
|
| 51 |
+
else:
|
| 52 |
+
ffmpeg_bin = os.path.join('macos', 'ffmpeg')
|
| 53 |
+
FFMPEG_PATH = os.path.join(BASE_DIR, '', 'ffmpeg', ffmpeg_bin)
|
| 54 |
+
|
| 55 |
+
if 'ffmpeg.exe' not in os.listdir(os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64')):
|
| 56 |
+
fs = Filesplit()
|
| 57 |
+
fs.merge(input_dir=os.path.join(BASE_DIR, '', 'ffmpeg', 'win_x64'))
|
| 58 |
+
# 将ffmpeg添加可执行权限
|
| 59 |
+
os.chmod(FFMPEG_PATH, stat.S_IRWXU + stat.S_IRWXG + stat.S_IRWXO)
|
| 60 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
| 61 |
+
|
| 62 |
+
# 是否使用ONNX(DirectML/AMD/Intel)
|
| 63 |
+
ONNX_PROVIDERS = []
|
| 64 |
+
available_providers = ort.get_available_providers()
|
| 65 |
+
for provider in available_providers:
|
| 66 |
+
if provider in [
|
| 67 |
+
"CPUExecutionProvider"
|
| 68 |
+
]:
|
| 69 |
+
continue
|
| 70 |
+
if provider not in [
|
| 71 |
+
"DmlExecutionProvider", # DirectML,适用于 Windows GPU
|
| 72 |
+
"ROCMExecutionProvider", # AMD ROCm
|
| 73 |
+
"MIGraphXExecutionProvider", # AMD MIGraphX
|
| 74 |
+
"VitisAIExecutionProvider", # AMD VitisAI,适用于 RyzenAI & Windows, 实测和DirectML性能似乎差不多
|
| 75 |
+
"OpenVINOExecutionProvider", # Intel GPU
|
| 76 |
+
"MetalExecutionProvider", # Apple macOS
|
| 77 |
+
"CoreMLExecutionProvider", # Apple macOS
|
| 78 |
+
"CUDAExecutionProvider", # Nvidia GPU
|
| 79 |
+
]:
|
| 80 |
+
continue
|
| 81 |
+
ONNX_PROVIDERS.append(provider)
|
| 82 |
+
# ×××××××××××××××××××× [不要改] end ××××××××××××××××××××
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@unique
|
| 86 |
+
class InpaintMode(Enum):
|
| 87 |
+
"""
|
| 88 |
+
图像重绘算法枚举
|
| 89 |
+
"""
|
| 90 |
+
STTN = 'sttn'
|
| 91 |
+
LAMA = 'lama'
|
| 92 |
+
PROPAINTER = 'propainter'
|
| 93 |
+
STABLE_DIFFUSION = 'sd' # Stable Diffusion Inpainting
|
| 94 |
+
DIFFUERASER = 'diffueraser' # DiffuEraser (diffusion-based)
|
| 95 |
+
E2FGVI = 'e2fgvi' # Flow-guided video inpainting
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ×××××××××××××××××××× [可以改] start ××××××××××××××××××××
|
| 99 |
+
# 是否使用h264编码,如果需要安卓手机分享生成的视频,请打开该选项
|
| 100 |
+
USE_H264 = True
|
| 101 |
+
|
| 102 |
+
# ×××××××××× 通用设置 start ××××××××××
|
| 103 |
+
"""
|
| 104 |
+
MODE可选算法类型
|
| 105 |
+
- InpaintMode.STTN 算法:对于真人视频效果较好,速度快,可以跳过字幕检测
|
| 106 |
+
- InpaintMode.LAMA 算法:对于动画类视频效果好,速度一般,不可以跳过字幕检测
|
| 107 |
+
- InpaintMode.PROPAINTER 算法: 需要消耗大量显存,速度较慢,对运动非常剧烈的视频效果较好
|
| 108 |
+
"""
|
| 109 |
+
# 默认重绘算法模式 sttn/lama/propainter/sd/diffueraser/e2fgvi
|
| 110 |
+
MODE = InpaintMode.STTN
|
| 111 |
+
|
| 112 |
+
# ×××××××××××××××××××× Stable Diffusion Settings ××××××××××××××××××××
|
| 113 |
+
SD_MODEL_PATH = 'backend/models/stable-diffusion-inpainting'
|
| 114 |
+
SD_STEPS = 50 # Inference steps
|
| 115 |
+
SD_GUIDANCE_SCALE = 7.5 # Classifier-free guidance
|
| 116 |
+
SD_PROMPT = "natural scene, high quality" # Text prompt for guidance
|
| 117 |
+
SD_USE_FP16 = True # Use half precision for faster inference
|
| 118 |
+
|
| 119 |
+
# ×××××××××××××××××××× DiffuEraser Settings ××××××××××××××××××××
|
| 120 |
+
DIFFUERASER_MODEL_PATH = 'backend/models/diffueraser'
|
| 121 |
+
DIFFUERASER_STEPS = 50 # Diffusion steps
|
| 122 |
+
DIFFUERASER_GUIDANCE = 7.5 # Guidance scale
|
| 123 |
+
DIFFUERASER_USE_SAM2 = False # Auto-masking with SAM2
|
| 124 |
+
DIFFUERASER_MAX_LOAD_NUM = 80 # Max frames per batch
|
| 125 |
+
|
| 126 |
+
# ×××××××××××××××××××× E2FGVI Settings ××××××××××××××××××××
|
| 127 |
+
E2FGVI_MODEL_PATH = 'backend/models/e2fgvi'
|
| 128 |
+
E2FGVI_MAX_LOAD_NUM = 80 # Max frames per batch
|
| 129 |
+
E2FGVI_NEIGHBOR_LENGTH = 10 # Temporal window for flow
|
| 130 |
+
# 【设置像素点偏差】
|
| 131 |
+
# 用于判断是不是非字幕区域(一般认为字幕文本框的长度是要大于宽度的,如果字幕框的高大于宽,且大于的幅度超过指定像素点大小,则认为是错误检测)
|
| 132 |
+
THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10
|
| 133 |
+
# 用于放大mask大小,防止自动检测的文本框过小,inpaint阶段出现文字边,有残留
|
| 134 |
+
SUBTITLE_AREA_DEVIATION_PIXEL = 20
|
| 135 |
+
# 同于判断两个文本框是否为同一行字幕,高度差距指定像素点以内认为是同一行
|
| 136 |
+
THRESHOLD_HEIGHT_DIFFERENCE = 20
|
| 137 |
+
# 用于判断两个字幕文本的矩形框是否相似,如果X轴和Y轴偏差都在指定阈值内,则认为时同一个文本框
|
| 138 |
+
PIXEL_TOLERANCE_Y = 20 # 允许检测框纵向偏差的像素点数
|
| 139 |
+
PIXEL_TOLERANCE_X = 20 # 允许检测框横向偏差的像素点数
|
| 140 |
+
# ×××××××××× 通用设置 end ××××××××××
|
| 141 |
+
|
| 142 |
+
# ×××××××××× InpaintMode.STTN算法设置 start ××××××××××
|
| 143 |
+
# 以下参数仅适用STTN算法时,才生效
|
| 144 |
+
"""
|
| 145 |
+
1. STTN_SKIP_DETECTION
|
| 146 |
+
含义:是否使用跳过检测
|
| 147 |
+
效果:设置为True跳过字幕检测,会省去很大时间,但是可能误伤无字幕的视频帧或者会导致去除的字幕漏了
|
| 148 |
+
|
| 149 |
+
2. STTN_NEIGHBOR_STRIDE
|
| 150 |
+
含义:相邻帧数步长, 如果需要为第50帧填充缺失的区域,STTN_NEIGHBOR_STRIDE=5,那么算法会使用第45帧、第40帧等作为参照。
|
| 151 |
+
效果:用于控制参考帧选择的密度,较大的步长意味着使用更少、更分散的参考帧,较小的步长意味着使用更多、更集中的参考帧。
|
| 152 |
+
|
| 153 |
+
3. STTN_REFERENCE_LENGTH
|
| 154 |
+
含义:参数帧数量,STTN算法会查看每个待修复帧的前后若干帧来获得用于修复的上下文信息
|
| 155 |
+
效果:调大会增加显存占用,处理效果变好,但是处理速度变慢
|
| 156 |
+
|
| 157 |
+
4. STTN_MAX_LOAD_NUM
|
| 158 |
+
含义:STTN算法每次最多加载的视频帧数量
|
| 159 |
+
效果:设置越大速度越慢,但效果越好
|
| 160 |
+
注意:要保证STTN_MAX_LOAD_NUM大于STTN_NEIGHBOR_STRIDE和STTN_REFERENCE_LENGTH
|
| 161 |
+
"""
|
| 162 |
+
STTN_SKIP_DETECTION = True
|
| 163 |
+
# 参考帧步长
|
| 164 |
+
STTN_NEIGHBOR_STRIDE = 5
|
| 165 |
+
# 参考帧长度(数量)
|
| 166 |
+
STTN_REFERENCE_LENGTH = 10
|
| 167 |
+
# 设置STTN算法最大同时处理的帧数量
|
| 168 |
+
STTN_MAX_LOAD_NUM = 50
|
| 169 |
+
if STTN_MAX_LOAD_NUM < STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE:
|
| 170 |
+
STTN_MAX_LOAD_NUM = STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
|
| 171 |
+
# ×××××××××× InpaintMode.STTN算法设置 end ××××××××××
|
| 172 |
+
|
| 173 |
+
# ×××××××××× InpaintMode.PROPAINTER算法设置 start ××××××××××
|
| 174 |
+
# 【根据自己的GPU显存大小设置】最大同时处理的图片数量,设置越大处理效果越好,但是要求显存越高
|
| 175 |
+
# 1280x720p视频设置80需要25G显存,设置50需要19G显存
|
| 176 |
+
# 720x480p视频设置80需要8G显存,设置50需要7G显存
|
| 177 |
+
PROPAINTER_MAX_LOAD_NUM = 70
|
| 178 |
+
# ×××××××××× InpaintMode.PROPAINTER算法设置 end ××××××××××
|
| 179 |
+
|
| 180 |
+
# ×××××××××× InpaintMode.LAMA算法设置 start ××××××××××
|
| 181 |
+
# 是否开启极速模式,开启后不保证inpaint效果,仅仅对包含文本的区域文本进行去除
|
| 182 |
+
LAMA_SUPER_FAST = False
|
| 183 |
+
# ×××××××××× InpaintMode.LAMA算法设置 end ××××××××××
|
| 184 |
+
# ×××××××××××××××××××× [可以改] end ××××××××××××××××××××
|
backend/scenedetect/detectors/motion_detector.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# PySceneDetect: Python-Based Video Scene Detector
|
| 4 |
+
# -------------------------------------------------------------------
|
| 5 |
+
# [ Site: https://scenedetect.com ]
|
| 6 |
+
# [ Docs: https://scenedetect.com/docs/ ]
|
| 7 |
+
# [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
|
| 8 |
+
#
|
| 9 |
+
# Copyright (C) 2014-2023 Brandon Castellano <http://www.bcastell.com>.
|
| 10 |
+
# PySceneDetect is licensed under the BSD 3-Clause License; see the
|
| 11 |
+
# included LICENSE file, or visit one of the above pages for details.
|
| 12 |
+
#
|
| 13 |
+
""":class:`MotionDetector`, detects motion events using background subtraction, morphological
|
| 14 |
+
transforms, and thresholding."""
|
| 15 |
+
|
| 16 |
+
# Third-Party Library Imports
|
| 17 |
+
import cv2
|
| 18 |
+
|
| 19 |
+
# PySceneDetect Library Imports
|
| 20 |
+
from scenedetect.scene_detector import SparseSceneDetector
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MotionDetector(SparseSceneDetector):
|
| 24 |
+
"""Detects motion events in scenes containing a static background.
|
| 25 |
+
|
| 26 |
+
Uses background subtraction followed by noise removal (via morphological
|
| 27 |
+
opening) to generate a frame score compared against the set threshold.
|
| 28 |
+
|
| 29 |
+
Attributes:
|
| 30 |
+
threshold: floating point value compared to each frame's score, which
|
| 31 |
+
represents average intensity change per pixel (lower values are
|
| 32 |
+
more sensitive to motion changes). Default 0.5, must be > 0.0.
|
| 33 |
+
num_frames_post_scene: Number of frames to include in each motion
|
| 34 |
+
event after the frame score falls below the threshold, adding any
|
| 35 |
+
subsequent motion events to the same scene.
|
| 36 |
+
kernel_size: Size of morphological opening kernel for noise removal.
|
| 37 |
+
Setting to -1 (default) will auto-compute based on video resolution
|
| 38 |
+
(typically 3 for SD, 5-7 for HD). Must be an odd integer > 1.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, threshold=0.50, num_frames_post_scene=30, kernel_size=-1):
|
| 42 |
+
"""Initializes motion-based scene detector object."""
|
| 43 |
+
# TODO: Requires porting to v0.5 API.
|
| 44 |
+
raise NotImplementedError()
|
| 45 |
+
"""
|
| 46 |
+
self.threshold = float(threshold)
|
| 47 |
+
self.num_frames_post_scene = int(num_frames_post_scene)
|
| 48 |
+
|
| 49 |
+
self.kernel_size = int(kernel_size)
|
| 50 |
+
if self.kernel_size < 0:
|
| 51 |
+
# Set kernel size when process_frame first runs based on
|
| 52 |
+
# video resolution (480p = 3x3, 720p = 5x5, 1080p = 7x7).
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
self.bg_subtractor = cv2.createBackgroundSubtractorMOG2(
|
| 56 |
+
detectShadows = False )
|
| 57 |
+
|
| 58 |
+
self.last_frame_score = 0.0
|
| 59 |
+
|
| 60 |
+
self.in_motion_event = False
|
| 61 |
+
self.first_motion_frame_index = -1
|
| 62 |
+
self.last_motion_frame_index = -1
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def process_frame(self, frame_num, frame_img):
|
| 66 |
+
# TODO.
|
| 67 |
+
"""
|
| 68 |
+
frame_grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
| 69 |
+
masked_frame = self.bg_subtractor.apply(frame_grayscale)
|
| 70 |
+
|
| 71 |
+
kernel = numpy.ones((self.kernel_size, self.kernel_size), numpy.uint8)
|
| 72 |
+
filtered_frame = cv2.morphologyEx(fgmask, cv2.MORPH_OPEN, kernel)
|
| 73 |
+
|
| 74 |
+
frame_score = numpy.sum(filtered_frame) / float(
|
| 75 |
+
filtered_frame.shape[0] * filtered_frame.shape[1] )
|
| 76 |
+
"""
|
| 77 |
+
return []
|
| 78 |
+
|
| 79 |
+
def post_process(self, frame_num):
|
| 80 |
+
"""Writes the last scene if the video ends while in a motion event.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
# If the last fade detected was a fade out, we add a corresponding new
|
| 84 |
+
# scene break to indicate the end of the scene. This is only done for
|
| 85 |
+
# fade-outs, as a scene cut is already added when a fade-in is found.
|
| 86 |
+
"""
|
| 87 |
+
if self.in_motion_event:
|
| 88 |
+
# Write new scene based on first and last motion event frames.
|
| 89 |
+
pass
|
| 90 |
+
return self.in_motion_event
|
| 91 |
+
"""
|
| 92 |
+
return []
|
backend/scenedetect/detectors/threshold_detector.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# PySceneDetect: Python-Based Video Scene Detector
|
| 4 |
+
# -------------------------------------------------------------------
|
| 5 |
+
# [ Site: https://scenedetect.com ]
|
| 6 |
+
# [ Docs: https://scenedetect.com/docs/ ]
|
| 7 |
+
# [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
|
| 8 |
+
#
|
| 9 |
+
# Copyright (C) 2014-2023 Brandon Castellano <http://www.bcastell.com>.
|
| 10 |
+
# PySceneDetect is licensed under the BSD 3-Clause License; see the
|
| 11 |
+
# included LICENSE file, or visit one of the above pages for details.
|
| 12 |
+
#
|
| 13 |
+
""":class:`ThresholdDetector` uses a set intensity as a threshold to scene_detect cuts, which are
|
| 14 |
+
triggered when the average pixel intensity exceeds or falls below this threshold.
|
| 15 |
+
|
| 16 |
+
This detector is available from the command-line as the `scene_detect-threshold` command.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from enum import Enum
|
| 20 |
+
from logging import getLogger
|
| 21 |
+
from typing import List, Optional
|
| 22 |
+
|
| 23 |
+
import numpy
|
| 24 |
+
|
| 25 |
+
from backend.scenedetect.scene_detector import SceneDetector
|
| 26 |
+
|
| 27 |
+
logger = getLogger('pyscenedetect')
|
| 28 |
+
|
| 29 |
+
##
|
| 30 |
+
## ThresholdDetector Helper Functions
|
| 31 |
+
##
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _compute_frame_average(frame: numpy.ndarray) -> float:
|
| 35 |
+
"""Computes the average pixel value/intensity for all pixels in a frame.
|
| 36 |
+
|
| 37 |
+
The value is computed by adding up the 8-bit R, G, and B values for
|
| 38 |
+
each pixel, and dividing by the number of pixels multiplied by 3.
|
| 39 |
+
|
| 40 |
+
Arguments:
|
| 41 |
+
frame: Frame representing the RGB pixels to average.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Average pixel intensity across all 3 channels of `frame`
|
| 45 |
+
"""
|
| 46 |
+
num_pixel_values = float(frame.shape[0] * frame.shape[1] * frame.shape[2])
|
| 47 |
+
avg_pixel_value = numpy.sum(frame[:, :, :]) / num_pixel_values
|
| 48 |
+
return avg_pixel_value
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
##
|
| 52 |
+
## ThresholdDetector Class Implementation
|
| 53 |
+
##
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ThresholdDetector(SceneDetector):
|
| 57 |
+
"""Detects fast cuts/slow fades in from and out to a given threshold level.
|
| 58 |
+
|
| 59 |
+
Detects both fast cuts and slow fades so long as an appropriate threshold
|
| 60 |
+
is chosen (especially taking into account the minimum grey/black level).
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
class Method(Enum):
|
| 64 |
+
"""Method for ThresholdDetector to use when comparing frame brightness to the threshold."""
|
| 65 |
+
FLOOR = 0
|
| 66 |
+
"""Fade out happens when frame brightness falls below threshold."""
|
| 67 |
+
CEILING = 1
|
| 68 |
+
"""Fade out happens when frame brightness rises above threshold."""
|
| 69 |
+
|
| 70 |
+
THRESHOLD_VALUE_KEY = 'average_rgb'
|
| 71 |
+
|
| 72 |
+
def __init__(
|
| 73 |
+
self,
|
| 74 |
+
threshold: float = 12,
|
| 75 |
+
min_scene_len: int = 15,
|
| 76 |
+
fade_bias: float = 0.0,
|
| 77 |
+
add_final_scene: bool = False,
|
| 78 |
+
method: Method = Method.FLOOR,
|
| 79 |
+
block_size=None,
|
| 80 |
+
):
|
| 81 |
+
"""
|
| 82 |
+
Arguments:
|
| 83 |
+
threshold: 8-bit intensity value that each pixel value (R, G, and B)
|
| 84 |
+
must be <= to in order to trigger a fade in/out.
|
| 85 |
+
min_scene_len: FrameTimecode object or integer greater than 0 of the
|
| 86 |
+
minimum length, in frames, of a scene (or subsequent scene cut).
|
| 87 |
+
fade_bias: Float between -1.0 and +1.0 representing the percentage of
|
| 88 |
+
timecode skew for the start of a scene (-1.0 causing a cut at the
|
| 89 |
+
fade-to-black, 0.0 in the middle, and +1.0 causing the cut to be
|
| 90 |
+
right at the position where the threshold is passed).
|
| 91 |
+
add_final_scene: Boolean indicating if the video ends on a fade-out to
|
| 92 |
+
generate an additional scene at this timecode.
|
| 93 |
+
method: How to treat `threshold` when detecting fade events.
|
| 94 |
+
block_size: [DEPRECATED] DO NOT USE. For backwards compatibility.
|
| 95 |
+
"""
|
| 96 |
+
# TODO(v0.7): Replace with DeprecationWarning that `block_size` will be removed in v0.8.
|
| 97 |
+
if block_size is not None:
|
| 98 |
+
logger.error('block_size is deprecated.')
|
| 99 |
+
|
| 100 |
+
super().__init__()
|
| 101 |
+
self.threshold = int(threshold)
|
| 102 |
+
self.method = ThresholdDetector.Method(method)
|
| 103 |
+
self.fade_bias = fade_bias
|
| 104 |
+
self.min_scene_len = min_scene_len
|
| 105 |
+
self.processed_frame = False
|
| 106 |
+
self.last_scene_cut = None
|
| 107 |
+
# Whether to add an additional scene or not when ending on a fade out
|
| 108 |
+
# (as cuts are only added on fade ins; see post_process() for details).
|
| 109 |
+
self.add_final_scene = add_final_scene
|
| 110 |
+
# Where the last fade (threshold crossing) was detected.
|
| 111 |
+
self.last_fade = {
|
| 112 |
+
'frame': 0, # frame number where the last detected fade is
|
| 113 |
+
'type': None # type of fade, can be either 'in' or 'out'
|
| 114 |
+
}
|
| 115 |
+
self._metric_keys = [ThresholdDetector.THRESHOLD_VALUE_KEY]
|
| 116 |
+
|
| 117 |
+
def get_metrics(self) -> List[str]:
|
| 118 |
+
return self._metric_keys
|
| 119 |
+
|
| 120 |
+
def process_frame(self, frame_num: int, frame_img: Optional[numpy.ndarray]) -> List[int]:
|
| 121 |
+
"""
|
| 122 |
+
Args:
|
| 123 |
+
frame_num (int): Frame number of frame that is being passed.
|
| 124 |
+
frame_img (numpy.ndarray or None): Decoded frame image (numpy.ndarray) to perform
|
| 125 |
+
scene detection with. Can be None *only* if the self.is_processing_required()
|
| 126 |
+
method (inhereted from the base SceneDetector class) returns True.
|
| 127 |
+
Returns:
|
| 128 |
+
List[int]: List of frames where scene cuts have been detected. There may be 0
|
| 129 |
+
or more frames in the list, and not necessarily the same as frame_num.
|
| 130 |
+
"""
|
| 131 |
+
|
| 132 |
+
# Initialize last scene cut point at the beginning of the frames of interest.
|
| 133 |
+
if self.last_scene_cut is None:
|
| 134 |
+
self.last_scene_cut = frame_num
|
| 135 |
+
|
| 136 |
+
# Compare the # of pixels under threshold in current_frame & last_frame.
|
| 137 |
+
# If absolute value of pixel intensity delta is above the threshold,
|
| 138 |
+
# then we trigger a new scene cut/break.
|
| 139 |
+
|
| 140 |
+
# List of cuts to return.
|
| 141 |
+
cut_list = []
|
| 142 |
+
|
| 143 |
+
# The metric used here to scene_detect scene breaks is the percent of pixels
|
| 144 |
+
# less than or equal to the threshold; however, since this differs on
|
| 145 |
+
# user-supplied values, we supply the average pixel intensity as this
|
| 146 |
+
# frame metric instead (to assist with manually selecting a threshold)
|
| 147 |
+
if (self.stats_manager is not None) and (self.stats_manager.metrics_exist(
|
| 148 |
+
frame_num, self._metric_keys)):
|
| 149 |
+
frame_avg = self.stats_manager.get_metrics(frame_num, self._metric_keys)[0]
|
| 150 |
+
else:
|
| 151 |
+
frame_avg = _compute_frame_average(frame_img)
|
| 152 |
+
if self.stats_manager is not None:
|
| 153 |
+
self.stats_manager.set_metrics(frame_num, {self._metric_keys[0]: frame_avg})
|
| 154 |
+
|
| 155 |
+
if self.processed_frame:
|
| 156 |
+
if self.last_fade['type'] == 'in' and ((
|
| 157 |
+
(self.method == ThresholdDetector.Method.FLOOR and frame_avg < self.threshold) or
|
| 158 |
+
(self.method == ThresholdDetector.Method.CEILING and frame_avg >= self.threshold))):
|
| 159 |
+
# Just faded out of a scene, wait for next fade in.
|
| 160 |
+
self.last_fade['type'] = 'out'
|
| 161 |
+
self.last_fade['frame'] = frame_num
|
| 162 |
+
|
| 163 |
+
elif self.last_fade['type'] == 'out' and (
|
| 164 |
+
(self.method == ThresholdDetector.Method.FLOOR and frame_avg >= self.threshold) or
|
| 165 |
+
(self.method == ThresholdDetector.Method.CEILING and frame_avg < self.threshold)):
|
| 166 |
+
# Only add the scene if min_scene_len frames have passed.
|
| 167 |
+
if (frame_num - self.last_scene_cut) >= self.min_scene_len:
|
| 168 |
+
# Just faded into a new scene, compute timecode for the scene
|
| 169 |
+
# split based on the fade bias.
|
| 170 |
+
f_out = self.last_fade['frame']
|
| 171 |
+
f_split = int(
|
| 172 |
+
(frame_num + f_out + int(self.fade_bias * (frame_num - f_out))) / 2)
|
| 173 |
+
cut_list.append(f_split)
|
| 174 |
+
self.last_scene_cut = frame_num
|
| 175 |
+
self.last_fade['type'] = 'in'
|
| 176 |
+
self.last_fade['frame'] = frame_num
|
| 177 |
+
else:
|
| 178 |
+
self.last_fade['frame'] = 0
|
| 179 |
+
if frame_avg < self.threshold:
|
| 180 |
+
self.last_fade['type'] = 'out'
|
| 181 |
+
else:
|
| 182 |
+
self.last_fade['type'] = 'in'
|
| 183 |
+
self.processed_frame = True
|
| 184 |
+
return cut_list
|
| 185 |
+
|
| 186 |
+
def post_process(self, frame_num: int):
|
| 187 |
+
"""Writes a final scene cut if the last detected fade was a fade-out.
|
| 188 |
+
|
| 189 |
+
Only writes the scene cut if add_final_scene is true, and the last fade
|
| 190 |
+
that was detected was a fade-out. There is no bias applied to this cut
|
| 191 |
+
(since there is no corresponding fade-in) so it will be located at the
|
| 192 |
+
exact frame where the fade-out crossed the detection threshold.
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
# If the last fade detected was a fade out, we add a corresponding new
|
| 196 |
+
# scene break to indicate the end of the scene. This is only done for
|
| 197 |
+
# fade-outs, as a scene cut is already added when a fade-in is found.
|
| 198 |
+
cut_times = []
|
| 199 |
+
if self.last_fade['type'] == 'out' and self.add_final_scene and (
|
| 200 |
+
(self.last_scene_cut is None and frame_num >= self.min_scene_len) or
|
| 201 |
+
(frame_num - self.last_scene_cut) >= self.min_scene_len):
|
| 202 |
+
cut_times.append(self.last_fade['frame'])
|
| 203 |
+
return cut_times
|
backend/tools/common_tools.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
video_extensions = {
|
| 4 |
+
'.mp4', '.m4a', '.m4v', '.f4v', '.f4a', '.m4b', '.m4r', '.f4b', '.mov',
|
| 5 |
+
'.3gp', '.3gp2', '.3g2', '.3gpp', '.3gpp2', '.ogg', '.oga', '.ogv', '.ogx',
|
| 6 |
+
'.wmv', '.wma', '.asf', '.webm', '.flv', '.avi', '.gifv', '.mkv', '.rm',
|
| 7 |
+
'.rmvb', '.vob', '.dvd', '.mpg', '.mpeg', '.mp2', '.mpe', '.mpv', '.mpg',
|
| 8 |
+
'.mpeg', '.m2v', '.svi', '.3gp', '.mxf', '.roq', '.nsv', '.flv', '.f4v',
|
| 9 |
+
'.f4p', '.f4a', '.f4b'
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
image_extensions = {
|
| 13 |
+
'.jpg', '.jpeg', '.jpe', '.jif', '.jfif', '.jfi', '.png', '.gif',
|
| 14 |
+
'.webp', '.tiff', '.tif', '.psd', '.raw', '.arw', '.cr2', '.nrw',
|
| 15 |
+
'.k25', '.bmp', '.dib', '.heif', '.heic', '.ind', '.indd', '.indt',
|
| 16 |
+
'.jp2', '.j2k', '.jpf', '.jpx', '.jpm', '.mj2', '.svg', '.svgz',
|
| 17 |
+
'.ai', '.eps', '.ico'
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def is_video_file(filename):
|
| 22 |
+
return os.path.splitext(filename)[-1].lower() in video_extensions
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def is_image_file(filename):
|
| 26 |
+
return os.path.splitext(filename)[-1].lower() in image_extensions
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_video_or_image(filename):
|
| 30 |
+
file_extension = os.path.splitext(filename)[-1].lower()
|
| 31 |
+
# 检查扩展名是否在定义的视频或图片文件后缀集合中
|
| 32 |
+
return file_extension in video_extensions or file_extension in image_extensions
|
backend/tools/inpaint_tools.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from backend import config
|
| 6 |
+
from backend.inpaint.lama_inpaint import LamaInpaint
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def batch_generator(data, max_batch_size):
|
| 10 |
+
"""
|
| 11 |
+
根据data大小,生成最大长度不超过max_batch_size的均匀批次数据
|
| 12 |
+
"""
|
| 13 |
+
n_samples = len(data)
|
| 14 |
+
# 尝试找到一个比MAX_BATCH_SIZE小的batch_size,以使得所有的批次数量尽量接近
|
| 15 |
+
batch_size = max_batch_size
|
| 16 |
+
num_batches = n_samples // batch_size
|
| 17 |
+
|
| 18 |
+
# 处理最后一批可能不足batch_size的情况
|
| 19 |
+
# 如果最后一批少于其他批次,则减小batch_size尝试平衡每批的数量
|
| 20 |
+
while n_samples % batch_size < batch_size / 2.0 and batch_size > 1:
|
| 21 |
+
batch_size -= 1 # 减小批次大小
|
| 22 |
+
num_batches = n_samples // batch_size
|
| 23 |
+
|
| 24 |
+
# 生成前num_batches个批次
|
| 25 |
+
for i in range(num_batches):
|
| 26 |
+
yield data[i * batch_size:(i + 1) * batch_size]
|
| 27 |
+
|
| 28 |
+
# 将剩余的数据作为最后一个批次
|
| 29 |
+
last_batch_start = num_batches * batch_size
|
| 30 |
+
if last_batch_start < n_samples:
|
| 31 |
+
yield data[last_batch_start:]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def inference_task(batch_data):
|
| 35 |
+
inpainted_frame_dict = dict()
|
| 36 |
+
for data in batch_data:
|
| 37 |
+
index, original_frame, coords_list = data
|
| 38 |
+
mask_size = original_frame.shape[:2]
|
| 39 |
+
mask = create_mask(mask_size, coords_list)
|
| 40 |
+
inpaint_frame = inpaint(original_frame, mask)
|
| 41 |
+
inpainted_frame_dict[index] = inpaint_frame
|
| 42 |
+
return inpainted_frame_dict
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def parallel_inference(inputs, batch_size=None, pool_size=None):
|
| 46 |
+
"""
|
| 47 |
+
并行推理,同时保持结果顺序
|
| 48 |
+
"""
|
| 49 |
+
if pool_size is None:
|
| 50 |
+
pool_size = multiprocessing.cpu_count()
|
| 51 |
+
# 使用上下文管理器自动管理进程池
|
| 52 |
+
with multiprocessing.Pool(processes=pool_size) as pool:
|
| 53 |
+
batched_inputs = list(batch_generator(inputs, batch_size))
|
| 54 |
+
# 使用map函数保证输入输出的顺序是一致的
|
| 55 |
+
batch_results = pool.map(inference_task, batched_inputs)
|
| 56 |
+
# 将批推理结果展平
|
| 57 |
+
index_inpainted_frames = [item for sublist in batch_results for item in sublist]
|
| 58 |
+
return index_inpainted_frames
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def inpaint(img, mask):
|
| 62 |
+
lama_inpaint_instance = LamaInpaint()
|
| 63 |
+
img_inpainted = lama_inpaint_instance(img, mask)
|
| 64 |
+
return img_inpainted
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def inpaint_with_multiple_masks(censored_img, mask_list):
|
| 68 |
+
inpainted_frame = censored_img
|
| 69 |
+
if mask_list:
|
| 70 |
+
for mask in mask_list:
|
| 71 |
+
inpainted_frame = inpaint(inpainted_frame, mask)
|
| 72 |
+
return inpainted_frame
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def create_mask(size, coords_list):
|
| 76 |
+
mask = np.zeros(size, dtype="uint8")
|
| 77 |
+
if coords_list:
|
| 78 |
+
for coords in coords_list:
|
| 79 |
+
xmin, xmax, ymin, ymax = coords
|
| 80 |
+
# 为了避免框过小,放大10个像素
|
| 81 |
+
x1 = xmin - config.SUBTITLE_AREA_DEVIATION_PIXEL
|
| 82 |
+
if x1 < 0:
|
| 83 |
+
x1 = 0
|
| 84 |
+
y1 = ymin - config.SUBTITLE_AREA_DEVIATION_PIXEL
|
| 85 |
+
if y1 < 0:
|
| 86 |
+
y1 = 0
|
| 87 |
+
x2 = xmax + config.SUBTITLE_AREA_DEVIATION_PIXEL
|
| 88 |
+
y2 = ymax + config.SUBTITLE_AREA_DEVIATION_PIXEL
|
| 89 |
+
cv2.rectangle(mask, (x1, y1),
|
| 90 |
+
(x2, y2), (255, 255, 255), thickness=-1)
|
| 91 |
+
return mask
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def inpaint_video(video_path, sub_list):
|
| 95 |
+
index = 0
|
| 96 |
+
frame_to_inpaint_list = []
|
| 97 |
+
video_cap = cv2.VideoCapture(video_path)
|
| 98 |
+
while True:
|
| 99 |
+
# 读取视频帧
|
| 100 |
+
ret, frame = video_cap.read()
|
| 101 |
+
if not ret:
|
| 102 |
+
break
|
| 103 |
+
index += 1
|
| 104 |
+
if index in sub_list.keys():
|
| 105 |
+
frame_to_inpaint_list.append((index, frame, sub_list[index]))
|
| 106 |
+
if len(frame_to_inpaint_list) > config.PROPAINTER_MAX_LOAD_NUM:
|
| 107 |
+
batch_results = parallel_inference(frame_to_inpaint_list)
|
| 108 |
+
for index, frame in batch_results:
|
| 109 |
+
file_name = f'/home/yao/Documents/Project/video-subtitle-remover/test/temp/{index}.png'
|
| 110 |
+
cv2.imwrite(file_name, frame)
|
| 111 |
+
print(f"success write: {file_name}")
|
| 112 |
+
frame_to_inpaint_list.clear()
|
| 113 |
+
print(f'finished')
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == '__main__':
|
| 117 |
+
multiprocessing.set_start_method("spawn")
|
backend/tools/makedist.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from qpt.executor import CreateExecutableModule as CEM
|
| 4 |
+
from qpt.modules.cuda import CopyCUDAPackage
|
| 5 |
+
from qpt.smart_opt import set_default_pip_source
|
| 6 |
+
from qpt.kernel.qinterpreter import PYPI_PIP_SOURCE
|
| 7 |
+
from qpt.modules.package import CustomPackage, DEFAULT_DEPLOY_MODE
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
WORK_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 13 |
+
LAUNCH_PATH = os.path.join(WORK_DIR, 'gui.py')
|
| 14 |
+
SAVE_PATH = os.path.join(os.path.dirname(WORK_DIR), 'vsr_out')
|
| 15 |
+
ICON_PATH = os.path.join(WORK_DIR, "design", "vsr.ico")
|
| 16 |
+
|
| 17 |
+
# 解析命令行参数
|
| 18 |
+
parser = argparse.ArgumentParser(description="打包程序")
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--cuda",
|
| 21 |
+
nargs="?", # 可选参数值
|
| 22 |
+
const="11.8", # 如果只写 --cuda,默认值是 10.2
|
| 23 |
+
default=None, # 不写 --cuda,则为 None
|
| 24 |
+
help="是否包含CUDA模块,可指定版本,如 --cuda 或 --cuda=11.8"
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--directml",
|
| 28 |
+
nargs="?", # 可选参数值
|
| 29 |
+
const=True, # 如果只写 --directml,默认为True
|
| 30 |
+
default=None, # 不写 --directml,则为 None
|
| 31 |
+
help="是否使用DirectML加速,仅指定 --directml 即可启用"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
args = parser.parse_args()
|
| 35 |
+
|
| 36 |
+
sub_modules = []
|
| 37 |
+
|
| 38 |
+
if args.cuda == "11.8":
|
| 39 |
+
sub_modules.append(CustomPackage("torch==2.7.0 torchvision==0.22.0", deploy_mode=DEFAULT_DEPLOY_MODE, find_links=PYPI_PIP_SOURCE, opts="--index-url https://download.pytorch.org/whl/cu118 "))
|
| 40 |
+
elif args.cuda == "12.6":
|
| 41 |
+
sub_modules.append(CustomPackage("torch==2.7.0 torchvision==0.22.0", deploy_mode=DEFAULT_DEPLOY_MODE, find_links=PYPI_PIP_SOURCE, opts="--index-url https://download.pytorch.org/whl/cu126 "))
|
| 42 |
+
elif args.cuda == "12.8":
|
| 43 |
+
sub_modules.append(CustomPackage("torch==2.7.0 torchvision==0.22.0", deploy_mode=DEFAULT_DEPLOY_MODE, find_links=PYPI_PIP_SOURCE, opts="--index-url https://download.pytorch.org/whl/cu128 "))
|
| 44 |
+
|
| 45 |
+
if args.directml:
|
| 46 |
+
sub_modules.append(CustomPackage("torch_directml==0.2.5.dev240914", deploy_mode=DEFAULT_DEPLOY_MODE))
|
| 47 |
+
|
| 48 |
+
if os.getenv("QPT_Action") == "True":
|
| 49 |
+
set_default_pip_source(PYPI_PIP_SOURCE)
|
| 50 |
+
|
| 51 |
+
module = CEM(
|
| 52 |
+
work_dir=WORK_DIR,
|
| 53 |
+
launcher_py_path=LAUNCH_PATH,
|
| 54 |
+
save_path=SAVE_PATH,
|
| 55 |
+
icon=ICON_PATH,
|
| 56 |
+
hidden_terminal=False,
|
| 57 |
+
requirements_file="./requirements.txt",
|
| 58 |
+
sub_modules=sub_modules,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
module.make()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
|
| 65 |
+
main()
|
backend/tools/merge_video.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def merge_video(video_input_path0, video_input_path1, video_output_path):
|
| 5 |
+
"""
|
| 6 |
+
将两个视频文件安装水平方向合并
|
| 7 |
+
"""
|
| 8 |
+
input_video_cap0 = cv2.VideoCapture(video_input_path0)
|
| 9 |
+
input_video_cap1 = cv2.VideoCapture(video_input_path1)
|
| 10 |
+
fps = input_video_cap1.get(cv2.CAP_PROP_FPS)
|
| 11 |
+
size = (int(input_video_cap1.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_video_cap1.get(cv2.CAP_PROP_FRAME_HEIGHT)) * 2)
|
| 12 |
+
video_writer = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
|
| 13 |
+
while True:
|
| 14 |
+
ret0, frame0 = input_video_cap0.read()
|
| 15 |
+
ret1, frame1 = input_video_cap1.read()
|
| 16 |
+
if not ret1 and not ret0:
|
| 17 |
+
break
|
| 18 |
+
else:
|
| 19 |
+
show = cv2.vconcat([frame0, frame1])
|
| 20 |
+
video_writer.write(show)
|
| 21 |
+
video_writer.release()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
|
| 25 |
+
v0_path = '../../test/test4.mp4'
|
| 26 |
+
v1_path = '../../test/test4_no_sub(1).mp4'
|
| 27 |
+
video_out_path = '../../test/demo.mp4'
|
| 28 |
+
merge_video(v0_path, v1_path, video_out_path)
|
| 29 |
+
# ffmpeg 命令 mp4转gif
|
| 30 |
+
# ffmpeg -i demo3.mp4 -vf "scale=w=720:h=-1,fps=15,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" -loop 0 -r 15 -f gif output.gif
|
| 31 |
+
# 宽度固定400,高度成比例:
|
| 32 |
+
# ffmpeg - i input.avi -vf scale=400:-2
|
backend/tools/train/dataset_sttn.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision.transforms as transforms
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from backend.tools.train.utils_sttn import ZipReader, create_random_shape_with_random_motion
|
| 8 |
+
from backend.tools.train.utils_sttn import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# 自定义的数据集
|
| 12 |
+
class Dataset(torch.utils.data.Dataset):
|
| 13 |
+
def __init__(self, args: dict, split='train', debug=False):
|
| 14 |
+
# 初始化函数,传入配置参数字典,数据集划分类型,默认为'train'
|
| 15 |
+
self.args = args
|
| 16 |
+
self.split = split
|
| 17 |
+
self.sample_length = args['sample_length'] # 样本长度参数
|
| 18 |
+
self.size = self.w, self.h = (args['w'], args['h']) # 设置图像的目标宽高
|
| 19 |
+
|
| 20 |
+
# 打开存放数据相关信息的json文件
|
| 21 |
+
with open(os.path.join(args['data_root'], args['name'], split+'.json'), 'r') as f:
|
| 22 |
+
self.video_dict = json.load(f) # 加载json文件内容
|
| 23 |
+
self.video_names = list(self.video_dict.keys()) # 获取视频的名称列表
|
| 24 |
+
if debug or split != 'train': # 如果是调试模式或者不是训练集,只取前100个视频
|
| 25 |
+
self.video_names = self.video_names[:100]
|
| 26 |
+
|
| 27 |
+
# 定义数据的转换操作,转换成堆叠的张量
|
| 28 |
+
self._to_tensors = transforms.Compose([
|
| 29 |
+
Stack(),
|
| 30 |
+
ToTorchFormatTensor(), # 便于在PyTorch中使用的张量格式
|
| 31 |
+
])
|
| 32 |
+
|
| 33 |
+
def __len__(self):
|
| 34 |
+
# 返回数据集中视频的数量
|
| 35 |
+
return len(self.video_names)
|
| 36 |
+
|
| 37 |
+
def __getitem__(self, index):
|
| 38 |
+
# 获取一个样本项
|
| 39 |
+
try:
|
| 40 |
+
item = self.load_item(index) # 尝试加载指定索引的数据项
|
| 41 |
+
except:
|
| 42 |
+
print('Loading error in video {}'.format(self.video_names[index])) # 如果加载出错,打印出错信息
|
| 43 |
+
item = self.load_item(0) # 加载第一个项目作为兜底
|
| 44 |
+
return item
|
| 45 |
+
|
| 46 |
+
def load_item(self, index):
|
| 47 |
+
# 加载数据项的具体实现
|
| 48 |
+
video_name = self.video_names[index] # 根据索引获取视频名称
|
| 49 |
+
# 为所有视频帧生成帧文件名列表
|
| 50 |
+
all_frames = [f"{str(i).zfill(5)}.jpg" for i in range(self.video_dict[video_name])]
|
| 51 |
+
# 生成随机运动的随机形状的遮罩
|
| 52 |
+
all_masks = create_random_shape_with_random_motion(
|
| 53 |
+
len(all_frames), imageHeight=self.h, imageWidth=self.w)
|
| 54 |
+
# 获取参考帧的索引
|
| 55 |
+
ref_index = get_ref_index(len(all_frames), self.sample_length)
|
| 56 |
+
# 读取视频帧
|
| 57 |
+
frames = []
|
| 58 |
+
masks = []
|
| 59 |
+
for idx in ref_index:
|
| 60 |
+
# 读取图片,转化为RGB,调整大小并添加到列表中
|
| 61 |
+
img = ZipReader.imread('{}/{}/JPEGImages/{}.zip'.format(
|
| 62 |
+
self.args['data_root'], self.args['name'], video_name), all_frames[idx]).convert('RGB')
|
| 63 |
+
img = img.resize(self.size)
|
| 64 |
+
frames.append(img)
|
| 65 |
+
masks.append(all_masks[idx])
|
| 66 |
+
if self.split == 'train':
|
| 67 |
+
# 如果是训练集,随机水平翻转图像
|
| 68 |
+
frames = GroupRandomHorizontalFlip()(frames)
|
| 69 |
+
# 转换成张量形式
|
| 70 |
+
frame_tensors = self._to_tensors(frames)*2.0 - 1.0 # 归一化处理
|
| 71 |
+
mask_tensors = self._to_tensors(masks) # 将遮罩转换成张量
|
| 72 |
+
return frame_tensors, mask_tensors # 返回图像和遮罩的张量
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_ref_index(length, sample_length):
|
| 76 |
+
# 获取参考帧索引的实现
|
| 77 |
+
if random.uniform(0, 1) > 0.5:
|
| 78 |
+
# 有一半的概率随机选择帧
|
| 79 |
+
ref_index = random.sample(range(length), sample_length)
|
| 80 |
+
ref_index.sort() # 排序保证顺序
|
| 81 |
+
else:
|
| 82 |
+
# 另一半概率选择连续的帧
|
| 83 |
+
pivot = random.randint(0, length-sample_length)
|
| 84 |
+
ref_index = [pivot+i for i in range(sample_length)]
|
| 85 |
+
return ref_index
|
backend/tools/train/loss_sttn.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AdversarialLoss(nn.Module):
|
| 6 |
+
"""
|
| 7 |
+
对抗性损失
|
| 8 |
+
根据论文 https://arxiv.org/abs/1711.10337 实现
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
|
| 12 |
+
"""
|
| 13 |
+
可以选择的损失类型有 'nsgan' | 'lsgan' | 'hinge'
|
| 14 |
+
type: 指定使用哪种类型的 GAN 损失。
|
| 15 |
+
target_real_label: 真实图像的目标标签值。
|
| 16 |
+
target_fake_label: 生成图像的目标标签值。
|
| 17 |
+
"""
|
| 18 |
+
super(AdversarialLoss, self).__init__()
|
| 19 |
+
self.type = type # 损失类型
|
| 20 |
+
# 使用缓冲区注册标签,这样在模型保存和加载时会一同保存和加载
|
| 21 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
| 22 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
| 23 |
+
|
| 24 |
+
# 根据选择的类型初始化不同的损失函数
|
| 25 |
+
if type == 'nsgan':
|
| 26 |
+
self.criterion = nn.BCELoss() # 二进制交叉熵损失(非饱和GAN)
|
| 27 |
+
elif type == 'lsgan':
|
| 28 |
+
self.criterion = nn.MSELoss() # 均方误差损失(最小平方GAN)
|
| 29 |
+
elif type == 'hinge':
|
| 30 |
+
self.criterion = nn.ReLU() # 适用于hinge损失的ReLU函数
|
| 31 |
+
|
| 32 |
+
def __call__(self, outputs, is_real, is_disc=None):
|
| 33 |
+
"""
|
| 34 |
+
调用函数计算损失。
|
| 35 |
+
outputs: 网络输出。
|
| 36 |
+
is_real: 如果是真实样本,则为 True;如果是生成样本,则为 False。
|
| 37 |
+
is_disc: 指示当前是否在优化判别器。
|
| 38 |
+
"""
|
| 39 |
+
if self.type == 'hinge':
|
| 40 |
+
# 对于 hinge 损失
|
| 41 |
+
if is_disc:
|
| 42 |
+
# 如果是判别器
|
| 43 |
+
if is_real:
|
| 44 |
+
outputs = -outputs # 对真实样本反向标签
|
| 45 |
+
# max(0, 1 - (真/假)示例输出)
|
| 46 |
+
return self.criterion(1 + outputs).mean()
|
| 47 |
+
else:
|
| 48 |
+
# 如果是生成器, -min(0, -输出) = max(0, 输出)
|
| 49 |
+
return (-outputs).mean()
|
| 50 |
+
else:
|
| 51 |
+
# 对于 nsgan 和 lsgan 损失
|
| 52 |
+
labels = (self.real_label if is_real else self.fake_label).expand_as(
|
| 53 |
+
outputs)
|
| 54 |
+
# 计算模型输出和目标标签之间的损失
|
| 55 |
+
loss = self.criterion(outputs, labels)
|
| 56 |
+
return loss
|
backend/tools/train/train_sttn.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
from shutil import copyfile
|
| 5 |
+
import torch
|
| 6 |
+
import torch.multiprocessing as mp
|
| 7 |
+
|
| 8 |
+
from backend.tools.train.trainer_sttn import Trainer
|
| 9 |
+
from backend.tools.train.utils_sttn import (
|
| 10 |
+
get_world_size,
|
| 11 |
+
get_local_rank,
|
| 12 |
+
get_global_rank,
|
| 13 |
+
get_master_ip,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser(description='STTN')
|
| 17 |
+
parser.add_argument('-c', '--config', default='configs_sttn/youtube-vos.json', type=str)
|
| 18 |
+
parser.add_argument('-m', '--model', default='sttn', type=str)
|
| 19 |
+
parser.add_argument('-p', '--port', default='23455', type=str)
|
| 20 |
+
parser.add_argument('-e', '--exam', action='store_true')
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main_worker(rank, config):
|
| 25 |
+
# 如果配置中没有提到局部排序(local_rank),就给它和全局排序(global_rank)赋值为传入的排序(rank)
|
| 26 |
+
if 'local_rank' not in config:
|
| 27 |
+
config['local_rank'] = config['global_rank'] = rank
|
| 28 |
+
|
| 29 |
+
# 如果配置指定为分布式训练
|
| 30 |
+
if config['distributed']:
|
| 31 |
+
# 设置使用的CUDA设备为当前的本地排名对应的GPU
|
| 32 |
+
torch.cuda.set_device(int(config['local_rank']))
|
| 33 |
+
# 初始化分布式进程组,通过nccl后端
|
| 34 |
+
torch.distributed.init_process_group(
|
| 35 |
+
backend='nccl',
|
| 36 |
+
init_method=config['init_method'],
|
| 37 |
+
world_size=config['world_size'],
|
| 38 |
+
rank=config['global_rank'],
|
| 39 |
+
group_name='mtorch'
|
| 40 |
+
)
|
| 41 |
+
# 打印当前GPU的使用情况,输出全球排名和本地排名
|
| 42 |
+
print('using GPU {}-{} for training'.format(
|
| 43 |
+
int(config['global_rank']), int(config['local_rank']))
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# 创建模型保存的目录路径,包括模型名和配置文件名
|
| 47 |
+
config['save_dir'] = os.path.join(
|
| 48 |
+
config['save_dir'], '{}_{}'.format(config['model'], os.path.basename(args.config).split('.')[0])
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# 如果CUDA可用,则设置设备为相应的CUDA设备,否则为CPU
|
| 52 |
+
if torch.cuda.is_available():
|
| 53 |
+
config['device'] = torch.device("cuda:{}".format(config['local_rank']))
|
| 54 |
+
else:
|
| 55 |
+
config['device'] = 'cpu'
|
| 56 |
+
|
| 57 |
+
# 如果不是分布式训练,或者是分布式训练的主节点(rank 0)
|
| 58 |
+
if (not config['distributed']) or config['global_rank'] == 0:
|
| 59 |
+
# 创建模型保存目录,并允许如果该目录存在则忽略创建(exist_ok=True)
|
| 60 |
+
os.makedirs(config['save_dir'], exist_ok=True)
|
| 61 |
+
# 设置配置文件的保存路径
|
| 62 |
+
config_path = os.path.join(
|
| 63 |
+
config['save_dir'], config['config'].split('/')[-1]
|
| 64 |
+
)
|
| 65 |
+
# 如果配置文件不存在,则从给定的配置文件路径复制到新路径
|
| 66 |
+
if not os.path.isfile(config_path):
|
| 67 |
+
copyfile(config['config'], config_path)
|
| 68 |
+
# 打印创建目录的信息
|
| 69 |
+
print('[**] create folder {}'.format(config['save_dir']))
|
| 70 |
+
|
| 71 |
+
# 初始化训练器,传入配置参数和debug标记
|
| 72 |
+
trainer = Trainer(config, debug=args.exam)
|
| 73 |
+
# 开始训练
|
| 74 |
+
trainer.train()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
# 加载配置文件
|
| 79 |
+
config = json.load(open(args.config))
|
| 80 |
+
config['model'] = args.model # 设置模型名称
|
| 81 |
+
config['config'] = args.config # 设置配置文件路径
|
| 82 |
+
|
| 83 |
+
# 设置分布式训练的相关配置
|
| 84 |
+
config['world_size'] = get_world_size() # 获取全局进程数,即训练过程中参与计算的总GPU数量
|
| 85 |
+
config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" # 设置初始化方法,包括主节点IP和端口
|
| 86 |
+
config['distributed'] = True if config['world_size'] > 1 else False # 根据世界规模确定是否启用分布式训练
|
| 87 |
+
|
| 88 |
+
# 设置分布式并行训练环境
|
| 89 |
+
if get_master_ip() == "127.0.0.1":
|
| 90 |
+
# 如果主节点IP是本机地址,那么手动启动多个分布式训练进程
|
| 91 |
+
mp.spawn(main_worker, nprocs=config['world_size'], args=(config,))
|
| 92 |
+
else:
|
| 93 |
+
# 如果是由其他工具如OpenMPI启动的多个进程,不需手动创建进程。
|
| 94 |
+
config['local_rank'] = get_local_rank() # 获取本地(单个节点)排名
|
| 95 |
+
config['global_rank'] = get_global_rank() # 获取全局排名
|
| 96 |
+
main_worker(-1, config) # 启动主工作函数
|
backend/tools/train/trainer_sttn.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 8 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 9 |
+
from tensorboardX import SummaryWriter
|
| 10 |
+
|
| 11 |
+
from backend.inpaint.sttn.auto_sttn import Discriminator
|
| 12 |
+
from backend.inpaint.sttn.auto_sttn import InpaintGenerator
|
| 13 |
+
from backend.tools.train.dataset_sttn import Dataset
|
| 14 |
+
from backend.tools.train.loss_sttn import AdversarialLoss
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Trainer:
|
| 18 |
+
def __init__(self, config, debug=False):
|
| 19 |
+
# 训练器初始化
|
| 20 |
+
self.config = config # 保存配置信息
|
| 21 |
+
self.epoch = 0 # 当前训练所处的epoch
|
| 22 |
+
self.iteration = 0 # 当前训练迭代次数
|
| 23 |
+
if debug:
|
| 24 |
+
# 如果是调试模式,设置更频繁的保存和验证频率
|
| 25 |
+
self.config['trainer']['save_freq'] = 5
|
| 26 |
+
self.config['trainer']['valid_freq'] = 5
|
| 27 |
+
self.config['trainer']['iterations'] = 5
|
| 28 |
+
|
| 29 |
+
# 设置数据集和数据加载器
|
| 30 |
+
self.train_dataset = Dataset(config['data_loader'], split='train', debug=debug) # 创建训练集对象
|
| 31 |
+
self.train_sampler = None # 初始化训练集采样器为None
|
| 32 |
+
self.train_args = config['trainer'] # 训练过程参数
|
| 33 |
+
if config['distributed']:
|
| 34 |
+
# 如果是分布式训练,则初始化分布式采样器
|
| 35 |
+
self.train_sampler = DistributedSampler(
|
| 36 |
+
self.train_dataset,
|
| 37 |
+
num_replicas=config['world_size'],
|
| 38 |
+
rank=config['global_rank']
|
| 39 |
+
)
|
| 40 |
+
self.train_loader = DataLoader(
|
| 41 |
+
self.train_dataset,
|
| 42 |
+
batch_size=self.train_args['batch_size'] // config['world_size'],
|
| 43 |
+
shuffle=(self.train_sampler is None), # 如果没有采样器则进行打乱
|
| 44 |
+
num_workers=self.train_args['num_workers'],
|
| 45 |
+
sampler=self.train_sampler
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# 设置损失函数
|
| 49 |
+
self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) # 对抗性损失
|
| 50 |
+
self.adversarial_loss = self.adversarial_loss.to(self.config['device']) # 将损失函数转移到相应设备
|
| 51 |
+
self.l1_loss = nn.L1Loss() # L1损失
|
| 52 |
+
|
| 53 |
+
# 初始化生成器和判别器模型
|
| 54 |
+
self.netG = InpaintGenerator() # 生成网络
|
| 55 |
+
self.netG = self.netG.to(self.config['device']) # 转移到设备
|
| 56 |
+
self.netD = Discriminator(
|
| 57 |
+
in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge'
|
| 58 |
+
)
|
| 59 |
+
self.netD = self.netD.to(self.config['device']) # 判别网络
|
| 60 |
+
# 初始化优化器
|
| 61 |
+
self.optimG = torch.optim.Adam(
|
| 62 |
+
self.netG.parameters(), # 生成器参数
|
| 63 |
+
lr=config['trainer']['lr'], # 学习率
|
| 64 |
+
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
|
| 65 |
+
)
|
| 66 |
+
self.optimD = torch.optim.Adam(
|
| 67 |
+
self.netD.parameters(), # 判别器参数
|
| 68 |
+
lr=config['trainer']['lr'], # 学习率
|
| 69 |
+
betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])
|
| 70 |
+
)
|
| 71 |
+
self.load() # 加载模型
|
| 72 |
+
|
| 73 |
+
if config['distributed']:
|
| 74 |
+
# 如果是分布式训练,则使用分布式数据并行包装器
|
| 75 |
+
self.netG = DDP(
|
| 76 |
+
self.netG,
|
| 77 |
+
device_ids=[self.config['local_rank']],
|
| 78 |
+
output_device=self.config['local_rank'],
|
| 79 |
+
broadcast_buffers=True,
|
| 80 |
+
find_unused_parameters=False
|
| 81 |
+
)
|
| 82 |
+
self.netD = DDP(
|
| 83 |
+
self.netD,
|
| 84 |
+
device_ids=[self.config['local_rank']],
|
| 85 |
+
output_device=self.config['local_rank'],
|
| 86 |
+
broadcast_buffers=True,
|
| 87 |
+
find_unused_parameters=False
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# 设置日志记录器
|
| 91 |
+
self.dis_writer = None # 判别器写入器
|
| 92 |
+
self.gen_writer = None # 生成器写入器
|
| 93 |
+
self.summary = {} # 存放摘要统计
|
| 94 |
+
if self.config['global_rank'] == 0 or (not config['distributed']):
|
| 95 |
+
# 如果不是分布式训练或者为分布式训练的主节点
|
| 96 |
+
self.dis_writer = SummaryWriter(
|
| 97 |
+
os.path.join(config['save_dir'], 'dis')
|
| 98 |
+
)
|
| 99 |
+
self.gen_writer = SummaryWriter(
|
| 100 |
+
os.path.join(config['save_dir'], 'gen')
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# 获取当前学习率
|
| 104 |
+
def get_lr(self):
|
| 105 |
+
return self.optimG.param_groups[0]['lr']
|
| 106 |
+
|
| 107 |
+
# 调整学习率
|
| 108 |
+
def adjust_learning_rate(self):
|
| 109 |
+
# 计算衰减的学习率
|
| 110 |
+
decay = 0.1 ** (min(self.iteration, self.config['trainer']['niter_steady']) // self.config['trainer']['niter'])
|
| 111 |
+
new_lr = self.config['trainer']['lr'] * decay
|
| 112 |
+
# 如果新的学习率和当前学习率不同,则更新优化器中的学习率
|
| 113 |
+
if new_lr != self.get_lr():
|
| 114 |
+
for param_group in self.optimG.param_groups:
|
| 115 |
+
param_group['lr'] = new_lr
|
| 116 |
+
for param_group in self.optimD.param_groups:
|
| 117 |
+
param_group['lr'] = new_lr
|
| 118 |
+
|
| 119 |
+
# 添加摘要信息
|
| 120 |
+
def add_summary(self, writer, name, val):
|
| 121 |
+
# 添加并更新统计信息,每次迭代都累加
|
| 122 |
+
if name not in self.summary:
|
| 123 |
+
self.summary[name] = 0
|
| 124 |
+
self.summary[name] += val
|
| 125 |
+
# 每100次迭代记录一次
|
| 126 |
+
if writer is not None and self.iteration % 100 == 0:
|
| 127 |
+
writer.add_scalar(name, self.summary[name] / 100, self.iteration)
|
| 128 |
+
self.summary[name] = 0
|
| 129 |
+
|
| 130 |
+
# 加载模型netG and netD
|
| 131 |
+
def load(self):
|
| 132 |
+
model_path = self.config['save_dir'] # 模型的保存路径
|
| 133 |
+
# 检测是否存在最近的模型检查点
|
| 134 |
+
if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
|
| 135 |
+
# 读取最后一个epoch的编号
|
| 136 |
+
latest_epoch = open(os.path.join(
|
| 137 |
+
model_path, 'latest.ckpt'), 'r').read().splitlines()[-1]
|
| 138 |
+
else:
|
| 139 |
+
# 如果不存在latest.ckpt,尝试读取存储好的模型文件列表,获取最近的一个
|
| 140 |
+
ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob(
|
| 141 |
+
os.path.join(model_path, '*.pth'))]
|
| 142 |
+
ckpts.sort() # 排序模型文件,以获取最近的一个
|
| 143 |
+
latest_epoch = ckpts[-1] if len(ckpts) > 0 else None # 获取最近的epoch值
|
| 144 |
+
if latest_epoch is not None:
|
| 145 |
+
# 拼接得到生成器和判别器的模型文件路径
|
| 146 |
+
gen_path = os.path.join(
|
| 147 |
+
model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
|
| 148 |
+
dis_path = os.path.join(
|
| 149 |
+
model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
|
| 150 |
+
opt_path = os.path.join(
|
| 151 |
+
model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
|
| 152 |
+
# 如果是主节点,输出加载模型的信息
|
| 153 |
+
if self.config['global_rank'] == 0:
|
| 154 |
+
print('Loading model from {}...'.format(gen_path))
|
| 155 |
+
# 加载生成器模型
|
| 156 |
+
data = torch.load(gen_path, map_location=self.config['device'])
|
| 157 |
+
self.netG.load_state_dict(data['netG'])
|
| 158 |
+
# 加载判别器模型
|
| 159 |
+
data = torch.load(dis_path, map_location=self.config['device'])
|
| 160 |
+
self.netD.load_state_dict(data['netD'])
|
| 161 |
+
# 加载优化器状态
|
| 162 |
+
data = torch.load(opt_path, map_location=self.config['device'])
|
| 163 |
+
self.optimG.load_state_dict(data['optimG'])
|
| 164 |
+
self.optimD.load_state_dict(data['optimD'])
|
| 165 |
+
# 更新当前epoch和迭代次数
|
| 166 |
+
self.epoch = data['epoch']
|
| 167 |
+
self.iteration = data['iteration']
|
| 168 |
+
else:
|
| 169 |
+
# 如果没有找到模型文件,则输出警告信息
|
| 170 |
+
if self.config['global_rank'] == 0:
|
| 171 |
+
print('Warning: There is no trained model found. An initialized model will be used.')
|
| 172 |
+
|
| 173 |
+
# 保存模型参数,每次评估周期 (eval_epoch) 调用一次
|
| 174 |
+
def save(self, it):
|
| 175 |
+
# 只在全局排名为0的进程上执行保存操作,通常代表主节点
|
| 176 |
+
if self.config['global_rank'] == 0:
|
| 177 |
+
# 生成保存生成器模型状态字典的文件路径
|
| 178 |
+
gen_path = os.path.join(
|
| 179 |
+
self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5)))
|
| 180 |
+
# 生成保存判别器模型状态字典的文件路径
|
| 181 |
+
dis_path = os.path.join(
|
| 182 |
+
self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5)))
|
| 183 |
+
# 生成保存优化器状态字典的文件路径
|
| 184 |
+
opt_path = os.path.join(
|
| 185 |
+
self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5)))
|
| 186 |
+
|
| 187 |
+
# 打印消息表示模型正在保存
|
| 188 |
+
print('\nsaving model to {} ...'.format(gen_path))
|
| 189 |
+
|
| 190 |
+
# 判断模型是否是经过DataParallel或DDP包装的,若是则获取原始的模型
|
| 191 |
+
if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
|
| 192 |
+
netG = self.netG.module
|
| 193 |
+
netD = self.netD.module
|
| 194 |
+
else:
|
| 195 |
+
netG = self.netG
|
| 196 |
+
netD = self.netD
|
| 197 |
+
|
| 198 |
+
# 保存生成器和判别器的模型参数
|
| 199 |
+
torch.save({'netG': netG.state_dict()}, gen_path)
|
| 200 |
+
torch.save({'netD': netD.state_dict()}, dis_path)
|
| 201 |
+
# 保存当前的epoch、迭代次数和优化器的状态
|
| 202 |
+
torch.save({
|
| 203 |
+
'epoch': self.epoch,
|
| 204 |
+
'iteration': self.iteration,
|
| 205 |
+
'optimG': self.optimG.state_dict(),
|
| 206 |
+
'optimD': self.optimD.state_dict()
|
| 207 |
+
}, opt_path)
|
| 208 |
+
|
| 209 |
+
# 写入最新的迭代次数到"latest.ckpt"文件
|
| 210 |
+
os.system('echo {} > {}'.format(str(it).zfill(5),
|
| 211 |
+
os.path.join(self.config['save_dir'], 'latest.ckpt')))
|
| 212 |
+
|
| 213 |
+
# 训练入口
|
| 214 |
+
|
| 215 |
+
def train(self):
|
| 216 |
+
# 初始化进度条范围
|
| 217 |
+
pbar = range(int(self.train_args['iterations']))
|
| 218 |
+
# 如果是全局rank 0的进程,则设置显示进度条
|
| 219 |
+
if self.config['global_rank'] == 0:
|
| 220 |
+
pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01)
|
| 221 |
+
|
| 222 |
+
# 开始训练循环
|
| 223 |
+
while True:
|
| 224 |
+
self.epoch += 1 # epoch计数增加
|
| 225 |
+
if self.config['distributed']:
|
| 226 |
+
# 如果是分布式训练,则对采样器进行设置,保证每个进程获取的数据不同
|
| 227 |
+
self.train_sampler.set_epoch(self.epoch)
|
| 228 |
+
|
| 229 |
+
# 调用训练一个epoch的函数
|
| 230 |
+
self._train_epoch(pbar)
|
| 231 |
+
# 如果迭代次数超过配置中的迭代上限,则退出循环
|
| 232 |
+
if self.iteration > self.train_args['iterations']:
|
| 233 |
+
break
|
| 234 |
+
# 训练结束输出
|
| 235 |
+
print('\nEnd training....')
|
| 236 |
+
|
| 237 |
+
# 每个训练周期处理输入并计算损失
|
| 238 |
+
|
| 239 |
+
def _train_epoch(self, pbar):
|
| 240 |
+
device = self.config['device'] # 获取设备信息
|
| 241 |
+
|
| 242 |
+
# 遍历数据加载器中的数据
|
| 243 |
+
for frames, masks in self.train_loader:
|
| 244 |
+
# 调整学习率
|
| 245 |
+
self.adjust_learning_rate()
|
| 246 |
+
# 迭代次数+1
|
| 247 |
+
self.iteration += 1
|
| 248 |
+
|
| 249 |
+
# 将frames和masks转移到设备上
|
| 250 |
+
frames, masks = frames.to(device), masks.to(device)
|
| 251 |
+
b, t, c, h, w = frames.size() # 获取帧和蒙版的尺寸
|
| 252 |
+
masked_frame = (frames * (1 - masks).float()) # 应用蒙版到图像
|
| 253 |
+
pred_img = self.netG(masked_frame, masks) # 使用生成器生成填充图像
|
| 254 |
+
# 调整frames和masks的维度以符合网络的输入要求
|
| 255 |
+
frames = frames.view(b * t, c, h, w)
|
| 256 |
+
masks = masks.view(b * t, 1, h, w)
|
| 257 |
+
comp_img = frames * (1. - masks) + masks * pred_img # 生成最终的组合图像
|
| 258 |
+
|
| 259 |
+
gen_loss = 0 # 初始化生成器损失
|
| 260 |
+
dis_loss = 0 # 初始化判别器损失
|
| 261 |
+
|
| 262 |
+
# 判别器对抗性损失
|
| 263 |
+
real_vid_feat = self.netD(frames) # 判别器对真实图像判别
|
| 264 |
+
fake_vid_feat = self.netD(comp_img.detach()) # 判别器对生成图像判别,注意detach是为了不计算梯度
|
| 265 |
+
dis_real_loss = self.adversarial_loss(real_vid_feat, True, True) # 真实图像的损失
|
| 266 |
+
dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True) # 生成图像的损失
|
| 267 |
+
dis_loss += (dis_real_loss + dis_fake_loss) / 2 # 求平均的判别器损失
|
| 268 |
+
# 添加判别器损失到摘要
|
| 269 |
+
self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
|
| 270 |
+
self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
|
| 271 |
+
# 优化判别器
|
| 272 |
+
self.optimD.zero_grad()
|
| 273 |
+
dis_loss.backward()
|
| 274 |
+
self.optimD.step()
|
| 275 |
+
|
| 276 |
+
# 生成器对抗性损失
|
| 277 |
+
gen_vid_feat = self.netD(comp_img)
|
| 278 |
+
gan_loss = self.adversarial_loss(gen_vid_feat, True, False) # 生成器的对抗损失
|
| 279 |
+
gan_loss = gan_loss * self.config['losses']['adversarial_weight'] # 权重放大
|
| 280 |
+
gen_loss += gan_loss # 累加到生成器损失
|
| 281 |
+
# 添加生成器对抗性损失到摘要
|
| 282 |
+
self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())
|
| 283 |
+
|
| 284 |
+
# 生成器L1损失
|
| 285 |
+
hole_loss = self.l1_loss(pred_img * masks, frames * masks) # 只计算有蒙版区域的损失
|
| 286 |
+
# 考虑蒙版的平均值,乘以配置中的hole_weight
|
| 287 |
+
hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
|
| 288 |
+
gen_loss += hole_loss # 累加到生成器损失
|
| 289 |
+
# 添加hole_loss到摘要
|
| 290 |
+
self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())
|
| 291 |
+
|
| 292 |
+
# 计算蒙版外区域的L1损失
|
| 293 |
+
valid_loss = self.l1_loss(pred_img * (1 - masks), frames * (1 - masks))
|
| 294 |
+
# 考虑非蒙版区的平均值,乘以配置中的valid_weight
|
| 295 |
+
valid_loss = valid_loss / torch.mean(1 - masks) * self.config['losses']['valid_weight']
|
| 296 |
+
gen_loss += valid_loss # 累加到生成器损失
|
| 297 |
+
# 添加valid_loss到摘要
|
| 298 |
+
self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())
|
| 299 |
+
|
| 300 |
+
# 生成器优化
|
| 301 |
+
self.optimG.zero_grad()
|
| 302 |
+
gen_loss.backward()
|
| 303 |
+
self.optimG.step()
|
| 304 |
+
|
| 305 |
+
# 控制台日志输出
|
| 306 |
+
if self.config['global_rank'] == 0:
|
| 307 |
+
pbar.update(1) # 进度条更新
|
| 308 |
+
pbar.set_description(( # 设置进度条描述
|
| 309 |
+
f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};" # 打印损失数值
|
| 310 |
+
f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# 模型保存
|
| 314 |
+
if self.iteration % self.train_args['save_freq'] == 0:
|
| 315 |
+
self.save(int(self.iteration // self.train_args['save_freq']))
|
| 316 |
+
# 迭代次数终止判断
|
| 317 |
+
if self.iteration > self.train_args['iterations']:
|
| 318 |
+
break
|
| 319 |
+
|
backend/tools/train/utils_sttn.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib.patches as patches
|
| 4 |
+
from matplotlib.path import Path
|
| 5 |
+
import io
|
| 6 |
+
import cv2
|
| 7 |
+
import random
|
| 8 |
+
import zipfile
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image, ImageOps
|
| 11 |
+
import torch
|
| 12 |
+
import matplotlib
|
| 13 |
+
from matplotlib import pyplot as plt
|
| 14 |
+
matplotlib.use('agg')
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ZipReader(object):
|
| 18 |
+
file_dict = dict()
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
super(ZipReader, self).__init__()
|
| 22 |
+
|
| 23 |
+
@staticmethod
|
| 24 |
+
def build_file_dict(path):
|
| 25 |
+
file_dict = ZipReader.file_dict
|
| 26 |
+
if path in file_dict:
|
| 27 |
+
return file_dict[path]
|
| 28 |
+
else:
|
| 29 |
+
file_handle = zipfile.ZipFile(path, 'r')
|
| 30 |
+
file_dict[path] = file_handle
|
| 31 |
+
return file_dict[path]
|
| 32 |
+
|
| 33 |
+
@staticmethod
|
| 34 |
+
def imread(path, image_name):
|
| 35 |
+
zfile = ZipReader.build_file_dict(path)
|
| 36 |
+
data = zfile.read(image_name)
|
| 37 |
+
im = Image.open(io.BytesIO(data))
|
| 38 |
+
return im
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class GroupRandomHorizontalFlip(object):
|
| 42 |
+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, is_flow=False):
|
| 46 |
+
self.is_flow = is_flow
|
| 47 |
+
|
| 48 |
+
def __call__(self, img_group, is_flow=False):
|
| 49 |
+
v = random.random()
|
| 50 |
+
if v < 0.5:
|
| 51 |
+
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
|
| 52 |
+
if self.is_flow:
|
| 53 |
+
for i in range(0, len(ret), 2):
|
| 54 |
+
# invert flow pixel values when flipping
|
| 55 |
+
ret[i] = ImageOps.invert(ret[i])
|
| 56 |
+
return ret
|
| 57 |
+
else:
|
| 58 |
+
return img_group
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Stack(object):
|
| 62 |
+
def __init__(self, roll=False):
|
| 63 |
+
self.roll = roll
|
| 64 |
+
|
| 65 |
+
def __call__(self, img_group):
|
| 66 |
+
mode = img_group[0].mode
|
| 67 |
+
if mode == '1':
|
| 68 |
+
img_group = [img.convert('L') for img in img_group]
|
| 69 |
+
mode = 'L'
|
| 70 |
+
if mode == 'L':
|
| 71 |
+
return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
|
| 72 |
+
elif mode == 'RGB':
|
| 73 |
+
if self.roll:
|
| 74 |
+
return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
|
| 75 |
+
else:
|
| 76 |
+
return np.stack(img_group, axis=2)
|
| 77 |
+
else:
|
| 78 |
+
raise NotImplementedError(f"Image mode {mode}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ToTorchFormatTensor(object):
|
| 82 |
+
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
|
| 83 |
+
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
|
| 84 |
+
|
| 85 |
+
def __init__(self, div=True):
|
| 86 |
+
self.div = div
|
| 87 |
+
|
| 88 |
+
def __call__(self, pic):
|
| 89 |
+
if isinstance(pic, np.ndarray):
|
| 90 |
+
# numpy img: [L, C, H, W]
|
| 91 |
+
img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
|
| 92 |
+
else:
|
| 93 |
+
# handle PIL Image
|
| 94 |
+
img = torch.ByteTensor(
|
| 95 |
+
torch.ByteStorage.from_buffer(pic.tobytes()))
|
| 96 |
+
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
|
| 97 |
+
# put it from HWC to CHW format
|
| 98 |
+
# yikes, this transpose takes 80% of the loading time/CPU
|
| 99 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
| 100 |
+
img = img.float().div(255) if self.div else img.float()
|
| 101 |
+
return img
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
|
| 105 |
+
# get a random shape
|
| 106 |
+
height = random.randint(imageHeight//3, imageHeight-1)
|
| 107 |
+
width = random.randint(imageWidth//3, imageWidth-1)
|
| 108 |
+
edge_num = random.randint(6, 8)
|
| 109 |
+
ratio = random.randint(6, 8)/10
|
| 110 |
+
region = get_random_shape(
|
| 111 |
+
edge_num=edge_num, ratio=ratio, height=height, width=width)
|
| 112 |
+
region_width, region_height = region.size
|
| 113 |
+
# get random position
|
| 114 |
+
x, y = random.randint(
|
| 115 |
+
0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
|
| 116 |
+
velocity = get_random_velocity(max_speed=3)
|
| 117 |
+
m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
| 118 |
+
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
|
| 119 |
+
masks = [m.convert('L')]
|
| 120 |
+
# return fixed masks
|
| 121 |
+
if random.uniform(0, 1) > 0.5:
|
| 122 |
+
return masks*video_length
|
| 123 |
+
# return moving masks
|
| 124 |
+
for _ in range(video_length-1):
|
| 125 |
+
x, y, velocity = random_move_control_points(
|
| 126 |
+
x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
|
| 127 |
+
m = Image.fromarray(
|
| 128 |
+
np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
| 129 |
+
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
|
| 130 |
+
masks.append(m.convert('L'))
|
| 131 |
+
return masks
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
|
| 135 |
+
'''
|
| 136 |
+
There is the initial point and 3 points per cubic bezier curve.
|
| 137 |
+
Thus, the curve will only pass though n points, which will be the sharp edges.
|
| 138 |
+
The other 2 modify the shape of the bezier curve.
|
| 139 |
+
edge_num, Number of possibly sharp edges
|
| 140 |
+
points_num, number of points in the Path
|
| 141 |
+
ratio, (0, 1) magnitude of the perturbation from the unit circle,
|
| 142 |
+
'''
|
| 143 |
+
points_num = edge_num*3 + 1
|
| 144 |
+
angles = np.linspace(0, 2*np.pi, points_num)
|
| 145 |
+
codes = np.full(points_num, Path.CURVE4)
|
| 146 |
+
codes[0] = Path.MOVETO
|
| 147 |
+
# Using this instad of Path.CLOSEPOLY avoids an innecessary straight line
|
| 148 |
+
verts = np.stack((np.cos(angles), np.sin(angles))).T * \
|
| 149 |
+
(2*ratio*np.random.random(points_num)+1-ratio)[:, None]
|
| 150 |
+
verts[-1, :] = verts[0, :]
|
| 151 |
+
path = Path(verts, codes)
|
| 152 |
+
# draw paths into images
|
| 153 |
+
fig = plt.figure()
|
| 154 |
+
ax = fig.add_subplot(111)
|
| 155 |
+
patch = patches.PathPatch(path, facecolor='black', lw=2)
|
| 156 |
+
ax.add_patch(patch)
|
| 157 |
+
ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
|
| 158 |
+
ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
|
| 159 |
+
ax.axis('off') # removes the axis to leave only the shape
|
| 160 |
+
fig.canvas.draw()
|
| 161 |
+
# convert plt images into numpy images
|
| 162 |
+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 163 |
+
data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
|
| 164 |
+
plt.close(fig)
|
| 165 |
+
# postprocess
|
| 166 |
+
data = cv2.resize(data, (width, height))[:, :, 0]
|
| 167 |
+
data = (1 - np.array(data > 0).astype(np.uint8))*255
|
| 168 |
+
corrdinates = np.where(data > 0)
|
| 169 |
+
xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
|
| 170 |
+
corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
|
| 171 |
+
region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
|
| 172 |
+
return region
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
|
| 176 |
+
speed, angle = velocity
|
| 177 |
+
d_speed, d_angle = maxAcceleration
|
| 178 |
+
if dist == 'uniform':
|
| 179 |
+
speed += np.random.uniform(-d_speed, d_speed)
|
| 180 |
+
angle += np.random.uniform(-d_angle, d_angle)
|
| 181 |
+
elif dist == 'guassian':
|
| 182 |
+
speed += np.random.normal(0, d_speed / 2)
|
| 183 |
+
angle += np.random.normal(0, d_angle / 2)
|
| 184 |
+
else:
|
| 185 |
+
raise NotImplementedError(
|
| 186 |
+
f'Distribution type {dist} is not supported.')
|
| 187 |
+
return (speed, angle)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_random_velocity(max_speed=3, dist='uniform'):
|
| 191 |
+
if dist == 'uniform':
|
| 192 |
+
speed = np.random.uniform(max_speed)
|
| 193 |
+
elif dist == 'guassian':
|
| 194 |
+
speed = np.abs(np.random.normal(0, max_speed / 2))
|
| 195 |
+
else:
|
| 196 |
+
raise NotImplementedError(
|
| 197 |
+
f'Distribution type {dist} is not supported.')
|
| 198 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
| 199 |
+
return (speed, angle)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3):
|
| 203 |
+
region_width, region_height = region_size
|
| 204 |
+
speed, angle = lineVelocity
|
| 205 |
+
X += int(speed * np.cos(angle))
|
| 206 |
+
Y += int(speed * np.sin(angle))
|
| 207 |
+
lineVelocity = random_accelerate(
|
| 208 |
+
lineVelocity, maxLineAcceleration, dist='guassian')
|
| 209 |
+
if (X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0):
|
| 210 |
+
lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
|
| 211 |
+
new_X = np.clip(X, 0, imageHeight - region_height)
|
| 212 |
+
new_Y = np.clip(Y, 0, imageWidth - region_width)
|
| 213 |
+
return new_X, new_Y, lineVelocity
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def get_world_size():
|
| 217 |
+
"""Find OMPI world size without calling mpi functions
|
| 218 |
+
:rtype: int
|
| 219 |
+
"""
|
| 220 |
+
if os.environ.get('PMI_SIZE') is not None:
|
| 221 |
+
return int(os.environ.get('PMI_SIZE') or 1)
|
| 222 |
+
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
|
| 223 |
+
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
|
| 224 |
+
else:
|
| 225 |
+
return torch.cuda.device_count()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_global_rank():
|
| 229 |
+
"""Find OMPI world rank without calling mpi functions
|
| 230 |
+
:rtype: int
|
| 231 |
+
"""
|
| 232 |
+
if os.environ.get('PMI_RANK') is not None:
|
| 233 |
+
return int(os.environ.get('PMI_RANK') or 0)
|
| 234 |
+
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
|
| 235 |
+
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
|
| 236 |
+
else:
|
| 237 |
+
return 0
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def get_local_rank():
|
| 241 |
+
"""Find OMPI local rank without calling mpi functions
|
| 242 |
+
:rtype: int
|
| 243 |
+
"""
|
| 244 |
+
if os.environ.get('MPI_LOCALRANKID') is not None:
|
| 245 |
+
return int(os.environ.get('MPI_LOCALRANKID') or 0)
|
| 246 |
+
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
|
| 247 |
+
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
|
| 248 |
+
else:
|
| 249 |
+
return 0
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_master_ip():
|
| 253 |
+
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
|
| 254 |
+
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
|
| 255 |
+
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
|
| 256 |
+
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
|
| 257 |
+
else:
|
| 258 |
+
return "127.0.0.1"
|
| 259 |
+
|
| 260 |
+
if __name__ == '__main__':
|
| 261 |
+
trials = 10
|
| 262 |
+
for _ in range(trials):
|
| 263 |
+
video_length = 10
|
| 264 |
+
# The returned masks are either stationary (50%) or moving (50%)
|
| 265 |
+
masks = create_random_shape_with_random_motion(
|
| 266 |
+
video_length, imageHeight=240, imageWidth=432)
|
| 267 |
+
|
| 268 |
+
for m in masks:
|
| 269 |
+
cv2.imshow('mask', np.array(m))
|
| 270 |
+
cv2.waitKey(500)
|
| 271 |
+
|
docker/Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12
|
| 2 |
+
|
| 3 |
+
RUN --mount=type=cache,target=/root/.cache,sharing=private \
|
| 4 |
+
apt update && \
|
| 5 |
+
apt install -y libgl1-mesa-glx && \
|
| 6 |
+
true
|
| 7 |
+
|
| 8 |
+
ADD . /vsr
|
| 9 |
+
ARG CUDA_VERSION=11.8
|
| 10 |
+
ARG USE_DIRECTML=0
|
| 11 |
+
|
| 12 |
+
# 如果是 CUDA 版本,执行 CUDA 特定设置
|
| 13 |
+
RUN --mount=type=cache,target=/root/.cache,sharing=private \
|
| 14 |
+
if [ "${USE_DIRECTML:-0}" != "1" ]; then \
|
| 15 |
+
pip install paddlepaddle==3.0 && \
|
| 16 |
+
pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu$(echo ${CUDA_VERSION} | tr -d '.') && \
|
| 17 |
+
pip install -r /vsr/requirements.txt; \
|
| 18 |
+
fi
|
| 19 |
+
|
| 20 |
+
# 如果是 DirectML 版本,执行 DirectML 特定设置
|
| 21 |
+
RUN --mount=type=cache,target=/root/.cache,sharing=private \
|
| 22 |
+
if [ "${USE_DIRECTML:-0}" = "1" ]; then \
|
| 23 |
+
pip install paddlepaddle==3.0 && \
|
| 24 |
+
pip install torch_directml==0.2.5.dev240914 && \
|
| 25 |
+
pip install -r /vsr/requirements.txt; \
|
| 26 |
+
fi
|
| 27 |
+
|
| 28 |
+
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.12/site-packages/nvidia/cudnn/lib/
|
| 29 |
+
WORKDIR /vsr
|
| 30 |
+
CMD ["python", "/vsr/backend/main.py"]
|
google_colabs/README.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Google Colab Gradio Interface
|
| 2 |
+
|
| 3 |
+
This folder contains two versions of the Google Colab notebook:
|
| 4 |
+
|
| 5 |
+
## Files
|
| 6 |
+
|
| 7 |
+
### 1. `Video_Subtitle_Remover_Gradio.ipynb` ⭐ **NEW - Recommended**
|
| 8 |
+
**Gradio Web Interface** - Easy-to-use browser-based UI
|
| 9 |
+
|
| 10 |
+
**Features:**
|
| 11 |
+
- 🖱️ Click-and-upload interface (no coding required)
|
| 12 |
+
- 🎨 Visual algorithm selection
|
| 13 |
+
- ⚙️ Adjustable parameters with sliders
|
| 14 |
+
- 📊 Real-time progress tracking
|
| 15 |
+
- 📥 One-click download
|
| 16 |
+
|
| 17 |
+
**Best for:**
|
| 18 |
+
- Users who prefer GUI
|
| 19 |
+
- Quick testing
|
| 20 |
+
- Non-technical users
|
| 21 |
+
- Multiple video processing
|
| 22 |
+
|
| 23 |
+
**Usage:**
|
| 24 |
+
1. Open in Colab
|
| 25 |
+
2. Run all cells
|
| 26 |
+
3. Click the generated link
|
| 27 |
+
4. Use web interface in browser
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
### 2. `Video_Subtitle_Remover.ipynb`
|
| 32 |
+
**Traditional Notebook** - Code-based approach
|
| 33 |
+
|
| 34 |
+
**Features:**
|
| 35 |
+
- Step-by-step execution
|
| 36 |
+
- Full control over parameters
|
| 37 |
+
- Good for understanding the process
|
| 38 |
+
- Batch processing scripts
|
| 39 |
+
|
| 40 |
+
**Best for:**
|
| 41 |
+
- Users comfortable with code
|
| 42 |
+
- Custom workflows
|
| 43 |
+
- Debugging
|
| 44 |
+
- Learning the internals
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## Quick Start
|
| 49 |
+
|
| 50 |
+
### For Gradio Interface (Recommended):
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
1. Open Video_Subtitle_Remover_Gradio.ipynb in Colab
|
| 54 |
+
2. Runtime → Change runtime type → GPU
|
| 55 |
+
3. Run all cells (Ctrl+F9)
|
| 56 |
+
4. Click the gradio.live URL
|
| 57 |
+
5. Upload video and click "Remove Subtitles"
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### For Traditional Notebook:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
1. Open Video_Subtitle_Remover.ipynb in Colab
|
| 64 |
+
2. Runtime → Change runtime type → GPU
|
| 65 |
+
3. Run cells step by step
|
| 66 |
+
4. Configure settings in Step 5
|
| 67 |
+
5. Run processing in Step 7
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Algorithm Recommendations
|
| 71 |
+
|
| 72 |
+
| Use Case | Algorithm | Quality | Speed |
|
| 73 |
+
|----------|-----------|---------|-------|
|
| 74 |
+
| **Best Quality** | DiffuEraser | ⭐⭐⭐⭐⭐ | ⭐⭐ |
|
| 75 |
+
| **Fastest** | STTN | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ |
|
| 76 |
+
| **Balanced** | Stable Diffusion | ⭐⭐⭐⭐ | ⭐⭐⭐ |
|
| 77 |
+
| **High Motion** | ProPainter | ⭐⭐⭐⭐⭐ | ⭐ |
|
| 78 |
+
|
| 79 |
+
## System Requirements
|
| 80 |
+
|
| 81 |
+
- **GPU**: Required (T4/P100/V100)
|
| 82 |
+
- **Storage**: 10-20GB for models
|
| 83 |
+
- **VRAM**:
|
| 84 |
+
- STTN: 4GB
|
| 85 |
+
- DiffuEraser: 12GB
|
| 86 |
+
- Stable Diffusion: 8GB
|
| 87 |
+
|
| 88 |
+
## Performance (Colab T4 GPU)
|
| 89 |
+
|
| 90 |
+
| Video | Algorithm | Time |
|
| 91 |
+
|-------|-----------|------|
|
| 92 |
+
| 1 min 720p | STTN | ~30s |
|
| 93 |
+
| 1 min 720p | DiffuEraser | ~3-5min |
|
| 94 |
+
| 5 min 720p | STTN | ~2min |
|
| 95 |
+
| 5 min 720p | DiffuEraser | ~15-20min |
|
| 96 |
+
|
| 97 |
+
## Troubleshooting
|
| 98 |
+
|
| 99 |
+
### Gradio not loading
|
| 100 |
+
- Wait 30-60 seconds for models to load
|
| 101 |
+
- Check if all cells ran successfully
|
| 102 |
+
- Restart runtime and try again
|
| 103 |
+
|
| 104 |
+
### Out of Memory
|
| 105 |
+
- Reduce batch size in settings
|
| 106 |
+
- Use STTN instead of DiffuEraser
|
| 107 |
+
- Process shorter videos
|
| 108 |
+
|
| 109 |
+
### Slow processing
|
| 110 |
+
- Use STTN for preview
|
| 111 |
+
- Enable GPU in Colab settings
|
| 112 |
+
- Consider Colab Pro for unlimited runtime
|
| 113 |
+
|
| 114 |
+
## Links
|
| 115 |
+
|
| 116 |
+
- **GitHub**: https://github.com/YaoFANGUK/video-subtitle-remover
|
| 117 |
+
- **Documentation**: See `docs/` folder
|
| 118 |
+
- **Issues**: Report on GitHub
|
| 119 |
+
|
| 120 |
+
## Tips
|
| 121 |
+
|
| 122 |
+
1. **Start with STTN** to test quickly
|
| 123 |
+
2. **Use DiffuEraser** for final high-quality output
|
| 124 |
+
3. **Keep videos under 10 minutes** on free tier
|
| 125 |
+
4. **Save to Google Drive** to avoid data loss
|
| 126 |
+
5. **Monitor GPU usage** with `!nvidia-smi`
|
google_colabs/Video_Subtitle_Remover_Gradio.ipynb
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🎬 Video Subtitle Remover - Gradio Interface\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"**Easy-to-use web interface for removing hardcoded subtitles from videos**\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook provides a Gradio web UI that runs in your browser.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"**Features:**\n",
|
| 14 |
+
"- 🖱️ Click-and-upload interface\n",
|
| 15 |
+
"- 🎨 Multiple AI algorithms (STTN, LAMA, DiffuEraser, etc.)\n",
|
| 16 |
+
"- ⚙️ Adjustable parameters\n",
|
| 17 |
+
"- 📊 Real-time progress\n",
|
| 18 |
+
"- 📥 Direct download\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"**Requirements:**\n",
|
| 21 |
+
"- Google Colab with GPU (Runtime → Change runtime type → GPU)\n",
|
| 22 |
+
"- ~10-20GB storage (for models)"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "markdown",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"source": [
|
| 29 |
+
"## Step 1: Check GPU"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"!nvidia-smi"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "markdown",
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"source": [
|
| 45 |
+
"## Step 2: Clone Repository"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"!git clone https://github.com/walidurrosyad/sub-remover.git\n",
|
| 55 |
+
"%cd sub-remover"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "markdown",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"source": [
|
| 62 |
+
"## Step 3: Install Dependencies"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"execution_count": null,
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"outputs": [],
|
| 70 |
+
"source": [
|
| 71 |
+
"# Core dependencies\n",
|
| 72 |
+
"!pip install -q filesplit==3.0.2 albumentations scikit-image imgaug pyclipper lmdb\n",
|
| 73 |
+
"!pip install -q PyYAML omegaconf tqdm easydict scikit-learn pandas webdataset\n",
|
| 74 |
+
"!pip install -q protobuf av einops paddleocr paddle2onnx onnxruntime-gpu\n",
|
| 75 |
+
"!pip install -q paddlepaddle-gpu==2.6.2\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"# Gradio for web interface\n",
|
| 78 |
+
"!pip install -q gradio\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"# Advanced models (optional - uncomment if using SD or DiffuEraser)\n",
|
| 81 |
+
"!pip install -q diffusers transformers accelerate\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"print(\"✓ All dependencies installed!\")"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "markdown",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"## Step 4: Launch Gradio Interface\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"This will create a web interface you can use in your browser!\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"**Click the public URL that appears below to access the interface.**"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": null,
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"outputs": [],
|
| 102 |
+
"source": [
|
| 103 |
+
"# Launch Gradio interface\n",
|
| 104 |
+
"import sys\n",
|
| 105 |
+
"import os\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"# Add paths\n",
|
| 108 |
+
"sys.path.insert(0, '/content/sub-remover')\n",
|
| 109 |
+
"sys.path.insert(0, '/content/sub-remover/backend')\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"# Change to google_colabs directory to import gradio_app\n",
|
| 112 |
+
"os.chdir('/content/sub-remover/google_colabs')\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"from gradio_app import create_interface\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"demo = create_interface()\n",
|
| 117 |
+
"demo.launch(share=True, debug=True)"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "markdown",
|
| 122 |
+
"metadata": {},
|
| 123 |
+
"source": [
|
| 124 |
+
"## Alternative: Run Gradio in Notebook\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"If the above doesn't work, run Gradio directly in the notebook:"
|
| 127 |
+
]
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"cell_type": "code",
|
| 131 |
+
"execution_count": null,
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"outputs": [],
|
| 134 |
+
"source": [
|
| 135 |
+
"import sys\n",
|
| 136 |
+
"sys.path.insert(0, '/content/sub-remover')\n",
|
| 137 |
+
"sys.path.insert(0, '/content/sub-remover/backend')\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"# Import and run\n",
|
| 140 |
+
"from google_colabs.gradio_app import create_interface\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"demo = create_interface()\n",
|
| 143 |
+
"demo.launch(share=True, debug=True)"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "markdown",
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"source": [
|
| 150 |
+
"## How to Use the Gradio Interface\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"1. **Click the public URL** (looks like: https://xxxxx.gradio.live)\n",
|
| 153 |
+
"2. **Upload your video** using the upload button\n",
|
| 154 |
+
"3. **Select algorithm**:\n",
|
| 155 |
+
" - **DiffuEraser (Recommended)**: Best quality for subtitles\n",
|
| 156 |
+
" - **STTN (Fast)**: Quickest processing\n",
|
| 157 |
+
" - **Stable Diffusion**: High quality alternative\n",
|
| 158 |
+
"4. **Adjust settings** (optional) in \"Advanced Settings\"\n",
|
| 159 |
+
"5. **Click \"Remove Subtitles\"**\n",
|
| 160 |
+
"6. **Wait for processing** (progress shown)\n",
|
| 161 |
+
"7. **Download result** using the download button\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"## Performance Guide\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"### Colab T4 GPU (Free Tier)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"| Video Length | Algorithm | Time |\n",
|
| 168 |
+
"|--------------|-----------|------|\n",
|
| 169 |
+
"| 1 min 720p | STTN | ~30s |\n",
|
| 170 |
+
"| 1 min 720p | DiffuEraser | ~3-5min |\n",
|
| 171 |
+
"| 5 min 720p | STTN | ~2min |\n",
|
| 172 |
+
"| 5 min 720p | DiffuEraser | ~15-20min |\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"### Tips\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"- **Use STTN for preview**, DiffuEraser for final\n",
|
| 177 |
+
"- **Keep videos under 10 minutes** to avoid Colab timeout\n",
|
| 178 |
+
"- **Enable GPU** in Runtime settings\n",
|
| 179 |
+
"- **Reduce batch size** if you get OOM errors\n",
|
| 180 |
+
"\n",
|
| 181 |
+
"## Troubleshooting\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"### \"Out of Memory\" Error\n",
|
| 184 |
+
"Reduce batch size in Advanced Settings:\n",
|
| 185 |
+
"- DiffuEraser: Set \"Max Frames per Batch\" to 40\n",
|
| 186 |
+
"- STTN: Set to 30\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"### Gradio Not Loading\n",
|
| 189 |
+
"Restart runtime and run all cells again.\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"### Slow Processing\n",
|
| 192 |
+
"- Use STTN algorithm for faster results\n",
|
| 193 |
+
"- Process shorter video clips\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"### Session Timeout\n",
|
| 196 |
+
"Colab free tier has time limits. Process shorter videos or use Colab Pro."
|
| 197 |
+
]
|
| 198 |
+
}
|
| 199 |
+
],
|
| 200 |
+
"metadata": {
|
| 201 |
+
"accelerator": "GPU",
|
| 202 |
+
"colab": {
|
| 203 |
+
"gpuType": "T4",
|
| 204 |
+
"provenance": []
|
| 205 |
+
},
|
| 206 |
+
"kernelspec": {
|
| 207 |
+
"display_name": "Python 3",
|
| 208 |
+
"language": "python",
|
| 209 |
+
"name": "python3"
|
| 210 |
+
},
|
| 211 |
+
"language_info": {
|
| 212 |
+
"name": "python",
|
| 213 |
+
"version": "3.10.12"
|
| 214 |
+
}
|
| 215 |
+
},
|
| 216 |
+
"nbformat": 4,
|
| 217 |
+
"nbformat_minor": 0
|
| 218 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
albumentations~=1.4.10
|
| 2 |
+
filesplit==3.0.2
|
| 3 |
+
opencv-python==4.11.0.86
|
| 4 |
+
scikit-image==0.25.2
|
| 5 |
+
imgaug==0.4.0
|
| 6 |
+
pyclipper==1.3.0.post6
|
| 7 |
+
lmdb==1.6.2
|
| 8 |
+
PyYAML==6.0.2
|
| 9 |
+
omegaconf==2.3.0
|
| 10 |
+
tqdm==4.67.1
|
| 11 |
+
PySimpleGUI-4-foss==4.60.4.1
|
| 12 |
+
easydict==1.13
|
| 13 |
+
scikit-learn==1.6.1
|
| 14 |
+
pandas==2.2.3
|
| 15 |
+
webdataset==0.2.111
|
| 16 |
+
numpy==2.2.5
|
| 17 |
+
protobuf==6.30.2
|
| 18 |
+
av==14.3.0
|
| 19 |
+
einops==0.8.1
|
| 20 |
+
paddleocr==2.10.0
|
| 21 |
+
paddle2onnx==1.3.1
|
| 22 |
+
onnxruntime-gpu==1.20.1
|
| 23 |
+
onnxruntime-directml==1.20.1; sys_platform == 'win32'
|
| 24 |
+
|
| 25 |
+
# Advanced Inpainting Models
|
| 26 |
+
diffusers>=0.27.0 # For Stable Diffusion & DiffuEraser
|
| 27 |
+
transformers>=4.36.0 # Required by diffusers
|
| 28 |
+
accelerate>=0.25.0 # For faster inference
|
| 29 |
+
xformers>=0.0.23; sys_platform != 'darwin' # Memory optimization (not on macOS)
|