r3gm committed
Commit d7ac01d · verified · Parent: 4053aa4

Upload 4 files

Files changed (4):
  1. README.md +2 -2
  2. app.py +75 -165
  3. inference_video_w.py +316 -0
  4. packages.txt +1 -0
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
-title: Test Mcp1
-emoji: 🌍
+title: FPS_Enhancer
+emoji: 🎞️⏫
 colorFrom: pink
 colorTo: red
 sdk: gradio
app.py CHANGED
@@ -2,127 +2,72 @@ import gradio as gr
 import subprocess
 import os
 import spaces
+import inference_video_w
+import torch
 
 # Download the file
 subprocess.run([
     "wget",
-    "--no-check-certificate",
-    "https://drive.google.com/uc?id=1mj9lH6Be7ztYtHAr1xUUGT3hRtWJBy_5",
+    "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip",
     "-O",
-    "RIFE_trained_model_v4.13.2.zip"
+    "RIFEv4.26_0921.zip"
 ], check=True)
 
 # Unzip the downloaded file
 subprocess.run([
     "unzip",
-    "RIFE_trained_model_v4.13.2.zip"
+    "-o",
+    "RIFEv4.26_0921.zip"
 ], check=True)
 
-# The name of your script
-SCRIPT_NAME = "inference_video.py"
-
-@spaces.GPU()
+@spaces.GPU(duration=120)
 def run_rife(
     input_video,
-    model_dir,
-    multi,
-    exp,
-    fps,
-    scale,
-    uhd,
-    fp16,
-    skip,
-    montage,
-    png_mode,
-    ext
+    frame_multiplier,
+    time_exponent,
+    fixed_fps,
+    video_scale,
+    remove_duplicate_frames,
+    create_montage,
+    progress=gr.Progress(track_tqdm=True),
 ):
-    """
-    Constructs the command line arguments based on Gradio inputs
-    and runs the inference_video.py script via subprocess.
-    """
-
     if input_video is None:
         raise gr.Error("Please upload a video first.")
 
-    # 1. Define Output Filename
-    output_name = f"output_{multi}X.{ext}"
-
-    # 2. Build the Command
-    cmd = ["python3", SCRIPT_NAME]
-
-    # --video
-    cmd.extend(["--video", input_video])
-
-    # --output
-    cmd.extend(["--output", output_name])
+    ext = "mp4"
+    model_dir = "train_log"
 
-    # --multi (Multiplier)
-    cmd.extend(["--multi", str(int(multi))])
+    # Construct output filename pattern to match what inference_video.py expects/generates
+    video_path_wo_ext = os.path.splitext(os.path.basename(input_video))[0]
+    # We pass the desired output name, though the function logic tries to stick to this pattern anyway
+    output_base_name = "{}_{}X_fps.{}".format(video_path_wo_ext, int(frame_multiplier), ext)
 
-    # --exp
-    # Only add exp if it is not default, or if specific logic requires it.
-    # Usually multi overrides exp in RIFE logic, but we pass it if set.
-    if exp != 1:
-        cmd.extend(["--exp", str(int(exp))])
-
-    # --fps (Target FPS)
-    if fps > 0:
-        cmd.extend(["--fps", str(int(fps))])
-
-    # --scale (Resolution scale)
-    # Check against float 1.0
-    if scale != 1.0:
-        cmd.extend(["--scale", str(scale)])
-
-    # --ext (Extension)
-    cmd.extend(["--ext", ext])
-
-    # --model (Model directory)
-    if model_dir and model_dir.strip() != "":
-        cmd.extend(["--model", model_dir])
-
-    # --- Boolean Flags ---
-
-    if uhd:
-        cmd.append("--UHD")
-
-    if fp16:
-        cmd.append("--fp16")
-
-    if skip:
-        cmd.append("--skip")
-
-    if montage:
-        cmd.append("--montage")
-
-    if png_mode:
-        cmd.append("--png")
-
-    print(f"Executing command: {' '.join(cmd)}")
-
-    # 3. Run the Subprocess
+    if fixed_fps > 0:
+        gr.Warning("Will not merge audio because using fps flag!")
+
+    print(f"Starting Inference for: {input_video}")
+
     try:
-        # We use a large timeout because video processing takes time
-        process = subprocess.run(cmd, capture_output=True, text=True)
-
-        # Log stdout/stderr
-        if process.stdout:
-            print("STDOUT:", process.stdout)
-        if process.stderr:
-            print("STDERR:", process.stderr)
-
-        if process.returncode != 0:
-            raise gr.Error(f"Inference failed. Error: {process.stderr}")
-
-        # 4. Return Result
-        if png_mode:
-            gr.Info("Processing complete. Output is a folder of PNGs (Video preview unavailable for PNG mode).")
-            return None
-
-        if os.path.exists(output_name):
-            return output_name
+        # Call the imported function directly
+        result_path = inference_video_w.inference(
+            video=input_video,
+            output=output_base_name,
+            modelDir=model_dir,
+            fp16=(True if torch.cuda.is_available() else False),
+            UHD=False,
+            scale=video_scale,
+            skip=remove_duplicate_frames,
+            fps=int(fixed_fps) if fixed_fps > 0 else None,
+            ext=ext,
+            exp=int(time_exponent),
+            multi=int(frame_multiplier),
+            montage=create_montage
+        )
+
+        if result_path and os.path.exists(result_path):
+            return result_path
         else:
-            raise gr.Error("Output file was not found. Check console for details.")
+            raise gr.Error(f"Output file not found. Expected: {result_path}")
 
     except Exception as e:
         raise gr.Error(f"An error occurred: {str(e)}")
@@ -130,111 +75,75 @@ def run_rife(
 
 # --- Gradio UI Layout ---
 
-with gr.Blocks(title="RIFE Video Interpolation") as app:
-    gr.Markdown("# 🚀 RIFE: Real-Time Intermediate Flow Estimation")
-    gr.Markdown("Upload a video to increase its frame rate (smoothness) using AI.")
+with gr.Blocks(title="Frame Rate Enhancer") as app:
+    gr.Markdown("# RIFE: Frame Rate Enhancer")
+    gr.Markdown("Creates extra frames between the original ones to make motion in your videos smoother and more fluid.")
+    gr.Markdown("⚠️ **Notice:** Keep input videos under 60 seconds for frame interpolation to prevent GPU task aborts.")
 
     with gr.Row():
         # --- Left Column: Inputs & Settings ---
        with gr.Column(scale=1):
-            input_vid = gr.Video(label="Input Video", sources=["upload"])
+            input_vid = gr.Video(label="🎬 Input Source Video", sources=["upload"])
 
-            with gr.Group():
-                gr.Markdown("### 🎯 Core Parameters")
-
-                with gr.Row():
-                    multi_param = gr.Dropdown(
-                        choices=["2", "4", "8", "16", "32"],
-                        value="2",
-                        label="Interpolation Multiplier (--multi)",
-                        info="How many times to multiply the frames. 2X doubles the FPS (e.g., 30fps -> 60fps). 4X quadruples it."
-                    )
-                    ext_param = gr.Dropdown(
-                        choices=["mp4", "avi", "mov", "mkv"],
-                        value="mp4",
-                        label="Output Format (--ext)",
-                        info="The file extension for the generated video."
-                    )
-
-                model_param = gr.Textbox(
-                    value="train_log",
-                    label="Model Directory (--model)",
-                    placeholder="train_log",
-                    info="Path to the folder containing the trained model files (e.g., 'train_log' or 'rife-v4.6')."
+            with gr.Group():
+                multi_param = gr.Dropdown(
+                    choices=["2", "3", "4", "5", "6"],
+                    value="2",
+                    label="🗃️ Frame Multiplier",
+                    info="2X = Double FPS (e.g. 30 -> 60). Higher multipliers create more intermediate frames."
                 )
 
-            with gr.Accordion("️ Advanced Settings", open=False):
-                gr.Markdown("Fine-tune the inference process.")
+            with gr.Accordion("🛠️ Advanced Configuration", open=False):
+                gr.Markdown("Control rendering parameters.")
 
                 with gr.Row():
-                    scale_param = gr.Slider(
-                        minimum=0.1, maximum=1.0, value=1.0, step=0.1,
-                        label="Input Scale (--scale)",
-                        info="1.0 = Original resolution. Set to 0.5 to reduce memory usage for 4K video inputs."
+                    scale_param = gr.Dropdown(
+                        choices=[0.25, 0.5, 1.0, 2.0, 4.0],
+                        value=1.0,
+                        label="📉 Render Scale",
+                        info="1.0 = Original Resolution. Reduce to 0.5 for faster processing on 4K content."
                     )
                     fps_param = gr.Number(
                         value=0,
-                        label="Force Target FPS (--fps)",
-                        info="0 = Auto-calculate based on multiplier. Enter a number (e.g., 60) to force a specific output frame rate."
+                        label="🎯 Force Output FPS",
+                        info="0 = Auto-calculate. Set to 30 or 60 to lock the framerate. Audio will be removed when forcing FPS"
                     )
                     exp_param = gr.Number(
                         value=1,
-                        label="Exponent Power (--exp)",
-                        info="Alternative to Multiplier. Sets multiplier to 2^exp. (Usually left at 1 if Multiplier is set)."
+                        label="🔢 Exponent Power",
+                        info="Alternative multiplier calculation (2^exp)."
                     )
 
-                with gr.Row():
-                    uhd_chk = gr.Checkbox(
-                        label="UHD Mode (--UHD)",
-                        value=False,
-                        info="Optimized for 4K video. Equivalent to setting scale=0.5 manually."
-                    )
-                    fp16_chk = gr.Checkbox(
-                        label="FP16 Mode (--fp16)",
-                        value=True,
-                        info="Uses half-precision floating point. Faster and uses less VRAM with minimal quality loss."
-                    )
-
                 with gr.Row():
                     skip_chk = gr.Checkbox(
-                        label="Skip Static Frames (--skip)",
+                        label="Skip Static Frames",
                         value=False,
-                        info="If the video has frames that don't move, skip processing them to save time."
+                        info="Bypass processing for static frames to save time."
                     )
                     montage_chk = gr.Checkbox(
-                        label="Montage (--montage)",
-                        value=False,
-                        info="Creates a video with the Original on the Left and Interpolated on the Right for comparison."
-                    )
-                    png_chk = gr.Checkbox(
-                        label="Output as PNGs (--png)",
+                        label="🆚 Split-Screen Comparison",
                         value=False,
-                        info="Outputs a sequence of images instead of a video file. (Video Preview will be disabled)."
+                        info="Output video showing Original vs. Processed."
                     )
 
-            btn_run = gr.Button(" Start Interpolation", variant="primary", size="lg")
+            btn_run = gr.Button("GENERATE INTERMEDIATE FRAMES", variant="primary", size="lg")
 
         # --- Right Column: Output ---
         with gr.Column(scale=1):
-            output_vid = gr.Video(label="Interpolated Result")
-            gr.Markdown("**Note:** Processing time depends on video length, resolution, and your GPU speed.")
+            output_vid = gr.Video(label="INTERPOLATED RESULT")
+            gr.Markdown("**Status:** Rendering time depends on input resolution and duration.")
 
     # --- Bind Logic ---
     btn_run.click(
         fn=run_rife,
         inputs=[
             input_vid,
-            model_param,
             multi_param,
             exp_param,
             fps_param,
             scale_param,
-            uhd_chk,
-            fp16_chk,
             skip_chk,
-            montage_chk,
-            png_chk,
-            ext_param
+            montage_chk
         ],
         outputs=output_vid
     )
@@ -242,4 +151,5 @@ with gr.Blocks(title="RIFE Video Interpolation") as app:
 if __name__ == "__main__":
     app.launch(
         theme=gr.themes.Soft(),
+        mcp_server=True,
     )
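
Reviewer note: with the subprocess call replaced by a direct import, the interpolation entry point can also be exercised outside Gradio. A minimal sketch, assuming the RIFE weights are already unpacked into `train_log/` (as app.py does at startup) and using a hypothetical `clip.mp4` as input:

```python
# Minimal sketch of a direct call to the new module (no Gradio, no Space).
# Assumes the RIFEv4.26 weights were unzipped into train_log/,
# and that "clip.mp4" is a hypothetical local input file.
import inference_video_w

result_path = inference_video_w.inference(
    video="clip.mp4",
    output="clip_2X_fps.mp4",   # desired output name
    modelDir="train_log",
    multi=2,                    # 2X: e.g. 30 fps in -> 60 fps out
)
print(result_path)              # path to the interpolated video
```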
inference_video_w.py ADDED
@@ -0,0 +1,316 @@
+import os
+import cv2
+import torch
+import numpy as np
+from tqdm import tqdm
+from torch.nn import functional as F
+import warnings
+import _thread
+import skvideo.io
+from queue import Queue, Empty
+from model.pytorch_msssim import ssim_matlab
+import shutil
+import tempfile
+import time
+
+warnings.filterwarnings("ignore")
+
+# Utility class to mimic argparse object
+class Args:
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+
+def transferAudio(sourceVideo, targetVideo):
+    # generate a unique temp directory for this user
+    unique_temp_dir = tempfile.mkdtemp()
+    tempAudioFileName = os.path.join(unique_temp_dir, "audio.mkv")
+
+    # extract audio from video
+    os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -c:a copy -vn {}'.format(sourceVideo, tempAudioFileName))
+
+    targetNoAudio = os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
+    os.rename(targetVideo, targetNoAudio)
+    # combine audio file and new video file
+    os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+
+    if os.path.getsize(targetVideo) == 0: # if ffmpeg failed to merge the video and audio together try converting the audio to aac
+        tempAudioFileName = os.path.join(unique_temp_dir, "audio.m4a")
+        os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -c:a aac -b:a 160k -vn {}'.format(sourceVideo, tempAudioFileName))
+        os.system('ffmpeg -hide_banner -loglevel error -y -i "{}" -i {} -c copy "{}"'.format(targetNoAudio, tempAudioFileName, targetVideo))
+        if (os.path.getsize(targetVideo) == 0): # if aac is not supported by selected format
+            os.rename(targetNoAudio, targetVideo)
+            print("Audio transfer failed. Interpolated video will have no audio")
+        else:
+            print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
+            # remove audio-less video
+            os.remove(targetNoAudio)
+    else:
+        os.remove(targetNoAudio)
+
+    # remove temp directory
+    shutil.rmtree(unique_temp_dir)
+
+def inference(
+    video=None,
+    output=None,
+    img=None,
+    montage=False,
+    modelDir='train_log',
+    fp16=False,
+    UHD=False,
+    scale=1.0,
+    skip=False,
+    fps=None,
+    png=False,
+    ext='mp4',
+    exp=1,
+    multi=2
+):
+    # Initialize Arguments Object
+    args = Args(
+        video=video, output=output, img=img, montage=montage,
+        modelDir=modelDir, fp16=fp16, UHD=UHD, scale=scale,
+        skip=skip, fps=fps, png=png, ext=ext, exp=exp, multi=multi
+    )
+
+    # Argument Logic Adjustment
+    if args.exp != 1:
+        args.multi = (2 ** args.exp)
+
+    # Assertions
+    assert (not args.video is None or not args.img is None)
+    if args.skip:
+        print("skip flag is abandoned, please refer to issue #207.")
+    if args.UHD and args.scale == 1.0:
+        args.scale = 0.5
+    assert args.scale in [0.25, 0.5, 1.0, 2.0, 4.0]
+    if not args.img is None:
+        args.png = True
+
+    # Device Setup
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    torch.set_grad_enabled(False)
+    if torch.cuda.is_available():
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.benchmark = True
+        if(args.fp16):
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
+    # Load Model
+    from train_log.RIFE_HDv3 import Model
+    model = Model()
+    if not hasattr(model, 'version'):
+        model.version = 0
+    model.load_model(args.modelDir, -1)
+    print("Loaded 3.x/4.x HD model.")
+    model.eval()
+    model.device()
+
+    # Video/Image Setup
+    if not args.video is None:
+        videoCapture = cv2.VideoCapture(args.video)
+        original_fps = videoCapture.get(cv2.CAP_PROP_FPS)
+        tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+        videoCapture.release()
+
+        if args.fps is None or args.fps == 0:
+            fpsNotAssigned = True
+            args.fps = original_fps * args.multi
+        else:
+            fpsNotAssigned = False
+
+        videogen = skvideo.io.vreader(args.video)
+        lastframe = next(videogen)
+        # fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') # Unused in original logic for skvideo
+        video_path_wo_ext, ext = os.path.splitext(args.video)
+        print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, original_fps, args.fps))
+
+        if args.png == False and fpsNotAssigned == True:
+            print("The audio will be merged after interpolation process")
+        else:
+            print("Will not merge audio because using png or fps flag!")
+    else:
+        videogen = []
+        for f in os.listdir(args.img):
+            if 'png' in f:
+                videogen.append(f)
+        tot_frame = len(videogen)
+        videogen.sort(key=lambda x: int(x[:-4]))
+        lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+        videogen = videogen[1:]
+
+    h, w, _ = lastframe.shape
+    vid_out_name = None
+    vid_out = None
+
+    if args.png:
+        if not os.path.exists('vid_out'):
+            os.mkdir('vid_out')
+    else:
+        if args.output is not None:
+            vid_out_name = args.output
+        else:
+            vid_out_name = '{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.multi, int(np.round(args.fps)), args.ext)
+
+        outputdict = {
+            '-c:v': 'libx264',
+            '-crf': '17',
+            '-preset': 'slow',
+            '-pix_fmt': 'yuv420p'
+        }
+        vid_out = skvideo.io.FFmpegWriter(vid_out_name, inputdict={'-r': str(args.fps)}, outputdict=outputdict)
+
+    # --- Nested Helper Functions to capture 'args', 'model', 'vid_out' scope ---
+
+    def clear_write_buffer(write_buffer):
+        cnt = 0
+        while True:
+            item = write_buffer.get()
+            if item is None:
+                break
+            if args.png:
+                cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1])
+                cnt += 1
+            else:
+                vid_out.writeFrame(item)
+
+    def build_read_buffer(read_buffer, videogen):
+        try:
+            for frame in videogen:
+                if not args.img is None:
+                    frame = cv2.imread(os.path.join(args.img, frame), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy()
+                if args.montage:
+                    frame = frame[:, left: left + w]
+                read_buffer.put(frame)
+        except:
+            pass
+        read_buffer.put(None)
+
+    def make_inference(I0, I1, n):
+        if model.version >= 3.9:
+            res = []
+            for i in range(n):
+                res.append(model.inference(I0, I1, (i+1) * 1. / (n+1), args.scale))
+            return res
+        else:
+            middle = model.inference(I0, I1, args.scale)
+            if n == 1:
+                return [middle]
+            first_half = make_inference(I0, middle, n=n//2)
+            second_half = make_inference(middle, I1, n=n//2)
+            if n % 2:
+                return [*first_half, middle, *second_half]
+            else:
+                return [*first_half, *second_half]
+
+    def pad_image(img):
+        if(args.fp16):
+            return F.pad(img, padding).half()
+        else:
+            return F.pad(img, padding)
+
+    # --- Pre-Loop Setup ---
+
+    left = 0  # Define default
+    if args.montage:
+        left = w // 4
+        w = w // 2
+
+    tmp = max(128, int(128 / args.scale))
+    ph = ((h - 1) // tmp + 1) * tmp
+    pw = ((w - 1) // tmp + 1) * tmp
+    padding = (0, pw - w, 0, ph - h)
+
+    pbar = tqdm(total=tot_frame)
+    if args.montage:
+        lastframe = lastframe[:, left: left + w]
+
+    write_buffer = Queue(maxsize=500)
+    read_buffer = Queue(maxsize=500)
+
+    # Start threads
+    _thread.start_new_thread(build_read_buffer, (read_buffer, videogen))
+    _thread.start_new_thread(clear_write_buffer, (write_buffer,))
+
+    I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+    I1 = pad_image(I1)
+    temp = None
+
+    # --- Main Loop ---
+
+    while True:
+        if temp is not None:
+            frame = temp
+            temp = None
+        else:
+            frame = read_buffer.get()
+        if frame is None:
+            break
+        I0 = I1
+        I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+        I1 = pad_image(I1)
+        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
+        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+        break_flag = False
+        if ssim > 0.996:
+            frame = read_buffer.get()  # read a new frame
+            if frame is None:
+                break_flag = True
+                frame = lastframe
+            else:
+                temp = frame
+            I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
+            I1 = pad_image(I1)
+            I1 = model.inference(I0, I1, scale=args.scale)
+            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+            frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+
+        if ssim < 0.2:
+            output_frames = []
+            for i in range(args.multi - 1):
+                output_frames.append(I0)
+        else:
+            output_frames = make_inference(I0, I1, args.multi - 1)
+
+        if args.montage:
+            write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+            for mid in output_frames:
+                mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+                write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
+        else:
+            write_buffer.put(lastframe)
+            for mid in output_frames:
+                mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
+                write_buffer.put(mid[:h, :w])
+        pbar.update(1)
+        lastframe = frame
+        if break_flag:
+            break
+
+    if args.montage:
+        write_buffer.put(np.concatenate((lastframe, lastframe), 1))
+    else:
+        write_buffer.put(lastframe)
+
+    write_buffer.put(None)
+
+    while(not write_buffer.empty()):
+        time.sleep(0.1)
+    pbar.close()
+
+    if not vid_out is None:
+        vid_out.close()
+
+    # Audio Transfer Logic
+    if args.png == False and fpsNotAssigned == True and not args.video is None:
+        try:
+            transferAudio(args.video, vid_out_name)
+        except:
+            print("Audio transfer failed. Interpolated video will have no audio")
+            targetNoAudio = os.path.splitext(vid_out_name)[0] + "_noaudio" + os.path.splitext(vid_out_name)[1]
+            os.rename(targetNoAudio, vid_out_name)
+
+    return vid_out_name
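
Reviewer note on the padding setup in `inference()`: input frames are padded so height and width become multiples of `max(128, 128/scale)`, letting the model's internal downsampling stages divide evenly, and interpolated frames are cropped back with `[:h, :w]` before writing. A worked example for a hypothetical 1280×720 input at scale=1.0:

```python
# Worked example of the padding arithmetic above (values are illustrative).
h, w, scale = 720, 1280, 1.0
tmp = max(128, int(128 / scale))   # 128
ph = ((h - 1) // tmp + 1) * tmp    # 768: next multiple of 128 >= 720
pw = ((w - 1) // tmp + 1) * tmp    # 1280: already a multiple of 128
padding = (0, pw - w, 0, ph - h)   # (0, 0, 0, 48): F.pad adds 48 rows at the bottom
```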
packages.txt ADDED
@@ -0,0 +1 @@
+ffmpeg
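
Reviewer note: packages.txt installs ffmpeg system-wide, which both `transferAudio()` and skvideo's `FFmpegWriter` depend on. An optional sanity check, sketched with the hypothetical output file from the earlier example, reads the resulting frame rate back via ffprobe (bundled with ffmpeg):

```python
# Optional check: confirm the interpolated file's frame rate with ffprobe.
# "clip_2X_fps.mp4" is the hypothetical output from the sketch above.
import subprocess

rate = subprocess.run(
    ["ffprobe", "-v", "error", "-select_streams", "v:0",
     "-show_entries", "stream=r_frame_rate",
     "-of", "default=noprint_wrappers=1:nokey=1",
     "clip_2X_fps.mp4"],
    capture_output=True, text=True, check=True,
).stdout.strip()
print(rate)  # e.g. "60/1" after a 2X pass over 30 fps input
```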