Spaces:
Runtime error
Runtime error
| import torch | |
| # the first flag below was False when we tested this script but True makes A100 training a lot faster: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| import os | |
| import spaces | |
| from diffusers.models import AutoencoderKL | |
| from models import FLAV | |
| from huggingface_hub import hf_hub_download | |
| import torch | |
| from diffusion.rectified_flow import RectifiedFlow | |
| from diffusers.training_utils import EMAModel | |
| from converter import Generator | |
| from utils import * | |
| import tempfile | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# Audio timesteps per video frame: 1600 audio samples per frame divided by a
# 160-sample hop gives 10 spectrogram columns per frame.
AUDIO_T_PER_FRAME = 1600 // 160
#################################################################################
#                              Global Model Setup                               #
#################################################################################
# These variables will be initialized in setup_models() and used in main()
vae = None      # AutoencoderKL image VAE (set by setup_models)
model = None    # FLAV generative model (set by setup_models)
vocoder = None  # spectrogram -> waveform generator (set by setup_models)
# Fixed scale applied to generated audio before vocoding — presumably a
# dataset/training statistic baked into the checkpoint; TODO confirm.
audio_scale = 3.5009668382765917
def setup_models():
    """Download and initialize the VAE, FLAV model, and vocoder.

    Populates the module-level ``vae``, ``model`` and ``vocoder`` globals and
    moves them to the selected device. Must be called once before
    ``generate_video``.
    """
    global vae, model, vocoder
    # Fall back to CPU when CUDA is unavailable — matches the device selection
    # in generate_video; the previous hard-coded "cuda" crashed on CPU hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
    vae.eval()

    model_ckpt = "MaverickAlex/R-FLAV-B-1-AIST"  # alt: MaverickAlex/R-FLAV-B-1-LS
    model = FLAV.from_pretrained(model_ckpt)

    # Fetch the vocoder config next to the weights, then load the vocoder from
    # the cache directory that now contains both files.
    hf_hub_download(repo_id=model_ckpt, filename="vocoder/config.json")
    vocoder_path = hf_hub_download(repo_id=model_ckpt, filename="vocoder/vocoder.pt")
    vocoder_path = vocoder_path.replace("vocoder.pt", "")
    vocoder = Generator.from_pretrained(vocoder_path)

    vae.to(device)
    model.to(device)
    vocoder.to(device)
def generate_video(num_frames=10, steps=2, seed=42):
    """Generate one audio-video sample and return the path to the saved mp4.

    Uses the module-level ``vae``, ``model`` and ``vocoder`` globals, so
    ``setup_models()`` must have been called first.

    Args:
        num_frames: Number of video frames to generate.
        steps: Number of rectified-flow timesteps (the UI labels this as
            "multiplied by a factor of 10").
        seed: Random seed for reproducible sampling.

    Returns:
        Filesystem path to the generated mp4 inside a fresh temp directory.
    """
    global vae, model, vocoder
    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)
    # Set up generation parameters
    # NOTE(review): the "10" in both latent sizes matches window_size below,
    # not num_frames — presumably the sliding-window length, with
    # generate_sample expanding to num_frames via video_length; confirm.
    video_latent_size = (1, 10, 4, 256//8, 256//8)
    audio_latent_size = (1, 10, 1, 256, AUDIO_T_PER_FRAME)
    rectified_flow = RectifiedFlow(num_timesteps=steps,
                                   warmup_timesteps=10,
                                   window_size=10)
    # Generate sample
    video, audio = generate_sample(
        vae=vae,  # These globals are set by setup_models
        rectified_flow=rectified_flow,
        forward_fn=model.forward,
        video_length=num_frames,
        video_latent_size=video_latent_size,
        audio_latent_size=audio_latent_size,
        y=None,          # no label conditioning
        cfg_scale=None,  # classifier-free guidance disabled
        device=device
    )
    # Convert generated audio to waveforms via the vocoder
    wavs = get_wavs(audio, vocoder, audio_scale, device)
    # Save to temporary files
    temp_dir = tempfile.mkdtemp()
    # NOTE(review): assumes save_multimodal writes
    # <temp_dir>/video/generated_video.mp4 for name "generated" — verify
    # this path matches save_multimodal's actual output naming.
    video_path = os.path.join(temp_dir, "video", "generated_video.mp4")
    # Use the first video and wav
    vid, wav = video[0], wavs[0]
    save_multimodal(vid, wav, temp_dir, "generated")
    return video_path
def ui_generate_video(num_frames, steps, seed):
    """Gradio callback: coerce slider values to int and run generation.

    Returns:
        The generated video path, or ``None`` on failure (Gradio then shows
        an empty video component).
    """
    try:
        return generate_video(int(num_frames), int(steps), int(seed))
    except Exception:
        # Log the failure to the Space logs instead of silently swallowing it
        # (the original discarded the exception, making errors undiagnosable);
        # the UI still receives None so the app degrades gracefully.
        import traceback
        traceback.print_exc()
        return None
# Create Gradio interface
with gr.Blocks(title="FLAV Video Generator") as demo:
    gr.Markdown("# FLAV Video Generator")
    gr.Markdown("Generate videos using the FLAV model")
    # Pre-declare component handles so they are assigned inside the layout
    # contexts below and referenced together in the click wiring.
    num_frames = None
    steps = None
    seed = None
    video_output = None
    with gr.Row():
        with gr.Column():
            # Generation controls (defaults mirror generate_video's seed;
            # note the steps default here is 5, not generate_video's 2).
            num_frames = gr.Slider(minimum=5, maximum=30, step=1, value=10, label="Number of Frames")
            steps = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Steps (multiplied by a factor of 10)")
            seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
            generate_btn = gr.Button("Generate Video")
        with gr.Column():
            # Output panel sized to the 256x256 frames produced by the model
            video_output = gr.PlayableVideo(label="Generated Video", width=256, height=256)
    # Wire the button to the generation callback
    generate_btn.click(
        fn=ui_generate_video,
        inputs=[num_frames, steps, seed],
        outputs=[video_output]
    )
if __name__ == "__main__":
    # Load all models once at startup, then serve the Gradio app.
    setup_models()
    demo.launch()