darshankr committed on
Commit
167fbbe
·
verified ·
1 Parent(s): b5490e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -68
app.py CHANGED
@@ -1,78 +1,65 @@
1
  import gradio as gr
2
- import subprocess
3
  import os
4
- import requests
5
 
6
- def process_video(audio_file, video_file, status):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  try:
8
- # Unpack the audio and video file paths
9
- audio_path = audio_file[1] if isinstance(audio_file, tuple) else audio_file
10
- video_path = video_file if isinstance(video_file, str) else video_file.name
11
- out_path = "output_video.mp4"
12
-
13
- # Define command flags
14
- sample_mode = "cross" # or "reconstruction"
15
- generate_from_filelist = 0
16
- model_path = "checkpoints/checkpoint.pt"
17
- pads = "0,0,0,0"
18
-
19
- if sample_mode == "reconstruction":
20
- sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
21
- elif sample_mode == "cross":
22
- sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
23
- else:
24
- return None, "Error: sample_mode can only be \"cross\" or \"reconstruction\""
25
-
26
- MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
27
- DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
28
- SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
29
- DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
30
- TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
31
- GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"
32
-
33
- # Combine all flags into one command
34
- command = f"python generate.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
35
-
36
- # Execute command
37
  subprocess.run(command, shell=True, check=True)
38
-
39
- # If successful, return the output path and success message
40
- return out_path, "Processing completed successfully!"
41
-
42
- except Exception as e:
43
- # If there's an error, return None for the video and the error message
44
- return None, f"Error during processing: {str(e)}"
45
-
46
  finally:
47
- # Clean up output file if it exists
 
 
 
 
48
  if os.path.exists(out_path):
49
  os.remove(out_path)
50
 
51
- # Create a Gradio interface
52
- with gr.Blocks() as iface:
53
- gr.Markdown("# Audio-Video Processing")
54
- gr.Markdown("Upload an audio file and a video file to process the video based on the audio input.")
55
-
56
- with gr.Row():
57
- audio_input = gr.Audio(label="Input Audio")
58
- video_input = gr.Video(label="Input Video")
59
-
60
- status_msg = gr.Textbox(label="Status", interactive=False)
61
- process_button = gr.Button("Process Video")
62
- video_output = gr.Video(label="Processed Video")
63
-
64
- process_button.click(
65
- fn=process_video,
66
- inputs=[
67
- audio_input,
68
- video_input,
69
- status_msg
70
- ],
71
- outputs=[
72
- video_output,
73
- status_msg
74
- ]
75
- )
76
 
77
- # Launch the interface
78
- iface.launch()
 
1
  import gradio as gr
 
2
  import os
3
+ import subprocess
4
 
5
# Gradio handler: glue the uploaded audio/video pair to the external
# generation script and hand the produced video path back to the UI.
def process_audio_video(audio_file, video_file):
    """Generate a processed video from an uploaded audio/video pair.

    Parameters
    ----------
    audio_file :
        Value produced by ``gr.Audio`` — a filepath string, a tuple whose
        first element is a filepath, or an object with a ``.name``
        attribute (the exact form varies across Gradio versions;
        TODO confirm for the pinned Gradio release — ``type="filepath"``
        on the component guarantees a plain path).
    video_file :
        Value produced by ``gr.Video`` — a filepath string or an object
        with a ``.name`` attribute.

    Returns
    -------
    str
        ``"output_video.mp4"`` on success, otherwise a human-readable
        error message.
    """
    import shutil  # local imports keep this fix self-contained
    import sys

    audio_path = "input_audio.wav"
    video_path = "input_video.mp4"
    out_path = "output_video.mp4"

    # BUG FIX: Gradio upload values are file *paths* (or objects exposing
    # .name), not objects with a .save() method — the original
    # `audio_file.save(...)` raised AttributeError on every call.
    src_audio = audio_file[0] if isinstance(audio_file, tuple) else audio_file
    if not isinstance(src_audio, str):
        src_audio = getattr(src_audio, "name", src_audio)
    src_video = video_file if isinstance(video_file, str) else getattr(video_file, "name", video_file)
    shutil.copy(src_audio, audio_path)
    shutil.copy(src_video, video_path)

    # Remove any stale output from a previous run so a failed run can
    # never serve an older result.
    if os.path.exists(out_path):
        os.remove(out_path)

    # Define command flags
    sample_mode = "cross"  # or "reconstruction"
    generate_from_filelist = 0
    model_path = "checkpoints/checkpoint.pt"
    pads = "0,0,0,0"

    if sample_mode == "reconstruction":
        sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
    elif sample_mode == "cross":
        sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
    else:
        return "Error: sample_mode can only be \"cross\" or \"reconstruction\""

    MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
    DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
    SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
    DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
    TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
    GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"

    # BUG FIX: build an argv list instead of a shell string — no shell
    # injection surface, and sys.executable guarantees the interpreter
    # running this app (a bare "python" may be absent or point elsewhere).
    argv = [sys.executable, "your_model_script.py"]
    for flags in (MODEL_FLAGS, DIFFUSION_FLAGS, SAMPLE_FLAGS, DATA_FLAGS, TFG_FLAGS, GEN_FLAGS):
        argv.extend(flags.split())  # safe: no flag value contains spaces

    try:
        subprocess.run(argv, check=True)
        return out_path
    except subprocess.CalledProcessError as e:
        return f"Error processing video: {e}"
    finally:
        # Clean up only the *input* copies. The original also deleted
        # out_path here, but `finally` runs even after `return out_path`,
        # so the generated video was destroyed before Gradio could serve
        # it — a guaranteed failure on the success path.
        if os.path.exists(audio_path):
            os.remove(audio_path)
        if os.path.exists(video_path):
            os.remove(video_path)
51
 
52
# Define the Gradio interface.
# NOTE(review): on failure process_audio_video returns an error *string*
# while the output component is a video, so Gradio shows a load error
# rather than the message — consider adding a gr.Textbox status output.
interface = gr.Interface(
    fn=process_audio_video,
    inputs=[
        # BUG FIX: without type="filepath" gr.Audio hands the handler a
        # (sample_rate, ndarray) tuple, which cannot be copied as a file.
        gr.Audio(label="Audio File", type="filepath"),
        gr.Video(label="Video File"),
    ],
    outputs="video",
    description="Process Audio and Video with your Model",
    # BUG FIX: allow_flagging accepts "never"/"auto"/"manual"; the boolean
    # False used previously is not a valid value in Gradio.
    allow_flagging="never",
)

# Launch the Gradio app (share=True exposes a public tunnel URL).
interface.launch(share=True)