darshankr committed
Commit ba39569 · verified · 1 Parent(s): 8a9b78c

Update app.py

Files changed (1)
  1. app.py +66 -65
app.py CHANGED
@@ -1,88 +1,89 @@
  import gradio as gr
- import os
  import subprocess

- # Define the paths where the output video will be stored
- OUTPUT_VIDEO_PATH = "output_video.mp4"
- MODEL_PATH = "checkpoints/checkpoint.pt"

- # Sample mode configuration
- SAMPLE_MODE = "cross"  # Options: "cross" or "reconstruction"
- PADS = "0,0,0,0"
- GENERATE_FROM_FILELIST = 0

- # Generate the appropriate flags based on the sample mode
- def get_sample_flags(sample_mode):
      if sample_mode == "reconstruction":
-         return "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
      elif sample_mode == "cross":
-         return "--sampling_input_type=gt --sampling_ref_type=gt"
      else:
-         return None
-
- # Function to run the model inference command
- def generate_video(audio_path, video_path):
-     sample_input_flags = get_sample_flags(SAMPLE_MODE)
-     if not sample_input_flags:
          return "Error: sample_mode can only be 'cross' or 'reconstruction'"

-     # Build the command string
-     MODEL_FLAGS = (
-         "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True "
-         "--num_channels 128 --num_head_channels 64 --num_res_blocks 2 "
-         "--resblock_updown True --use_fp16 True --use_scale_shift_norm False"
-     )
-     DIFFUSION_FLAGS = (
-         "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear "
-         "--rescale_timesteps False"
-     )
-     SAMPLE_FLAGS = (
-         f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 "
-         f"--use_ddim True --model_path={MODEL_PATH}"
-     )
      DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
-     TFG_FLAGS = (
-         "--face_hide_percentage 0.5 --use_ref=True --use_audio=True "
-         "--audio_as_style=True"
-     )
-     GEN_FLAGS = (
-         f"--generate_from_filelist {GENERATE_FROM_FILELIST} "
-         f"--video_path={video_path} --audio_path={audio_path} "
-         f"--out_path={OUTPUT_VIDEO_PATH} --save_orig=False "
-         f"--face_det_batch_size 16 --pads {PADS} --is_voxceleb2=False"
-     )

-     command = (
-         f"python generate.py {MODEL_FLAGS} {DIFFUSION_FLAGS} "
-         f"{SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
-     )

-     # Run the command and wait for it to complete
-     process = subprocess.run(command, shell=True, capture_output=True, text=True)

-     if process.returncode != 0:
-         return f"Error: {process.stderr}"

-     # Return the generated video file
-     return OUTPUT_VIDEO_PATH

- # Gradio interface
  with gr.Blocks() as demo:
-     gr.Markdown("## Audio-Video Synthesis Model")

-     with gr.Row():
-         audio_input = gr.Audio(label="Upload Audio", type="filepath")
-         video_input = gr.Video(label="Upload Video")  # No 'type' argument here
-
      output_video = gr.Video(label="Generated Video")

-     generate_button = gr.Button("Generate")

-     generate_button.click(
-         fn=generate_video,
-         inputs=[audio_input, video_input],
          outputs=output_video
-     )

- if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
  import subprocess
+ import os
+ import cv2
+ import numpy as np

+ # Paths and Model Config
+ sample_mode = "cross"  # "reconstruction" or "cross"
+ model_path = "checkpoints/checkpoint.pt"
+ pads = "0,0,0,0"
+ generate_from_filelist = 0  # 0 means real-time generation

+ def process_video(audio_path, video_path):
+     # Step 1: Check if input files exist
+     audio_exists = os.path.exists(audio_path)
+     video_exists = os.path.exists(video_path)
+     print(f"Audio exists: {audio_exists}, Video exists: {video_exists}")
+
+     if not (audio_exists and video_exists):
+         return "Error: One or both input files do not exist."

+     # Set flags based on sample mode
      if sample_mode == "reconstruction":
+         sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
      elif sample_mode == "cross":
+         sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
      else:
          return "Error: sample_mode can only be 'cross' or 'reconstruction'"

+     # Model flags and configurations
+     MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
+     DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
+     SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
      DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
+     TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
+     GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path=output.mp4 --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"

+     # Step 2: Combine all flags into one command
+     command = f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
+     print(f"Running command: {command}")

+     # Step 3: Execute the command and capture output
+     result = subprocess.run(command, shell=True, capture_output=True, text=True)
+     print("STDOUT:", result.stdout)
+     print("STDERR:", result.stderr)

+     if result.returncode != 0:
+         return f"Error during video generation: {result.stderr}"

+     # Step 4: Verify that the output video is generated correctly
+     if not os.path.exists("output.mp4"):
+         return "Error: Output video not generated."
+
+     print("Video generation successful!")
+     return "output.mp4"
+
+ # Step 5: Create a test function for video writing
+ def create_test_video():
+     print("Creating test video...")
+     out = cv2.VideoWriter('test_output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (128, 128))
+     frame = 255 * np.ones((128, 128, 3), dtype=np.uint8)
+     for _ in range(60):  # 2 seconds of video
+         out.write(frame)
+     out.release()
+     print("Test video created.")

+ # Gradio Interface
  with gr.Blocks() as demo:
+     gr.Markdown("### Upload an Audio and Video file to generate an output video.")

+     audio_input = gr.Audio(label="Upload Audio", type="filepath")
+     video_input = gr.Video(label="Upload Video")
      output_video = gr.Video(label="Generated Video")

+     create_test_video()  # Run the test video function once to ensure setup is correct
+
+     def inference(audio, video):
+         result = process_video(audio, video)
+         if result.endswith(".mp4"):
+             return result  # Return path to the generated video
+         else:
+             return f"Error: {result}"  # Display any errors

+     gr.Interface(
+         fn=inference,
+         inputs=[audio_input, video_input],
          outputs=output_video
+     ).launch()