import os
import cv2
import glob
import torch
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import (
    img2tensor, tensor2img,
    InputPadder,
    check_dim_and_resize
)
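# Run on the GPU when one is available; models and tensors are moved to this device.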
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}


def img2vid(model_type, img0, img1, frame_ratio, iters):
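    """Interpolate between two images and render the result as a video.

    Runs `iters` recursive passes with the selected AMT model, doubling the
    frame count each pass, then writes the frames to an MP4 at `frame_ratio` fps.
    """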
    model = model_dict[model_type]()
    model.to(device)
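    # Fetch the pretrained checkpoint from the Hugging Face Hub (cached locally).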
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    img0_t = img2tensor(img0).to(device)
    img1_t = img2tensor(img1).to(device)
    inputs = [img0_t, img1_t]
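
    # VRAM budget heuristic: anchored to a 1024x512 run that needs roughly
    # 1.5 GB on top of ~2.5 GB of fixed overhead; used below to pick a safe
    # downscale factor.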
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # CPU mode: effectively disable downscaling.
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
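
    # t = 0.5: each pass predicts the temporal midpoint of every frame pair.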
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
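
    # Downscale (if needed) so inference fits the VRAM budget, then pad so the
    # possibly downscaled dimensions remain multiples of 16.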
    inputs = check_dim_and_resize(inputs)
    h, w = inputs[0].shape[-2:]
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder(inputs[0].shape, padding)
    inputs = padder.pad(*inputs)
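
    # Recursive interpolation: each pass inserts a midpoint frame between every
    # adjacent pair, so n frames become 2n - 1.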
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred.cpu(), in_1.cpu()]  # keep frames on the CPU to bound VRAM use
        inputs = outputs
    outputs = padder.unpad(*outputs)
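
    # Encode the frames to MP4; OpenCV expects BGR frames and a (width, height) size.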
    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    size = outputs[0].shape[2:][::-1]
    writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
    for imgt_pred in outputs:
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    return f'{out_path}/demo.mp4'


def demo_img():
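    """Build the Gradio Blocks UI for the two-image interpolation demo."""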
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                    <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                        Description: Given two input images, generate a short video that interpolates between them.
                    </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Number of interpolation passes; each pass doubles the frame count.
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    # Frame rate (fps) of the generated video.
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                run_button = gr.Button('Run')
        # Order must match img2vid's signature: the `ratio` slider is passed
        # as the `iters` argument.
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.png"),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )

        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result)
    return demo
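

if __name__ == '__main__':
    # Minimal launcher sketch (an assumption; not part of the original file):
    # build the Blocks app and serve it on Gradio's default local server.
    demo = demo_img()
    demo.launch()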