File size: 5,459 Bytes
20e9692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Model lifecycle and device management for the VAD segmenter."""

import torch

from config import SEGMENTER_MODEL, DTYPE, IS_HF_SPACE, TORCH_COMPILE
from ..zero_gpu import ZERO_GPU_AVAILABLE, is_quota_exhausted, is_user_forced_cpu


# =============================================================================
# Model caches
# =============================================================================

_segmenter_cache = dict(
    model=None,       # the loaded model (may be wrapped by torch.compile)
    processor=None,   # feature extractor loaded alongside the model
    loaded=False,     # flipped to True only after a successful load_segmenter()
    load_time=0.0,    # seconds the successful load took
    device=None,      # device type string the model was last placed on ("cpu"/"cuda")
)
_env_logged = False  # guards _log_env_once() so versions are printed a single time


def _log_env_once():
    """Print torch / CUDA / cuDNN / library versions on the first call only.

    Meant for diagnosing dependency mismatches on HF Spaces. Subsequent calls
    do nothing. Any error raised while collecting versions is printed instead
    of propagated.
    """
    global _env_logged
    if not _env_logged:
        # Set the flag up front so a failing probe is never retried.
        _env_logged = True
        try:
            import importlib.metadata as metadata

            def pkg_version(name: str) -> str:
                # Missing / unreadable distributions are reported, not raised.
                try:
                    return metadata.version(name)
                except Exception:
                    return "unknown"

            if torch.backends.cudnn.is_available():
                cudnn_info = torch.backends.cudnn.version()
            else:
                cudnn_info = 'none'
            print(f"[ENV] torch={torch.__version__} cuda={torch.version.cuda} cudnn={cudnn_info}")
            print(
                f"[ENV] transformers={pkg_version('transformers')} "
                f"recitations_segmenter={pkg_version('recitations_segmenter')}"
            )
            if torch.cuda.is_available():
                print(f"[ENV] GPU={torch.cuda.get_device_name(0)}")
        except Exception as err:
            print(f"[ENV] Failed to log env: {err}")


# Torch dtype used for every model placement in this module: half precision
# only when the config asks for exactly "float16", full precision otherwise.
if DTYPE == "float16":
    _TORCH_DTYPE = torch.float16
else:
    _TORCH_DTYPE = torch.float32


def _get_device_and_dtype():
    """Return the (device, dtype) pair used when a model is first loaded.

    On HF Spaces or when ZeroGPU is available the model starts on CPU; it is
    moved to CUDA later via ensure_models_on_gpu(). Otherwise CUDA is chosen
    whenever torch reports it as available. The dtype is always _TORCH_DTYPE.
    """
    # `and` short-circuits, so CUDA is not probed on HF Space / ZeroGPU.
    use_cuda = not (IS_HF_SPACE or ZERO_GPU_AVAILABLE) and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu"), _TORCH_DTYPE


def ensure_models_on_gpu(asr_model_name=None):
    """
    Move models to GPU. Call this INSIDE a GPU-decorated function
    after ZeroGPU lease is acquired.

    Args:
        asr_model_name: If provided, move only this ASR model to GPU.
            If None, skip ASR model movement (e.g. during VAD-only lease).

    Returns 0.0 immediately, without moving anything, if the user forced CPU,
    the GPU quota is exhausted, or CUDA is unavailable.

    The segmenter is moved only if it has been loaded and its parameters are
    not already on CUDA, so repeated calls do not re-copy it.

    Returns:
        float: Time in seconds spent moving models to GPU.
    """
    import time
    from ..alignment.phoneme_asr import move_phoneme_asr_to_gpu

    if is_user_forced_cpu() or is_quota_exhausted() or not torch.cuda.is_available():
        return 0.0

    device = torch.device("cuda")
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can jump (NTP / manual changes) and skew durations.
    move_start = time.perf_counter()

    # Move segmenter to GPU (only once it has actually been loaded)
    if _segmenter_cache["loaded"] and _segmenter_cache["model"] is not None:
        model = _segmenter_cache["model"]
        # The first parameter's device is used as the model's current device.
        if next(model.parameters()).device.type != "cuda":
            print("[GPU] Moving segmenter to CUDA...")
            model.to(device, dtype=_TORCH_DTYPE)
            _segmenter_cache["model"] = model
            _segmenter_cache["device"] = "cuda"
            print("[GPU] Segmenter on CUDA")

    # Move phoneme ASR to GPU (only the requested model)
    if asr_model_name is not None:
        move_phoneme_asr_to_gpu(asr_model_name)

    return time.perf_counter() - move_start


def ensure_models_on_cpu():
    """
    Bring every model back to CPU so CPU-fallback inference can run.

    Intended for the paths where the GPU lease fails or the quota is
    exhausted. The segmenter is only moved when it is loaded and not already
    on CPU; the phoneme ASR move is delegated to move_phoneme_asr_to_cpu().
    """
    from ..alignment.phoneme_asr import move_phoneme_asr_to_cpu

    cache = _segmenter_cache
    segmenter = cache["model"] if cache["loaded"] else None

    # Compare the first parameter's device to decide whether a move is needed.
    if segmenter is not None and next(segmenter.parameters()).device.type != "cpu":
        print("[CPU] Moving segmenter to CPU...")
        segmenter.to(torch.device("cpu"), dtype=_TORCH_DTYPE)
        cache["model"] = segmenter
        cache["device"] = "cpu"
        print("[CPU] Segmenter on CPU")

    move_phoneme_asr_to_cpu()


def load_segmenter():
    """Load the VAD segmenter model and its feature extractor, caching both.

    The model is placed on the device returned by _get_device_and_dtype().
    That is CPU on HF Spaces / ZeroGPU, and CUDA elsewhere when available.
    Outside HF Space / ZeroGPU it is optionally wrapped with torch.compile.
    Use ensure_models_on_gpu() inside GPU-decorated functions to move it to
    CUDA.

    Returns:
        tuple: (model, processor, load_time).
            - load_time is the seconds spent loading on this call, or 0.0
              when the cached model is returned.
            - On any failure the error is printed and (None, None, 0.0) is
              returned. Nothing is cached in that case, so a later call
              retries the load.
    """
    if _segmenter_cache["loaded"]:
        return _segmenter_cache["model"], _segmenter_cache["processor"], 0.0

    import time
    # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
    start_time = time.perf_counter()

    try:
        from transformers import AutoModelForAudioFrameClassification, AutoFeatureExtractor

        print(f"Loading segmenter: {SEGMENTER_MODEL}")
        device, dtype = _get_device_and_dtype()

        model = AutoModelForAudioFrameClassification.from_pretrained(SEGMENTER_MODEL)
        model.to(device, dtype=dtype)
        model.eval()
        # Compilation is skipped on HF Space / ZeroGPU, where the model is
        # moved between devices per GPU lease.
        if TORCH_COMPILE and not (IS_HF_SPACE or ZERO_GPU_AVAILABLE):
            model = torch.compile(model, mode="reduce-overhead")

        processor = AutoFeatureExtractor.from_pretrained(SEGMENTER_MODEL)

        load_time = time.perf_counter() - start_time
        _segmenter_cache["model"] = model
        _segmenter_cache["processor"] = processor
        _segmenter_cache["loaded"] = True
        _segmenter_cache["load_time"] = load_time
        _segmenter_cache["device"] = device.type

        print(f"Segmenter loaded on {device} in {load_time:.2f}s")
        return model, processor, load_time

    except Exception as e:
        # Best-effort: report and let the caller handle the (None, None) result.
        print(f"Failed to load segmenter: {e}")
        return None, None, 0.0