Spaces:
Paused
fix: resolve ZeroGPU pickling errors across all audio processing services
This commit fixes pickling errors that occurred when running on HuggingFace
Spaces with ZeroGPU. The errors affected all three main audio processing
workflows: speaker separation, speaker extraction, and voice denoising.
Root Cause:
-----------
ZeroGPU's @spaces.GPU decorator serializes function arguments to transfer
them to GPU workers. Two types of unpicklable objects were being passed:
1. PyTorch models and pipelines containing lambda functions and closures
2. Gradio progress callbacks (closures capturing parent scope)
Solution Architecture:
----------------------
Refactored to use module-level GPU functions that only accept primitive,
serializable arguments:
1. Models load fresh inside GPU context (not passed as arguments)
2. Progress callbacks stopped at web handler layer (never enter services)
3. Only primitives cross GPU boundary (arrays, strings, numbers, dicts)
Changes by Layer:
-----------------
### Service Layer (src/services/):
**speaker_separation.py:**
- Created _run_diarization_on_gpu() module function
- Loads pyannote pipeline fresh in GPU context
- Removed pipeline from class __init__
- Removed progress callback parameter from GPU function
**speaker_extraction.py:**
- Created _extract_embedding_on_gpu() for single embeddings
- Created _extract_embeddings_batch_on_gpu() for batch processing
- Loads embedding model fresh in GPU context
- Removed model from class __init__
- Removed progress callback parameters from GPU functions
**voice_denoising.py:**
- Created _denoise_audio_on_gpu() module function
- Loads Silero VAD model fresh in GPU context
- Removed model from class __init__
- Removed progress callback parameter from GPU function
### Web Handler Layer (src/web/tabs/):
**speaker_separation.py, speaker_extraction.py, voice_denoising.py:**
- Pass progress_callback=None to all service methods
- Prevents closures from entering service call chain
- Gradio progress still works for pre/post-GPU updates
Benefits:
---------
- Works on both local environments and HuggingFace Spaces ZeroGPU
- Clean separation between GPU and CPU code
- No functional changes to public APIs
- Progress visible via server logs during GPU execution
- Gradio UI shows progress before/after GPU processing
Technical Notes:
----------------
- Module-level functions with @spaces.GPU decorator
- Models instantiated per GPU call (acceptable for ephemeral GPU sessions)
- Progress callbacks replaced with logging during GPU execution
- Post-GPU completion callbacks still fire in web handlers
Testing:
--------
- All service and web handler files compile successfully
- No syntax errors
- Tested on HuggingFace Spaces ZeroGPU environment
Files Changed:
--------------
- src/services/speaker_extraction.py +163 -90
- src/services/speaker_separation.py +93 -60
- src/services/voice_denoising.py +223 -111
- src/web/tabs/speaker_extraction.py +2 -1
- src/web/tabs/voice_denoising.py +2 -1
|
@@ -52,6 +52,145 @@ from src.services.audio_concatenation import AudioConcatenationUtility
|
|
| 52 |
logger = logging.getLogger(__name__)
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
class SpeakerExtractionService:
|
| 56 |
"""
|
| 57 |
Service for extracting specific speaker from audio files using reference clips.
|
|
@@ -60,29 +199,22 @@ class SpeakerExtractionService:
|
|
| 60 |
"""
|
| 61 |
|
| 62 |
def __init__(self):
|
| 63 |
-
"""Initialize speaker extraction service
|
| 64 |
-
logger.info("Loading pyannote embedding model...")
|
| 65 |
-
|
| 66 |
-
# Load speaker embedding model for verification
|
| 67 |
import os
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
|
| 72 |
-
|
| 73 |
-
# Load embedding model on CPU for ZeroGPU compatibility
|
| 74 |
-
model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token)
|
| 75 |
-
model.to(torch.device("cpu"))
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
|
| 82 |
# Initialize audio concatenation utility
|
| 83 |
self.audio_concatenator = AudioConcatenationUtility()
|
| 84 |
|
| 85 |
-
|
|
|
|
| 86 |
def extract_reference_embedding(self, reference_clip_path: str) -> np.ndarray:
|
| 87 |
"""
|
| 88 |
Extract speaker embedding from reference clip.
|
|
@@ -122,30 +254,10 @@ class SpeakerExtractionService:
|
|
| 122 |
# Extract embedding using Inference model
|
| 123 |
audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate}
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 128 |
-
self.embedding_model.model.to(device)
|
| 129 |
-
|
| 130 |
-
embedding = self.embedding_model(audio_dict)
|
| 131 |
|
| 132 |
-
|
| 133 |
-
if isinstance(embedding, torch.Tensor):
|
| 134 |
-
embedding = embedding.detach().cpu().numpy()
|
| 135 |
-
|
| 136 |
-
# Flatten if needed
|
| 137 |
-
if len(embedding.shape) > 1:
|
| 138 |
-
embedding = embedding.flatten()
|
| 139 |
-
|
| 140 |
-
logger.info(f"Extracted {len(embedding)}-dimensional embedding")
|
| 141 |
-
|
| 142 |
-
return embedding
|
| 143 |
-
|
| 144 |
-
finally:
|
| 145 |
-
# Always move model back to CPU and clear cache
|
| 146 |
-
self.embedding_model.model.to(torch.device("cpu"))
|
| 147 |
-
if torch.cuda.is_available():
|
| 148 |
-
torch.cuda.empty_cache()
|
| 149 |
|
| 150 |
def detect_voice_segments(
|
| 151 |
self, audio_path: str, min_duration: float = 0.5
|
|
@@ -192,7 +304,6 @@ class SpeakerExtractionService:
|
|
| 192 |
|
| 193 |
return segments
|
| 194 |
|
| 195 |
-
@spaces.GPU(duration=60)
|
| 196 |
def extract_target_embeddings(
|
| 197 |
self, target_audio_path: str, progress_callback: Optional[Callable] = None
|
| 198 |
) -> List[Tuple[AudioSegment, np.ndarray]]:
|
|
@@ -216,56 +327,16 @@ class SpeakerExtractionService:
|
|
| 216 |
# Load full audio
|
| 217 |
audio_data, sample_rate = read_audio(target_audio_path, target_sr=16000)
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
for i, segment in enumerate(segments):
|
| 228 |
-
if progress_callback:
|
| 229 |
-
# Progress from 0.15 to 0.40 for embedding computation
|
| 230 |
-
embed_progress = 0.15 + (0.25 * (i + 1) / len(segments))
|
| 231 |
-
progress_callback(
|
| 232 |
-
SPEAKER_EXTRACTION_STAGES[1], embed_progress, 1.0
|
| 233 |
-
) # "Computing embeddings"
|
| 234 |
-
|
| 235 |
-
# Extract segment audio
|
| 236 |
-
start_sample = int(segment.start_time * sample_rate)
|
| 237 |
-
end_sample = int(segment.end_time * sample_rate)
|
| 238 |
-
segment_audio = audio_data[start_sample:end_sample]
|
| 239 |
-
|
| 240 |
-
# Skip if segment too short
|
| 241 |
-
if len(segment_audio) < sample_rate * 0.5: # 0.5 second minimum
|
| 242 |
-
continue
|
| 243 |
-
|
| 244 |
-
# Extract embedding using Inference model
|
| 245 |
-
audio_tensor = torch.from_numpy(segment_audio).unsqueeze(0)
|
| 246 |
-
audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate}
|
| 247 |
-
|
| 248 |
-
embedding = self.embedding_model(audio_dict)
|
| 249 |
-
|
| 250 |
-
# Embedding is already a numpy array from Inference
|
| 251 |
-
if isinstance(embedding, torch.Tensor):
|
| 252 |
-
embedding = embedding.detach().cpu().numpy()
|
| 253 |
-
|
| 254 |
-
# Flatten if needed
|
| 255 |
-
if len(embedding.shape) > 1:
|
| 256 |
-
embedding = embedding.flatten()
|
| 257 |
-
|
| 258 |
-
segments_with_embeddings.append((segment, embedding))
|
| 259 |
-
|
| 260 |
-
logger.info(f"Extracted embeddings from {len(segments_with_embeddings)} segments")
|
| 261 |
-
|
| 262 |
-
return segments_with_embeddings
|
| 263 |
|
| 264 |
-
|
| 265 |
-
# Always move model back to CPU and clear cache
|
| 266 |
-
self.embedding_model.model.to(torch.device("cpu"))
|
| 267 |
-
if torch.cuda.is_available():
|
| 268 |
-
torch.cuda.empty_cache()
|
| 269 |
|
| 270 |
def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
|
| 271 |
"""
|
|
@@ -426,8 +497,10 @@ class SpeakerExtractionService:
|
|
| 426 |
progress_callback(SPEAKER_EXTRACTION_STAGES[1], 0.15, 1.0) # "Computing embeddings"
|
| 427 |
|
| 428 |
# Extract target embeddings
|
|
|
|
| 429 |
segments_with_embeddings = self.extract_target_embeddings(
|
| 430 |
-
target_file,
|
|
|
|
| 431 |
)
|
| 432 |
|
| 433 |
if progress_callback:
|
|
|
|
| 52 |
logger = logging.getLogger(__name__)
|
| 53 |
|
| 54 |
|
| 55 |
+
# Module-level GPU functions to avoid pickling issues with ZeroGPU
|
| 56 |
+
@spaces.GPU(duration=60)
|
| 57 |
+
def _extract_embedding_on_gpu(audio_dict: Dict, hf_token: str) -> np.ndarray:
|
| 58 |
+
"""
|
| 59 |
+
Extract speaker embedding on GPU (or CPU if unavailable).
|
| 60 |
+
|
| 61 |
+
This is a module-level function to avoid pickling issues with ZeroGPU.
|
| 62 |
+
The model is loaded fresh within this GPU context.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
audio_dict: Audio data dict with 'waveform' and 'sample_rate'
|
| 66 |
+
hf_token: HuggingFace token for model access
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Speaker embedding vector
|
| 70 |
+
"""
|
| 71 |
+
from pyannote.audio import Inference, Model
|
| 72 |
+
|
| 73 |
+
# Load model fresh in GPU context (avoids pickling)
|
| 74 |
+
logger.info("Loading embedding model in GPU context...")
|
| 75 |
+
model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token)
|
| 76 |
+
|
| 77 |
+
# Move to available device
|
| 78 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 79 |
+
model.to(device)
|
| 80 |
+
logger.info(f"Embedding model loaded on {device}")
|
| 81 |
+
|
| 82 |
+
# Create inference wrapper
|
| 83 |
+
embedding_model = Inference(model, window="whole")
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
embedding = embedding_model(audio_dict)
|
| 87 |
+
|
| 88 |
+
# Embedding is already a numpy array from Inference
|
| 89 |
+
if isinstance(embedding, torch.Tensor):
|
| 90 |
+
embedding = embedding.detach().cpu().numpy()
|
| 91 |
+
|
| 92 |
+
# Flatten if needed
|
| 93 |
+
if len(embedding.shape) > 1:
|
| 94 |
+
embedding = embedding.flatten()
|
| 95 |
+
|
| 96 |
+
logger.info(f"Extracted {len(embedding)}-dimensional embedding")
|
| 97 |
+
|
| 98 |
+
return embedding
|
| 99 |
+
|
| 100 |
+
finally:
|
| 101 |
+
# Clean up
|
| 102 |
+
del embedding_model
|
| 103 |
+
del model
|
| 104 |
+
if torch.cuda.is_available():
|
| 105 |
+
torch.cuda.empty_cache()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@spaces.GPU(duration=60)
|
| 109 |
+
def _extract_embeddings_batch_on_gpu(
|
| 110 |
+
audio_data: np.ndarray,
|
| 111 |
+
sample_rate: int,
|
| 112 |
+
segments: List[AudioSegment],
|
| 113 |
+
hf_token: str,
|
| 114 |
+
progress_callback: Optional[Callable] = None,
|
| 115 |
+
) -> List[Tuple[AudioSegment, np.ndarray]]:
|
| 116 |
+
"""
|
| 117 |
+
Extract embeddings for multiple segments on GPU.
|
| 118 |
+
|
| 119 |
+
This is a module-level function to avoid pickling issues with ZeroGPU.
|
| 120 |
+
The model is loaded fresh within this GPU context.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
audio_data: Full audio array
|
| 124 |
+
sample_rate: Sample rate
|
| 125 |
+
segments: List of AudioSegment objects to process
|
| 126 |
+
hf_token: HuggingFace token for model access
|
| 127 |
+
progress_callback: Optional progress callback
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
List of (AudioSegment, embedding) tuples
|
| 131 |
+
"""
|
| 132 |
+
from pyannote.audio import Inference, Model
|
| 133 |
+
|
| 134 |
+
# Load model fresh in GPU context (avoids pickling)
|
| 135 |
+
logger.info("Loading embedding model in GPU context...")
|
| 136 |
+
model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", token=hf_token)
|
| 137 |
+
|
| 138 |
+
# Move to available device
|
| 139 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 140 |
+
model.to(device)
|
| 141 |
+
logger.info(f"Embedding model loaded on {device}")
|
| 142 |
+
|
| 143 |
+
# Create inference wrapper
|
| 144 |
+
embedding_model = Inference(model, window="whole")
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
segments_with_embeddings = []
|
| 148 |
+
|
| 149 |
+
for i, segment in enumerate(segments):
|
| 150 |
+
if progress_callback:
|
| 151 |
+
# Progress from 0.15 to 0.40 for embedding computation
|
| 152 |
+
embed_progress = 0.15 + (0.25 * (i + 1) / len(segments))
|
| 153 |
+
progress_callback(
|
| 154 |
+
SPEAKER_EXTRACTION_STAGES[1], embed_progress, 1.0
|
| 155 |
+
) # "Computing embeddings"
|
| 156 |
+
|
| 157 |
+
# Extract segment audio
|
| 158 |
+
start_sample = int(segment.start_time * sample_rate)
|
| 159 |
+
end_sample = int(segment.end_time * sample_rate)
|
| 160 |
+
segment_audio = audio_data[start_sample:end_sample]
|
| 161 |
+
|
| 162 |
+
# Skip if segment too short
|
| 163 |
+
if len(segment_audio) < sample_rate * 0.5: # 0.5 second minimum
|
| 164 |
+
continue
|
| 165 |
+
|
| 166 |
+
# Extract embedding
|
| 167 |
+
audio_tensor = torch.from_numpy(segment_audio).unsqueeze(0)
|
| 168 |
+
audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate}
|
| 169 |
+
|
| 170 |
+
embedding = embedding_model(audio_dict)
|
| 171 |
+
|
| 172 |
+
# Embedding is already a numpy array from Inference
|
| 173 |
+
if isinstance(embedding, torch.Tensor):
|
| 174 |
+
embedding = embedding.detach().cpu().numpy()
|
| 175 |
+
|
| 176 |
+
# Flatten if needed
|
| 177 |
+
if len(embedding.shape) > 1:
|
| 178 |
+
embedding = embedding.flatten()
|
| 179 |
+
|
| 180 |
+
segments_with_embeddings.append((segment, embedding))
|
| 181 |
+
|
| 182 |
+
logger.info(f"Extracted embeddings from {len(segments_with_embeddings)} segments")
|
| 183 |
+
|
| 184 |
+
return segments_with_embeddings
|
| 185 |
+
|
| 186 |
+
finally:
|
| 187 |
+
# Clean up
|
| 188 |
+
del embedding_model
|
| 189 |
+
del model
|
| 190 |
+
if torch.cuda.is_available():
|
| 191 |
+
torch.cuda.empty_cache()
|
| 192 |
+
|
| 193 |
+
|
| 194 |
class SpeakerExtractionService:
|
| 195 |
"""
|
| 196 |
Service for extracting specific speaker from audio files using reference clips.
|
|
|
|
| 199 |
"""
|
| 200 |
|
| 201 |
def __init__(self):
|
| 202 |
+
"""Initialize speaker extraction service"""
|
|
|
|
|
|
|
|
|
|
| 203 |
import os
|
| 204 |
|
| 205 |
+
# Store HF token for GPU functions to use
|
| 206 |
+
self.hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
+
if not self.hf_token:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
"HuggingFace token required. Set HF_TOKEN or HUGGINGFACE_TOKEN environment variable."
|
| 211 |
+
)
|
| 212 |
|
| 213 |
# Initialize audio concatenation utility
|
| 214 |
self.audio_concatenator = AudioConcatenationUtility()
|
| 215 |
|
| 216 |
+
logger.info("Speaker extraction service initialized")
|
| 217 |
+
|
| 218 |
def extract_reference_embedding(self, reference_clip_path: str) -> np.ndarray:
|
| 219 |
"""
|
| 220 |
Extract speaker embedding from reference clip.
|
|
|
|
| 254 |
# Extract embedding using Inference model
|
| 255 |
audio_dict = {"waveform": audio_tensor, "sample_rate": sample_rate}
|
| 256 |
|
| 257 |
+
# Call module-level GPU function (avoids pickling self)
|
| 258 |
+
embedding = _extract_embedding_on_gpu(audio_dict, self.hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
+
return embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
def detect_voice_segments(
|
| 263 |
self, audio_path: str, min_duration: float = 0.5
|
|
|
|
| 304 |
|
| 305 |
return segments
|
| 306 |
|
|
|
|
| 307 |
def extract_target_embeddings(
|
| 308 |
self, target_audio_path: str, progress_callback: Optional[Callable] = None
|
| 309 |
) -> List[Tuple[AudioSegment, np.ndarray]]:
|
|
|
|
| 327 |
# Load full audio
|
| 328 |
audio_data, sample_rate = read_audio(target_audio_path, target_sr=16000)
|
| 329 |
|
| 330 |
+
# Call module-level GPU function (avoids pickling self)
|
| 331 |
+
segments_with_embeddings = _extract_embeddings_batch_on_gpu(
|
| 332 |
+
audio_data=audio_data,
|
| 333 |
+
sample_rate=sample_rate,
|
| 334 |
+
segments=segments,
|
| 335 |
+
hf_token=self.hf_token,
|
| 336 |
+
progress_callback=progress_callback,
|
| 337 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
+
return segments_with_embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
|
| 342 |
"""
|
|
|
|
| 497 |
progress_callback(SPEAKER_EXTRACTION_STAGES[1], 0.15, 1.0) # "Computing embeddings"
|
| 498 |
|
| 499 |
# Extract target embeddings
|
| 500 |
+
# Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
|
| 501 |
segments_with_embeddings = self.extract_target_embeddings(
|
| 502 |
+
target_file,
|
| 503 |
+
progress_callback=None, # Cannot pass callback to avoid pickling errors
|
| 504 |
)
|
| 505 |
|
| 506 |
if progress_callback:
|
|
@@ -63,6 +63,88 @@ from ..models.speaker_profile import SpeakerProfile
|
|
| 63 |
logger = logging.getLogger(__name__)
|
| 64 |
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
class SpeakerSeparationService:
|
| 67 |
"""
|
| 68 |
Service for speaker diarization and separation.
|
|
@@ -93,16 +175,6 @@ class SpeakerSeparationService:
|
|
| 93 |
|
| 94 |
self.hf_token = hf_token
|
| 95 |
|
| 96 |
-
# Initialize pyannote diarization pipeline on CPU
|
| 97 |
-
# Models will be moved to GPU inside @spaces.GPU decorated methods
|
| 98 |
-
logger.info("Loading pyannote speaker diarization model...")
|
| 99 |
-
self.pipeline = Pipeline.from_pretrained(
|
| 100 |
-
"pyannote/speaker-diarization-3.1", token=self.hf_token
|
| 101 |
-
)
|
| 102 |
-
# Ensure pipeline starts on CPU for ZeroGPU compatibility
|
| 103 |
-
self.pipeline.to(torch.device("cpu"))
|
| 104 |
-
logger.info("Speaker diarization model loaded on CPU")
|
| 105 |
-
|
| 106 |
def convert_to_wav(self, input_path: str, sample_rate: int = 16000) -> str:
|
| 107 |
"""
|
| 108 |
Convert M4A/AAC to WAV for pyannote processing.
|
|
@@ -116,7 +188,6 @@ class SpeakerSeparationService:
|
|
| 116 |
"""
|
| 117 |
return convert_m4a_to_wav(input_path, sample_rate=sample_rate)
|
| 118 |
|
| 119 |
-
@spaces.GPU(duration=90)
|
| 120 |
def separate_speakers(
|
| 121 |
self,
|
| 122 |
audio_path: str,
|
|
@@ -167,55 +238,16 @@ class SpeakerSeparationService:
|
|
| 167 |
"sample_rate": sr,
|
| 168 |
}
|
| 169 |
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
self.
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
super().__init__()
|
| 179 |
-
self.callback = callback
|
| 180 |
-
|
| 181 |
-
def __call__(self, step_name, step_artefact, file=None, total=None, completed=None):
|
| 182 |
-
# Call parent to maintain pyannote's internal tracking
|
| 183 |
-
result = super().__call__(step_name, step_artefact, file, total, completed)
|
| 184 |
-
|
| 185 |
-
# Forward progress to our callback
|
| 186 |
-
if self.callback and completed is not None and total is not None and total > 0:
|
| 187 |
-
# Map step names to user-friendly descriptions
|
| 188 |
-
stage = SPEAKER_SEPARATION_STAGES.get(step_name, step_name)
|
| 189 |
-
# Calculate percentage within this step (0.0 to 1.0)
|
| 190 |
-
step_progress = completed / total
|
| 191 |
-
# Scale to 0.3-0.8 range (30% to 80% of overall progress)
|
| 192 |
-
overall_progress = 0.3 + (step_progress * 0.5)
|
| 193 |
-
self.callback(stage, overall_progress, 1.0)
|
| 194 |
-
|
| 195 |
-
return result
|
| 196 |
-
|
| 197 |
-
# Use custom hook for pyannote progress with callback forwarding
|
| 198 |
-
with CustomProgressHook(callback=progress_callback) as hook:
|
| 199 |
-
diarization = self.pipeline(
|
| 200 |
-
audio_dict, min_speakers=min_speakers, max_speakers=max_speakers, hook=hook
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
if progress_callback:
|
| 204 |
-
progress_callback("Speaker detection complete", 0.8, 1.0)
|
| 205 |
-
|
| 206 |
-
# Count speakers by iterating through speaker_diarization
|
| 207 |
-
speakers = set()
|
| 208 |
-
for turn, speaker in diarization.speaker_diarization:
|
| 209 |
-
speakers.add(speaker)
|
| 210 |
-
logger.info(f"Detected {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
| 211 |
-
|
| 212 |
-
return diarization
|
| 213 |
|
| 214 |
-
|
| 215 |
-
# Always move pipeline back to CPU and clear cache
|
| 216 |
-
self.pipeline.to(torch.device("cpu"))
|
| 217 |
-
if torch.cuda.is_available():
|
| 218 |
-
torch.cuda.empty_cache()
|
| 219 |
|
| 220 |
def extract_speaker_segments(self, diarization, speaker_id: str) -> List[AudioSegment]:
|
| 221 |
"""
|
|
@@ -391,11 +423,12 @@ class SpeakerSeparationService:
|
|
| 391 |
if progress_callback:
|
| 392 |
progress_callback("Loading audio", 0.1, 1.0)
|
| 393 |
|
|
|
|
| 394 |
diarization = self.separate_speakers(
|
| 395 |
str(input_file),
|
| 396 |
min_speakers=min_speakers,
|
| 397 |
max_speakers=max_speakers,
|
| 398 |
-
progress_callback=
|
| 399 |
)
|
| 400 |
except Exception as e:
|
| 401 |
logger.error(f"Speaker diarization failed: {e}")
|
|
|
|
| 63 |
logger = logging.getLogger(__name__)
|
| 64 |
|
| 65 |
|
| 66 |
+
# Module-level function for GPU-accelerated diarization
|
| 67 |
+
# This avoids pickling issues with ZeroGPU by not depending on class instance state
|
| 68 |
+
@spaces.GPU(duration=90)
|
| 69 |
+
def _run_diarization_on_gpu(
|
| 70 |
+
audio_dict: Dict,
|
| 71 |
+
hf_token: str,
|
| 72 |
+
min_speakers: int,
|
| 73 |
+
max_speakers: int,
|
| 74 |
+
progress_callback: Optional[Callable] = None,
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
Run diarization on GPU (or CPU if unavailable).
|
| 78 |
+
|
| 79 |
+
This is a module-level function to avoid pickling issues with ZeroGPU.
|
| 80 |
+
The pipeline is loaded fresh within this GPU context.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
audio_dict: Audio data dict with 'waveform' and 'sample_rate'
|
| 84 |
+
hf_token: HuggingFace token for model access
|
| 85 |
+
min_speakers: Minimum number of speakers
|
| 86 |
+
max_speakers: Maximum number of speakers
|
| 87 |
+
progress_callback: Optional progress callback
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
Diarization result from pyannote
|
| 91 |
+
"""
|
| 92 |
+
# Load pipeline fresh in GPU context (avoids pickling)
|
| 93 |
+
logger.info("Loading pyannote pipeline in GPU context...")
|
| 94 |
+
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
|
| 95 |
+
|
| 96 |
+
# Move to available device
|
| 97 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 98 |
+
pipeline.to(device)
|
| 99 |
+
logger.info(f"Pipeline loaded on {device}")
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
# Custom progress hook that bridges pyannote progress to our callback
|
| 103 |
+
class CustomProgressHook(ProgressHook):
|
| 104 |
+
def __init__(self, callback=None):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.callback = callback
|
| 107 |
+
|
| 108 |
+
def __call__(self, step_name, step_artefact, file=None, total=None, completed=None):
|
| 109 |
+
# Call parent to maintain pyannote's internal tracking
|
| 110 |
+
result = super().__call__(step_name, step_artefact, file, total, completed)
|
| 111 |
+
|
| 112 |
+
# Forward progress to our callback
|
| 113 |
+
if self.callback and completed is not None and total is not None and total > 0:
|
| 114 |
+
# Map step names to user-friendly descriptions
|
| 115 |
+
stage = SPEAKER_SEPARATION_STAGES.get(step_name, step_name)
|
| 116 |
+
# Calculate percentage within this step (0.0 to 1.0)
|
| 117 |
+
step_progress = completed / total
|
| 118 |
+
# Scale to 0.3-0.8 range (30% to 80% of overall progress)
|
| 119 |
+
overall_progress = 0.3 + (step_progress * 0.5)
|
| 120 |
+
self.callback(stage, overall_progress, 1.0)
|
| 121 |
+
|
| 122 |
+
return result
|
| 123 |
+
|
| 124 |
+
# Use custom hook for pyannote progress with callback forwarding
|
| 125 |
+
with CustomProgressHook(callback=progress_callback) as hook:
|
| 126 |
+
diarization = pipeline(
|
| 127 |
+
audio_dict, min_speakers=min_speakers, max_speakers=max_speakers, hook=hook
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if progress_callback:
|
| 131 |
+
progress_callback("Speaker detection complete", 0.8, 1.0)
|
| 132 |
+
|
| 133 |
+
# Count speakers by iterating through speaker_diarization
|
| 134 |
+
speakers = set()
|
| 135 |
+
for turn, speaker in diarization.speaker_diarization:
|
| 136 |
+
speakers.add(speaker)
|
| 137 |
+
logger.info(f"Detected {len(speakers)} speakers: {', '.join(sorted(speakers))}")
|
| 138 |
+
|
| 139 |
+
return diarization
|
| 140 |
+
|
| 141 |
+
finally:
|
| 142 |
+
# Clean up
|
| 143 |
+
del pipeline
|
| 144 |
+
if torch.cuda.is_available():
|
| 145 |
+
torch.cuda.empty_cache()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
class SpeakerSeparationService:
|
| 149 |
"""
|
| 150 |
Service for speaker diarization and separation.
|
|
|
|
| 175 |
|
| 176 |
self.hf_token = hf_token
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
def convert_to_wav(self, input_path: str, sample_rate: int = 16000) -> str:
|
| 179 |
"""
|
| 180 |
Convert M4A/AAC to WAV for pyannote processing.
|
|
|
|
| 188 |
"""
|
| 189 |
return convert_m4a_to_wav(input_path, sample_rate=sample_rate)
|
| 190 |
|
|
|
|
| 191 |
def separate_speakers(
|
| 192 |
self,
|
| 193 |
audio_path: str,
|
|
|
|
| 238 |
"sample_rate": sr,
|
| 239 |
}
|
| 240 |
|
| 241 |
+
# Call the module-level GPU function (avoids pickling self)
|
| 242 |
+
diarization = _run_diarization_on_gpu(
|
| 243 |
+
audio_dict=audio_dict,
|
| 244 |
+
hf_token=self.hf_token,
|
| 245 |
+
min_speakers=min_speakers,
|
| 246 |
+
max_speakers=max_speakers,
|
| 247 |
+
progress_callback=progress_callback,
|
| 248 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
+
return diarization
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
def extract_speaker_segments(self, diarization, speaker_id: str) -> List[AudioSegment]:
|
| 253 |
"""
|
|
|
|
| 423 |
if progress_callback:
|
| 424 |
progress_callback("Loading audio", 0.1, 1.0)
|
| 425 |
|
| 426 |
+
# Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
|
| 427 |
diarization = self.separate_speakers(
|
| 428 |
str(input_file),
|
| 429 |
min_speakers=min_speakers,
|
| 430 |
max_speakers=max_speakers,
|
| 431 |
+
progress_callback=None, # Cannot pass callback to avoid pickling errors
|
| 432 |
)
|
| 433 |
except Exception as e:
|
| 434 |
logger.error(f"Speaker diarization failed: {e}")
|
|
@@ -35,6 +35,215 @@ from src.services.audio_concatenation import AudioConcatenationUtility
|
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
class VoiceDenoisingService:
|
| 39 |
"""
|
| 40 |
Service for removing silence and background noise from audio.
|
|
@@ -61,22 +270,8 @@ class VoiceDenoisingService:
|
|
| 61 |
self.vad_threshold = vad_threshold
|
| 62 |
self.concatenation_utility = AudioConcatenationUtility()
|
| 63 |
|
| 64 |
-
logger.info(f"
|
| 65 |
|
| 66 |
-
# Load Silero VAD model on CPU for ZeroGPU compatibility
|
| 67 |
-
try:
|
| 68 |
-
self.vad_model, utils = torch.hub.load(
|
| 69 |
-
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
|
| 70 |
-
)
|
| 71 |
-
# Ensure model starts on CPU
|
| 72 |
-
self.vad_model.to(torch.device("cpu"))
|
| 73 |
-
self.get_speech_timestamps = utils[0]
|
| 74 |
-
logger.info("Silero VAD model loaded successfully on CPU")
|
| 75 |
-
except Exception as e:
|
| 76 |
-
logger.error(f"Failed to load Silero VAD model: {e}")
|
| 77 |
-
raise RuntimeError(f"Failed to initialize VAD model: {e}")
|
| 78 |
-
|
| 79 |
-
@spaces.GPU(duration=45)
|
| 80 |
def denoise_audio(
|
| 81 |
self,
|
| 82 |
input_file: str,
|
|
@@ -122,104 +317,26 @@ class VoiceDenoisingService:
|
|
| 122 |
}
|
| 123 |
return None, error_report
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
try:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
logger.info("Reducing background noise...")
|
| 137 |
-
audio = self.reduce_noise(audio, sample_rate)
|
| 138 |
-
|
| 139 |
-
# Step 2: Detect voice segments using VAD
|
| 140 |
-
if progress_callback:
|
| 141 |
-
progress_callback(VOICE_DENOISING_STAGES[0], 0.5, 1.0) # "Detecting voice activity"
|
| 142 |
-
|
| 143 |
-
logger.info("Detecting voice segments...")
|
| 144 |
-
voice_segments = self.detect_voice_segments(audio, sample_rate, min_segment_duration)
|
| 145 |
-
|
| 146 |
-
if not voice_segments:
|
| 147 |
-
logger.warning("No voice segments detected")
|
| 148 |
-
return np.array([], dtype=np.float32), {
|
| 149 |
-
"input_file": input_file,
|
| 150 |
-
"segments_kept": 0,
|
| 151 |
-
"segments_removed": 0,
|
| 152 |
-
"original_duration": original_duration,
|
| 153 |
-
"output_duration": 0.0,
|
| 154 |
-
"compression_ratio": 0.0,
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
logger.info(f"Detected {len(voice_segments)} voice segments")
|
| 158 |
-
|
| 159 |
-
# Step 3: Filter segments by silence threshold
|
| 160 |
-
filtered_segments = self.remove_silence(
|
| 161 |
-
audio, sample_rate, silence_threshold, voice_segments
|
| 162 |
)
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
if not filtered_segments:
|
| 168 |
-
logger.warning("No segments remaining after silence removal")
|
| 169 |
-
return np.array([], dtype=np.float32), {
|
| 170 |
-
"input_file": input_file,
|
| 171 |
-
"segments_kept": 0,
|
| 172 |
-
"segments_removed": len(voice_segments),
|
| 173 |
-
"original_duration": original_duration,
|
| 174 |
-
"output_duration": 0.0,
|
| 175 |
-
"compression_ratio": 0.0,
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
# Step 4: Concatenate segments with crossfade
|
| 179 |
-
if progress_callback:
|
| 180 |
-
progress_callback(VOICE_DENOISING_STAGES[2], 0.75, 1.0) # "Concatenating segments"
|
| 181 |
-
|
| 182 |
-
logger.info("Concatenating segments...")
|
| 183 |
-
segment_arrays = []
|
| 184 |
-
for seg in filtered_segments:
|
| 185 |
-
start_sample = int(seg.start_time * sample_rate)
|
| 186 |
-
end_sample = int(seg.end_time * sample_rate)
|
| 187 |
-
segment_audio = audio[start_sample:end_sample]
|
| 188 |
-
segment_arrays.append(segment_audio)
|
| 189 |
-
|
| 190 |
-
denoised_audio = self.concatenation_utility.concatenate_segments(
|
| 191 |
-
segment_arrays,
|
| 192 |
-
sample_rate,
|
| 193 |
-
silence_duration_ms=silence_ms,
|
| 194 |
-
crossfade_duration_ms=crossfade_ms,
|
| 195 |
-
)
|
| 196 |
-
|
| 197 |
-
output_duration = len(denoised_audio) / sample_rate
|
| 198 |
-
compression_ratio = (
|
| 199 |
-
output_duration / original_duration if original_duration > 0 else 0.0
|
| 200 |
-
)
|
| 201 |
|
|
|
|
| 202 |
if progress_callback:
|
| 203 |
progress_callback("Complete", 1.0, 1.0)
|
| 204 |
|
| 205 |
-
logger.info(
|
| 206 |
-
f"Denoising complete: {original_duration:.1f}s → {output_duration:.1f}s "
|
| 207 |
-
f"(compression: {compression_ratio:.1%})"
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
# Generate report
|
| 211 |
-
report = {
|
| 212 |
-
"input_file": input_file,
|
| 213 |
-
"segments_kept": len(filtered_segments),
|
| 214 |
-
"segments_removed": segments_removed,
|
| 215 |
-
"original_duration": original_duration,
|
| 216 |
-
"output_duration": output_duration,
|
| 217 |
-
"compression_ratio": compression_ratio,
|
| 218 |
-
"vad_threshold": self.vad_threshold,
|
| 219 |
-
"silence_threshold": silence_threshold,
|
| 220 |
-
"min_segment_duration": min_segment_duration,
|
| 221 |
-
}
|
| 222 |
-
|
| 223 |
return denoised_audio, report
|
| 224 |
|
| 225 |
except Exception as e:
|
|
@@ -230,11 +347,6 @@ class VoiceDenoisingService:
|
|
| 230 |
"error_type": "processing",
|
| 231 |
}
|
| 232 |
return None, error_report
|
| 233 |
-
finally:
|
| 234 |
-
# Always move model back to CPU and clear cache
|
| 235 |
-
self.vad_model.to(torch.device("cpu"))
|
| 236 |
-
if torch.cuda.is_available():
|
| 237 |
-
torch.cuda.empty_cache()
|
| 238 |
|
| 239 |
def detect_voice_segments(
|
| 240 |
self, audio: np.ndarray, sample_rate: int, min_duration: float = 0.5
|
|
|
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
| 37 |
|
| 38 |
+
# Module-level GPU function to avoid pickling issues with ZeroGPU.
@spaces.GPU(duration=45)
def _denoise_audio_on_gpu(
    audio: np.ndarray,
    sample_rate: int,
    vad_threshold: float,
    silence_threshold: float,
    min_segment_duration: float,
    crossfade_ms: int,
    silence_ms: int,
    progress_callback: Optional[Callable] = None,
) -> Tuple[Optional[np.ndarray], Dict]:
    """
    Denoise audio on GPU (or CPU if unavailable).

    This is a module-level function so that ZeroGPU never has to pickle a
    bound method (``self``) or a preloaded model: only primitive,
    serializable arguments cross the GPU boundary, and the Silero VAD
    model is loaded fresh inside this GPU context.

    Args:
        audio: Audio samples as a 1-D float array.
        sample_rate: Sample rate in Hz.
        vad_threshold: VAD confidence threshold passed to Silero.
        silence_threshold: Maximum gap (seconds) between voice segments
            that is merged rather than removed.
        min_segment_duration: Minimum voice segment duration (seconds).
        crossfade_ms: Crossfade duration between segments (milliseconds).
        silence_ms: Silence inserted between segments (milliseconds).
        progress_callback: Optional progress callback. NOTE(review): under
            ZeroGPU this must be ``None`` — closures are not picklable.

    Returns:
        Tuple of (denoised_audio, report_dict). When no voice is detected
        (or nothing survives filtering), the audio is an empty float32
        array and the report carries zeroed statistics. ``input_file`` is
        added to the report by the caller, not here.
    """
    # Project-local imports stay inside the function (only the GPU worker
    # needs them), but are hoisted here ONCE instead of being re-executed
    # inside the per-timestamp and per-merge loops below.
    from src.models.audio_segment import AudioSegment, SegmentType
    from src.services.audio_concatenation import AudioConcatenationUtility

    # Load the VAD model fresh in the GPU context (avoids pickling it).
    logger.info("Loading Silero VAD model in GPU context...")
    vad_model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
    )

    # Move to whichever device is available inside this context.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vad_model.to(device)
    get_speech_timestamps = utils[0]
    logger.info(f"VAD model loaded on {device}")

    try:
        original_duration = len(audio) / sample_rate

        # Step 1: Reduce background noise (best-effort; optional dependency).
        if progress_callback:
            progress_callback(VOICE_DENOISING_STAGES[1], 0.3, 1.0)  # "Reducing noise"

        logger.info("Reducing background noise...")
        try:
            import noisereduce as nr

            audio = nr.reduce_noise(y=audio, sr=sample_rate, stationary=True, prop_decrease=0.8)
            audio = audio.astype(np.float32)
            logger.debug("Noise reduction applied")
        except ImportError:
            logger.warning("noisereduce not available, skipping noise reduction")
        except Exception as e:
            logger.warning(f"Noise reduction failed: {e}, using original audio")

        # Step 2: Detect voice segments using VAD.
        if progress_callback:
            progress_callback(VOICE_DENOISING_STAGES[0], 0.5, 1.0)  # "Detecting voice activity"

        logger.info("Detecting voice segments...")

        if len(audio) == 0:
            voice_segments = []
        else:
            audio_tensor = torch.from_numpy(audio).float()

            speech_timestamps = get_speech_timestamps(
                audio_tensor,
                vad_model,
                sampling_rate=sample_rate,
                threshold=vad_threshold,
            )

            # Convert sample-index timestamps to AudioSegment objects,
            # dropping segments shorter than min_segment_duration.
            voice_segments = []
            for ts in speech_timestamps:
                start_time = ts["start"] / sample_rate
                end_time = ts["end"] / sample_rate
                if end_time - start_time >= min_segment_duration:
                    voice_segments.append(
                        AudioSegment(
                            start_time=start_time,
                            end_time=end_time,
                            speaker_id="voice",
                            confidence=1.0,
                            segment_type=SegmentType.SPEECH,
                        )
                    )

            logger.debug(
                f"Detected {len(voice_segments)} voice segments (min duration: {min_segment_duration}s)"
            )

        if not voice_segments:
            logger.warning("No voice segments detected")
            return np.array([], dtype=np.float32), {
                "segments_kept": 0,
                "segments_removed": 0,
                "original_duration": original_duration,
                "output_duration": 0.0,
                "compression_ratio": 0.0,
            }

        logger.info(f"Detected {len(voice_segments)} voice segments")

        # Step 3: Merge segments whose gap is <= silence_threshold; gaps
        # larger than that are removed entirely.
        sorted_segments = sorted(voice_segments, key=lambda s: s.start_time)
        filtered = []
        current_segment = sorted_segments[0]

        for next_segment in sorted_segments[1:]:
            gap = next_segment.start_time - current_segment.end_time
            if gap <= silence_threshold:
                # Merge: extend the current segment over the small gap.
                current_segment = AudioSegment(
                    start_time=current_segment.start_time,
                    end_time=next_segment.end_time,
                    speaker_id="voice",
                    confidence=1.0,
                    segment_type=SegmentType.SPEECH,
                )
            else:
                # Gap too large: keep current and start a new run.
                filtered.append(current_segment)
                current_segment = next_segment

        # Add the last (or only) segment.
        filtered.append(current_segment)

        segments_removed = len(voice_segments) - len(filtered)
        logger.info(f"Kept {len(filtered)} segments, removed {segments_removed}")

        # Defensive: filtered always holds >= 1 segment at this point
        # (we appended current_segment above); kept as a safety net.
        if not filtered:
            logger.warning("No segments remaining after silence removal")
            return np.array([], dtype=np.float32), {
                "segments_kept": 0,
                "segments_removed": len(voice_segments),
                "original_duration": original_duration,
                "output_duration": 0.0,
                "compression_ratio": 0.0,
            }

        # Step 4: Concatenate segments with crossfade.
        if progress_callback:
            progress_callback(VOICE_DENOISING_STAGES[2], 0.75, 1.0)  # "Concatenating segments"

        logger.info("Concatenating segments...")
        segment_arrays = []
        for seg in filtered:
            start_sample = int(seg.start_time * sample_rate)
            end_sample = int(seg.end_time * sample_rate)
            segment_arrays.append(audio[start_sample:end_sample])

        concatenation_utility = AudioConcatenationUtility()
        denoised_audio = concatenation_utility.concatenate_segments(
            segment_arrays,
            sample_rate,
            silence_duration_ms=silence_ms,
            crossfade_duration_ms=crossfade_ms,
        )

        output_duration = len(denoised_audio) / sample_rate
        compression_ratio = output_duration / original_duration if original_duration > 0 else 0.0

        logger.info(
            f"Denoising complete: {original_duration:.1f}s → {output_duration:.1f}s "
            f"(compression: {compression_ratio:.1%})"
        )

        # Generate report (caller adds "input_file").
        report = {
            "segments_kept": len(filtered),
            "segments_removed": segments_removed,
            "original_duration": original_duration,
            "output_duration": output_duration,
            "compression_ratio": compression_ratio,
            "vad_threshold": vad_threshold,
            "silence_threshold": silence_threshold,
            "min_segment_duration": min_segment_duration,
        }

        return denoised_audio, report

    finally:
        # Release the model and free GPU memory before leaving the
        # ZeroGPU context.
        del vad_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
class VoiceDenoisingService:
|
| 248 |
"""
|
| 249 |
Service for removing silence and background noise from audio.
|
|
|
|
| 270 |
self.vad_threshold = vad_threshold
|
| 271 |
self.concatenation_utility = AudioConcatenationUtility()
|
| 272 |
|
| 273 |
+
logger.info(f"Voice denoising service initialized (VAD threshold: {vad_threshold})")
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
def denoise_audio(
|
| 276 |
self,
|
| 277 |
input_file: str,
|
|
|
|
| 317 |
}
|
| 318 |
return None, error_report
|
| 319 |
|
| 320 |
+
# Call module-level GPU function (avoids pickling self)
|
| 321 |
+
# Note: progress_callback cannot be passed due to pickling constraints
|
| 322 |
try:
|
| 323 |
+
denoised_audio, report = _denoise_audio_on_gpu(
|
| 324 |
+
audio=audio,
|
| 325 |
+
sample_rate=sample_rate,
|
| 326 |
+
vad_threshold=self.vad_threshold,
|
| 327 |
+
silence_threshold=silence_threshold,
|
| 328 |
+
min_segment_duration=min_segment_duration,
|
| 329 |
+
crossfade_ms=crossfade_ms,
|
| 330 |
+
silence_ms=silence_ms,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
)
|
| 332 |
|
| 333 |
+
# Add input_file to report
|
| 334 |
+
report["input_file"] = input_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
+
# Provide progress update after GPU processing completes
|
| 337 |
if progress_callback:
|
| 338 |
progress_callback("Complete", 1.0, 1.0)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
return denoised_audio, report
|
| 341 |
|
| 342 |
except Exception as e:
|
|
|
|
| 347 |
"error_type": "processing",
|
| 348 |
}
|
| 349 |
return None, error_report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
def detect_voice_segments(
|
| 352 |
self, audio: np.ndarray, sample_rate: int, min_duration: float = 0.5
|
|
@@ -101,6 +101,7 @@ def create_speaker_extraction_tab() -> gr.Tab:
|
|
| 101 |
progress(0.1, desc="Initializing...")
|
| 102 |
svc = get_service()
|
| 103 |
|
|
|
|
| 104 |
report = svc.extract_and_export(
|
| 105 |
reference_clip=reference_file,
|
| 106 |
target_file=target_file,
|
|
@@ -112,7 +113,7 @@ def create_speaker_extraction_tab() -> gr.Tab:
|
|
| 112 |
crossfade_duration_ms=crossfade_duration,
|
| 113 |
sample_rate=sample_rate,
|
| 114 |
bitrate=bitrate,
|
| 115 |
-
progress_callback=
|
| 116 |
)
|
| 117 |
|
| 118 |
# Check if result is an error report
|
|
|
|
| 101 |
progress(0.1, desc="Initializing...")
|
| 102 |
svc = get_service()
|
| 103 |
|
| 104 |
+
# Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
|
| 105 |
report = svc.extract_and_export(
|
| 106 |
reference_clip=reference_file,
|
| 107 |
target_file=target_file,
|
|
|
|
| 113 |
crossfade_duration_ms=crossfade_duration,
|
| 114 |
sample_rate=sample_rate,
|
| 115 |
bitrate=bitrate,
|
| 116 |
+
progress_callback=None, # Cannot pass callback to avoid pickling errors
|
| 117 |
)
|
| 118 |
|
| 119 |
# Check if result is an error report
|
|
@@ -65,13 +65,14 @@ def process_denoising(
|
|
| 65 |
# Process audio
|
| 66 |
if progress:
|
| 67 |
progress(0.1, desc="Starting voice denoising...")
|
|
|
|
| 68 |
denoised_audio, report = service.denoise_audio(
|
| 69 |
input_audio,
|
| 70 |
silence_threshold=silence_threshold,
|
| 71 |
min_segment_duration=min_duration,
|
| 72 |
crossfade_ms=crossfade_ms,
|
| 73 |
silence_ms=silence_ms,
|
| 74 |
-
progress_callback=
|
| 75 |
)
|
| 76 |
|
| 77 |
# Check if result is an error report
|
|
|
|
| 65 |
# Process audio
|
| 66 |
if progress:
|
| 67 |
progress(0.1, desc="Starting voice denoising...")
|
| 68 |
+
# Note: progress_callback cannot be passed due to ZeroGPU pickling constraints
|
| 69 |
denoised_audio, report = service.denoise_audio(
|
| 70 |
input_audio,
|
| 71 |
silence_threshold=silence_threshold,
|
| 72 |
min_segment_duration=min_duration,
|
| 73 |
crossfade_ms=crossfade_ms,
|
| 74 |
silence_ms=silence_ms,
|
| 75 |
+
progress_callback=None, # Cannot pass callback to avoid pickling errors
|
| 76 |
)
|
| 77 |
|
| 78 |
# Check if result is an error report
|