File size: 2,397 Bytes
9bdccea
f84f6c9
9bdccea
 
 
 
f84f6c9
 
9bdccea
f84f6c9
9bdccea
681d8de
71ca23a
 
 
 
681d8de
71ca23a
681d8de
71ca23a
 
9bdccea
 
 
681d8de
bca482d
f84f6c9
9bdccea
 
 
 
f026dff
 
 
 
 
 
f84f6c9
 
9bdccea
f84f6c9
 
681d8de
 
 
 
9bdccea
 
f84f6c9
 
9bdccea
 
 
f84f6c9
9bdccea
f84f6c9
a734e0b
f84f6c9
9bdccea
 
 
 
f026dff
681d8de
 
 
a734e0b
681d8de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import tempfile

import cv2 as cv
import ffmpeg
import numpy as np
import torch
import tqdm
from PIL import Image
from RealESRGAN import RealESRGAN

# Inference device: the default CUDA device when PyTorch reports one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def infer_image(img: Image.Image, size_modifier: int) -> Image.Image:
    """Upscale a single PIL image with Real-ESRGAN.

    A fresh model is built and its weights loaded on every call.

    Args:
        img: The image to upscale; ``None`` (nothing uploaded) is rejected.
        size_modifier: Upscaling factor; also selects the weights file
            ``weights/RealESRGAN_x{size_modifier}.pth``.

    Returns:
        Whatever ``RealESRGAN.predict`` returns for the RGB-converted input.

    Raises:
        Exception: If ``img`` is ``None``, or if its width or height is
            5000 pixels or more.
    """
    if img is None:
        raise Exception("Image not uploaded")

    width, height = img.size
    # Reject oversized inputs before spending time loading the model.
    if max(width, height) >= 5000:
        raise Exception("Image is too large")

    weights_path = f'weights/RealESRGAN_x{size_modifier}.pth'
    upscaler = RealESRGAN(device, scale=size_modifier)
    upscaler.load_weights(weights_path, download=True)
    return upscaler.predict(img.convert('RGB'))

def infer_video(video_filepath: str, size_modifier: int) -> str:
    """Upscale every frame of a video with Real-ESRGAN and return the output path.

    Frames are decoded with OpenCV, upscaled one at a time, and written to a
    temporary ``.mp4`` file using the ``mp4v`` codec. If the source contains an
    audio stream, that temporary video is re-encoded with libx264 and muxed
    with the source's audio (encoded as AAC) into
    ``<input path without extension>_upscaled.mp4``, next to the input.

    Args:
        video_filepath: Path to the input video.
        size_modifier: Upscaling factor; also selects the weights file
            ``weights/RealESRGAN_x{size_modifier}.pth``.

    Returns:
        The path of the upscaled video: the temporary file when the source has
        no audio, otherwise ``<input path without extension>_upscaled.mp4``.

    Raises:
        ffmpeg.Error: If ffprobe cannot read the input or the final mux fails.
        ValueError: If OpenCV cannot open the input or the output writer.
    """
    model = RealESRGAN(device, scale=size_modifier)
    model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=True)

    probe = ffmpeg.probe(video_filepath)
    has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])

    tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    vid_output = tmpfile.name
    tmpfile.close()

    try:
        _upscale_video_frames(model, video_filepath, vid_output, size_modifier)
    except Exception:
        # Don't leave a half-written temporary file behind on failure.
        os.remove(vid_output)
        raise

    if not has_audio:
        return vid_output

    # splitext (rather than str.replace(".mp4", ...)) guarantees the output
    # path never equals the input path, so the source is never overwritten,
    # whatever its extension.
    root, _ = os.path.splitext(video_filepath)
    out_path = f"{root}_upscaled.mp4"

    # The OpenCV output is video-only: take the picture from it and the
    # audio directly from the original file.
    upscaled_video = ffmpeg.input(vid_output).video
    source_audio = ffmpeg.input(video_filepath).audio
    try:
        ffmpeg.output(
            upscaled_video, source_audio, out_path, vcodec='libx264', acodec='aac'
        ).run(overwrite_output=True)
    finally:
        # The intermediate video-only file is not needed once muxing is attempted.
        os.remove(vid_output)
    return out_path


def _upscale_video_frames(model, src_path: str, dst_path: str, size_modifier: int) -> None:
    """Upscale each frame of ``src_path`` with ``model`` and write them to ``dst_path``.

    The output frame size is the source frame size multiplied by
    ``size_modifier``. The capture and writer are released even on error.

    Raises:
        ValueError: If the input cannot be opened, or the writer cannot be created.
    """
    cap = cv.VideoCapture(src_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {src_path}")

    vid_writer = None
    try:
        vid_writer = cv.VideoWriter(
            dst_path,
            fourcc=cv.VideoWriter.fourcc(*'mp4v'),
            fps=cap.get(cv.CAP_PROP_FPS),
            frameSize=(
                int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier,
                int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier,
            ),
        )
        if not vid_writer.isOpened():
            raise ValueError(f"Could not open video writer for: {dst_path}")

        # CAP_PROP_FRAME_COUNT can be an estimate (or non-positive) for some
        # containers. Read until the capture is exhausted and use the count
        # only to size the progress bar.
        n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
        with tqdm.tqdm(total=n_frames if n_frames > 0 else None) as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; the model takes an RGB PIL image.
                rgb_image = Image.fromarray(cv.cvtColor(frame, cv.COLOR_BGR2RGB))
                upscaled = np.array(model.predict(rgb_image))
                vid_writer.write(cv.cvtColor(upscaled, cv.COLOR_RGB2BGR))
                progress.update(1)
    finally:
        if vid_writer is not None:
            vid_writer.release()
        cap.release()