import os
# Workaround for PyTorch 2.6+ weights_only=True default.
# pyannote VAD model checkpoints (used by WhisperX) contain omegaconf types
# and other globals that are not in torch's safe-globals allowlist.
# This env var tells PyTorch to fall back to weights_only=False when the
# caller did not explicitly pass weights_only. The pyannote models are
# published, trusted checkpoints.
os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
import whisperx
import gc
import torch
_model = None
_current_device = None
def _get_model(device: str | None = None):
    """Return the cached WhisperX model, loading it on first use.

    The model is loaded lazily and cached at module level. If a different
    device is requested than the one the cached model lives on, the old
    model is freed first and a new one is loaded.

    Args:
        device: "cuda" or "cpu". None selects "cuda" when available,
            otherwise "cpu".

    Returns:
        The loaded WhisperX model instance.
    """
    global _model, _current_device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if _model is None or _current_device != device:
        if _model is not None:
            # Free the previously loaded model before switching devices so
            # that two copies are never resident in memory at the same time.
            unload_model()
        _model = whisperx.load_model(
            "base",
            device=device,
            compute_type="int8",
        )
        _current_device = device
    return _model
def unload_model():
    """Free WhisperX model from GPU memory to make room for other models."""
    global _model, _current_device
    if _model is None:
        # Nothing loaded — nothing to release.
        return
    # Dropping the module-level reference lets the collector reclaim the
    # model's tensors; the CUDA cache is then returned to the driver.
    _model = None
    _current_device = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("[WhisperX] Model unloaded, GPU memory freed.")
def transcribe_audio(audio_path: str, language: str | None = None, device: str | None = None) -> str:
    """
    Transcribe audio file using WhisperX.

    Args:
        audio_path: Path to audio file (any format supported by ffmpeg).
        language: ISO 639-1 language code (e.g. "en", "ko", "ja").
            None for auto-detection.
        device: "cuda" or "cpu". None selects automatically.

    Returns:
        Transcribed text as a single string (segment texts joined by spaces).
    """
    model = _get_model(device)
    audio = whisperx.load_audio(audio_path)
    transcribe_kwargs = {"batch_size": 16}
    # Only pass `language` when the caller supplied one; omitting it lets
    # WhisperX auto-detect the spoken language.
    if language:
        transcribe_kwargs["language"] = language
    result = model.transcribe(audio, **transcribe_kwargs)
    # Segments may contain empty/missing text entries; skip them.
    segments = result.get("segments", [])
    text = " ".join(seg["text"].strip() for seg in segments if seg.get("text"))
    return text
|