Spaces:
Sleeping
Sleeping
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8 | """Speech-to-text module for REFRAME. | |
| Modular, fail-safe STT using Cohere Transcribe (primary) or Whisper (fallback). | |
| If dependencies are missing or model fails to load, the app continues without voice. | |
| """ | |
| import logging | |
| import os | |
| import config | |
# Module logger; every message from this file is tagged "[STT]".
logger = logging.getLogger(__name__)
# Give the app INFO-level output even if nothing else configured logging
# (basicConfig does nothing when the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
def _on_zerogpu() -> bool:
    """Tell whether we are running on a Hugging Face ZeroGPU Space.

    Two conditions must hold: ``SPACE_ID`` is set to a non-empty value, and the
    ``spaces`` package imports cleanly (it is only present on Spaces). On
    ZeroGPU the model must target CUDA and be loaded on ``cuda`` at startup: a
    CUDA-emulation layer makes that work without a real GPU, and the real GPU
    is only used inside ``@spaces.GPU`` functions.
    """
    if os.environ.get("SPACE_ID"):
        try:
            import spaces  # noqa: F401
        except Exception:
            # Any import problem means "not ZeroGPU"; never let it escape.
            return False
        return True
    return False
# Lazily-populated backend handles; all stay None until _load_model() fills them.
_model = _processor = None  # native Cohere Transcribe model and its processor
_pipeline = _fallback_pipeline = None  # generic ASR pipeline / Whisper fallback pipeline
# _load_attempted: becomes True once loading succeeded (or settled on Whisper).
# _is_cohere: whether config.STT_MODEL names a Cohere model (set by _load_model).
_load_attempted = _is_cohere = False
# Hub id of the model used whenever Cohere Transcribe cannot be used.
_FALLBACK_MODEL = "openai/whisper-small"
def is_available() -> bool:
    """Report whether the STT stack (``torch`` and ``transformers``) is importable."""
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except ImportError:
        return False
    return True
def _get_device() -> str:
    """Choose the compute device for the ASR model.

    Priority: ZeroGPU ``cuda`` (forced) > ``xpu`` (Intel) > ``mps`` (Mac) >
    ``cuda`` > ``cpu``. Any error while probing torch falls back to ``cpu``.
    """
    if _on_zerogpu():
        # The model is loaded on (emulated) cuda at module level; the real GPU
        # is attached inside @spaces.GPU. cuda.is_available() is False at
        # startup on ZeroGPU, so it must not be consulted here.
        return "cuda"
    chosen = "cpu"
    try:
        import torch

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            chosen = "xpu"
        elif torch.backends.mps.is_available():
            chosen = "mps"
        elif torch.cuda.is_available():
            chosen = "cuda"
    except Exception:
        chosen = "cpu"
    return chosen
def _dtype_for(device: str):
    """Return the torch compute dtype to use on *device*.

    ``cpu`` gets float32. ``xpu`` and ``cuda`` get bfloat16, the checkpoint's
    native dtype, which Intel GPUs and ZeroGPU support. Anything else (i.e.
    ``mps``) gets float16.
    """
    import torch

    dtype_by_device = {
        "cpu": torch.float32,
        "xpu": torch.bfloat16,
        "cuda": torch.bfloat16,
    }
    return dtype_by_device.get(device, torch.float16)
def _load_whisper_fallback(device: str, dtype):
    """Create the Whisper fallback ASR pipeline on first use; later calls are no-ops.

    Args:
        device: Device string from ``_get_device()``.
        dtype: Torch dtype from ``_dtype_for(device)``.

    Exceptions from ``transformers.pipeline`` propagate to the caller.
    """
    global _fallback_pipeline
    if _fallback_pipeline is None:
        from transformers import pipeline

        logger.warning(f"[STT] Falling back to {_FALLBACK_MODEL}")
        # Pin English transcription (not translation or language detection).
        whisper_generate_kwargs = {"language": "en", "task": "transcribe"}
        # NOTE(review): `dtype=` is the newer spelling of `torch_dtype`; confirm
        # the transformers versions that reach this path accept it.
        _fallback_pipeline = pipeline(
            "automatic-speech-recognition",
            model=_FALLBACK_MODEL,
            dtype=dtype,
            device=device,
            generate_kwargs=whisper_generate_kwargs,
        )
def _run_pipeline(pipe, audio_filepath: str) -> str:
    """Transcribe *audio_filepath* with an ASR pipeline and return its text.

    The file is decoded to a 16 kHz array with ``load_audio`` before it is
    handed to the pipeline. Passing a bare path can fail with
    "Soundfile ... malformed" (no ffmpeg on PATH or an unsupported WAV
    encoding); the Cohere path loads audio the same way.
    """
    from transformers.audio_utils import load_audio

    rate = 16000
    waveform = load_audio(audio_filepath, sampling_rate=rate)
    output = pipe({"raw": waveform, "sampling_rate": rate})
    if isinstance(output, dict):
        return output.get("text", "")
    return str(output)
# Minimum transformers (major, minor) for Cohere Transcribe, per the warning
# logged below ("needs transformers>=5.4.0").
_COHERE_MIN_TRANSFORMERS = (5, 4)


def _major_minor(version: str) -> tuple[int, int]:
    """Parse the leading ``major.minor`` of a version string.

    Tolerates suffixes such as ``"5.4.0.dev0"`` or ``"5.4.0rc1"``. A missing
    minor counts as 0, and an unparseable string yields ``(0, 0)``.
    """
    import re

    match = re.match(r"\s*(\d+)(?:\.(\d+))?", version)
    if match is None:
        return (0, 0)
    return int(match.group(1)), int(match.group(2) or 0)


def _load_model():
    """Load the configured ASR backend into the module-level handles.

    * Cohere model id (``"cohere"`` in ``config.STT_MODEL``) on transformers
      >= 5.4: native ``AutoProcessor`` + ``AutoModelForSpeechSeq2Seq``.
    * Cohere model id on older transformers: the Whisper fallback pipeline.
    * Any other model id: a generic ``automatic-speech-recognition`` pipeline.

    ``_load_attempted`` is set only after a successful load, or after settling
    on Whisper because transformers is too old. A failed load is therefore
    retried on the next call. If the Cohere load raises, the Whisper fallback
    is loaded too so ``transcribe()`` still has a backend, and Cohere is still
    retried on later calls. Exceptions are logged rather than raised.
    """
    global _model, _processor, _pipeline, _load_attempted, _is_cohere
    if _load_attempted:
        return
    device = dtype = None
    fallback_tried = False
    try:
        import transformers

        device = _get_device()
        dtype = _dtype_for(device)
        _is_cohere = "cohere" in config.STT_MODEL.lower()
        # Compare major AND minor. Checking only the major let 5.0-5.3 through,
        # even though the stated requirement is >=5.4.0.
        if _is_cohere and _major_minor(transformers.__version__) < _COHERE_MIN_TRANSFORMERS:
            logger.warning(
                "[STT] Cohere Transcribe needs transformers>=5.4.0. "
                f"Current version is {transformers.__version__}. Using Whisper fallback."
            )
            fallback_tried = True
            _load_whisper_fallback(device=device, dtype=dtype)
            _load_attempted = True
            return
        logger.info(f"[STT] Loading model: {config.STT_MODEL} on {device} ({dtype})")
        if _is_cohere:
            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

            # Use transformers' built-in (native) cohere_asr implementation, NOT the
            # Hub remote code. The remote-code path (trust_remote_code=True) produces
            # garbage output on transformers>=5.3 and is deprecated by the model
            # authors; the native path is the recommended way going forward.
            # (See model discussion #28.)
            _processor = AutoProcessor.from_pretrained(config.STT_MODEL)
            _model = AutoModelForSpeechSeq2Seq.from_pretrained(config.STT_MODEL, dtype=dtype)
            _model = _model.to(device)
        else:
            from transformers import pipeline

            _pipeline = pipeline(
                "automatic-speech-recognition",
                model=config.STT_MODEL,
                dtype=dtype,
                device=device,
                generate_kwargs={"language": "en", "task": "transcribe"},
            )
        _load_attempted = True  # Only lock after success
        logger.info("[STT] Model loaded successfully.")
    except Exception:
        logger.exception("[STT] Failed to load model")
        _model = None
        _processor = None
        _pipeline = None
        # Keep voice usable: when the primary Cohere load failed, bring up the
        # Whisper fallback (a no-op if it is already loaded). Skipped if the
        # fallback was the thing that just failed, or device/dtype never resolved.
        if _is_cohere and dtype is not None and not fallback_tried:
            try:
                _load_whisper_fallback(device=device, dtype=dtype)
            except Exception:
                logger.exception("[STT] Whisper fallback failed to load")
def preload_model() -> bool:
    """Eagerly download and load the STT model at startup (blocks until done).

    Returns:
        True if any backend (native Cohere model, primary pipeline, or Whisper
        fallback) is loaded afterwards; otherwise False. Callers that ignored
        the previous ``None`` return are unaffected.
    """
    # The backend handles are only read here, so no `global` statement is needed.
    if not is_available():
        logger.warning("[STT] Cannot preload — dependencies not available")
        return False
    logger.info(f"[STT] Preloading model: {config.STT_MODEL}")
    _load_model()
    ready = any(
        backend is not None for backend in (_pipeline, _model, _fallback_pipeline)
    )
    if ready:
        logger.info("[STT] Model preloaded and ready.")
    else:
        logger.error("[STT] Preload failed — voice input will not work.")
    return ready
def _transcribe_cohere(audio_filepath: str):
    """Decode *audio_filepath* with the loaded native Cohere Transcribe model.

    Returns what ``_processor.batch_decode`` returns (one entry per batch item;
    the batch here holds a single item). Requires ``_model`` and ``_processor``
    to be loaded. Exceptions propagate to the caller.
    """
    from transformers.audio_utils import load_audio

    audio = load_audio(audio_filepath, sampling_rate=16000)
    inputs = _processor(audio, sampling_rate=16000, return_tensors="pt", language="en")
    # Move/cast the inputs to the model's device and parameter dtype, which may
    # be a reduced-precision type, so generate() sees matching tensors.
    model_dtype = next(_model.parameters()).dtype
    inputs = inputs.to(device=_model.device, dtype=model_dtype)
    outputs = _model.generate(**inputs, max_new_tokens=256)
    return _processor.batch_decode(outputs, skip_special_tokens=True)


def transcribe(audio_filepath: str) -> str:
    """Transcribe an audio file to text.

    Backend order: the native Cohere model if loaded, else the Whisper
    fallback pipeline, else the generic primary pipeline. If the Cohere model
    raises, the audio is retried once through the Whisper fallback.

    Args:
        audio_filepath: Path to WAV/MP3 file from Gradio.

    Returns:
        The stripped transcription, or an empty string on any failure.
    """
    if not audio_filepath:
        logger.warning("[STT] No audio filepath provided")
        return ""
    if not is_available():
        logger.warning("[STT] Dependencies not available")
        return ""
    logger.info(f"[STT] Transcribing: {audio_filepath}")
    _load_model()
    used_cohere = False
    try:
        if _is_cohere and _model is not None and _processor is not None:
            used_cohere = True
            text = _transcribe_cohere(audio_filepath)
        elif _fallback_pipeline is not None:
            text = _run_pipeline(_fallback_pipeline, audio_filepath)
        elif _pipeline is not None:
            text = _run_pipeline(_pipeline, audio_filepath)
        else:
            logger.error("[STT] Model not loaded — cannot transcribe")
            return ""
        # Some processors return a list of strings (one per batch item).
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        text = text.strip()
        # Transcripts are the user's spoken words (potentially sensitive in a
        # CBT app): log only their length at INFO, full content at DEBUG.
        logger.info(f"[STT] Result: {len(text)} chars")
        logger.debug(f"[STT] Result: '{text}'")
        return text
    except Exception as e:
        logger.exception(f"[STT] Transcription failed: {e}")
        # Only a Cohere failure has somewhere else to go. If the Whisper
        # fallback itself just failed, re-running it on the same audio is wasted work.
        if not used_cohere:
            return ""
        try:
            device = _get_device()
            dtype = _dtype_for(device)
            _load_whisper_fallback(device=device, dtype=dtype)
            text = _run_pipeline(_fallback_pipeline, audio_filepath).strip()
        except Exception as fallback_err:
            logger.exception(f"[STT] Whisper fallback failed: {fallback_err}")
            return ""
        logger.info(f"[STT] Fallback result: {len(text)} chars")
        logger.debug(f"[STT] Fallback result: '{text}'")
        return text