import gradio as gr
import subprocess

import yaml
from tqdm import tqdm

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
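
# Note: sync_batchnorm, modules.generator, modules.keypoint_detector, and animate
# are not pip packages; this script assumes they are available locally from a
# checkout of https://github.com/AliaksandrSiarohin/first-order-model.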


def load_checkpoints(config_path, checkpoint_path, cpu=False):
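    """Build the generator and keypoint detector from the given YAML config
    and load their weights from a first-order-model checkpoint file."""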
    with open(config_path) as f:
        # PyYAML >= 5.1 deprecates (and 6.x removes) calling yaml.load
        # without an explicit Loader argument.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config["model_params"]["generator_params"], **config["model_params"]["common_params"]
    )
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"])
    if not cpu:
        kp_detector.cuda()

    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector


def make_animation(
    source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False
):
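    """Transfer the motion of driving_video onto source_image, returning one
    predicted frame of shape (H, W, 3) with values in [0, 1] per driving frame."""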
    with torch.no_grad():
        predictions = []
        # Source image -> (1, 3, H, W); driving video -> (1, 3, T, H, W).
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        initial_frame = driving[:, :, 0]
        if not cpu:
            initial_frame = initial_frame.cuda()
        kp_driving_initial = kp_detector(initial_frame)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            # Express the driving keypoints relative to the first frame so the
            # source identity is kept while only the motion is transferred.
            kp_norm = normalize_kp(
                kp_source=kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=kp_driving_initial,
                use_relative_movement=relative,
                use_relative_jacobian=relative,
                adapt_movement_scale=adapt_movement_scale,
            )
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            predictions.append(np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0])
        return predictions


def inference(video, image):
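    """Gradio handler: animate the uploaded image with the uploaded video and
    return the path of the rendered side-by-side comparison video."""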
| | cmd = f"ffmpeg -y -ss 00:00:00 -i {video} -to 00:00:08 -c copy video_input.mp4" |
| | subprocess.run(cmd.split()) |
| | video = "video_input.mp4" |
| |
|
    source_image = imageio.imread(image)
    reader = imageio.get_reader(video)
    fps = reader.get_meta_data()["fps"]
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        # imageio can raise RuntimeError on a truncated final frame; keep the
        # frames read so far.
        pass
    reader.close()

    # The model operates on 256x256 inputs; resize and drop any alpha channel.
    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    predictions = make_animation(
        source_image,
        driving_video,
        generator,
        kp_detector,
        relative=True,
        adapt_movement_scale=True,
        cpu=True,
    )
    imageio.mimsave("result.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
    imageio.mimsave("driving.mp4", [img_as_ubyte(frame) for frame in driving_video], fps=fps)
    # Copy the video stream from the result and the audio stream (assumed to be
    # stream 1) from the trimmed input into one file.
    subprocess.run(["ffmpeg", "-y", "-i", "result.mp4", "-i", video, "-c", "copy", "-map", "0:0", "-map", "1:1", "-shortest", "out.mp4"])
    # Stack the driving video and the result side by side for comparison.
    subprocess.run(["ffmpeg", "-y", "-i", "driving.mp4", "-i", "out.mp4", "-filter_complex", "hstack=inputs=2", "final.mp4"])
    return "final.mp4"


| | title = "First Order Motion Model" |
| | description = "Gradio demo for First Order Motion Model. Read more at the links below. Upload a video file (cropped to face), a facial image and have fun :D. Please note that your video will be trimmed to first 8 seconds." |
| | article = "<p style='text-align: center'><a href='https://papers.nips.cc/paper/2019/file/31c0b36aef265d9221af80872ceb62f9-Paper.pdf' target='_blank'>First Order Motion Model for Image Animation</a> | <a href='https://github.com/AliaksandrSiarohin/first-order-model' target='_blank'>Github Repo</a></p>" |
| | examples = [["bella_porch.mp4", "julien.png"]] |
| | generator, kp_detector = load_checkpoints( |
| | config_path="config/vox-256.yaml", |
| | checkpoint_path="weights/vox-adv-cpk.pth.tar", |
| | cpu=True, |
| | ) |
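# Note: config/vox-256.yaml and weights/vox-adv-cpk.pth.tar are assumed to be
# local copies of the files distributed with the first-order-model repository.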

iface = gr.Interface(
    inference,
    [
        gr.inputs.Video(type="mp4"),
        gr.inputs.Image(type="filepath"),
    ],
    outputs=gr.outputs.Video(label="Output Video"),
    examples=examples,
    enable_queue=True,
    title=title,
    article=article,
    description=description,
)
iface.launch(debug=True)