Spaces:
Paused
Paused
Commit
·
45378e6
1
Parent(s):
ac802d5
dropdown added
Browse files
app.py
CHANGED
|
@@ -18,11 +18,15 @@ def process_video(
|
|
| 18 |
guidance_scale,
|
| 19 |
inference_steps,
|
| 20 |
seed,
|
|
|
|
| 21 |
):
|
| 22 |
# Create the temp directory if it doesn't exist
|
| 23 |
output_dir = Path("./temp")
|
| 24 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
# Convert paths to absolute Path objects and normalize them
|
| 27 |
video_file_path = Path(video_path)
|
| 28 |
video_path = video_file_path.absolute().as_posix()
|
|
@@ -44,7 +48,7 @@ def process_video(
|
|
| 44 |
)
|
| 45 |
|
| 46 |
# Parse the arguments
|
| 47 |
-
args = create_args(video_path, audio_path, output_path, guidance_scale, seed)
|
| 48 |
|
| 49 |
try:
|
| 50 |
result = main(
|
|
@@ -59,7 +63,7 @@ def process_video(
|
|
| 59 |
|
| 60 |
|
| 61 |
def create_args(
|
| 62 |
-
video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int
|
| 63 |
) -> argparse.Namespace:
|
| 64 |
parser = argparse.ArgumentParser()
|
| 65 |
parser.add_argument("--inference_ckpt_path", type=str, required=True)
|
|
@@ -72,7 +76,7 @@ def create_args(
|
|
| 72 |
return parser.parse_args(
|
| 73 |
[
|
| 74 |
"--inference_ckpt_path",
|
| 75 |
-
|
| 76 |
"--video_path",
|
| 77 |
video_path,
|
| 78 |
"--audio_path",
|
|
@@ -86,6 +90,12 @@ def create_args(
|
|
| 86 |
]
|
| 87 |
)
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# Create Gradio interface
|
| 91 |
with gr.Blocks(title="SoundImage") as demo:
|
|
@@ -99,6 +109,12 @@ with gr.Blocks(title="SoundImage") as demo:
|
|
| 99 |
|
| 100 |
with gr.Row():
|
| 101 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
video_input = gr.Video(label="Input Video")
|
| 103 |
audio_input = gr.Audio(label="Input Audio", type="filepath")
|
| 104 |
|
|
@@ -139,6 +155,7 @@ with gr.Blocks(title="SoundImage") as demo:
|
|
| 139 |
guidance_scale,
|
| 140 |
inference_steps,
|
| 141 |
seed,
|
|
|
|
| 142 |
],
|
| 143 |
outputs=video_output,
|
| 144 |
)
|
|
|
|
| 18 |
guidance_scale,
|
| 19 |
inference_steps,
|
| 20 |
seed,
|
| 21 |
+
checkpoint_file,
|
| 22 |
):
|
| 23 |
# Create the temp directory if it doesn't exist
|
| 24 |
output_dir = Path("./temp")
|
| 25 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 26 |
|
| 27 |
+
# Use selected checkpoint or fall back to default
|
| 28 |
+
checkpoint_path = Path("checkpoints/unetFiles") / checkpoint_file if checkpoint_file else CHECKPOINT_PATH
|
| 29 |
+
|
| 30 |
# Convert paths to absolute Path objects and normalize them
|
| 31 |
video_file_path = Path(video_path)
|
| 32 |
video_path = video_file_path.absolute().as_posix()
|
|
|
|
| 48 |
)
|
| 49 |
|
| 50 |
# Parse the arguments
|
| 51 |
+
args = create_args(video_path, audio_path, output_path, guidance_scale, seed, checkpoint_path)
|
| 52 |
|
| 53 |
try:
|
| 54 |
result = main(
|
|
|
|
| 63 |
|
| 64 |
|
| 65 |
def create_args(
|
| 66 |
+
video_path: str, audio_path: str, output_path: str, guidance_scale: float, seed: int, checkpoint_path: Path
|
| 67 |
) -> argparse.Namespace:
|
| 68 |
parser = argparse.ArgumentParser()
|
| 69 |
parser.add_argument("--inference_ckpt_path", type=str, required=True)
|
|
|
|
| 76 |
return parser.parse_args(
|
| 77 |
[
|
| 78 |
"--inference_ckpt_path",
|
| 79 |
+
checkpoint_path.absolute().as_posix(),
|
| 80 |
"--video_path",
|
| 81 |
video_path,
|
| 82 |
"--audio_path",
|
|
|
|
| 90 |
]
|
| 91 |
)
|
| 92 |
|
| 93 |
+
# Add this function to get checkpoint files
|
| 94 |
+
def get_checkpoint_files():
|
| 95 |
+
unet_files_dir = Path("unetFiles")
|
| 96 |
+
if not unet_files_dir.exists():
|
| 97 |
+
return []
|
| 98 |
+
return [f.name for f in unet_files_dir.glob("*.pt")]
|
| 99 |
|
| 100 |
# Create Gradio interface
|
| 101 |
with gr.Blocks(title="SoundImage") as demo:
|
|
|
|
| 109 |
|
| 110 |
with gr.Row():
|
| 111 |
with gr.Column():
|
| 112 |
+
# Add checkpoint selector dropdown
|
| 113 |
+
checkpoint_dropdown = gr.Dropdown(
|
| 114 |
+
choices=get_checkpoint_files(),
|
| 115 |
+
label="Select Checkpoint",
|
| 116 |
+
value=get_checkpoint_files()[0] if get_checkpoint_files() else None
|
| 117 |
+
)
|
| 118 |
video_input = gr.Video(label="Input Video")
|
| 119 |
audio_input = gr.Audio(label="Input Audio", type="filepath")
|
| 120 |
|
|
|
|
| 155 |
guidance_scale,
|
| 156 |
inference_steps,
|
| 157 |
seed,
|
| 158 |
+
checkpoint_dropdown,
|
| 159 |
],
|
| 160 |
outputs=video_output,
|
| 161 |
)
|