# NOTE(review): the three lines below the original header ("Spaces: Sleeping
# Sleeping") were Hugging Face Spaces page residue from the scrape, not part
# of this module; kept here as a comment so the file remains valid Python.
import os

# Workaround for PyTorch 2.6+ weights_only=True default.
# pyannote VAD model checkpoints (used by WhisperX) contain omegaconf types
# and other globals that are not in torch's safe-globals allowlist.
# This env var tells PyTorch to fall back to weights_only=False when the
# caller did not explicitly pass weights_only. The pyannote models are
# published, trusted checkpoints.
#
# Set BEFORE importing whisperx/torch so any checkpoint loads triggered at
# import time already see it.
os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")

import whisperx
import gc
import torch

# Module-level cache: the one loaded WhisperX model and the device it was
# loaded on. Managed by _get_model() / unload_model().
_model = None
_current_device = None
def _get_model(device: str | None = None):
    """Return the cached WhisperX model, (re)loading it if needed.

    Args:
        device: "cuda" or "cpu". None selects CUDA when available.

    Returns:
        The loaded whisperx model instance, cached across calls.
    """
    global _model, _current_device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Reload when nothing is cached yet, or when the caller asks for a
    # different device than the cached model lives on.
    if _model is None or _current_device != device:
        _model = whisperx.load_model(
            "base",
            device=device,
            compute_type="int8",  # int8 keeps memory low on CPU and GPU
        )
        _current_device = device
    return _model
def unload_model():
    """Free WhisperX model from GPU memory to make room for other models."""
    global _model, _current_device
    if _model is None:
        # Nothing cached — nothing to release.
        return
    del _model
    _model = None
    _current_device = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("[WhisperX] Model unloaded, GPU memory freed.")
def transcribe_audio(audio_path: str, language: str | None = None, device: str | None = None) -> str:
    """
    Transcribe audio file using WhisperX.

    Args:
        audio_path: Path to audio file (any format supported by ffmpeg).
        language: ISO 639-1 language code (e.g. "en", "ko", "ja").
                  None for auto-detection.
        device: "cuda" or "cpu". None selects CUDA when available.

    Returns:
        Transcribed text as a single string.
    """
    model = _get_model(device)
    audio = whisperx.load_audio(audio_path)
    transcribe_kwargs = {"batch_size": 16}
    if language:
        # An explicit language skips WhisperX's language auto-detection pass.
        transcribe_kwargs["language"] = language
    result = model.transcribe(audio, **transcribe_kwargs)
    # Join per-segment texts, skipping segments with missing/empty text.
    segments = result.get("segments", [])
    text = " ".join(seg["text"].strip() for seg in segments if seg.get("text"))
    return text