Spaces:
Sleeping
Sleeping
File size: 8,394 Bytes
"""Speech-to-text module for REFRAME.

Modular, fail-safe STT using Cohere Transcribe (primary) or Whisper (fallback).
If dependencies are missing or the model fails to load, the app continues without voice input.
"""
import logging
import os
import config
logger = logging.getLogger(__name__)
# Make this module's INFO-level "[STT] ..." messages visible by default.
# basicConfig only installs a handler when the root logger has none yet.
# NOTE(review): configuring the root logger at import time affects the whole app.
logging.basicConfig(level=logging.INFO)
def _on_zerogpu() -> bool:
"""True on an HF ZeroGPU Space (the `spaces` package is present only there).
On ZeroGPU we must target CUDA and load the model on `cuda` at startup — a
CUDA-emulation layer makes that work without a real GPU, and the real GPU is
used inside @spaces.GPU functions.
"""
if not os.environ.get("SPACE_ID"):
return False
try:
import spaces # noqa: F401
return True
except Exception:
return False
# Lazily initialised model state shared by the loader and transcribe().
# Cohere path: a native model plus its processor.
_model = _processor = None
# Non-Cohere STT_MODEL pipeline, and the Whisper fallback pipeline.
_pipeline = _fallback_pipeline = None
# _load_attempted becomes True once loading has succeeded; _is_cohere records
# whether config.STT_MODEL names a Cohere model.
_load_attempted = _is_cohere = False
# Checkpoint used by the Whisper fallback.
_FALLBACK_MODEL = "openai/whisper-small"
def is_available() -> bool:
    """Return whether the optional STT stack (torch and transformers) is importable."""
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def _get_device() -> str:
    """Choose the device to load models on.

    Priority: ZeroGPU (forced ``cuda``) > ``xpu`` (Intel) > ``mps`` (Mac) >
    ``cuda`` > ``cpu``. Any error while probing torch falls back to ``cpu``.
    """
    if _on_zerogpu():
        # Load on cuda at module level (emulated); the real GPU is attached
        # inside @spaces.GPU. cuda.is_available() is False at startup on
        # ZeroGPU, so it must not be consulted here.
        return "cuda"
    try:
        import torch

        # Lazy probes, tried in priority order; the first truthy one wins.
        probes = (
            ("xpu", lambda: hasattr(torch, "xpu") and torch.xpu.is_available()),
            ("mps", lambda: torch.backends.mps.is_available()),
            ("cuda", lambda: torch.cuda.is_available()),
        )
        for name, probe in probes:
            if probe():
                return name
    except Exception:
        pass
    return "cpu"
def _dtype_for(device: str):
"""Pick a sensible compute dtype for the device."""
import torch
if device == "cpu":
return torch.float32
if device in ("xpu", "cuda"):
return torch.bfloat16 # checkpoint's native dtype; supported on Intel GPU + ZeroGPU
return torch.float16 # mps
def _load_whisper_fallback(device: str, dtype):
    """Build the Whisper fallback ASR pipeline on first use; later calls do nothing.

    Args:
        device: Device string passed to ``transformers.pipeline``.
        dtype: Compute dtype passed to ``transformers.pipeline``.
    """
    global _fallback_pipeline
    if _fallback_pipeline is None:
        from transformers import pipeline

        logger.warning(f"[STT] Falling back to {_FALLBACK_MODEL}")
        # Force English transcription (no language detection / translation).
        english_transcribe = {"language": "en", "task": "transcribe"}
        _fallback_pipeline = pipeline(
            "automatic-speech-recognition",
            model=_FALLBACK_MODEL,
            dtype=dtype,
            device=device,
            generate_kwargs=english_transcribe,
        )
def _run_pipeline(pipe, audio_filepath: str) -> str:
    """Transcribe ``audio_filepath`` with an ASR pipeline, feeding it decoded audio.

    Handing the pipeline a bare path can fail with 'Soundfile ... malformed'
    (no ffmpeg on PATH, or an unsupported WAV encoding). Decoding to a 16 kHz
    array with ``load_audio`` first is reliable and matches how the Cohere path
    already reads audio.

    Returns:
        The ``"text"`` entry of a dict result (``""`` if absent), otherwise
        ``str(result)``.
    """
    from transformers.audio_utils import load_audio

    rate = 16000
    waveform = load_audio(audio_filepath, sampling_rate=rate)
    output = pipe({"raw": waveform, "sampling_rate": rate})
    if isinstance(output, dict):
        return output.get("text", "")
    return str(output)
# Oldest transformers release required for native Cohere Transcribe, as stated
# in the warning emitted by _load_model below.
_MIN_COHERE_TRANSFORMERS = (5, 4)


def _major_minor(version: str) -> tuple[int, int]:
    """Parse ``(major, minor)`` from a version string such as ``"5.4.0.dev0"``.

    Returns ``(0, 0)`` when the string does not start with a number, so an
    unparseable version is treated as too old instead of aborting the load.
    """
    import re

    match = re.match(r"(\d+)(?:\.(\d+))?", version)
    if match is None:
        return (0, 0)
    return int(match.group(1)), int(match.group(2) or 0)


def _load_fallback_after_failure() -> None:
    """Best-effort load of the Whisper fallback after the primary model failed.

    Errors are logged, not raised, so the app keeps running without voice input.
    """
    try:
        device = _get_device()
        _load_whisper_fallback(device=device, dtype=_dtype_for(device))
    except Exception:
        logger.exception("[STT] Whisper fallback failed to load")


def _load_model():
    """Load the configured ASR model into the module-level globals.

    Retries on failure: ``_load_attempted`` is only set once loading succeeded,
    so a later call tries again. Paths:

    * Cohere model on transformers < 5.4 -> load the Whisper fallback instead.
    * Cohere model -> native ``AutoProcessor`` + ``AutoModelForSpeechSeq2Seq``.
    * Any other model -> a transformers ASR ``pipeline``.

    If loading a Cohere model fails, the Whisper fallback is loaded (best
    effort) so ``transcribe`` still has a model to run. The Cohere load is
    retried on the next call. Load errors are logged, never raised.
    """
    global _model, _processor, _pipeline, _load_attempted, _is_cohere
    if _load_attempted:
        return
    fallback_tried = False
    try:
        import transformers

        device = _get_device()
        dtype = _dtype_for(device)
        _is_cohere = "cohere" in config.STT_MODEL.lower()
        # Compare (major, minor): a major-only check let 5.0-5.3 through even
        # though native Cohere support needs 5.4.
        too_old = _major_minor(transformers.__version__) < _MIN_COHERE_TRANSFORMERS
        if _is_cohere and too_old:
            logger.warning(
                "[STT] Cohere Transcribe needs transformers>=5.4.0. "
                f"Current version is {transformers.__version__}. Using Whisper fallback."
            )
            fallback_tried = True
            _load_whisper_fallback(device=device, dtype=dtype)
            _load_attempted = True
            return
        logger.info(f"[STT] Loading model: {config.STT_MODEL} on {device} ({dtype})")
        if _is_cohere:
            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

            # Use transformers' built-in (native) cohere_asr implementation, NOT the
            # Hub remote code. The remote-code path (trust_remote_code=True) produces
            # garbage output on transformers>=5.3 and is deprecated by the model
            # authors; the native path is the recommended way going forward.
            # (See model discussion #28.)
            _processor = AutoProcessor.from_pretrained(config.STT_MODEL)
            _model = AutoModelForSpeechSeq2Seq.from_pretrained(config.STT_MODEL, dtype=dtype)
            _model = _model.to(device)
        else:
            from transformers import pipeline

            _pipeline = pipeline(
                "automatic-speech-recognition",
                model=config.STT_MODEL,
                dtype=dtype,
                device=device,
                generate_kwargs={"language": "en", "task": "transcribe"},
            )
        _load_attempted = True  # Only lock after success
        logger.info("[STT] Model loaded successfully.")
    except Exception:
        logger.exception("[STT] Failed to load model")
        _model = None
        _processor = None
        _pipeline = None
        # Whisper is the documented fallback for Cohere; without this, transcribe()
        # would find no model at all and return "" on every call.
        if _is_cohere and not fallback_tried:
            _load_fallback_after_failure()
def preload_model():
    """Load the STT model eagerly at startup, blocking until done.

    Returns early with a warning when torch/transformers are missing.
    Otherwise it runs the loader and then logs whether any usable model
    (pipeline, native model, or Whisper fallback) ended up loaded.
    """
    if is_available():
        logger.info(f"[STT] Preloading model: {config.STT_MODEL}")
        _load_model()
        # Only reads the module globals, so no `global` declaration is needed.
        candidates = (_pipeline, _model, _fallback_pipeline)
        if any(obj is not None for obj in candidates):
            logger.info("[STT] Model preloaded and ready.")
        else:
            logger.error("[STT] Preload failed — voice input will not work.")
        return
    logger.warning("[STT] Cannot preload — dependencies not available")
def _transcribe_with_cohere(audio_filepath: str):
    """Run the loaded native Cohere Transcribe model on one audio file.

    Returns what ``batch_decode`` produced (one string per batch item; here a
    batch of 1). The caller normalizes it to a single string.
    """
    from transformers.audio_utils import load_audio

    audio = load_audio(audio_filepath, sampling_rate=16000)
    inputs = _processor(audio, sampling_rate=16000, return_tensors="pt", language="en")
    # Match the model's device and parameter dtype (bf16/fp16 off-CPU).
    model_dtype = next(_model.parameters()).dtype
    inputs = inputs.to(device=_model.device, dtype=model_dtype)
    outputs = _model.generate(**inputs, max_new_tokens=256)
    return _processor.batch_decode(outputs, skip_special_tokens=True)


def _transcribe_with_whisper_fallback(audio_filepath: str) -> str:
    """Load the Whisper fallback pipeline if needed, then transcribe with it."""
    device = _get_device()
    _load_whisper_fallback(device=device, dtype=_dtype_for(device))
    return _run_pipeline(_fallback_pipeline, audio_filepath)


def transcribe(audio_filepath: str) -> str:
    """Transcribe audio file to text.

    Uses the Cohere model when it is loaded and falls back to Whisper when it
    is not (or when Cohere inference raises). Otherwise it uses the pipeline
    for the configured non-Cohere model.

    Args:
        audio_filepath: Path to WAV/MP3 file from Gradio.

    Returns:
        Transcribed text (stripped), or empty string on any failure.
    """
    if not audio_filepath:
        logger.warning("[STT] No audio filepath provided")
        return ""
    if not is_available():
        logger.warning("[STT] Dependencies not available")
        return ""
    logger.info(f"[STT] Transcribing: {audio_filepath}")
    _load_model()
    use_cohere = _is_cohere and _model is not None and _processor is not None
    try:
        if use_cohere:
            text = _transcribe_with_cohere(audio_filepath)
        elif _fallback_pipeline is not None:
            text = _run_pipeline(_fallback_pipeline, audio_filepath)
        elif _pipeline is not None:
            text = _run_pipeline(_pipeline, audio_filepath)
        elif _is_cohere:
            # Cohere is configured but failed to load: use the documented
            # Whisper fallback instead of giving up.
            logger.warning("[STT] Cohere model not loaded — trying Whisper fallback")
            text = _transcribe_with_whisper_fallback(audio_filepath)
        else:
            logger.error("[STT] Model not loaded — cannot transcribe")
            return ""
        # Some processors return a list of strings (one per batch item).
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        logger.info(f"[STT] Result: '{text.strip()}'")
        return text.strip()
    except Exception as e:
        logger.exception(f"[STT] Transcription failed: {e}")
        if not use_cohere:
            # The failed attempt was already a Whisper pipeline (or the
            # fallback itself); re-running it would just repeat the error.
            return ""
        try:
            text = _transcribe_with_whisper_fallback(audio_filepath)
            logger.info(f"[STT] Fallback result: '{text.strip()}'")
            return text.strip()
        except Exception as fallback_err:
            logger.exception(f"[STT] Whisper fallback failed: {fallback_err}")
    return ""
|