Meloo committed on
Commit
188d68e
·
verified ·
1 Parent(s): 01e5b5f

Upload 4 files

Browse files
AIM24-VSR-SAFMNPP/SAFMNPP.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
class SimpleSAFM(nn.Module):
    """Feature modulation block mixing a local branch with a pooled branch.

    The input is projected with a 3x3 convolution and split into two halves
    along the channel axis. One half is max-pooled to roughly 1/8 of the
    spatial size, filtered by a depth-wise 3x3 convolution, upsampled back to
    the input resolution and used (after GELU) to gate that same half. The
    gated half is concatenated with the untouched half and fused by a 1x1
    convolution.

    Args:
        dim: Number of input/output channels. It must be even, because the
            projected tensor is split in two and the depth-wise convolution
            is built for ``dim // 2`` channels.
    """

    def __init__(self, dim):
        super().__init__()

        self.proj = nn.Conv2d(dim, dim, 3, 1, 1, bias=False)
        self.dwconv = nn.Conv2d(dim // 2, dim // 2, 3, 1, 1, groups=dim // 2, bias=False)
        self.out = nn.Conv2d(dim, dim, 1, 1, 0, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` (N, dim, H, W)."""
        h, w = x.size()[-2:]

        x0, x1 = self.proj(x).chunk(2, dim=1)

        # Clamp the pooled size to at least 1x1: for inputs smaller than
        # 8 pixels on a side, ``h // 8`` or ``w // 8`` would be 0, which is
        # not a valid pooling output size.
        pooled_size = (max(h // 8, 1), max(w // 8, 1))
        x2 = F.adaptive_max_pool2d(x0, pooled_size)
        x2 = self.dwconv(x2)
        x2 = F.interpolate(x2, size=(h, w), mode='bilinear')
        # The upsampled low-resolution response gates the full-resolution half.
        x2 = self.act(x2) * x0

        x = torch.cat([x1, x2], dim=1)
        x = self.out(self.act(x))
        return x
28
+
29
+
30
class CCM(nn.Module):
    """Two-layer convolutional mixer: 3x3 expand, GELU, 1x1 project back.

    Args:
        dim: Number of input/output channels.
        ffn_scale: Expansion ratio; the hidden width is ``int(dim * ffn_scale)``.
    """

    def __init__(self, dim, ffn_scale):
        super().__init__()

        hidden = int(dim * ffn_scale)
        layers = [
            nn.Conv2d(dim, hidden, kernel_size=3, stride=1, padding=1, bias=False),
            nn.GELU(),
            nn.Conv2d(hidden, dim, kernel_size=1, stride=1, padding=0, bias=False),
        ]
        # Kept as a Sequential named ``conv`` so parameter names stay
        # ``conv.0.weight`` / ``conv.2.weight``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the mixer; spatial size and channel count are preserved."""
        return self.conv(x)
42
+
43
class AttBlock(nn.Module):
    """One network stage: a ``SimpleSAFM`` followed by a ``CCM``.

    Args:
        dim: Channel count passed to both sub-modules.
        ffn_scale: Expansion ratio forwarded to ``CCM``.
    """

    def __init__(self, dim, ffn_scale):
        super().__init__()

        # Attribute names are part of the checkpoint's parameter names.
        self.conv1 = SimpleSAFM(dim)
        self.conv2 = CCM(dim, ffn_scale)

    def forward(self, x):
        """Run ``conv1`` then ``conv2``; no residual is added here."""
        return self.conv2(self.conv1(x))
55
+
56
class SAFMNPP(nn.Module):
    """Frame-wise super-resolution network applied to a batch of clips.

    Each frame is processed independently: a 3x3 convolution lifts RGB to
    ``dim`` features, ``n_blocks`` ``AttBlock`` stages refine them with one
    global residual connection, and a 3x3 convolution plus ``PixelShuffle``
    produces the upscaled RGB frame.

    Args:
        dim: Feature width.
        n_blocks: Number of ``AttBlock`` stages.
        ffn_scale: Expansion ratio forwarded to every ``AttBlock``.
        upscaling_factor: Spatial upscaling factor of the output.
    """

    def __init__(self, dim=32, n_blocks=2, ffn_scale=1.5, upscaling_factor=4):
        super().__init__()
        self.scale = upscaling_factor

        self.to_feat = nn.Conv2d(3, dim, kernel_size=3, stride=1, padding=1, bias=False)

        blocks = []
        for _ in range(n_blocks):
            blocks.append(AttBlock(dim, ffn_scale))
        self.feats = nn.Sequential(*blocks)

        # PixelShuffle turns r*r channel groups into an r-times larger image,
        # so the convolution must emit 3 * r^2 channels.
        shuffle_channels = 3 * upscaling_factor**2
        self.to_img = nn.Sequential(
            nn.Conv2d(dim, shuffle_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(upscaling_factor),
        )

    def forward(self, x):
        """Upscale a 5-D clip tensor laid out as (b, t, c, h, w).

        The time axis is folded into the batch axis for the 2-D convolutions
        and unfolded again on the way out.
        """
        batch = x.shape[0]
        frames = rearrange(x, 'b t c h w -> (b t) c h w')
        feat = self.to_feat(frames)
        feat = self.feats(feat) + feat
        upscaled = self.to_img(feat)
        return rearrange(upscaled, '(b t) c h w -> b t c h w', b=batch)
79
+
80
+
81
+
82
+
83
if __name__ == '__main__':
    ############# Test Model Complexity #############
    # import time
    # NOTE: fvcore and tqdm are only referenced by the commented-out
    # FLOP-count / timing code that follows this section.
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
    from tqdm import tqdm
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    scale = 4
    h, w = 3840, 2160

    # scale = 3
    # h, w = 1920, 1080

    x = torch.randn(1, 30, 3, h // scale, w // scale)

    model = SAFMNPP(upscaling_factor=scale)
    # The model lives on the CPU at this point, so load the checkpoint onto
    # the CPU as well. Without map_location, a checkpoint serialized from GPU
    # tensors cannot be loaded on a machine without CUDA, which defeats the
    # CPU fallback selected for `device` above.
    checkpoint = torch.load('light_safmnpp.pth', map_location='cpu')
    model.load_state_dict(checkpoint['params'], strict=True)

    # output = model(x)
    print(model)
103
+ # print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
104
+
105
+ # print(output.shape)
106
+
107
+
108
+ # num_frame = 30
109
+ # clip = 5
110
+
111
+ # torch.cuda.current_device()
112
+ # torch.cuda.empty_cache()
113
+ # torch.backends.cudnn.benchmark = False
114
+
115
+ # start = torch.cuda.Event(enable_timing=True)
116
+ # end = torch.cuda.Event(enable_timing=True)
117
+ # runtime = 0
118
+
119
+ # dummy_input = torch.randn((1, num_frame, 3, h // scale, w // scale)).to(device)
120
+ # # warm_up
121
+ # model.eval().to(device)
122
+ # with torch.no_grad():
123
+ # for _ in tqdm(range(clip)):
124
+ # _ = model(dummy_input)
125
+
126
+ # for _ in tqdm(range(clip)):
127
+ # start.record()
128
+ # _ = model(dummy_input)
129
+ # end.record()
130
+ # torch.cuda.synchronize()
131
+ # runtime += start.elapsed_time(end)
132
+
133
+ # per_frame_time = runtime / (num_frame * clip)
134
+
135
+ # print(f'{model.__class__.__name__} {num_frame * clip} Number Frames x{scale}SR Per Frame Time: {per_frame_time:.6f} ms')
136
+ # print(f'{model.__class__.__name__} x{scale}SR FPS: {(1000 / per_frame_time):.6f} FPS')
137
+
138
+
139
+
140
+
AIM24-VSR-SAFMNPP/light_safmnpp.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a542c92072cb25adab1f9cc5209d4f4f4ca8549db084e6703d2e032357cd50a7
3
+ size 538077
AIM24-VSR-SAFMNPP/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=1.8
2
+ av
3
+ torchvision
AIM24-VSR-SAFMNPP/vsr_run.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # (c) Meta Platforms, Inc. and affiliates.
2
+ import os
3
+ import subprocess
4
+ import torch
5
+ import torchvision
6
+ import imageio
7
+ import glob
8
+
9
+ from SAFMNPP import SAFMNPP
10
+
11
def main(input_path, output_path, video_name, model):
    """Super-resolve one video with ``model`` and encode the result with FFmpeg.

    The video is decoded with torchvision, run through ``model`` chunk by
    chunk, every output frame is written as a PNG into a per-video frame
    directory, and FFmpeg (libx264) assembles the PNGs into
    ``output_path/video_name``. The PNGs are deleted after a successful
    encode. If the output file already exists the function returns
    immediately without doing any work.

    Args:
        input_path: Directory containing the input video.
        output_path: Directory the encoded video is written to.
        video_name: File name of the video inside ``input_path``; the output
            file uses the same name.
        model: Callable that receives a (1, T, C, H, W) float tensor and
            returns the processed frames in the same layout.
    """
    video_path = os.path.join(output_path, video_name)
    if os.path.exists(video_path):
        # Check before creating the frame directory so that skipped videos
        # do not leave empty directories behind.
        return

    # NOTE(review): '/frams' is an absolute path at the filesystem root
    # (possibly meant to be a 'frames' directory) -- confirm it is writable
    # in the target environment.
    tmp_path = os.path.join('/frams', os.path.splitext(video_name)[0])
    os.makedirs(tmp_path, exist_ok=True)

    input_video = torchvision.io.read_video(os.path.join(input_path, video_name))
    normalized_frames = input_video[0].permute(0, 3, 1, 2)  # THWC to TCHW
    normalized_frames = torchvision.transforms.functional.convert_image_dtype(normalized_frames, torch.float32)
    input_data = normalized_frames.unsqueeze(0)

    # Run on whatever device the model already lives on; fall back to the
    # first CUDA device when the model exposes no parameters.
    first_param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
    device = first_param.device if first_param is not None else torch.device('cuda', 0)

    print(f'total frames: {input_data.size(1)}')
    with torch.no_grad():
        frame_idx = 0
        # NOTE: chunk(100) splits the clip into at most 100 pieces along the
        # time axis; it does not produce pieces of 100 frames each.
        for xi in input_data.chunk(100, dim=1):
            frames = model(xi.to(device)).cpu()
            for frame in frames.squeeze(0).unbind(dim=0):
                frame = frame.clamp(0, 1)  # Clamp values to be between 0 and 1
                frame = torchvision.transforms.functional.convert_image_dtype(frame, torch.uint8)
                frame = frame.permute(1, 2, 0)  # CHW to HWC

                frame_file = os.path.join(tmp_path, f'{frame_idx:08d}.png')
                if not os.path.exists(frame_file):
                    imageio.imwrite(frame_file, frame.numpy())
                    print('save frames : ', frame_file)
                else:
                    print('exist frame : ', frame_file)
                frame_idx += 1

    fps = input_video[2]['video_fps']
    # Pass an argument list instead of a shell string so that paths with
    # spaces or shell metacharacters are handled correctly.
    cmd = [
        'ffmpeg', '-r', str(fps), '-i', f'{tmp_path}/%08d.png',
        '-c:v', 'libx264', '-crf', '12', '-preset', 'veryfast', video_path,
    ]

    try:
        subprocess.run(cmd, check=True)
        print("Video created successfully.")

        # Delete the intermediate frame images.
        for frame_filename in glob.glob(os.path.join(tmp_path, '*.png')):
            os.remove(frame_filename)
            print(f"Deleted {frame_filename}")

    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError covers a missing ffmpeg executable, which the
        # shell-based invocation used to surface as a non-zero exit status.
        print(f"An error occurred while trying to run FFmpeg: {e}")
68
+
69
+
70
if __name__ == '__main__':
    # Load the x4 model onto the first CUDA device and switch to inference mode.
    device = torch.device('cuda', 0)
    model = SAFMNPP(upscaling_factor=4).to(device)
    model_path = r'light_safmnpp.pth'
    # map_location makes the load independent of the device the checkpoint
    # tensors were saved from.
    model.load_state_dict(torch.load(model_path, map_location=device)['params'], strict=True)
    model.eval()

    input_path = r'ValidationSet-1080p/bitstreams'
    output_path = r'Video_Output_4X'

    os.makedirs(output_path, exist_ok=True)

    # Sorted so the processing order is deterministic across runs.
    for video_name in sorted(os.listdir(input_path)):
        main(input_path, output_path, video_name, model)
        print("Done", video_name)
85
+
86
+