File size: 8,394 Bytes
4ae4ae8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Speech-to-text module for REFRAME.

Modular, fail-safe STT using Cohere Transcribe (primary) or Whisper (fallback).
If dependencies are missing or model fails to load, the app continues without voice.
"""

import logging
import os

import config

# Configure the root logger at import time so the "[STT]" INFO messages in
# this module are visible by default. NOTE(review): this is a global side
# effect of importing the module; logging.basicConfig does nothing if the
# root logger already has handlers, so an app that configures logging first wins.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _on_zerogpu() -> bool:
    """Report whether we are running on a Hugging Face ZeroGPU Space.

    Two signals are required: a non-empty ``SPACE_ID`` environment variable
    and an importable ``spaces`` package (which this module assumes is only
    installed on ZeroGPU). On ZeroGPU the model has to be placed on ``cuda``
    at startup: a CUDA-emulation layer accepts that placement before a real
    GPU exists, and the real GPU is used inside ``@spaces.GPU`` functions.
    """
    in_space = bool(os.environ.get("SPACE_ID"))
    if in_space:
        try:
            import spaces  # noqa: F401
        except Exception:
            return False
        return True
    return False

# Hub id of the Whisper model used when Cohere Transcribe is unavailable or fails.
_FALLBACK_MODEL = "openai/whisper-small"

# Loaded-model state, populated lazily by _load_model() / _load_whisper_fallback().
_model = None  # Cohere model object (Cohere path only)
_processor = None  # Cohere processor (Cohere path only)
_pipeline = None  # generic ASR pipeline (non-Cohere config.STT_MODEL)
_fallback_pipeline = None  # Whisper fallback pipeline, built on first need

# Load bookkeeping.
_load_attempted: bool = False  # flipped to True only after a load succeeds
_is_cohere: bool = False  # whether config.STT_MODEL names a Cohere model


def is_available() -> bool:
    """Return True when the optional STT stack (torch + transformers) imports cleanly."""
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except ImportError:
        return False
    return True


def _get_device() -> str:
    """Choose the torch device string used to load and run the model.

    Priority: ZeroGPU (always ``"cuda"``) > Intel ``"xpu"`` > Apple ``"mps"``
    > ``"cuda"`` > ``"cpu"``. Any error while probing torch (including torch
    being absent) falls through to ``"cpu"``.
    """
    if _on_zerogpu():
        # On ZeroGPU cuda.is_available() is False at startup, yet placing the
        # model on "cuda" works through emulation; the real GPU shows up inside
        # @spaces.GPU functions. So force "cuda" without probing.
        return "cuda"

    try:
        import torch

        # Probes run lazily, in priority order; the first truthy one wins.
        probes = (
            ("xpu", lambda: hasattr(torch, "xpu") and torch.xpu.is_available()),
            ("mps", lambda: torch.backends.mps.is_available()),
            ("cuda", lambda: torch.cuda.is_available()),
        )
        for name, probe in probes:
            if probe():
                return name
    except Exception:
        pass
    return "cpu"


def _dtype_for(device: str):
    """Map a device string to the torch dtype used when loading the model.

    ``"cpu"`` -> float32; ``"xpu"`` / ``"cuda"`` -> bfloat16 (the checkpoint's
    native dtype, supported on Intel GPUs and ZeroGPU); any other device
    (in practice ``"mps"``) -> float16.
    """
    import torch

    by_device = {
        "cpu": torch.float32,
        "xpu": torch.bfloat16,
        "cuda": torch.bfloat16,
    }
    return by_device.get(device, torch.float16)


def _load_whisper_fallback(device: str, dtype):
    """Build the shared Whisper ASR pipeline on first use; later calls do nothing.

    The pipeline is cached in the module-level ``_fallback_pipeline`` and is
    configured for English transcription. Loading errors propagate to the caller.
    """
    global _fallback_pipeline
    if _fallback_pipeline is None:
        from transformers import pipeline

        logger.warning(f"[STT] Falling back to {_FALLBACK_MODEL}")
        whisper_generation = {"language": "en", "task": "transcribe"}
        _fallback_pipeline = pipeline(
            "automatic-speech-recognition",
            model=_FALLBACK_MODEL,
            dtype=dtype,
            device=device,
            generate_kwargs=whisper_generation,
        )


def _run_pipeline(pipe, audio_filepath: str) -> str:
    """Transcribe ``audio_filepath`` with an ASR pipeline and return its text.

    The file is decoded to a 16 kHz array with ``load_audio`` before the
    pipeline sees it. Handing the pipeline a bare path can fail with
    "Soundfile ... malformed" (no ffmpeg on PATH, or an unsupported WAV
    encoding). Pre-decoding avoids that and mirrors how the Cohere path loads
    audio. A dict result yields its ``"text"`` entry (``""`` if absent);
    any other result is converted with ``str()``.
    """
    from transformers.audio_utils import load_audio

    rate = 16000
    waveform = load_audio(audio_filepath, sampling_rate=rate)
    output = pipe({"raw": waveform, "sampling_rate": rate})
    if not isinstance(output, dict):
        return str(output)
    return output.get("text", "")


# Oldest transformers release whose native cohere_asr support this module relies on.
_COHERE_MIN_TRANSFORMERS = (5, 4)


def _major_minor(version: str) -> tuple[int, int]:
    """Return the leading ``(major, minor)`` numbers of a version string.

    Tolerates suffixes such as ``"5.4.0.dev0"`` or ``"5.4rc1"``. A missing
    minor counts as 0, and an unparseable string yields ``(0, 0)``, which the
    caller treats as "too old".
    """
    import re

    match = re.match(r"(\d+)(?:\.(\d+))?", version)
    if match is None:
        return (0, 0)
    return int(match.group(1)), int(match.group(2) or 0)


def _load_model():
    """Load the configured ASR model into module state.

    The model is chosen from ``config.STT_MODEL``:

    * Cohere Transcribe ids load through transformers' native implementation
      (``AutoProcessor`` + ``AutoModelForSpeechSeq2Seq``) into ``_processor`` /
      ``_model``. If the installed transformers is older than 5.4, or loading
      the Cohere model raises, the Whisper fallback pipeline is loaded instead.
    * Any other id is wrapped in an English ``automatic-speech-recognition``
      pipeline stored in ``_pipeline``.

    ``_load_attempted`` is set only after something usable has been loaded,
    so a failed attempt is retried on the next call. Errors are logged rather
    than raised.
    """
    global _model, _processor, _pipeline, _load_attempted, _is_cohere
    if _load_attempted:
        return

    device = None
    dtype = None
    # True once we start loading the Cohere model itself; a failure after this
    # point should fall back to Whisper instead of leaving STT unusable.
    loading_cohere = False
    try:
        import transformers

        device = _get_device()
        dtype = _dtype_for(device)
        _is_cohere = "cohere" in config.STT_MODEL.lower()

        if _is_cohere and _major_minor(transformers.__version__) < _COHERE_MIN_TRANSFORMERS:
            logger.warning(
                "[STT] Cohere Transcribe needs transformers>=5.4.0. "
                f"Current version is {transformers.__version__}. Using Whisper fallback."
            )
            _load_whisper_fallback(device=device, dtype=dtype)
            _load_attempted = True
            return

        logger.info(f"[STT] Loading model: {config.STT_MODEL} on {device} ({dtype})")

        if _is_cohere:
            loading_cohere = True
            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

            # Use transformers' built-in (native) cohere_asr implementation, NOT the
            # Hub remote code. The remote-code path (trust_remote_code=True) produces
            # garbage output on transformers>=5.3 and is deprecated by the model
            # authors; the native path is the recommended way going forward.
            # (See model discussion #28.)
            _processor = AutoProcessor.from_pretrained(config.STT_MODEL)
            _model = AutoModelForSpeechSeq2Seq.from_pretrained(config.STT_MODEL, dtype=dtype)
            _model = _model.to(device)
        else:
            from transformers import pipeline

            _pipeline = pipeline(
                "automatic-speech-recognition",
                model=config.STT_MODEL,
                dtype=dtype,
                device=device,
                generate_kwargs={"language": "en", "task": "transcribe"},
            )

        _load_attempted = True  # Only lock after success
        logger.info("[STT] Model loaded successfully.")
    except Exception:
        logger.exception("[STT] Failed to load model")
        _model = None
        _processor = None
        _pipeline = None
        if loading_cohere:
            # Keep the module's "Whisper fallback" promise at load time too.
            # Lock only if the fallback actually loaded, matching the
            # too-old-transformers branch above.
            try:
                _load_whisper_fallback(device=device, dtype=dtype)
                _load_attempted = True
            except Exception:
                logger.exception("[STT] Whisper fallback failed to load")


def preload_model():
    """Load the STT model eagerly at startup, blocking until loading finishes.

    Only logs a warning when torch/transformers are missing. Otherwise it
    triggers the loader and logs whether any usable model ended up in memory:
    the Cohere model, the generic pipeline, or the Whisper fallback.
    """
    if not is_available():
        logger.warning("[STT] Cannot preload — dependencies not available")
        return

    logger.info(f"[STT] Preloading model: {config.STT_MODEL}")
    _load_model()

    # Module globals are only read here, so no `global` declaration is needed.
    loaded = any(engine is not None for engine in (_pipeline, _model, _fallback_pipeline))
    if not loaded:
        logger.error("[STT] Preload failed — voice input will not work.")
        return
    logger.info("[STT] Model preloaded and ready.")


def _transcribe_with_cohere(audio_filepath: str):
    """Run the loaded Cohere model on one file and return ``batch_decode``'s output."""
    from transformers.audio_utils import load_audio

    audio = load_audio(audio_filepath, sampling_rate=16000)
    inputs = _processor(audio, sampling_rate=16000, return_tensors="pt", language="en")
    # Cast inputs to the weights' device/dtype (e.g. bfloat16 on GPU) to avoid dtype mismatches.
    model_dtype = next(_model.parameters()).dtype
    inputs = inputs.to(device=_model.device, dtype=model_dtype)
    outputs = _model.generate(**inputs, max_new_tokens=256)
    # batch_decode returns one string per batch item; we have a batch of 1.
    return _processor.batch_decode(outputs, skip_special_tokens=True)


def _transcribe_with_whisper_fallback(audio_filepath: str) -> str:
    """Load the Whisper fallback if needed and run it; return stripped text or "" on failure."""
    try:
        device = _get_device()
        dtype = _dtype_for(device)
        _load_whisper_fallback(device=device, dtype=dtype)
        text = _run_pipeline(_fallback_pipeline, audio_filepath).strip()
        logger.info(f"[STT] Fallback result: '{text}'")
        return text
    except Exception:
        logger.exception("[STT] Whisper fallback failed")
        return ""


def transcribe(audio_filepath: str) -> str:
    """Transcribe audio file to text.

    Args:
        audio_filepath: Path to WAV/MP3 file from Gradio.

    Returns:
        The stripped transcription, or an empty string on any failure
        (missing path, missing dependencies, no model loaded, or an error
        during transcription).

    If the Cohere model is what raised, the Whisper fallback is loaded (if
    needed) and tried once. A failure of the Whisper/generic pipeline itself
    is not retried, because re-running the same engine would just fail again.
    """
    if not audio_filepath:
        logger.warning("[STT] No audio filepath provided")
        return ""

    if not is_available():
        logger.warning("[STT] Dependencies not available")
        return ""

    logger.info(f"[STT] Transcribing: {audio_filepath}")
    _load_model()

    used_cohere = False
    try:
        if _is_cohere and _model is not None and _processor is not None:
            used_cohere = True
            text = _transcribe_with_cohere(audio_filepath)
        elif _fallback_pipeline is not None:
            text = _run_pipeline(_fallback_pipeline, audio_filepath)
        elif _pipeline is not None:
            text = _run_pipeline(_pipeline, audio_filepath)
        else:
            logger.error("[STT] Model not loaded — cannot transcribe")
            return ""

        # Some processors return a list of strings (one per batch item).
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        text = text.strip()
        logger.info(f"[STT] Result: '{text}'")
        return text
    except Exception:
        # logger.exception keeps the traceback, which the bare message dropped.
        logger.exception("[STT] Transcription failed")
        if used_cohere:
            return _transcribe_with_whisper_fallback(audio_filepath)
        return ""