import os
import tempfile

import cv2 as cv
import ffmpeg
import numpy as np
import torch
import tqdm
from PIL import Image
from RealESRGAN import RealESRGAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Images whose width or height reaches this many pixels are rejected before
# inference (same limit the original inline check enforced).
MAX_IMAGE_DIMENSION = 5000


def _load_model(size_modifier: int) -> RealESRGAN:
    """Build a RealESRGAN model for ``size_modifier`` and load its weights.

    Weights are read from ``weights/RealESRGAN_x{size_modifier}.pth``;
    ``download=True`` is passed through as before (presumably lets the
    library fetch missing weights — confirm against RealESRGAN docs).
    """
    model = RealESRGAN(device, scale=size_modifier)
    model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=True)
    return model


def infer_image(img: Image.Image, size_modifier: int) -> Image.Image:
    """Upscale a single image by ``size_modifier`` with RealESRGAN.

    Args:
        img: Input image; converted to RGB before inference.
        size_modifier: Upscaling factor; also selects the weight file.

    Returns:
        The image returned by ``RealESRGAN.predict``.

    Raises:
        Exception: If ``img`` is None, or its width or height is at least
            ``MAX_IMAGE_DIMENSION`` pixels.
    """
    if img is None:
        raise Exception("Image not uploaded")

    width, height = img.size
    if width >= MAX_IMAGE_DIMENSION or height >= MAX_IMAGE_DIMENSION:
        raise Exception("Image is too large")

    model = _load_model(size_modifier)
    return model.predict(img.convert('RGB'))


def _has_audio(video_filepath: str) -> bool:
    """Return True if ffprobe reports at least one audio stream in the file."""
    probe = ffmpeg.probe(video_filepath)
    return any(stream['codec_type'] == 'audio' for stream in probe['streams'])


def _upscale_frames(model: RealESRGAN, cap: cv.VideoCapture, vid_writer: cv.VideoWriter) -> None:
    """Upscale every frame readable from ``cap`` and write it to ``vid_writer``."""
    # CAP_PROP_FRAME_COUNT comes from container metadata and can be wrong or
    # 0, so it only sizes the progress bar; reading stops when read() fails.
    n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
    with tqdm.tqdm(total=n_frames if n_frames > 0 else None) as progress:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV yields BGR arrays; the model consumes RGB PIL images.
            rgb_frame = Image.fromarray(cv.cvtColor(frame, cv.COLOR_BGR2RGB))
            upscaled = np.array(model.predict(rgb_frame))
            vid_writer.write(cv.cvtColor(upscaled, cv.COLOR_RGB2BGR))
            progress.update(1)


def _upscale_video(model: RealESRGAN, video_filepath: str, vid_output: str, size_modifier: int) -> None:
    """Write an upscaled, video-only (mp4v) copy of ``video_filepath`` to ``vid_output``.

    Raises:
        Exception: If the input cannot be opened or the writer cannot be created.
    """
    cap = cv.VideoCapture(video_filepath)
    try:
        if not cap.isOpened():
            raise Exception("Could not open video")
        vid_writer = cv.VideoWriter(
            vid_output,
            fourcc=cv.VideoWriter.fourcc(*'mp4v'),
            fps=cap.get(cv.CAP_PROP_FPS),
            frameSize=(
                int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier,
                int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier,
            ),
        )
        try:
            if not vid_writer.isOpened():
                raise Exception("Could not create output video")
            _upscale_frames(model, cap, vid_writer)
        finally:
            vid_writer.release()
    finally:
        cap.release()


def _mux_with_source_audio(video_only_path: str, source_path: str, out_path: str) -> None:
    """Combine the video of ``video_only_path`` with the audio of ``source_path``.

    The result is re-encoded as H.264 video + AAC audio into ``out_path``.
    """
    video = ffmpeg.input(video_only_path).video
    audio = ffmpeg.input(source_path).audio
    ffmpeg.output(video, audio, out_path, vcodec='libx264', acodec='aac').run(overwrite_output=True)


def infer_video(video_filepath: str, size_modifier: int) -> str:
    """Upscale every frame of a video by ``size_modifier`` with RealESRGAN.

    Args:
        video_filepath: Path to the input video.
        size_modifier: Upscaling factor; also selects the weight file.

    Returns:
        If the input has no audio: the path of a temporary ``.mp4`` (mp4v
        codec) holding the upscaled frames. Otherwise: the path
        ``<input stem>_upscaled.mp4`` next to the input, containing the
        upscaled video (H.264) and the input's audio (AAC).

    Raises:
        Exception: If ``video_filepath`` is None, or the video cannot be
            read/written. ffmpeg errors from probing or muxing propagate.
    """
    if video_filepath is None:
        raise Exception("Video not uploaded")

    # Probe first so an unreadable file fails before the slow upscale.
    has_audio = _has_audio(video_filepath)
    model = _load_model(size_modifier)

    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmpfile:
        vid_output = tmpfile.name
    try:
        _upscale_video(model, video_filepath, vid_output, size_modifier)
    except BaseException:
        os.remove(vid_output)
        raise

    if not has_audio:
        return vid_output

    # splitext (not str.replace) so non-.mp4 inputs never map back onto the
    # input path itself.
    out_path = os.path.splitext(video_filepath)[0] + "_upscaled.mp4"
    try:
        _mux_with_source_audio(vid_output, video_filepath, out_path)
    finally:
        # The video-only intermediate is no longer needed once muxed.
        os.remove(vid_output)
    return out_path