Spaces:
Sleeping
Sleeping
roychao19477
commited on
Commit
·
2cb0aee
1
Parent(s):
0b3d66c
Test on lengths
Browse files
app.py
CHANGED
|
@@ -7,6 +7,8 @@ import shutil
|
|
| 7 |
import glob
|
| 8 |
import gradio as gr
|
| 9 |
|
|
|
|
|
|
|
| 10 |
# install packages for mamba
|
| 11 |
def install_mamba():
|
| 12 |
subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
|
|
@@ -63,8 +65,6 @@ from moviepy import ImageSequenceClip
|
|
| 63 |
from scipy.io import wavfile
|
| 64 |
from avse_code import run_avse
|
| 65 |
|
| 66 |
-
# Load face detector
|
| 67 |
-
model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
|
| 68 |
|
| 69 |
|
| 70 |
from decord import VideoReader, cpu
|
|
@@ -75,15 +75,18 @@ import spaces
|
|
| 75 |
# Load model once globally
|
| 76 |
#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
|
| 77 |
#model = AVSEModule.load_from_checkpoint(ckpt_path)
|
| 78 |
-
avse_model = AVSEModule()
|
| 79 |
#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
avse_model.eval()
|
| 84 |
|
| 85 |
@spaces.GPU
|
| 86 |
def run_avse_inference(video_path, audio_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
estimated = run_avse(video_path, audio_path)
|
| 88 |
# Load audio
|
| 89 |
#noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
|
|
@@ -101,15 +104,39 @@ def run_avse_inference(video_path, audio_path):
|
|
| 101 |
]).astype(np.float32)
|
| 102 |
bg_frames /= 255.0
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
# Combine into input dict (match what model.enhance expects)
|
| 106 |
-
data = {
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
with torch.no_grad():
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# Save result
|
| 115 |
tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
|
|
@@ -135,9 +162,32 @@ def extract_resampled_audio(video_path, target_sr=16000):
|
|
| 135 |
torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
|
| 136 |
return resampled_audio_path
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
@spaces.GPU
|
| 140 |
def extract_faces(video_file):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
cap = cv2.VideoCapture(video_file)
|
| 142 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 143 |
frames = []
|
|
@@ -148,7 +198,8 @@ def extract_faces(video_file):
|
|
| 148 |
break
|
| 149 |
|
| 150 |
# Inference
|
| 151 |
-
results = model(frame, verbose=False)[0]
|
|
|
|
| 152 |
for box in results.boxes:
|
| 153 |
# version 1
|
| 154 |
# x1, y1, x2, y2 = map(int, box.xyxy[0])
|
|
@@ -218,14 +269,7 @@ def extract_faces(video_file):
|
|
| 218 |
enhanced_audio_path = run_avse_inference(output_path, audio_path)
|
| 219 |
|
| 220 |
|
| 221 |
-
|
| 222 |
-
flipped_output_path = os.path.join(tmpdir, "face_only_video_flipped.mp4")
|
| 223 |
-
flipped_clip = VideoFileClip(output_path, fps=25)
|
| 224 |
-
flipped_clip = flipped_clip.fx(vfx.mirror_y)
|
| 225 |
-
flipped_clip.write_videofile(flipped_output_path, codec="libx264", audio=False, fps=25)
|
| 226 |
-
|
| 227 |
-
return flipped_output_path, enhanced_audio_path
|
| 228 |
-
#return output_path, enhanced_audio_path
|
| 229 |
#return output_path, audio_path
|
| 230 |
|
| 231 |
iface = gr.Interface(
|
|
@@ -237,7 +281,9 @@ iface = gr.Interface(
|
|
| 237 |
gr.Audio(label="Enhanced Audio", type="filepath")
|
| 238 |
],
|
| 239 |
title="Face Detector",
|
| 240 |
-
description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
iface.launch()
|
|
|
|
|
|
| 7 |
import glob
|
| 8 |
import gradio as gr
|
| 9 |
|
| 10 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 11 |
+
|
| 12 |
# install packages for mamba
|
| 13 |
def install_mamba():
|
| 14 |
subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
|
|
|
|
| 65 |
from scipy.io import wavfile
|
| 66 |
from avse_code import run_avse
|
| 67 |
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
from decord import VideoReader, cpu
|
|
|
|
| 75 |
# Load model once globally
|
| 76 |
#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
|
| 77 |
#model = AVSEModule.load_from_checkpoint(ckpt_path)
|
|
|
|
| 78 |
#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
|
| 79 |
+
|
| 80 |
+
CHUNK_SIZE_AUDIO = 48000 # 3 sec at 16kHz
|
| 81 |
+
CHUNK_SIZE_VIDEO = 75 # 25fps × 3 sec
|
|
|
|
| 82 |
|
| 83 |
@spaces.GPU
|
| 84 |
def run_avse_inference(video_path, audio_path):
|
| 85 |
+
avse_model = AVSEModule()
|
| 86 |
+
avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
|
| 87 |
+
avse_model.load_state_dict(avse_state_dict, strict=True)
|
| 88 |
+
avse_model.to("cuda")
|
| 89 |
+
avse_model.eval()
|
| 90 |
estimated = run_avse(video_path, audio_path)
|
| 91 |
# Load audio
|
| 92 |
#noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
|
|
|
|
| 104 |
]).astype(np.float32)
|
| 105 |
bg_frames /= 255.0
|
| 106 |
|
| 107 |
+
audio_chunks = [
|
| 108 |
+
noisy[i:i + CHUNK_SIZE_AUDIO]
|
| 109 |
+
for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
video_chunks = [
|
| 113 |
+
bg_frames[i:i + CHUNK_SIZE_VIDEO]
|
| 114 |
+
for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
min_len = min(len(audio_chunks), len(video_chunks)) # sync length
|
| 118 |
+
|
| 119 |
|
| 120 |
# Combine into input dict (match what model.enhance expects)
|
| 121 |
+
#data = {
|
| 122 |
+
# "noisy_audio": noisy,
|
| 123 |
+
# "video_frames": bg_frames[np.newaxis, ...]
|
| 124 |
+
#}
|
| 125 |
+
|
| 126 |
+
#with torch.no_grad():
|
| 127 |
+
# estimated = avse_model.enhance(data).reshape(-1)
|
| 128 |
+
estimated_chunks = []
|
| 129 |
|
| 130 |
with torch.no_grad():
|
| 131 |
+
for i in range(min_len):
|
| 132 |
+
chunk_data = {
|
| 133 |
+
"noisy_audio": audio_chunks[i],
|
| 134 |
+
"video_frames": video_chunks[i][np.newaxis, ...]
|
| 135 |
+
}
|
| 136 |
+
est = avse_model.enhance(chunk_data).reshape(-1)
|
| 137 |
+
estimated_chunks.append(est)
|
| 138 |
+
|
| 139 |
+
estimated = np.concatenate(estimated_chunks, axis=0)
|
| 140 |
|
| 141 |
# Save result
|
| 142 |
tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
|
|
|
|
| 162 |
torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
|
| 163 |
return resampled_audio_path
|
| 164 |
|
| 165 |
+
@spaces.GPU
|
| 166 |
+
def yolo_detection(frame, verbose=False):
|
| 167 |
+
# Load face detector
|
| 168 |
+
model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
|
| 169 |
+
return model(frame, verbose=verbose)[0]
|
| 170 |
|
| 171 |
@spaces.GPU
|
| 172 |
def extract_faces(video_file):
|
| 173 |
+
if isinstance(video_input, dict):
|
| 174 |
+
video_path = video_input.get("path") or video_input.get("url")
|
| 175 |
+
if video_path.startswith("http"):
|
| 176 |
+
# download video
|
| 177 |
+
tmpdir = tempfile.mkdtemp()
|
| 178 |
+
ext = os.path.splitext(urlparse(video_path).path)[1]
|
| 179 |
+
local_path = os.path.join(tmpdir, "input_video" + ext)
|
| 180 |
+
with open(local_path, "wb") as f:
|
| 181 |
+
f.write(requests.get(video_path).content)
|
| 182 |
+
video_file = local_path
|
| 183 |
+
else:
|
| 184 |
+
video_file = video_path
|
| 185 |
+
else:
|
| 186 |
+
video_file = video_input # string path from UI
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
cap = cv2.VideoCapture(video_file)
|
| 192 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 193 |
frames = []
|
|
|
|
| 198 |
break
|
| 199 |
|
| 200 |
# Inference
|
| 201 |
+
#results = model(frame, verbose=False)[0]
|
| 202 |
+
results = yolo_detection(frame, verbose=False)
|
| 203 |
for box in results.boxes:
|
| 204 |
# version 1
|
| 205 |
# x1, y1, x2, y2 = map(int, box.xyxy[0])
|
|
|
|
| 269 |
enhanced_audio_path = run_avse_inference(output_path, audio_path)
|
| 270 |
|
| 271 |
|
| 272 |
+
return output_path, enhanced_audio_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
#return output_path, audio_path
|
| 274 |
|
| 275 |
iface = gr.Interface(
|
|
|
|
| 281 |
gr.Audio(label="Enhanced Audio", type="filepath")
|
| 282 |
],
|
| 283 |
title="Face Detector",
|
| 284 |
+
description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio.",
|
| 285 |
+
api_name="/predict"
|
| 286 |
)
|
| 287 |
|
| 288 |
iface.launch()
|
| 289 |
+
|