File size: 2,139 Bytes
f2532fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os

# Workaround for PyTorch 2.6+ weights_only=True default.
# pyannote VAD model checkpoints (used by WhisperX) contain omegaconf types
# and other globals that are not in torch's safe-globals allowlist.
# This env var tells PyTorch to fall back to weights_only=False when the
# caller did not explicitly pass weights_only.  The pyannote models are
# published, trusted checkpoints.
# NOTE: this must be set *before* importing whisperx (which pulls in torch),
# hence the deliberate deviation from top-of-file import grouping.
os.environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")

import whisperx

import gc
import torch

# Lazily-loaded singleton WhisperX model and the device it was loaded on.
# Managed by _get_model() / unload_model().
_model = None
_current_device = None


def _get_model(device: str | None = None):
    """Return the cached WhisperX model, loading it on first use.

    Args:
        device: "cuda" or "cpu".  None auto-selects "cuda" when available.

    Returns:
        The loaded WhisperX model.  Cached across calls; reloaded only
        when the requested device changes.
    """
    global _model, _current_device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if _model is None or _current_device != device:
        if _model is not None:
            # Release the model loaded on the previous device before
            # allocating a new one, so two copies are never held at once.
            _model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # "base" + int8 keeps the memory footprint small.
        _model = whisperx.load_model(
            "base",
            device=device,
            compute_type="int8",
        )
        _current_device = device
    return _model


def unload_model():
    """Release the cached WhisperX model and reclaim GPU memory.

    No-op if no model is currently loaded.  Other models can then use
    the freed memory.
    """
    global _model, _current_device
    if _model is None:
        return
    # Drop the only module-level reference; collection + cache flush
    # returns the memory to the allocator.
    _model = None
    _current_device = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("[WhisperX] Model unloaded, GPU memory freed.")


def transcribe_audio(audio_path: str, language: str | None = None, device: str | None = None) -> str:
    """
    Transcribe audio file using WhisperX.

    Args:
        audio_path: Path to audio file (any format supported by ffmpeg).
        language: ISO 639-1 language code (e.g. "en", "ko", "ja").
                  None for auto-detection.
        device: "cuda" or "cpu".  None auto-selects "cuda" when available.

    Returns:
        Transcribed text as a single string.
    """
    model = _get_model(device)
    audio = whisperx.load_audio(audio_path)

    # Only forward `language` when the caller gave one, so WhisperX
    # performs its own language detection otherwise.
    transcribe_kwargs = {"batch_size": 16}
    if language:
        transcribe_kwargs["language"] = language

    result = model.transcribe(audio, **transcribe_kwargs)

    # Flatten the per-segment texts into a single string, skipping
    # segments with missing or empty text.
    segments = result.get("segments", [])
    text = " ".join(seg["text"].strip() for seg in segments if seg.get("text"))
    return text