darshankr committed on
Commit
2142346
·
verified ·
1 Parent(s): 167fbbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -54
app.py CHANGED
@@ -2,64 +2,93 @@ import gradio as gr
2
  import os
3
  import subprocess
4
 
5
- # Replace with your model loading and processing function
6
- def process_audio_video(audio_file, video_file):
7
- audio_path = "input_audio.wav"
8
- video_path = "input_video.mp4"
9
- out_path = "output_video.mp4"
10
-
11
- # Save uploaded files
12
- audio_file.save(audio_path)
13
- video_file.save(video_path)
14
-
15
- # Define command flags
16
- sample_mode = "cross" # or "reconstruction"
17
- generate_from_filelist = 0
18
- model_path = "checkpoints/checkpoint.pt"
19
- pads = "0,0,0,0"
20
-
21
  if sample_mode == "reconstruction":
22
- sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
23
  elif sample_mode == "cross":
24
- sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
25
  else:
26
- return "Error: sample_mode can only be \"cross\" or \"reconstruction\""
27
-
28
- MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
29
- DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
30
- SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
32
- TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
33
- GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"
34
-
35
- # Combine all flags into one command
36
- command = f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- try:
39
- subprocess.run(command, shell=True, check=True)
40
- return out_path
41
- except subprocess.CalledProcessError as e:
42
- return f"Error processing video: {e}"
43
- finally:
44
- # Clean up the files after processing
45
- if os.path.exists(audio_path):
46
- os.remove(audio_path)
47
- if os.path.exists(video_path):
48
- os.remove(video_path)
49
- if os.path.exists(out_path):
50
- os.remove(out_path)
51
 
52
- # Define the Gradio interface
53
- interface = gr.Interface(
54
- fn=process_audio_video,
55
- inputs=[
56
- gr.Audio(label="Audio File"),
57
- gr.Video(label="Video File"),
58
- ],
59
- outputs="video",
60
- description="Process Audio and Video with your Model",
61
- allow_flagging=False # Disable flagging as output is a video
62
- )
63
 
64
- # Launch the Gradio app
65
- interface.launch(share=True)
 
import os
import subprocess

# Fixed filesystem locations for the uploaded inputs, the generated output,
# and the pretrained model checkpoint.
INPUT_AUDIO_PATH = "input_audio.wav"
INPUT_VIDEO_PATH = "input_video.mp4"
OUTPUT_VIDEO_PATH = "output_video.mp4"
MODEL_PATH = "checkpoints/checkpoint.pt"

# Sampling configuration passed through to the model command line.
SAMPLE_MODE = "cross"  # Options: "cross" or "reconstruction"
PADS = "0,0,0,0"  # forwarded via --pads (presumably crop padding — TODO confirm)
GENERATE_FROM_FILELIST = 0  # forwarded via --generate_from_filelist
16
# Map a sample mode onto the sampling-type flags the model script understands.
def get_sample_flags(sample_mode):
    """Return the sampling input/ref flag string for *sample_mode*.

    Returns ``None`` when *sample_mode* is neither ``"reconstruction"``
    nor ``"cross"`` (the caller treats that as a configuration error).
    """
    flag_map = {
        "reconstruction": "--sampling_input_type=first_frame --sampling_ref_type=first_frame",
        "cross": "--sampling_input_type=gt --sampling_ref_type=gt",
    }
    # dict.get returns None for unknown modes, matching the original fallback.
    return flag_map.get(sample_mode)
+
25
+ # Function to run the model inference command
26
+ def generate_video(audio_file, video_file):
27
+ # Save uploaded files to disk
28
+ audio_file.save(INPUT_AUDIO_PATH)
29
+ video_file.save(INPUT_VIDEO_PATH)
30
+
31
+ sample_input_flags = get_sample_flags(SAMPLE_MODE)
32
+ if not sample_input_flags:
33
+ return "Error: sample_mode can only be 'cross' or 'reconstruction'"
34
+
35
+ # Build the command string
36
+ MODEL_FLAGS = (
37
+ "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True "
38
+ "--num_channels 128 --num_head_channels 64 --num_res_blocks 2 "
39
+ "--resblock_updown True --use_fp16 True --use_scale_shift_norm False"
40
+ )
41
+ DIFFUSION_FLAGS = (
42
+ "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear "
43
+ "--rescale_timesteps False"
44
+ )
45
+ SAMPLE_FLAGS = (
46
+ f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 "
47
+ f"--use_ddim True --model_path={MODEL_PATH}"
48
+ )
49
  DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
50
+ TFG_FLAGS = (
51
+ "--face_hide_percentage 0.5 --use_ref=True --use_audio=True "
52
+ "--audio_as_style=True"
53
+ )
54
+ GEN_FLAGS = (
55
+ f"--generate_from_filelist {GENERATE_FROM_FILELIST} "
56
+ f"--video_path={INPUT_VIDEO_PATH} --audio_path={INPUT_AUDIO_PATH} "
57
+ f"--out_path={OUTPUT_VIDEO_PATH} --save_orig=False "
58
+ f"--face_det_batch_size 16 --pads {PADS} --is_voxceleb2=False"
59
+ )
60
+
61
+ command = (
62
+ f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} "
63
+ f"{SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
64
+ )
65
+
66
+ # Run the command and wait for it to complete
67
+ process = subprocess.run(command, shell=True, capture_output=True, text=True)
68
+
69
+ if process.returncode != 0:
70
+ return f"Error: {process.stderr}"
71
+
72
+ # Return the generated video file
73
+ return OUTPUT_VIDEO_PATH
74
+
75
+ # Gradio interface
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("## Audio-Video Synthesis Model")
78
 
79
+ with gr.Row():
80
+ audio_input = gr.Audio(label="Upload Audio", type="file")
81
+ video_input = gr.Video(label="Upload Video", type="file")
82
+
83
+ output_video = gr.Video(label="Generated Video")
84
+
85
+ generate_button = gr.Button("Generate")
 
 
 
 
 
 
86
 
87
+ generate_button.click(
88
+ fn=generate_video,
89
+ inputs=[audio_input, video_input],
90
+ outputs=output_video
91
+ )
 
 
 
 
 
 
92
 
93
+ if __name__ == "__main__":
94
+ demo.launch()