Fabrice-TIERCELIN commited on
Commit
0ba18b1
·
verified ·
1 Parent(s): 1fa1fb9

Delete video_super_resolution

Browse files
video_super_resolution/__pycache__/color_fix.cpython-39.pyc DELETED
Binary file (4.01 kB)
 
video_super_resolution/color_fix.py DELETED
@@ -1,122 +0,0 @@
1
- '''
2
- # --------------------------------------------------------------------------------
3
- # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4
- # --------------------------------------------------------------------------------
5
- '''
6
-
7
- import torch
8
- from PIL import Image
9
- from torch import Tensor
10
- from torch.nn import functional as F
11
-
12
- from torchvision.transforms import ToTensor, ToPILImage
13
- from einops import rearrange
14
-
15
- def adain_color_fix(target: Image, source: Image):
16
- # Convert images to tensors
17
- target = rearrange(target, 'T H W C -> T C H W') / 255
18
- source = (source + 1) / 2
19
-
20
- # Apply adaptive instance normalization
21
- result_tensor_list = []
22
- for i in range(0, target.shape[0]):
23
- result_tensor_list.append(adaptive_instance_normalization(target[i].unsqueeze(0), source[i].unsqueeze(0)))
24
-
25
- # Convert tensor back to image
26
- result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
27
- result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
28
-
29
- return result_video
30
-
31
- def wavelet_color_fix(target, source):
32
- # Convert images to tensors
33
- target = rearrange(target, 'T H W C -> T C H W') / 255
34
- source = (source + 1) / 2
35
-
36
- # Apply wavelet reconstruction
37
- result_tensor_list = []
38
- for i in range(0, target.shape[0]):
39
- result_tensor_list.append(wavelet_reconstruction(target[i].unsqueeze(0), source[i].unsqueeze(0)))
40
-
41
- # Convert tensor back to image
42
- result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
43
- result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
44
-
45
- return result_video
46
-
47
- def calc_mean_std(feat: Tensor, eps=1e-5):
48
- """Calculate mean and std for adaptive_instance_normalization.
49
- Args:
50
- feat (Tensor): 4D tensor.
51
- eps (float): A small value added to the variance to avoid
52
- divide-by-zero. Default: 1e-5.
53
- """
54
- size = feat.size()
55
- assert len(size) == 4, 'The input feature should be 4D tensor.'
56
- b, c = size[:2]
57
- feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
58
- feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
59
- feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
60
- return feat_mean, feat_std
61
-
62
- def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
63
- """Adaptive instance normalization.
64
- Adjust the reference features to have the similar color and illuminations
65
- as those in the degradate features.
66
- Args:
67
- content_feat (Tensor): The reference feature.
68
- style_feat (Tensor): The degradate features.
69
- """
70
- size = content_feat.size()
71
- style_mean, style_std = calc_mean_std(style_feat)
72
- content_mean, content_std = calc_mean_std(content_feat)
73
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
74
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
75
-
76
- def wavelet_blur(image: Tensor, radius: int):
77
- """
78
- Apply wavelet blur to the input tensor.
79
- """
80
- # input shape: (1, 3, H, W)
81
- # convolution kernel
82
- kernel_vals = [
83
- [0.0625, 0.125, 0.0625],
84
- [0.125, 0.25, 0.125],
85
- [0.0625, 0.125, 0.0625],
86
- ]
87
- kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
88
- # add channel dimensions to the kernel to make it a 4D tensor
89
- kernel = kernel[None, None]
90
- # repeat the kernel across all input channels
91
- kernel = kernel.repeat(3, 1, 1, 1)
92
- image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
93
- # apply convolution
94
- output = F.conv2d(image, kernel, groups=3, dilation=radius)
95
- return output
96
-
97
- def wavelet_decomposition(image: Tensor, levels=5):
98
- """
99
- Apply wavelet decomposition to the input tensor.
100
- This function only returns the low frequency & the high frequency.
101
- """
102
- high_freq = torch.zeros_like(image)
103
- for i in range(levels):
104
- radius = 2 ** i
105
- low_freq = wavelet_blur(image, radius)
106
- high_freq += (image - low_freq)
107
- image = low_freq
108
-
109
- return high_freq, low_freq
110
-
111
- def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
112
- """
113
- Apply wavelet decomposition, so that the content will have the same color as the style.
114
- """
115
- # calculate the wavelet decomposition of the content feature
116
- content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
117
- del content_low_freq
118
- # calculate the wavelet decomposition of the style feature
119
- style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
120
- del style_high_freq
121
- # reconstruct the content feature with the style's high frequency
122
- return content_high_freq + style_low_freq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_super_resolution/dataset.py DELETED
@@ -1,113 +0,0 @@
1
- import os
2
- import random
3
- import glob
4
- import torchvision
5
- from einops import rearrange
6
- from torch.utils import data as data
7
- import torch.nn.functional as F
8
- from torchvision import transforms
9
- from PIL import Image
10
-
11
- class PairedCaptionVideoDataset(data.Dataset):
12
- def __init__(
13
- self,
14
- root_folders=None,
15
- null_text_ratio=0.5,
16
- num_frames=16
17
- ):
18
- super(PairedCaptionVideoDataset, self).__init__()
19
-
20
- self.null_text_ratio = null_text_ratio
21
- self.lr_list = []
22
- self.gt_list = []
23
- self.tag_path_list = []
24
- self.num_frames = num_frames
25
-
26
- # root_folders = root_folders.split(',')
27
- for root_folder in root_folders:
28
- lr_path = root_folder +'/lq'
29
- tag_path = root_folder +'/text'
30
- gt_path = root_folder +'/gt'
31
-
32
- self.lr_list += glob.glob(os.path.join(lr_path, '*.mp4'))
33
- self.gt_list += glob.glob(os.path.join(gt_path, '*.mp4'))
34
- self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
35
-
36
- assert len(self.lr_list) == len(self.gt_list)
37
- assert len(self.lr_list) == len(self.tag_path_list)
38
-
39
- def __getitem__(self, index):
40
-
41
- gt_path = self.gt_list[index]
42
- vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
43
- fps = info['video_fps']
44
- vframes_gt = (rearrange(vframes_gt, "T C H W -> C T H W") / 255) * 2 - 1
45
- # gt = self.trandform(vframes_gt)
46
-
47
- lq_path = self.lr_list[index]
48
- vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
49
- vframes_lq = (rearrange(vframes_lq, "T C H W -> C T H W") / 255) * 2 - 1
50
- # lq = self.trandform(vframes_lq)
51
-
52
- if random.random() < self.null_text_ratio:
53
- tag = ''
54
- else:
55
- tag_path = self.tag_path_list[index]
56
- with open(tag_path, 'r', encoding='utf-8') as file:
57
- tag = file.read()
58
-
59
- return {"gt": vframes_gt[:, :self.num_frames, :, :], "lq": vframes_lq[:, :self.num_frames, :, :], "text": tag, 'fps': fps}
60
-
61
- def __len__(self):
62
- return len(self.gt_list)
63
-
64
-
65
- class PairedCaptionImageDataset(data.Dataset):
66
- def __init__(
67
- self,
68
- root_folder=None,
69
- ):
70
- super(PairedCaptionImageDataset, self).__init__()
71
-
72
- self.lr_list = []
73
- self.gt_list = []
74
- self.tag_path_list = []
75
-
76
- lr_path = root_folder +'/sr_bicubic'
77
- gt_path = root_folder +'/gt'
78
-
79
- self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
80
- self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
81
-
82
- assert len(self.lr_list) == len(self.gt_list)
83
-
84
- self.img_preproc = transforms.Compose([
85
- transforms.ToTensor(),
86
- ])
87
-
88
- # Define the crop size (e.g., 256x256)
89
- crop_size = (720, 1280)
90
-
91
- # CenterCrop transform
92
- self.center_crop = transforms.CenterCrop(crop_size)
93
-
94
- def __getitem__(self, index):
95
-
96
- gt_path = self.gt_list[index]
97
- gt_img = Image.open(gt_path).convert('RGB')
98
- gt_img = self.center_crop(self.img_preproc(gt_img))
99
-
100
- lq_path = self.lr_list[index]
101
- lq_img = Image.open(lq_path).convert('RGB')
102
- lq_img = self.center_crop(self.img_preproc(lq_img))
103
-
104
- example = dict()
105
-
106
- example["lq"] = (lq_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
107
- example["gt"] = (gt_img.squeeze(0) * 2.0 - 1.0).unsqueeze(1)
108
- example["text"] = ""
109
-
110
- return example
111
-
112
- def __len__(self):
113
- return len(self.gt_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_super_resolution/scripts/__pycache__/inference_sr.cpython-39.pyc DELETED
Binary file (3.97 kB)
 
video_super_resolution/scripts/inference_sr.py DELETED
@@ -1,142 +0,0 @@
1
- import os
2
- import torch
3
- from argparse import ArgumentParser, Namespace
4
- import json
5
- from typing import Any, Dict, List, Mapping, Tuple
6
- from easydict import EasyDict
7
-
8
- from video_to_video.video_to_video_model import VideoToVideo_sr
9
- from video_to_video.utils.seed import setup_seed
10
- from video_to_video.utils.logger import get_logger
11
- from video_super_resolution.color_fix import adain_color_fix
12
-
13
- from inference_utils import *
14
-
15
- logger = get_logger()
16
-
17
-
18
- class STAR_sr():
19
- def __init__(self,
20
- result_dir='./results/',
21
- file_name='000_video.mp4',
22
- model_path='./pretrained_weight',
23
- solver_mode='fast',
24
- steps=15,
25
- guide_scale=7.5,
26
- upscale=4,
27
- max_chunk_len=32,
28
- variant_info=None,
29
- chunk_size=3,
30
- ):
31
- self.model_path=model_path
32
- logger.info('checkpoint_path: {}'.format(self.model_path))
33
-
34
- self.result_dir = result_dir
35
- self.file_name = file_name
36
- os.makedirs(self.result_dir, exist_ok=True)
37
-
38
- model_cfg = EasyDict(__name__='model_cfg')
39
- model_cfg.model_path = self.model_path
40
- model_cfg.chunk_size = chunk_size
41
- self.model = VideoToVideo_sr(model_cfg)
42
-
43
- steps = 15 if solver_mode == 'fast' else steps
44
- self.solver_mode=solver_mode
45
- self.steps=steps
46
- self.guide_scale=guide_scale
47
- self.upscale = upscale
48
- self.max_chunk_len=max_chunk_len
49
- self.variant_info=variant_info
50
-
51
- def enhance_a_video(self, video_path, prompt):
52
- logger.info('input video path: {}'.format(video_path))
53
- text = prompt
54
- logger.info('text: {}'.format(text))
55
- caption = text + self.model.positive_prompt
56
-
57
- input_frames, input_fps = load_video(video_path)
58
- in_f_num = len(input_frames)
59
- logger.info('input frames length: {}'.format(in_f_num))
60
- logger.info('input fps: {}'.format(input_fps))
61
-
62
- video_data = preprocess(input_frames)
63
- _, _, h, w = video_data.shape
64
- logger.info('input resolution: {}'.format((h, w)))
65
- target_h, target_w = h * self.upscale, w * self.upscale # adjust_resolution(h, w, up_scale=4)
66
- logger.info('target resolution: {}'.format((target_h, target_w)))
67
-
68
- pre_data = {'video_data': video_data, 'y': caption}
69
- pre_data['target_res'] = (target_h, target_w)
70
-
71
- total_noise_levels = 900
72
- setup_seed(666)
73
-
74
- with torch.no_grad():
75
- data_tensor = collate_fn(pre_data, 'cuda:0')
76
- output = self.model.test(data_tensor, total_noise_levels, steps=self.steps, \
77
- solver_mode=self.solver_mode, guide_scale=self.guide_scale, \
78
- max_chunk_len=self.max_chunk_len
79
- )
80
-
81
- output = tensor2vid(output)
82
-
83
- # Using color fix
84
- output = adain_color_fix(output, video_data)
85
-
86
- save_video(output, self.result_dir, self.file_name, fps=input_fps)
87
- return os.path.join(self.result_dir, self.file_name)
88
-
89
-
90
- def parse_args():
91
- parser = ArgumentParser()
92
-
93
- parser.add_argument("--input_path", required=True, type=str, help="input video path")
94
- parser.add_argument("--save_dir", type=str, default='results', help="save directory")
95
- parser.add_argument("--file_name", type=str, help="file name")
96
- parser.add_argument("--model_path", type=str, default='./pretrained_weight/I2VGen-XL-based/heavy_deg.pt', help="model path")
97
- parser.add_argument("--prompt", type=str, default='a good video', help="prompt")
98
- parser.add_argument("--upscale", type=int, default=4, help='up-scale')
99
- parser.add_argument("--max_chunk_len", type=int, default=32, help='max_chunk_len')
100
- parser.add_argument("--variant_info", type=str, default=None, help='information of inference strategy')
101
-
102
- parser.add_argument("--cfg", type=float, default=7.5)
103
- parser.add_argument("--solver_mode", type=str, default='fast', help='fast | normal')
104
- parser.add_argument("--steps", type=int, default=15)
105
-
106
- return parser.parse_args()
107
-
108
- def main():
109
-
110
- args = parse_args()
111
-
112
- input_path = args.input_path
113
- prompt = args.prompt
114
- model_path = args.model_path
115
- save_dir = args.save_dir
116
- file_name = args.file_name
117
- upscale = args.upscale
118
- max_chunk_len = args.max_chunk_len
119
-
120
- steps = args.steps
121
- solver_mode = args.solver_mode
122
- guide_scale = args.cfg
123
-
124
- assert solver_mode in ('fast', 'normal')
125
-
126
- star_sr = STAR_sr(
127
- result_dir=save_dir,
128
- file_name=file_name, # new added
129
- model_path=model_path,
130
- solver_mode=solver_mode,
131
- steps=steps,
132
- guide_scale=guide_scale,
133
- upscale=upscale,
134
- max_chunk_len=max_chunk_len,
135
- variant_info=None,
136
- )
137
-
138
- star_sr.enhance_a_video(input_path, prompt)
139
-
140
-
141
- if __name__ == '__main__':
142
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_super_resolution/scripts/inference_sr.sh DELETED
@@ -1,56 +0,0 @@
1
- #!/bin/bash
2
-
3
- # Folder paths
4
- video_folder_path='./input/video'
5
- txt_file_path='./input/text/prompt.txt'
6
-
7
- # Get all .mp4 files in the folder using find to handle special characters
8
- mapfile -t mp4_files < <(find "$video_folder_path" -type f -name "*.mp4")
9
-
10
- # Print the list of MP4 files
11
- echo "MP4 files to be processed:"
12
- for mp4_file in "${mp4_files[@]}"; do
13
- echo "$mp4_file"
14
- done
15
-
16
- # Read lines from the text file, skipping empty lines
17
- mapfile -t lines < <(grep -v '^\s*$' "$txt_file_path")
18
-
19
- # List of frame counts
20
- frame_length=32
21
-
22
- # Debugging output
23
- echo "Number of MP4 files: ${#mp4_files[@]}"
24
- echo "Number of lines in the text file: ${#lines[@]}"
25
-
26
- # Ensure the number of video files matches the number of lines
27
- if [ ${#mp4_files[@]} -ne ${#lines[@]} ]; then
28
- echo "Number of MP4 files and lines in the text file do not match."
29
- exit 1
30
- fi
31
-
32
- # Loop through video files and corresponding lines
33
- for i in "${!mp4_files[@]}"; do
34
- mp4_file="${mp4_files[$i]}"
35
- line="${lines[$i]}"
36
-
37
- # Extract the filename without the extension
38
- file_name=$(basename "$mp4_file" .mp4)
39
-
40
- echo "Processing video file: $mp4_file with prompt: $line"
41
-
42
- # Run Python script with parameters
43
- python \
44
- ./video_super_resolution/scripts/inference_sr.py \
45
- --solver_mode 'fast' \
46
- --steps 15 \
47
- --input_path "${mp4_file}" \
48
- --model_path /mnt/bn/videodataset/VSR/pretrained_models/STAR/heavy_deg.pt \
49
- --prompt "${line}" \
50
- --upscale 4 \
51
- --max_chunk_len ${frame_length} \
52
- --file_name "${file_name}.mp4" \
53
- --save_dir ./results
54
- done
55
-
56
- echo "All videos processed successfully."