Spaces:
Runtime error
Runtime error
Delete video_super_resolution
Browse files- video_super_resolution/__pycache__/color_fix.cpython-39.pyc +0 -0
- video_super_resolution/color_fix.py +0 -122
- video_super_resolution/dataset.py +0 -113
- video_super_resolution/scripts/__pycache__/inference_sr.cpython-39.pyc +0 -0
- video_super_resolution/scripts/inference_sr.py +0 -142
- video_super_resolution/scripts/inference_sr.sh +0 -56
video_super_resolution/__pycache__/color_fix.cpython-39.pyc
DELETED
|
Binary file (4.01 kB)
|
|
|
video_super_resolution/color_fix.py
DELETED
|
@@ -1,122 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
# --------------------------------------------------------------------------------
|
| 3 |
-
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
| 4 |
-
# --------------------------------------------------------------------------------
|
| 5 |
-
'''
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
from PIL import Image
|
| 9 |
-
from torch import Tensor
|
| 10 |
-
from torch.nn import functional as F
|
| 11 |
-
|
| 12 |
-
from torchvision.transforms import ToTensor, ToPILImage
|
| 13 |
-
from einops import rearrange
|
| 14 |
-
|
| 15 |
-
def adain_color_fix(target: Image, source: Image):
|
| 16 |
-
# Convert images to tensors
|
| 17 |
-
target = rearrange(target, 'T H W C -> T C H W') / 255
|
| 18 |
-
source = (source + 1) / 2
|
| 19 |
-
|
| 20 |
-
# Apply adaptive instance normalization
|
| 21 |
-
result_tensor_list = []
|
| 22 |
-
for i in range(0, target.shape[0]):
|
| 23 |
-
result_tensor_list.append(adaptive_instance_normalization(target[i].unsqueeze(0), source[i].unsqueeze(0)))
|
| 24 |
-
|
| 25 |
-
# Convert tensor back to image
|
| 26 |
-
result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
|
| 27 |
-
result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
|
| 28 |
-
|
| 29 |
-
return result_video
|
| 30 |
-
|
| 31 |
-
def wavelet_color_fix(target, source):
|
| 32 |
-
# Convert images to tensors
|
| 33 |
-
target = rearrange(target, 'T H W C -> T C H W') / 255
|
| 34 |
-
source = (source + 1) / 2
|
| 35 |
-
|
| 36 |
-
# Apply wavelet reconstruction
|
| 37 |
-
result_tensor_list = []
|
| 38 |
-
for i in range(0, target.shape[0]):
|
| 39 |
-
result_tensor_list.append(wavelet_reconstruction(target[i].unsqueeze(0), source[i].unsqueeze(0)))
|
| 40 |
-
|
| 41 |
-
# Convert tensor back to image
|
| 42 |
-
result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
|
| 43 |
-
result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
|
| 44 |
-
|
| 45 |
-
return result_video
|
| 46 |
-
|
| 47 |
-
def calc_mean_std(feat: Tensor, eps=1e-5):
|
| 48 |
-
"""Calculate mean and std for adaptive_instance_normalization.
|
| 49 |
-
Args:
|
| 50 |
-
feat (Tensor): 4D tensor.
|
| 51 |
-
eps (float): A small value added to the variance to avoid
|
| 52 |
-
divide-by-zero. Default: 1e-5.
|
| 53 |
-
"""
|
| 54 |
-
size = feat.size()
|
| 55 |
-
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
| 56 |
-
b, c = size[:2]
|
| 57 |
-
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
| 58 |
-
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
| 59 |
-
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
| 60 |
-
return feat_mean, feat_std
|
| 61 |
-
|
| 62 |
-
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
| 63 |
-
"""Adaptive instance normalization.
|
| 64 |
-
Adjust the reference features to have the similar color and illuminations
|
| 65 |
-
as those in the degradate features.
|
| 66 |
-
Args:
|
| 67 |
-
content_feat (Tensor): The reference feature.
|
| 68 |
-
style_feat (Tensor): The degradate features.
|
| 69 |
-
"""
|
| 70 |
-
size = content_feat.size()
|
| 71 |
-
style_mean, style_std = calc_mean_std(style_feat)
|
| 72 |
-
content_mean, content_std = calc_mean_std(content_feat)
|
| 73 |
-
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 74 |
-
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
| 75 |
-
|
| 76 |
-
def wavelet_blur(image: Tensor, radius: int):
|
| 77 |
-
"""
|
| 78 |
-
Apply wavelet blur to the input tensor.
|
| 79 |
-
"""
|
| 80 |
-
# input shape: (1, 3, H, W)
|
| 81 |
-
# convolution kernel
|
| 82 |
-
kernel_vals = [
|
| 83 |
-
[0.0625, 0.125, 0.0625],
|
| 84 |
-
[0.125, 0.25, 0.125],
|
| 85 |
-
[0.0625, 0.125, 0.0625],
|
| 86 |
-
]
|
| 87 |
-
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
| 88 |
-
# add channel dimensions to the kernel to make it a 4D tensor
|
| 89 |
-
kernel = kernel[None, None]
|
| 90 |
-
# repeat the kernel across all input channels
|
| 91 |
-
kernel = kernel.repeat(3, 1, 1, 1)
|
| 92 |
-
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
| 93 |
-
# apply convolution
|
| 94 |
-
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
| 95 |
-
return output
|
| 96 |
-
|
| 97 |
-
def wavelet_decomposition(image: Tensor, levels=5):
|
| 98 |
-
"""
|
| 99 |
-
Apply wavelet decomposition to the input tensor.
|
| 100 |
-
This function only returns the low frequency & the high frequency.
|
| 101 |
-
"""
|
| 102 |
-
high_freq = torch.zeros_like(image)
|
| 103 |
-
for i in range(levels):
|
| 104 |
-
radius = 2 ** i
|
| 105 |
-
low_freq = wavelet_blur(image, radius)
|
| 106 |
-
high_freq += (image - low_freq)
|
| 107 |
-
image = low_freq
|
| 108 |
-
|
| 109 |
-
return high_freq, low_freq
|
| 110 |
-
|
| 111 |
-
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
| 112 |
-
"""
|
| 113 |
-
Apply wavelet decomposition, so that the content will have the same color as the style.
|
| 114 |
-
"""
|
| 115 |
-
# calculate the wavelet decomposition of the content feature
|
| 116 |
-
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
| 117 |
-
del content_low_freq
|
| 118 |
-
# calculate the wavelet decomposition of the style feature
|
| 119 |
-
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
| 120 |
-
del style_high_freq
|
| 121 |
-
# reconstruct the content feature with the style's high frequency
|
| 122 |
-
return content_high_freq + style_low_freq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_super_resolution/dataset.py
DELETED
|
@@ -1,113 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import random
|
| 3 |
-
import glob
|
| 4 |
-
import torchvision
|
| 5 |
-
from einops import rearrange
|
| 6 |
-
from torch.utils import data as data
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
from torchvision import transforms
|
| 9 |
-
from PIL import Image
|
| 10 |
-
|
| 11 |
-
class PairedCaptionVideoDataset(data.Dataset):
|
| 12 |
-
def __init__(
|
| 13 |
-
self,
|
| 14 |
-
root_folders=None,
|
| 15 |
-
null_text_ratio=0.5,
|
| 16 |
-
num_frames=16
|
| 17 |
-
):
|
| 18 |
-
super(PairedCaptionVideoDataset, self).__init__()
|
| 19 |
-
|
| 20 |
-
self.null_text_ratio = null_text_ratio
|
| 21 |
-
self.lr_list = []
|
| 22 |
-
self.gt_list = []
|
| 23 |
-
self.tag_path_list = []
|
| 24 |
-
self.num_frames = num_frames
|
| 25 |
-
|
| 26 |
-
# root_folders = root_folders.split(',')
|
| 27 |
-
for root_folder in root_folders:
|
| 28 |
-
lr_path = root_folder +'/lq'
|
| 29 |
-
tag_path = root_folder +'/text'
|
| 30 |
-
gt_path = root_folder +'/gt'
|
| 31 |
-
|
| 32 |
-
self.lr_list += glob.glob(os.path.join(lr_path, '*.mp4'))
|
| 33 |
-
self.gt_list += glob.glob(os.path.join(gt_path, '*.mp4'))
|
| 34 |
-
self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
|
| 35 |
-
|
| 36 |
-
assert len(self.lr_list) == len(self.gt_list)
|
| 37 |
-
assert len(self.lr_list) == len(self.tag_path_list)
|
| 38 |
-
|
| 39 |
-
def __getitem__(self, index):
|
| 40 |
-
|
| 41 |
-
gt_path = self.gt_list[index]
|
| 42 |
-
vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
|
| 43 |
-
fps = info['video_fps']
|
| 44 |
-
vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W") / 255) * 2 - 1
|
| 45 |
-
# gt = self.trandform(vframes_gt)
|
| 46 |
-
|
| 47 |
-
lq_path = self.lr_list[index]
|
| 48 |
-
vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
|
| 49 |
-
vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W") / 255) * 2 - 1
|
| 50 |
-
# lq = self.trandform(vframes_lq)
|
| 51 |
-
|
| 52 |
-
if random.random() < self.null_text_ratio:
|
| 53 |
-
tag = ''
|
| 54 |
-
else:
|
| 55 |
-
tag_path = self.tag_path_list[index]
|
| 56 |
-
with open(tag_path, 'r', encoding='utf-8') as file:
|
| 57 |
-
tag = file.read()
|
| 58 |
-
|
| 59 |
-
return {"gt": vframes_gt[:, :self.num_frames, :, :], "lq": vframes_lq[:, :self.num_frames, :, :], "text": tag, 'fps': fps}
|
| 60 |
-
|
| 61 |
-
def __len__(self):
|
| 62 |
-
return len(self.gt_list)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class PairedCaptionImageDataset(data.Dataset):
|
| 66 |
-
def __init__(
|
| 67 |
-
self,
|
| 68 |
-
root_folder=None,
|
| 69 |
-
):
|
| 70 |
-
super(PairedCaptionImageDataset, self).__init__()
|
| 71 |
-
|
| 72 |
-
self.lr_list = []
|
| 73 |
-
self.gt_list = []
|
| 74 |
-
self.tag_path_list = []
|
| 75 |
-
|
| 76 |
-
lr_path = root_folder +'/sr_bicubic'
|
| 77 |
-
gt_path = root_folder +'/gt'
|
| 78 |
-
|
| 79 |
-
self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
|
| 80 |
-
self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
|
| 81 |
-
|
| 82 |
-
assert len(self.lr_list) == len(self.gt_list)
|
| 83 |
-
|
| 84 |
-
self.img_preproc = transforms.Compose([
|
| 85 |
-
transforms.ToTensor(),
|
| 86 |
-
])
|
| 87 |
-
|
| 88 |
-
# Define the crop size (e.g., 256x256)
|
| 89 |
-
crop_size = (720, 1280)
|
| 90 |
-
|
| 91 |
-
# CenterCrop transform
|
| 92 |
-
self.center_crop = transforms.CenterCrop(crop_size)
|
| 93 |
-
|
| 94 |
-
def __getitem__(self, index):
|
| 95 |
-
|
| 96 |
-
gt_path = self.gt_list[index]
|
| 97 |
-
gt_img = Image.open(gt_path).convert('RGB')
|
| 98 |
-
gt_img = self.center_crop(self.img_preproc(gt_img))
|
| 99 |
-
|
| 100 |
-
lq_path = self.lr_list[index]
|
| 101 |
-
lq_img = Image.open(lq_path).convert('RGB')
|
| 102 |
-
lq_img = self.center_crop(self.img_preproc(lq_img))
|
| 103 |
-
|
| 104 |
-
example = dict()
|
| 105 |
-
|
| 106 |
-
example["lq"] = (lq_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
|
| 107 |
-
example["gt"] = (gt_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
|
| 108 |
-
example["text"] = ""
|
| 109 |
-
|
| 110 |
-
return example
|
| 111 |
-
|
| 112 |
-
def __len__(self):
|
| 113 |
-
return len(self.gt_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_super_resolution/scripts/__pycache__/inference_sr.cpython-39.pyc
DELETED
|
Binary file (3.97 kB)
|
|
|
video_super_resolution/scripts/inference_sr.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
from argparse import ArgumentParser, Namespace
|
| 4 |
-
import json
|
| 5 |
-
from typing import Any, Dict, List, Mapping, Tuple
|
| 6 |
-
from easydict import EasyDict
|
| 7 |
-
|
| 8 |
-
from video_to_video.video_to_video_model import VideoToVideo_sr
|
| 9 |
-
from video_to_video.utils.seed import setup_seed
|
| 10 |
-
from video_to_video.utils.logger import get_logger
|
| 11 |
-
from video_super_resolution.color_fix import adain_color_fix
|
| 12 |
-
|
| 13 |
-
from inference_utils import *
|
| 14 |
-
|
| 15 |
-
logger = get_logger()
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class STAR_sr():
|
| 19 |
-
def __init__(self,
|
| 20 |
-
result_dir='./results/',
|
| 21 |
-
file_name='000_video.mp4',
|
| 22 |
-
model_path='./pretrained_weight',
|
| 23 |
-
solver_mode='fast',
|
| 24 |
-
steps=15,
|
| 25 |
-
guide_scale=7.5,
|
| 26 |
-
upscale=4,
|
| 27 |
-
max_chunk_len=32,
|
| 28 |
-
variant_info=None,
|
| 29 |
-
chunk_size=3,
|
| 30 |
-
):
|
| 31 |
-
self.model_path=model_path
|
| 32 |
-
logger.info('checkpoint_path: {}'.format(self.model_path))
|
| 33 |
-
|
| 34 |
-
self.result_dir = result_dir
|
| 35 |
-
self.file_name = file_name
|
| 36 |
-
os.makedirs(self.result_dir, exist_ok=True)
|
| 37 |
-
|
| 38 |
-
model_cfg = EasyDict(__name__='model_cfg')
|
| 39 |
-
model_cfg.model_path = self.model_path
|
| 40 |
-
model_cfg.chunk_size = chunk_size
|
| 41 |
-
self.model = VideoToVideo_sr(model_cfg)
|
| 42 |
-
|
| 43 |
-
steps = 15 if solver_mode == 'fast' else steps
|
| 44 |
-
self.solver_mode=solver_mode
|
| 45 |
-
self.steps=steps
|
| 46 |
-
self.guide_scale=guide_scale
|
| 47 |
-
self.upscale = upscale
|
| 48 |
-
self.max_chunk_len=max_chunk_len
|
| 49 |
-
self.variant_info=variant_info
|
| 50 |
-
|
| 51 |
-
def enhance_a_video(self, video_path, prompt):
|
| 52 |
-
logger.info('input video path: {}'.format(video_path))
|
| 53 |
-
text = prompt
|
| 54 |
-
logger.info('text: {}'.format(text))
|
| 55 |
-
caption = text + self.model.positive_prompt
|
| 56 |
-
|
| 57 |
-
input_frames, input_fps = load_video(video_path)
|
| 58 |
-
in_f_num = len(input_frames)
|
| 59 |
-
logger.info('input frames length: {}'.format(in_f_num))
|
| 60 |
-
logger.info('input fps: {}'.format(input_fps))
|
| 61 |
-
|
| 62 |
-
video_data = preprocess(input_frames)
|
| 63 |
-
_, _, h, w = video_data.shape
|
| 64 |
-
logger.info('input resolution: {}'.format((h, w)))
|
| 65 |
-
target_h, target_w = h * self.upscale, w * self.upscale # adjust_resolution(h, w, up_scale=4)
|
| 66 |
-
logger.info('target resolution: {}'.format((target_h, target_w)))
|
| 67 |
-
|
| 68 |
-
pre_data = {'video_data': video_data, 'y': caption}
|
| 69 |
-
pre_data['target_res'] = (target_h, target_w)
|
| 70 |
-
|
| 71 |
-
total_noise_levels = 900
|
| 72 |
-
setup_seed(666)
|
| 73 |
-
|
| 74 |
-
with torch.no_grad():
|
| 75 |
-
data_tensor = collate_fn(pre_data, 'cuda:0')
|
| 76 |
-
output = self.model.test(data_tensor, total_noise_levels, steps=self.steps, \
|
| 77 |
-
solver_mode=self.solver_mode, guide_scale=self.guide_scale, \
|
| 78 |
-
max_chunk_len=self.max_chunk_len
|
| 79 |
-
)
|
| 80 |
-
|
| 81 |
-
output = tensor2vid(output)
|
| 82 |
-
|
| 83 |
-
# Using color fix
|
| 84 |
-
output = adain_color_fix(output, video_data)
|
| 85 |
-
|
| 86 |
-
save_video(output, self.result_dir, self.file_name, fps=input_fps)
|
| 87 |
-
return os.path.join(self.result_dir, self.file_name)
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def parse_args():
|
| 91 |
-
parser = ArgumentParser()
|
| 92 |
-
|
| 93 |
-
parser.add_argument("--input_path", required=True, type=str, help="input video path")
|
| 94 |
-
parser.add_argument("--save_dir", type=str, default='results', help="save directory")
|
| 95 |
-
parser.add_argument("--file_name", type=str, help="file name")
|
| 96 |
-
parser.add_argument("--model_path", type=str, default='./pretrained_weight/I2VGen-XL-based/heavy_deg.pt', help="model path")
|
| 97 |
-
parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
|
| 98 |
-
parser.add_argument("--upscale", type=int, default=4, help='up-scale')
|
| 99 |
-
parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
|
| 100 |
-
parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
|
| 101 |
-
|
| 102 |
-
parser.add_argument("--cfg", type=float, default=7.5)
|
| 103 |
-
parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
|
| 104 |
-
parser.add_argument("--steps", type=int, default=15)
|
| 105 |
-
|
| 106 |
-
return parser.parse_args()
|
| 107 |
-
|
| 108 |
-
def main():
|
| 109 |
-
|
| 110 |
-
args = parse_args()
|
| 111 |
-
|
| 112 |
-
input_path = args.input_path
|
| 113 |
-
prompt = args.prompt
|
| 114 |
-
model_path = args.model_path
|
| 115 |
-
save_dir = args.save_dir
|
| 116 |
-
file_name = args.file_name
|
| 117 |
-
upscale = args.upscale
|
| 118 |
-
max_chunk_len = args.max_chunk_len
|
| 119 |
-
|
| 120 |
-
steps = args.steps
|
| 121 |
-
solver_mode = args.solver_mode
|
| 122 |
-
guide_scale = args.cfg
|
| 123 |
-
|
| 124 |
-
assert solver_mode in ('fast', 'normal')
|
| 125 |
-
|
| 126 |
-
star_sr = STAR_sr(
|
| 127 |
-
result_dir=save_dir,
|
| 128 |
-
file_name=file_name, # new added
|
| 129 |
-
model_path=model_path,
|
| 130 |
-
solver_mode=solver_mode,
|
| 131 |
-
steps=steps,
|
| 132 |
-
guide_scale=guide_scale,
|
| 133 |
-
upscale=upscale,
|
| 134 |
-
max_chunk_len=max_chunk_len,
|
| 135 |
-
variant_info=None,
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
star_sr.enhance_a_video(input_path, prompt)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
if __name__ == '__main__':
|
| 142 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_super_resolution/scripts/inference_sr.sh
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
|
| 3 |
-
# Folder paths
|
| 4 |
-
video_folder_path='./input/video'
|
| 5 |
-
txt_file_path='./input/text/prompt.txt'
|
| 6 |
-
|
| 7 |
-
# Get all .mp4 files in the folder using find to handle special characters
|
| 8 |
-
mapfile -t mp4_files < <(find "$video_folder_path" -type f -name "*.mp4")
|
| 9 |
-
|
| 10 |
-
# Print the list of MP4 files
|
| 11 |
-
echo "MP4 files to be processed:"
|
| 12 |
-
for mp4_file in "${mp4_files[@]}"; do
|
| 13 |
-
echo "$mp4_file"
|
| 14 |
-
done
|
| 15 |
-
|
| 16 |
-
# Read lines from the text file, skipping empty lines
|
| 17 |
-
mapfile -t lines < <(grep -v '^\s*$' "$txt_file_path")
|
| 18 |
-
|
| 19 |
-
# List of frame counts
|
| 20 |
-
frame_length=32
|
| 21 |
-
|
| 22 |
-
# Debugging output
|
| 23 |
-
echo "Number of MP4 files: ${#mp4_files[@]}"
|
| 24 |
-
echo "Number of lines in the text file: ${#lines[@]}"
|
| 25 |
-
|
| 26 |
-
# Ensure the number of video files matches the number of lines
|
| 27 |
-
if [ ${#mp4_files[@]} -ne ${#lines[@]} ]; then
|
| 28 |
-
echo "Number of MP4 files and lines in the text file do not match."
|
| 29 |
-
exit 1
|
| 30 |
-
fi
|
| 31 |
-
|
| 32 |
-
# Loop through video files and corresponding lines
|
| 33 |
-
for i in "${!mp4_files[@]}"; do
|
| 34 |
-
mp4_file="${mp4_files[$i]}"
|
| 35 |
-
line="${lines[$i]}"
|
| 36 |
-
|
| 37 |
-
# Extract the filename without the extension
|
| 38 |
-
file_name=$(basename "$mp4_file" .mp4)
|
| 39 |
-
|
| 40 |
-
echo "Processing video file: $mp4_file with prompt: $line"
|
| 41 |
-
|
| 42 |
-
# Run Python script with parameters
|
| 43 |
-
python \
|
| 44 |
-
./video_super_resolution/scripts/inference_sr.py \
|
| 45 |
-
--solver_mode 'fast' \
|
| 46 |
-
--steps 15 \
|
| 47 |
-
--input_path "${mp4_file}" \
|
| 48 |
-
--model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/heavy_deg.pt \
|
| 49 |
-
--prompt "${line}" \
|
| 50 |
-
--upscale 4 \
|
| 51 |
-
--max_chunk_len ${frame_length} \
|
| 52 |
-
--file_name "${file_name}.mp4" \
|
| 53 |
-
--save_dir ./results
|
| 54 |
-
done
|
| 55 |
-
|
| 56 |
-
echo "All videos processed successfully."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|