# diff2lip/app.py

import gradio as gr
import os
import subprocess
# Paths to the generated output video and the model checkpoint
OUTPUT_VIDEO_PATH = "output_video.mp4"
MODEL_PATH = "checkpoints/checkpoint.pt"
# Sample mode configuration
SAMPLE_MODE = "cross" # Options: "cross" or "reconstruction"
PADS = "0,0,0,0"
GENERATE_FROM_FILELIST = 0
# Generate the appropriate flags based on the sample mode
def get_sample_flags(sample_mode):
    if sample_mode == "reconstruction":
        # Reconstruction: ground-truth frames exist for this audio, so condition on them.
        return "--sampling_input_type=gt --sampling_ref_type=gt"
    elif sample_mode == "cross":
        # Cross generation: the audio comes from another source, so fall back to the first frame.
        return "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
    else:
        return None

# Function to run the model inference command
def generate_video(audio_path, video_path):
    sample_input_flags = get_sample_flags(SAMPLE_MODE)
    if not sample_input_flags:
        raise gr.Error("sample_mode can only be 'cross' or 'reconstruction'")

    # Build the command string from the model, diffusion, sampling, data, TFG and
    # generation flag groups.
    MODEL_FLAGS = (
        "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True "
        "--num_channels 128 --num_head_channels 64 --num_res_blocks 2 "
        "--resblock_updown True --use_fp16 True --use_scale_shift_norm False"
    )
    DIFFUSION_FLAGS = (
        "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear "
        "--rescale_timesteps False"
    )
    SAMPLE_FLAGS = (
        f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 "
        f"--use_ddim True --model_path={MODEL_PATH}"
    )
    DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
    TFG_FLAGS = (
        "--face_hide_percentage 0.5 --use_ref=True --use_audio=True "
        "--audio_as_style=True"
    )
    GEN_FLAGS = (
        f"--generate_from_filelist {GENERATE_FROM_FILELIST} "
        f"--video_path={video_path} --audio_path={audio_path} "
        f"--out_path={OUTPUT_VIDEO_PATH} --save_orig=False "
        f"--face_det_batch_size 16 --pads {PADS} --is_voxceleb2=False"
    )

    # Note: 'your_model_script.py' is a placeholder; point it at the diff2lip generation
    # script. Because the command runs through the shell, paths containing spaces or
    # shell metacharacters should be quoted (e.g. with shlex.quote).
    command = (
        f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} "
        f"{SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
    )

    # Run the command and wait for it to complete
    process = subprocess.run(command, shell=True, capture_output=True, text=True)
    if process.returncode != 0:
        raise gr.Error(f"Generation failed: {process.stderr}")

    # Return the generated video file
    return OUTPUT_VIDEO_PATH

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio-Video Synthesis Model")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        video_input = gr.Video(label="Upload Video")  # gr.Video takes no 'type' argument; it passes a filepath
    output_video = gr.Video(label="Generated Video")
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_video,
        inputs=[audio_input, video_input],
        outputs=output_video,
    )

if __name__ == "__main__":
    demo.launch()
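
# Example of calling this app programmatically (a minimal sketch, not part of the app):
# it assumes the demo is running locally on Gradio's default port, that gradio_client is
# installed, and that the endpoint name "/generate_video" follows Gradio's default of
# naming endpoints after the handler function. The exact payload shape for the video
# input varies between Gradio versions; the app's "Use via API" page shows the
# authoritative form, and the file names below are hypothetical.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(
#       handle_file("driving_audio.wav"),             # audio_input (filepath)
#       {"video": handle_file("source_video.mp4")},   # video_input
#       api_name="/generate_video",
#   )
#   print(result)  # path to the generated output video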