Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Alias module to redirect whisper imports to whisperx. | |
| This allows OuteTTS to use whisperx instead of the standard whisper package. | |
| """ | |
| import sys | |
| import importlib.util | |
def setup_whisper_alias():
    """Register a whisperx-backed stand-in under the module name ``whisper``.

    On success ``sys.modules['whisper']`` holds an object exposing
    ``load_model(name, device=None, **kwargs)``. The model it returns exposes
    ``transcribe(audio, **kwargs)`` and yields a dict with ``text`` and
    ``segments`` keys, plus ``language`` when whisperx reported one.

    The setup is best-effort: when whisperx is not installed, or anything
    fails while building the alias, a warning is printed instead of raising.

    Returns:
        None.
    """
    try:
        # Probe before importing so a missing package produces a clear
        # message rather than an ImportError from deep inside the import.
        if importlib.util.find_spec("whisperx") is None:
            print("Warning: whisperx not found, falling back to regular whisper")
            return

        import whisperx

        class WhisperAlias:
            """Object stored in ``sys.modules`` in place of the whisper module."""

            def __init__(self):
                # Expose whisperx's model class when this version provides one.
                self.model = getattr(whisperx, "WhisperModel", None)
                self.load_model = self._load_model

            def _load_model(self, name, device=None, **kwargs):
                """Load a whisperx model behind a whisper-style ``load_model``.

                Args:
                    name: Model name, forwarded to ``whisperx.load_model``.
                    device: Requested device; may be given positionally or by
                        keyword. ``"cuda"``, ``"cuda:N"`` and objects whose
                        ``str()`` has that form select CUDA; anything else
                        (including ``None``) selects the CPU.
                    **kwargs: Accepted for call compatibility and ignored.

                Returns:
                    A ``WhisperXModelWrapper`` around the loaded model.
                """
                requested = "" if device is None else str(device)
                use_cuda = requested == "cuda" or requested.startswith("cuda:")
                # Only the device family is forwarded; an explicit device
                # index such as the "1" in "cuda:1" is not passed on.
                device = "cuda" if use_cuda else "cpu"
                compute_type = "float16" if use_cuda else "int8"
                model = whisperx.load_model(
                    name,
                    device=device,
                    compute_type=compute_type,
                )
                return WhisperXModelWrapper(model, device)

        class WhisperXModelWrapper:
            """Adapter giving a whisperx model a whisper-style ``transcribe``."""

            def __init__(self, model, device):
                self.model = model
                self.device = device

            def transcribe(self, audio, **kwargs):
                """Transcribe ``audio`` and return a whisper-shaped result dict.

                Args:
                    audio: A file path (loaded via ``whisperx.load_audio``) or
                        already-loaded audio data, which is passed through.
                    **kwargs: ``word_timestamps`` (default ``False``) requests
                        word-level alignment; ``batch_size`` (default 16) is
                        forwarded to the model. Other keys are ignored.

                Returns:
                    A dict that always has ``segments`` and ``text``. When
                    ``word_timestamps`` is true every segment has a ``words``
                    list (empty if alignment produced none or failed).
                """
                want_words = kwargs.get("word_timestamps", False)
                batch_size = kwargs.get("batch_size", 16)

                if isinstance(audio, str):
                    audio_data = whisperx.load_audio(audio)
                else:
                    audio_data = audio

                result = self.model.transcribe(audio_data, batch_size=batch_size)
                # Remember the detected language now: the alignment step
                # below replaces ``result`` and its return value may not
                # carry the language forward.
                language = result.get("language")

                if want_words and result.get("segments"):
                    try:
                        model_a, metadata = whisperx.load_align_model(
                            language_code=result.get("language", "en"),
                            device=self.device,
                        )
                        result = whisperx.align(
                            result["segments"],
                            model_a,
                            metadata,
                            audio_data,
                            self.device,
                            return_char_alignments=False,
                        )
                    except Exception as e:
                        # Alignment is optional; keep the unaligned result.
                        print(f"Warning: Could not perform alignment: {e}")

                # Fill in the keys whisper callers expect.
                if language is not None and "language" not in result:
                    result["language"] = language
                if "segments" not in result:
                    result["segments"] = []
                if "text" not in result:
                    result["text"] = " ".join(
                        segment.get("text", "") for segment in result["segments"]
                    )
                if want_words:
                    for segment in result["segments"]:
                        if "words" not in segment:
                            segment["words"] = []
                return result

        # From here on, 'import whisper' resolves to the alias object.
        sys.modules['whisper'] = WhisperAlias()
        print("✅ Successfully aliased whisper to whisperx")
    except ImportError as e:
        print(f"Warning: Could not setup whisper alias: {e}")
        print("Falling back to regular whisper (if available)")
    except Exception as e:
        # Deliberately broad: a failed alias must never break the importer.
        print(f"Warning: Error setting up whisper alias: {e}")
| # Auto-setup when module is imported: the call below runs as an import side effect, so importing this module is what installs the alias for any later 'import whisper'. | |
| setup_whisper_alias() | |