| import gradio as gr |
| import os |
| import subprocess |
|
|
| |
# Path where the model script writes the generated clip (passed via --out_path).
OUTPUT_VIDEO_PATH = "output_video.mp4"
# Pretrained checkpoint handed to the script via --model_path.
MODEL_PATH = "checkpoints/checkpoint.pt"

# Sampling mode; must be "cross" or "reconstruction" (see get_sample_flags).
SAMPLE_MODE = "cross"
# Face-crop padding forwarded verbatim to --pads; presumably
# "top,bottom,left,right" — confirm against the model script.
PADS = "0,0,0,0"
# Forwarded to --generate_from_filelist; 0 appears to mean "use the single
# audio/video pair given below" — verify in the model script.
GENERATE_FROM_FILELIST = 0
|
|
| |
def get_sample_flags(sample_mode):
    """Map a sampling mode name to its CLI flag string.

    Returns the flag string for "reconstruction" or "cross", and None for
    any unrecognized mode (callers treat None as a configuration error).
    """
    flag_table = {
        "reconstruction": (
            "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
        ),
        "cross": "--sampling_input_type=gt --sampling_ref_type=gt",
    }
    return flag_table.get(sample_mode)
|
|
| |
def generate_video(audio_path, video_path):
    """Run the synthesis model script on an audio/video pair.

    Parameters
    ----------
    audio_path : str
        Filesystem path to the driving audio clip (from the Gradio widget).
    video_path : str
        Filesystem path to the source video (from the Gradio widget).

    Returns
    -------
    str
        OUTPUT_VIDEO_PATH on success, or an "Error: ..." message when the
        configuration is invalid or the subprocess exits non-zero.
    """
    sample_input_flags = get_sample_flags(SAMPLE_MODE)
    if not sample_input_flags:
        return "Error: sample_mode can only be 'cross' or 'reconstruction'"

    # These flag groups are fixed strings with no user input, so splitting
    # them on whitespace into argv tokens is safe.
    MODEL_FLAGS = (
        "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True "
        "--num_channels 128 --num_head_channels 64 --num_res_blocks 2 "
        "--resblock_updown True --use_fp16 True --use_scale_shift_norm False"
    )
    DIFFUSION_FLAGS = (
        "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear "
        "--rescale_timesteps False"
    )
    SAMPLE_FLAGS = (
        f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 "
        f"--use_ddim True --model_path={MODEL_PATH}"
    )
    DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
    TFG_FLAGS = (
        "--face_hide_percentage 0.5 --use_ref=True --use_audio=True "
        "--audio_as_style=True"
    )

    # Build an argv list instead of a shell=True command string: the upload
    # paths are user-controlled, and interpolating them into a shell command
    # allowed shell injection (and broke on paths containing spaces). Each
    # user-supplied path is passed as a single argv element.
    command = ["python", "your_model_script.py"]
    command += MODEL_FLAGS.split()
    command += DIFFUSION_FLAGS.split()
    command += SAMPLE_FLAGS.split()
    command += DATA_FLAGS.split()
    command += TFG_FLAGS.split()
    command += [
        "--generate_from_filelist", str(GENERATE_FROM_FILELIST),
        f"--video_path={video_path}",
        f"--audio_path={audio_path}",
        f"--out_path={OUTPUT_VIDEO_PATH}",
        "--save_orig=False",
        "--face_det_batch_size", "16",
        "--pads", PADS,
        "--is_voxceleb2=False",
    ]

    process = subprocess.run(command, capture_output=True, text=True)

    if process.returncode != 0:
        return f"Error: {process.stderr}"

    return OUTPUT_VIDEO_PATH
|
|
| |
# Gradio UI, built at module level so `demo` exists for the launch guard below.
with gr.Blocks() as demo:
    gr.Markdown("## Audio-Video Synthesis Model")

    with gr.Row():
        # type="filepath" makes the widget hand generate_video a path string
        # rather than raw audio samples.
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        video_input = gr.Video(label="Upload Video")

    output_video = gr.Video(label="Generated Video")

    generate_button = gr.Button("Generate")

    # generate_video returns either the output file path (rendered by the
    # Video component) or an "Error: ..." string on failure.
    generate_button.click(
        fn=generate_video,
        inputs=[audio_input, video_input],
        outputs=output_video
    )
|
|
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|