|
|
import warnings |
|
|
from enum import Enum, unique |
|
|
warnings.filterwarnings('ignore') |
|
|
import os |
|
|
import torch |
|
|
import logging |
|
|
import platform |
|
|
import stat |
|
|
from fsplit.filesplit import Filesplit |
|
|
import onnxruntime as ort |
|
|
|
|
|
|
|
|
# Project version string.
VERSION = "1.1.1"

# Silence every log record of severity WARNING and below, process-wide.
# logging.disable() stores a single threshold, so an earlier
# logging.disable(logging.DEBUG) call would simply be overwritten by this one;
# only the effective call is kept.
logging.disable(logging.WARNING)
|
|
# Choose the compute device. DirectML is preferred when the optional
# torch_directml package can be imported and initialised; otherwise fall back
# to the first CUDA GPU if torch reports one, and finally to the CPU.
try:
    import torch_directml
    device = torch_directml.device(torch_directml.default_device())
    USE_DML = True
except Exception:
    # Handles both a missing package and a DirectML runtime that fails to
    # start. Unlike a bare `except:`, KeyboardInterrupt and SystemExit are
    # no longer swallowed here.
    USE_DML = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
# Directory that contains this file; bundled resources are resolved from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Root folder of the model files, and the detection model stored under a
# version-named sub-folder of it.
DET_MODEL_BASE = os.path.join(BASE_DIR, 'models')
MODEL_VERSION = 'V4'
DET_MODEL_PATH = os.path.join(DET_MODEL_BASE, MODEL_VERSION, 'ch_det')

# Inpainting model locations (all live under the same models root).
LAMA_MODEL_PATH = os.path.join(DET_MODEL_BASE, 'big-lama')
STTN_MODEL_PATH = os.path.join(DET_MODEL_BASE, 'sttn', 'infer_model.pth')
VIDEO_INPAINT_MODEL_PATH = os.path.join(DET_MODEL_BASE, 'video')
|
|
|
|
|
|
|
|
# Re-assemble model files that are not present yet by asking Filesplit to
# merge the pieces stored in the model's directory.
# Each entry is (expected merged file name, directory passed to Filesplit).
for _merged_name, _chunk_dir in (
    ('big-lama.pt', LAMA_MODEL_PATH),
    ('inference.pdiparams', DET_MODEL_PATH),
    ('ProPainter.pth', VIDEO_INPAINT_MODEL_PATH),
):
    # os.listdir() raises on a missing directory, which would abort the import
    # of this module before the auto-download step near the end of the file
    # gets a chance to fetch the model. Skip such directories instead.
    if not os.path.isdir(_chunk_dir):
        continue
    if _merged_name not in os.listdir(_chunk_dir):
        # NOTE(review): merge() presumably still fails when the directory
        # exists but holds no split pieces — confirm against the fsplit docs.
        fs = Filesplit()
        fs.merge(input_dir=_chunk_dir)
|
|
|
|
|
|
|
|
# Pick the bundled ffmpeg binary that matches the host operating system.
sys_str = platform.system()
if sys_str == "Windows":
    ffmpeg_bin = os.path.join('win_x64', 'ffmpeg.exe')
elif sys_str == "Linux":
    ffmpeg_bin = os.path.join('linux_x64', 'ffmpeg')
else:
    # Every other platform falls through to the macOS binary.
    ffmpeg_bin = os.path.join('macos', 'ffmpeg')
FFMPEG_PATH = os.path.join(BASE_DIR, 'ffmpeg', ffmpeg_bin)

# Re-assemble the Windows ffmpeg executable when it is not present yet.
# The directory check prevents os.listdir() from raising (and aborting the
# import) on installations that do not carry the win_x64 folder at all.
_win_ffmpeg_dir = os.path.join(BASE_DIR, 'ffmpeg', 'win_x64')
if os.path.isdir(_win_ffmpeg_dir) and 'ffmpeg.exe' not in os.listdir(_win_ffmpeg_dir):
    fs = Filesplit()
    fs.merge(input_dir=_win_ffmpeg_dir)

# Grant read/write/execute to user, group and others (0o777) so the binary can
# be launched. The permission bits are flags, hence combined with `|`.
# NOTE(review): this makes the binary world-writable — confirm that is intended.
os.chmod(FFMPEG_PATH, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)

# NOTE(review): presumably set so that a duplicated OpenMP runtime in the
# process does not abort it — confirm.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
|
|
|
|
|
|
|
# ONNX Runtime execution providers this project is willing to use.
# "CPUExecutionProvider" is deliberately not listed, so it never ends up in
# ONNX_PROVIDERS.
_ALLOWED_ONNX_PROVIDERS = (
    "DmlExecutionProvider",
    "ROCMExecutionProvider",
    "MIGraphXExecutionProvider",
    "VitisAIExecutionProvider",
    "OpenVINOExecutionProvider",
    "MetalExecutionProvider",
    "CoreMLExecutionProvider",
    "CUDAExecutionProvider",
)

# Keep the order reported by ONNX Runtime and drop everything not allow-listed.
available_providers = ort.get_available_providers()
ONNX_PROVIDERS = [
    name for name in available_providers if name in _ALLOWED_ONNX_PROVIDERS
]
|
|
|
|
|
|
|
|
|
|
|
@unique
class InpaintMode(Enum):
    """Enumeration of the available image inpainting algorithms."""

    # Each member's value is the short lowercase identifier of the algorithm.
    STTN = 'sttn'
    LAMA = 'lama'
    PROPAINTER = 'propainter'
    STABLE_DIFFUSION = 'sd'
    DIFFUERASER = 'diffueraser'
    E2FGVI = 'e2fgvi'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): presumably selects H.264 encoding for the output video —
# inferred from the name only; confirm against the code that reads it.
USE_H264 = True

# Algorithms that MODE can be set to:
# - InpaintMode.STTN: works well on live-action video, fast, and subtitle
#   detection can be skipped.
# - InpaintMode.LAMA: works well on animated video, moderate speed, and
#   subtitle detection cannot be skipped.
# - InpaintMode.PROPAINTER: consumes a lot of VRAM and is slow, but works well
#   on video with very intense motion.
MODE = InpaintMode.STTN
|
|
|
|
|
|
|
|
# --- Stable Diffusion inpainting settings ---
# NOTE(review): unlike the BASE_DIR-based paths in this file, the three
# *_MODEL_PATH values in this section are relative, so they resolve against
# the current working directory — confirm that is intended.
SD_MODEL_PATH = 'backend/models/stable-diffusion-inpainting'
SD_STEPS = 50
SD_GUIDANCE_SCALE = 7.5
SD_PROMPT = "natural scene, high quality"
SD_USE_FP16 = True

# --- DiffuEraser settings ---
DIFFUERASER_MODEL_PATH = 'backend/models/diffueraser'
DIFFUERASER_STEPS = 50
DIFFUERASER_GUIDANCE = 7.5
DIFFUERASER_USE_SAM2 = False
DIFFUERASER_MAX_LOAD_NUM = 80

# --- E2FGVI settings ---
E2FGVI_MODEL_PATH = 'backend/models/e2fgvi'
E2FGVI_MAX_LOAD_NUM = 80
E2FGVI_NEIGHBOR_LENGTH = 10
|
|
|
|
|
|
|
|
# Geometry thresholds and tolerances.
# NOTE(review): the values look like pixel counts used when comparing or
# grouping detected subtitle boxes — inferred from the names; confirm against
# the code that consumes them.
THRESHOLD_HEIGHT_WIDTH_DIFFERENCE = 10
SUBTITLE_AREA_DEVIATION_PIXEL = 20
THRESHOLD_HEIGHT_DIFFERENCE = 20

# Per-axis tolerances.
PIXEL_TOLERANCE_Y = 20
PIXEL_TOLERANCE_X = 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
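1. STTN_SKIP_DETECTION
Meaning: whether to skip subtitle detection.
Effect: when set to True, subtitle detection is skipped, which saves a lot of time, but frames without subtitles may be altered unnecessarily and some subtitles may be missed.

2. STTN_NEIGHBOR_STRIDE
Meaning: stride between neighboring reference frames. For example, to fill the missing region of frame 50 with STTN_NEIGHBOR_STRIDE=5, the algorithm uses frame 45, frame 40, and so on as references.
Effect: controls how densely reference frames are selected. A larger stride uses fewer, more widely spaced reference frames; a smaller stride uses more, more tightly clustered reference frames.

3. STTN_REFERENCE_LENGTH
Meaning: number of reference frames. The STTN algorithm looks at several frames before and after each frame to be repaired to obtain context for the repair.
Effect: increasing it raises VRAM usage and improves the result, but slows processing down.

4. STTN_MAX_LOAD_NUM
Meaning: maximum number of video frames the STTN algorithm loads at a time.
Effect: the larger the value, the slower the processing, but the better the result.
Note: make sure STTN_MAX_LOAD_NUM is greater than both STTN_NEIGHBOR_STRIDE and STTN_REFERENCE_LENGTH.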
|
|
""" |
|
|
# STTN tuning parameters.
STTN_SKIP_DETECTION = True   # skip the subtitle detection step
STTN_NEIGHBOR_STRIDE = 5     # stride between neighboring reference frames
STTN_REFERENCE_LENGTH = 10   # number of reference frames
STTN_MAX_LOAD_NUM = 50       # maximum number of frames loaded at a time

# Never load fewer frames than reference length times neighbor stride.
STTN_MAX_LOAD_NUM = max(
    STTN_MAX_LOAD_NUM, STTN_REFERENCE_LENGTH * STTN_NEIGHBOR_STRIDE
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ProPainter setting.
# NOTE(review): presumably the maximum number of frames ProPainter loads at a
# time, by analogy with STTN_MAX_LOAD_NUM — confirm.
PROPAINTER_MAX_LOAD_NUM = 70

# LaMa setting.
# NOTE(review): presumably toggles a faster LaMa processing path — inferred
# from the name only; confirm.
LAMA_SUPER_FAST = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Best-effort download of missing model files.
# The import is guarded on its own so that only a missing downloader module is
# skipped silently. An ImportError raised later, inside ensure_models_exist(),
# is reported by the warning below instead of being swallowed.
try:
    from backend.model_downloader import ensure_models_exist
except ImportError:
    # The downloader is optional; without it, auto-download is skipped.
    ensure_models_exist = None

if ensure_models_exist is not None:
    try:
        # (label, path whose existence means the model is present,
        #  path handed to the downloader for that label)
        _model_specs = (
            ('LAMA', os.path.join(LAMA_MODEL_PATH, 'big-lama.pt'), LAMA_MODEL_PATH),
            ('STTN', STTN_MODEL_PATH, STTN_MODEL_PATH),
            ('ProPainter', os.path.join(VIDEO_INPAINT_MODEL_PATH, 'ProPainter.pth'),
             VIDEO_INPAINT_MODEL_PATH),
            ('Detection', os.path.join(DET_MODEL_PATH, 'inference.pdiparams'),
             DET_MODEL_PATH),
        )
        models_to_check = {
            label: target
            for label, marker, target in _model_specs
            if not os.path.exists(marker)
        }

        if models_to_check:
            print("[Auto-Download] Missing models detected. Downloading from Hugging Face...")
            ensure_models_exist(models_to_check)
    except Exception as e:
        # A failed download must not prevent this module from being imported.
        print(f"[Warning] Could not auto-download models: {e}")
        print("[Info] You may need to manually download models from https://huggingface.co/Rasta02/dataku")
|
|
|