Asjad135 commited on
Commit
0ddcb8c
·
verified ·
1 Parent(s): b978d3c

Update real_esrgan.py

Browse files
Files changed (1) hide show
  1. real_esrgan.py +56 -18
real_esrgan.py CHANGED
@@ -1,22 +1,60 @@
 
1
  import os
2
  import tempfile
3
- import subprocess
 
 
 
4
 
5
- def upscale_video(video):
6
- with tempfile.TemporaryDirectory() as tempdir:
7
- input_path = os.path.join(tempdir, "input.mp4")
8
- output_path = os.path.join(tempdir, "output.mp4")
9
-
 
10
  with open(input_path, "wb") as f:
11
- f.write(video.read())
12
-
13
- cmd = [
14
- "python", "inference_realesrgan.py",
15
- "-i", input_path,
16
- "-o", output_path,
17
- "-n", "RealESRGAN_x4plus",
18
- "--outscale", "2"
19
- ]
20
- subprocess.run(cmd, check=True)
21
-
22
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import os
3
  import tempfile
4
+ from basicsr.archs.rrdbnet_arch import RRDBNet
5
+ from realesrgan import RealESRGANer
6
+ import torch
7
+ import cv2
8
 
9
+ def upscale_video(video_file):
10
+ with tempfile.TemporaryDirectory() as tmpdir:
11
+ input_path = os.path.join(tmpdir, "input.mp4")
12
+ output_path = os.path.join(tmpdir, "output.mp4")
13
+
14
+ # Save uploaded video
15
  with open(input_path, "wb") as f:
16
+ f.write(video_file.read())
17
+
18
+ cap = cv2.VideoCapture(input_path)
19
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
20
+ fps = cap.get(cv2.CAP_PROP_FPS)
21
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
22
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
23
+
24
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width*2, height*2))
25
+
26
+ # Load model
27
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
28
+ num_block=23, num_grow_ch=32, scale=4)
29
+ upsampler = RealESRGANer(
30
+ scale=4,
31
+ model_path='RealESRGAN_x4plus.pth',
32
+ model=model,
33
+ tile=0,
34
+ tile_pad=10,
35
+ pre_pad=0,
36
+ half=True
37
+ )
38
+
39
+ while True:
40
+ ret, frame = cap.read()
41
+ if not ret:
42
+ break
43
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
44
+ try:
45
+ output, _ = upsampler.enhance(frame)
46
+ except RuntimeError as error:
47
+ print(f"Error: {error}")
48
+ break
49
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
50
+ out.write(output)
51
+
52
+ cap.release()
53
+ out.release()
54
+
55
+ return output_path
56
+
57
+ demo = gr.Interface(fn=upscale_video, inputs=gr.Video(), outputs=gr.Video(), title="Video Upscaler (Real-ESRGAN)")
58
+
59
+ if __name__ == "__main__":
60
+ demo.launch()