darshankr committed on
Commit
b673ad3
·
verified ·
1 Parent(s): 98bc45c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -51
app.py CHANGED
@@ -3,41 +3,46 @@ import subprocess
3
  import os
4
  import requests
5
 
6
- def process_video(audio_file, video_file):
7
- print(gradio.__version__)
8
- # Unpack the audio and video file paths
9
- audio_path = audio_file[1] if isinstance(audio_file, tuple) else audio_file
10
- video_path = video_file if isinstance(video_file, str) else video_file.name
11
- out_path = "output_video.mp4"
12
-
13
- # Define command flags
14
- sample_mode = "cross" # or "reconstruction"
15
- generate_from_filelist = 0
16
- model_path = "checkpoints/checkpoint.pt"
17
- pads = "0,0,0,0"
18
-
19
- if sample_mode == "reconstruction":
20
- sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
21
- elif sample_mode == "cross":
22
- sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
23
- else:
24
- return "Error: sample_mode can only be \"cross\" or \"reconstruction\""
25
-
26
- MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
27
- DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
28
- SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
29
- DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
30
- TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
31
- GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"
32
-
33
- # Combine all flags into one command
34
- command = f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
35
-
36
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  subprocess.run(command, shell=True, check=True)
38
- return out_path
39
- except subprocess.CalledProcessError as e:
40
- return f"Error processing video: {e}"
 
 
 
 
 
41
  finally:
42
  # Clean up output file if it exists
43
  if os.path.exists(out_path):
@@ -52,27 +57,21 @@ with gr.Blocks() as iface:
52
  audio_input = gr.Audio(label="Input Audio")
53
  video_input = gr.Video(label="Input Video")
54
 
55
- with gr.Row():
56
- process_button = gr.Button("Process Video")
57
- status_msg = gr.Textbox(label="Status", interactive=False)
58
-
59
  video_output = gr.Video(label="Processed Video")
60
 
61
- def process_with_status(audio, video):
62
- try:
63
- status_msg.update(value="Processing... Please wait.")
64
- result = process_video(audio, video)
65
- status_msg.update(value="Done!")
66
- return [result, "Processing completed successfully!"]
67
- except Exception as e:
68
- error_msg = f"Error during processing: {str(e)}"
69
- status_msg.update(value=error_msg)
70
- return [None, error_msg]
71
-
72
  process_button.click(
73
- fn=process_with_status,
74
- inputs=[audio_input, video_input],
75
- outputs=[video_output, status_msg]
 
 
 
 
 
 
 
76
  )
77
 
78
  # Launch the interface
 
3
  import os
4
  import requests
5
 
6
+ def process_video(audio_file, video_file, status):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  try:
8
+ # Unpack the audio and video file paths
9
+ audio_path = audio_file[1] if isinstance(audio_file, tuple) else audio_file
10
+ video_path = video_file if isinstance(video_file, str) else video_file.name
11
+ out_path = "output_video.mp4"
12
+
13
+ # Define command flags
14
+ sample_mode = "cross" # or "reconstruction"
15
+ generate_from_filelist = 0
16
+ model_path = "checkpoints/checkpoint.pt"
17
+ pads = "0,0,0,0"
18
+
19
+ if sample_mode == "reconstruction":
20
+ sample_input_flags = "--sampling_input_type=first_frame --sampling_ref_type=first_frame"
21
+ elif sample_mode == "cross":
22
+ sample_input_flags = "--sampling_input_type=gt --sampling_ref_type=gt"
23
+ else:
24
+ return None, "Error: sample_mode can only be \"cross\" or \"reconstruction\""
25
+
26
+ MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --learn_sigma True --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm False"
27
+ DIFFUSION_FLAGS = "--predict_xstart False --diffusion_steps 1000 --noise_schedule linear --rescale_timesteps False"
28
+ SAMPLE_FLAGS = f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 --use_ddim True --model_path={model_path}"
29
+ DATA_FLAGS = "--nframes 5 --nrefer 1 --image_size 128 --sampling_batch_size=32"
30
+ TFG_FLAGS = "--face_hide_percentage 0.5 --use_ref=True --use_audio=True --audio_as_style=True"
31
+ GEN_FLAGS = f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} --audio_path={audio_path} --out_path={out_path} --save_orig=False --face_det_batch_size 16 --pads {pads} --is_voxceleb2=False"
32
+
33
+ # Combine all flags into one command
34
+ command = f"python your_model_script.py {MODEL_FLAGS} {DIFFUSION_FLAGS} {SAMPLE_FLAGS} {DATA_FLAGS} {TFG_FLAGS} {GEN_FLAGS}"
35
+
36
+ # Execute command
37
  subprocess.run(command, shell=True, check=True)
38
+
39
+ # If successful, return the output path and success message
40
+ return out_path, "Processing completed successfully!"
41
+
42
+ except Exception as e:
43
+ # If there's an error, return None for the video and the error message
44
+ return None, f"Error during processing: {str(e)}"
45
+
46
  finally:
47
  # Clean up output file if it exists
48
  if os.path.exists(out_path):
 
57
  audio_input = gr.Audio(label="Input Audio")
58
  video_input = gr.Video(label="Input Video")
59
 
60
+ status_msg = gr.Textbox(label="Status", interactive=False)
61
+ process_button = gr.Button("Process Video")
 
 
62
  video_output = gr.Video(label="Processed Video")
63
 
 
 
 
 
 
 
 
 
 
 
 
64
  process_button.click(
65
+ fn=process_video,
66
+ inputs=[
67
+ audio_input,
68
+ video_input,
69
+ status_msg
70
+ ],
71
+ outputs=[
72
+ video_output,
73
+ status_msg
74
+ ]
75
  )
76
 
77
  # Launch the interface