Spaces:
Build error
Build error
| import spaces | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.models import AutoencoderKL | |
| from diffusers.schedulers import PNDMScheduler | |
| from unet import AudioUNet3DConditionModel | |
| from audio_encoder import ImageBindSegmaskAudioEncoder | |
| from pipeline import AudioCondAnimationPipeline, generate_videos | |
# Inference device and precision; both are passed to create_pipeline at
# import time, and `device` is also forwarded to generate_videos per request.
device, dtype = torch.device("cuda"), torch.float16
def freeze_and_make_eval(model: nn.Module):
    """Freeze ``model`` in place for inference.

    Disables gradient tracking on every parameter of ``model`` (including
    those of its submodules) and switches the module tree to evaluation
    mode. Nothing is returned; the module itself is mutated.
    """
    # Module.requires_grad_ walks model.parameters() and flips each flag.
    model.requires_grad_(False)
    model.eval()
def create_pipeline(
    device=torch.device("cuda"),
    dtype=torch.float32,
    pretrained_stable_diffusion_path="./pretrained/stable-diffusion-v1-5",
    checkpoint_path="checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules",
    category_text_encodings_path="datasets/AVSync15/class_clip_text_encodings_stable-diffusion-v1-5.pt",
    null_text_encodings_path="./pretrained/openai-clip-l_null_text_encoding.pt",
):
    """Load every model component and assemble the audio-conditioned pipeline.

    The scheduler and VAE are read from a local Stable Diffusion v1.5
    directory, the UNet from a fine-tuned checkpoint, and the audio encoder
    is constructed in-process and frozen. All paths default to the assets
    this app ships with and can be overridden (e.g. to try another checkpoint).

    Args:
        device: Device every module is moved to.
        dtype: Floating-point dtype every module is cast to.
        pretrained_stable_diffusion_path: Directory containing ``scheduler``
            and ``vae`` subfolders.
        checkpoint_path: Directory containing a ``unet`` subfolder.
        category_text_encodings_path: File loaded with ``torch.load`` holding
            the per-category text encodings.
        null_text_encodings_path: Null text-encoding file handed to
            ``AudioCondAnimationPipeline``.

    Returns:
        ``(pipeline, category_text_encoding_mapping)``. The mapping is loaded
        on CPU and is indexed by category name at generation time.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only point this
    # at trusted files (weights_only=True would be safer if the file holds
    # nothing but tensors; confirm before switching).
    category_text_encoding_mapping = torch.load(category_text_encodings_path, map_location="cpu")

    scheduler = PNDMScheduler.from_pretrained(pretrained_stable_diffusion_path, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(pretrained_stable_diffusion_path, subfolder="vae").to(device=device, dtype=dtype)

    # The audio encoder is only used for inference: freeze it and put it in eval mode.
    audio_encoder = ImageBindSegmaskAudioEncoder(n_segment=12).to(device=device, dtype=dtype)
    freeze_and_make_eval(audio_encoder)

    unet = AudioUNet3DConditionModel.from_pretrained(checkpoint_path, subfolder="unet").to(device=device, dtype=dtype)

    pipeline = AudioCondAnimationPipeline(
        unet=unet,
        scheduler=scheduler,
        vae=vae,
        audio_encoder=audio_encoder,
        null_text_encodings_path=null_text_encodings_path,
    )
    pipeline.to(torch_device=device, dtype=dtype)
    # Generation runs behind a web UI; a console progress bar is just noise.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline, category_text_encoding_mapping
# Build the models once at import time; generate_video reuses these globals
# for every request.
pipeline, category_text_encoding_mapping = create_pipeline(device=device, dtype=dtype)
# NOTE(review): `spaces` is imported at the top of the file but never used.
# If this Space runs on ZeroGPU hardware, this function most likely needs the
# `@spaces.GPU` decorator to get a GPU during inference -- confirm the hardware.
def generate_video(image, audio, text, audio_guidance_scale, denoising_step):
    """Animate ``image`` in sync with ``audio`` and return the video's path.

    Args:
        image: Filepath of the input image (the UI uses ``type="filepath"``).
        audio: Filepath of the driving audio clip (``type="filepath"``).
        text: Audio category name; must be a key of
            ``category_text_encoding_mapping``.
        audio_guidance_scale: Guidance scale for the audio condition,
            forwarded unchanged to ``generate_videos``.
        denoising_step: Number of denoising steps; coerced to ``int`` because
            slider values may arrive as floats.

    Returns:
        The path of the written MP4, ``"./output_video.mp4"``.

    Raises:
        gr.Error: If the image or audio is missing, or the category is not
            selected / unknown, so the UI shows a readable message instead of
            a bare KeyError.
    """
    if not image:
        raise gr.Error("Please upload an image.")
    if not audio:
        raise gr.Error("Please upload an audio clip.")
    # An unselected dropdown yields None; guard before indexing the mapping
    # (presumably a dict keyed by category name -- `in` works for any mapping).
    if text is None or text not in category_text_encoding_mapping:
        raise gr.Error(f"Please select a valid audio category (got {text!r}).")

    # Reshape the stored encoding to a batch of one: (1, 77, 768).
    category_text_encoding = category_text_encoding_mapping[text].view(1, 77, 768)

    # NOTE(review): a fixed output path is shared by all requests; this is
    # only safe while the Gradio queue processes one job at a time.
    save_path = "./output_video.mp4"
    generate_videos(
        pipeline,
        audio_path=audio,
        image_path=image,
        category_text_encoding=category_text_encoding,
        image_size=(256, 256),
        video_fps=6,
        video_num_frame=12,
        text_guidance_scale=1.0,
        audio_guidance_scale=audio_guidance_scale,
        denoising_step=int(denoising_step),
        seed=123,
        save_path=save_path,
        device=device,
    )
    return save_path
if __name__ == "__main__":
    # Dropdown choices. generate_video indexes category_text_encoding_mapping
    # with the selected string, so each entry must match a key there exactly.
    categories = [
        "baby babbling crying", "dog barking", "hammering", "striking bowling", "cap gun shooting",
        "chicken crowing", "frog croaking", "lions roaring", "machine gun shooting", "playing cello",
        "playing trombone", "playing trumpet", "playing violin fiddle", "sharpen knife", "toilet flushing"
    ]
    title = ""
    # HTML header shown above the interface. The string content is kept at
    # column 0 on purpose: Gradio also parses the description as Markdown,
    # where deeply indented lines could render as a code block.
    description = """
<div align="center">
<h1 style="font-size: 60px;">Audio-Synchronized Visual Animation</h1>
<p style="font-size: 30px;">
<a href="https://lzhangbj.github.io/projects/asva/asva.html">Project Webpage</a>
</p>
<p style="font-size: 30px;">
<a href="https://lzhangbj.github.io/">Lin Zhang</a>,
<a href="https://scholar.google.com/citations?user=6aYncPAAAAAJ">Shentong Mo</a>,
<a href="https://yijingz02.github.io/">Yijing Zhang</a>,
<a href="https://pedro-morgado.github.io/">Pedro Morgado</a>
</p>
<p style="font-size: 30px;">
University of Wisconsin Madison,
Carnegie Mellon University
</p>
<strong style="font-size: 30px;">ECCV 2024</strong>
<strong style="font-size: 25px;">Animate your images with audio-synchronized motion! </strong>
<p style="font-size: 18px;">Notes:</p>
<p style="font-size: 18px;">(1) Only the first 2 seconds of audio is used. </p>
<p style="font-size: 18px;">(2) Increase audio guidance scale for amplified visual dynamics. </p>
<p style="font-size: 18px;">(3) Increase sampling steps for higher visual quality. </p>
</div>
"""
    # Optional notice, currently disabled (re-add to `description` if needed):
    # <p style="font-size: 20px;">Please be patient. Due to limited resources on huggingface, the generation may take up to 10mins </p>

    # Gradio interface: the input order matches generate_video's parameters
    # (image, audio, text, audio_guidance_scale, denoising_step), and each
    # example row follows the same order.
    iface = gr.Interface(
        fn=generate_video,
        inputs=[
            gr.Image(label="Upload Image", type="filepath", height=256),
            gr.Audio(label="Upload Audio", type="filepath"),
            gr.Dropdown(choices=categories, label="Select Audio Category"),
            gr.Slider(minimum=1.0, maximum=12.0, step=0.1, value=4.0, label="Audio Guidance Scale"),
            gr.Slider(minimum=1, maximum=50, step=1, value=20, label="Sampling steps"),
        ],
        outputs=gr.Video(label="Generated Video", height=256),
        title=title,
        description=description,
        examples=[
            ["./assets/lion_and_gun.png", "./assets/lions_roaring.wav", "lions roaring", 4.0, 20],
            ["./assets/lion_and_gun.png", "./assets/machine_gun_shooting.wav", "machine gun shooting", 4.0, 20],
        ],
    )
    # Start the web server (blocks until shut down).
    iface.launch()