"""Speech-to-text module for REFRAME.

Modular, fail-safe STT using Cohere Transcribe (primary) or Whisper (fallback).
If dependencies are missing or the model fails to load, the app continues
without voice.
"""

import logging
import os
import re

import config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _on_zerogpu() -> bool:
    """True on an HF ZeroGPU Space (the `spaces` package is present only there).

    On ZeroGPU we must target CUDA and load the model on `cuda` at startup — a
    CUDA-emulation layer makes that work without a real GPU, and the real GPU
    is used inside @spaces.GPU functions.
    """
    if not os.environ.get("SPACE_ID"):
        return False
    try:
        import spaces  # noqa: F401

        return True
    except Exception:
        return False


# Module-level model state, populated lazily by _load_model().
_model = None  # Cohere Transcribe seq2seq model (Cohere path only)
_processor = None  # processor paired with _model (Cohere path only)
_pipeline = None  # transformers ASR pipeline for a non-Cohere STT_MODEL
_fallback_pipeline = None  # Whisper pipeline, loaded only when Cohere can't be used
_load_attempted = False  # set once a load succeeds; later _load_model() calls no-op
_is_cohere = False  # whether config.STT_MODEL names a Cohere model

_FALLBACK_MODEL = "openai/whisper-small"
# Minimum transformers (major, minor) for Cohere Transcribe; keep in sync with
# the "needs transformers>=5.4.0" warning logged by _load_model().
_MIN_COHERE_TRANSFORMERS = (5, 4)


def is_available() -> bool:
    """Check if STT dependencies (torch, transformers) can be imported."""
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except (ImportError, OSError):
        # OSError too (e.g. a shared library that fails to load at import time),
        # so callers fall back to "no voice" instead of crashing.
        return False
    return True


def _get_device() -> str:
    """Best device: ZeroGPU cuda (forced) > xpu (Intel) > mps (Mac) > cuda > cpu."""
    if _on_zerogpu():
        # Load on cuda at module level (emulated); real GPU inside @spaces.GPU.
        # Don't call cuda.is_available() here — it's False at startup on ZeroGPU.
        return "cuda"
    try:
        import torch

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
    except Exception:
        pass
    return "cpu"


def _dtype_for(device: str):
    """Pick a sensible compute dtype for the device."""
    import torch

    if device == "cpu":
        return torch.float32
    if device in ("xpu", "cuda"):
        return torch.bfloat16  # checkpoint's native dtype; supported on Intel GPU + ZeroGPU
    return torch.float16  # mps


def _version_tuple(version: str, parts: int = 2) -> tuple:
    """Return the first `parts` numeric components of a version string.

    "5.4.0.dev0" -> (5, 4), "4.57.1" -> (4, 57), "5.4.0rc1" -> (5, 4).
    Components that are missing or have no leading digits count as 0.
    """
    numbers = []
    for piece in version.split(".")[:parts]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    numbers.extend([0] * (parts - len(numbers)))
    return tuple(numbers)


def _load_whisper_fallback(device: str, dtype):
    """Load Whisper fallback pipeline once and reuse it."""
    global _fallback_pipeline
    if _fallback_pipeline is not None:
        return
    from transformers import pipeline

    logger.warning(f"[STT] Falling back to {_FALLBACK_MODEL}")
    _fallback_pipeline = pipeline(
        "automatic-speech-recognition",
        model=_FALLBACK_MODEL,
        dtype=dtype,
        device=device,
        generate_kwargs={"language": "en", "task": "transcribe"},
    )


def _try_load_whisper_fallback() -> None:
    """Best-effort load of the Whisper fallback; logs instead of raising."""
    try:
        device = _get_device()
        _load_whisper_fallback(device=device, dtype=_dtype_for(device))
    except Exception:
        logger.exception("[STT] Whisper fallback failed to load")


def _run_pipeline(pipe, audio_filepath: str) -> str:
    """Run an ASR pipeline on a file path, decoding to a 16 kHz array first.

    Passing a bare file path can fail with 'Soundfile ... malformed' (no ffmpeg
    on PATH / unsupported WAV encoding); a decoded array via load_audio is
    reliable and is how the Cohere path already loads audio.
    """
    from transformers.audio_utils import load_audio

    audio = load_audio(audio_filepath, sampling_rate=16000)
    result = pipe({"raw": audio, "sampling_rate": 16000})
    return result.get("text", "") if isinstance(result, dict) else str(result)


def _load_model():
    """Load the configured ASR model into the module-level state.

    If config.STT_MODEL contains "cohere", a processor and a seq2seq model are
    loaded. Any other model id is loaded as a transformers ASR pipeline.

    When Cohere can't be used, the Whisper fallback is loaded instead so voice
    input keeps working. That happens when transformers is older than
    _MIN_COHERE_TRANSFORMERS, or when the Cohere load raises.

    Retries on failure: _load_attempted is set only after a successful primary
    load, or after the too-old-transformers fallback. A failed primary load is
    therefore retried on the next call. Any fallback already loaded is kept
    and used in the meantime.
    """
    global _model, _processor, _pipeline, _load_attempted, _is_cohere
    if _load_attempted:
        return
    fallback_tried = False  # avoid a second Whisper attempt in the except path
    try:
        import transformers

        device = _get_device()
        dtype = _dtype_for(device)
        _is_cohere = "cohere" in config.STT_MODEL.lower()

        if _is_cohere and _version_tuple(transformers.__version__) < _MIN_COHERE_TRANSFORMERS:
            logger.warning(
                "[STT] Cohere Transcribe needs transformers>=5.4.0. "
                f"Current version is {transformers.__version__}. Using Whisper fallback."
            )
            fallback_tried = True
            _load_whisper_fallback(device=device, dtype=dtype)
            _load_attempted = True
            return

        logger.info(f"[STT] Loading model: {config.STT_MODEL} on {device} ({dtype})")
        if _is_cohere:
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

            # Use transformers' built-in (native) cohere_asr implementation, NOT the
            # Hub remote code. The remote-code path (trust_remote_code=True) produces
            # garbage output on transformers>=5.3 and is deprecated by the model
            # authors; the native path is the recommended way going forward.
            # (See model discussion #28.)
            _processor = AutoProcessor.from_pretrained(config.STT_MODEL)
            _model = AutoModelForSpeechSeq2Seq.from_pretrained(config.STT_MODEL, dtype=dtype)
            _model = _model.to(device)
        else:
            from transformers import pipeline

            _pipeline = pipeline(
                "automatic-speech-recognition",
                model=config.STT_MODEL,
                dtype=dtype,
                device=device,
                generate_kwargs={"language": "en", "task": "transcribe"},
            )
        _load_attempted = True  # Only lock after success
        logger.info("[STT] Model loaded successfully.")
    except Exception:
        logger.exception("[STT] Failed to load model")
        _model = None
        _processor = None
        _pipeline = None
        if _is_cohere and not fallback_tried:
            # Keep voice working on Whisper while the Cohere load is retried later.
            _try_load_whisper_fallback()


def preload_model():
    """Eagerly download and load the model at startup (blocks until done)."""
    if not is_available():
        logger.warning("[STT] Cannot preload — dependencies not available")
        return
    logger.info(f"[STT] Preloading model: {config.STT_MODEL}")
    _load_model()
    if any(obj is not None for obj in (_pipeline, _model, _fallback_pipeline)):
        logger.info("[STT] Model preloaded and ready.")
    else:
        logger.error("[STT] Preload failed — voice input will not work.")


def transcribe(audio_filepath: str) -> str:
    """Transcribe audio file to text.

    Prefers the loaded Cohere model, then the Whisper fallback, then a
    non-Cohere pipeline. If the Cohere model raises during transcription, the
    same file is retried once on the Whisper fallback.

    Args:
        audio_filepath: Path to WAV/MP3 file from Gradio.

    Returns:
        Transcribed text, or empty string on any failure.
    """
    if not audio_filepath:
        logger.warning("[STT] No audio filepath provided")
        return ""
    if not is_available():
        logger.warning("[STT] Dependencies not available")
        return ""

    logger.info(f"[STT] Transcribing: {audio_filepath}")
    _load_model()

    used_cohere = False  # only a failed Cohere run is worth retrying on Whisper
    try:
        if _is_cohere and _model is not None and _processor is not None:
            used_cohere = True
            from transformers.audio_utils import load_audio

            audio = load_audio(audio_filepath, sampling_rate=16000)
            inputs = _processor(audio, sampling_rate=16000, return_tensors="pt", language="en")
            # Move inputs to the model's device and cast to the dtype its weights use.
            model_dtype = next(_model.parameters()).dtype
            inputs = inputs.to(device=_model.device, dtype=model_dtype)
            outputs = _model.generate(**inputs, max_new_tokens=256)
            # batch_decode returns one string per batch item; we have a batch of 1.
            text = _processor.batch_decode(outputs, skip_special_tokens=True)
        elif _fallback_pipeline is not None:
            text = _run_pipeline(_fallback_pipeline, audio_filepath)
        elif _pipeline is not None:
            text = _run_pipeline(_pipeline, audio_filepath)
        else:
            logger.error("[STT] Model not loaded — cannot transcribe")
            return ""

        # Some processors return a list of strings (one per batch item).
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        text = text.strip()
        logger.info(f"[STT] Result: '{text}'")
        return text
    except Exception as e:
        logger.exception(f"[STT] Transcription failed: {e}")
        if used_cohere:
            try:
                device = _get_device()
                _load_whisper_fallback(device=device, dtype=_dtype_for(device))
                text = _run_pipeline(_fallback_pipeline, audio_filepath).strip()
                logger.info(f"[STT] Fallback result: '{text}'")
                return text
            except Exception as fallback_err:
                logger.exception(f"[STT] Whisper fallback failed: {fallback_err}")
    return ""