# reframe / stt.py
# Venkatesh Rajagopal
# REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
# 4ae4ae8
# Raw
# History Blame Contribute Delete
# 8.39 kB
"""Speech-to-text module for REFRAME.
Modular, fail-safe STT using Cohere Transcribe (primary) or Whisper (fallback).
If dependencies are missing or model fails to load, the app continues without voice.
"""
import logging
import os
import config
# Import-time side effect: configure the root logger at INFO. basicConfig() is a
# no-op when the root logger already has handlers, so an application that set up
# logging before importing this module keeps its own configuration.
logging.basicConfig(level=logging.INFO)
# Module-level logger used for every "[STT] ..." message in this file.
logger = logging.getLogger(__name__)
def _on_zerogpu() -> bool:
"""True on an HF ZeroGPU Space (the `spaces` package is present only there).
On ZeroGPU we must target CUDA and load the model on `cuda` at startup — a
CUDA-emulation layer makes that work without a real GPU, and the real GPU is
used inside @spaces.GPU functions.
"""
if not os.environ.get("SPACE_ID"):
return False
try:
import spaces # noqa: F401
return True
except Exception:
return False
# Lazily populated module state. Written by _load_model() and
# _load_whisper_fallback(); read by preload_model() and transcribe().
_model = None  # Cohere path: model loaded via AutoModelForSpeechSeq2Seq
_processor = None  # Cohere path: the matching AutoProcessor
_pipeline = None  # non-Cohere path: ASR pipeline built for config.STT_MODEL
_fallback_pipeline = None  # Whisper ASR pipeline, built at most once
_load_attempted = False  # set True once _load_model() completes without raising
_is_cohere = False  # True when "cohere" occurs in config.STT_MODEL (case-insensitive)
_FALLBACK_MODEL = "openai/whisper-small"  # checkpoint used for the Whisper fallback
def is_available() -> bool:
    """Check whether the STT dependencies (torch and transformers) import cleanly.

    Returns:
        True when both packages import; False when either is missing or when
        importing them fails for any other reason.
    """
    try:
        import torch  # noqa: F401
        import transformers  # noqa: F401
    except ImportError:
        # Not installed: the expected "no voice support" case, stay quiet.
        return False
    except Exception:
        # Installed but broken (an import can raise errors other than
        # ImportError). Callers invoke this outside their own try blocks, so
        # report "unavailable" rather than letting the error escape.
        logger.warning(
            "[STT] Dependencies are installed but failed to import", exc_info=True
        )
        return False
    return True
def _get_device() -> str:
    """Pick the device string to load the ASR model on.

    Priority: ZeroGPU cuda (forced) > xpu (Intel) > mps (Mac) > cuda > cpu.

    Each backend is probed independently, so an exception raised by one probe
    (for example a torch build without that backend module) no longer prevents
    the later backends from being checked.

    Returns:
        One of "cuda", "xpu", "mps" or "cpu".
    """
    if _on_zerogpu():
        # Load on cuda at module level (emulated); real GPU inside @spaces.GPU.
        # Don't call cuda.is_available() here — it's False at startup on ZeroGPU.
        return "cuda"
    try:
        import torch
    except Exception:
        # No usable torch: nothing to probe.
        return "cpu"

    probes = (
        ("xpu", lambda: hasattr(torch, "xpu") and torch.xpu.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
        ("cuda", lambda: torch.cuda.is_available()),
    )
    for name, probe in probes:
        try:
            if probe():
                return name
        except Exception:
            # A failing probe only rules out that backend; keep checking.
            logger.debug("[STT] %s availability probe failed", name, exc_info=True)
    return "cpu"
def _dtype_for(device: str):
"""Pick a sensible compute dtype for the device."""
import torch
if device == "cpu":
return torch.float32
if device in ("xpu", "cuda"):
return torch.bfloat16 # checkpoint's native dtype; supported on Intel GPU + ZeroGPU
return torch.float16 # mps
def _load_whisper_fallback(device: str, dtype):
    """Build the Whisper fallback pipeline once and cache it in ``_fallback_pipeline``.

    A second call is a no-op when the pipeline already exists.

    Args:
        device: Device string handed to the transformers pipeline.
        dtype: torch dtype handed to the transformers pipeline.

    Raises:
        Whatever ``transformers.pipeline`` raises while downloading or loading
        the model; callers are expected to handle it.
    """
    global _fallback_pipeline
    if _fallback_pipeline is not None:
        return
    import inspect

    from transformers import pipeline

    # This fallback is used when the installed transformers is too old for the
    # Cohere model, so it has to work on older releases. The dtype argument of
    # pipeline() is named `dtype` on newer releases and `torch_dtype` on older
    # ones; pick whichever the installed signature declares.
    # NOTE(review): confirm against the transformers versions you support.
    try:
        accepts_dtype = "dtype" in inspect.signature(pipeline).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: keep the original keyword.
        accepts_dtype = True
    dtype_kwarg = "dtype" if accepts_dtype else "torch_dtype"

    logger.warning(f"[STT] Falling back to {_FALLBACK_MODEL}")
    _fallback_pipeline = pipeline(
        "automatic-speech-recognition",
        model=_FALLBACK_MODEL,
        device=device,
        generate_kwargs={"language": "en", "task": "transcribe"},
        **{dtype_kwarg: dtype},
    )
def _run_pipeline(pipe, audio_filepath: str) -> str:
    """Transcribe one audio file with an ASR pipeline.

    The file is first decoded to a 16 kHz array and the array is handed to the
    pipeline. Author's rationale, kept from the original: passing a bare file
    path can fail with 'Soundfile ... malformed' (no ffmpeg on PATH /
    unsupported WAV encoding), whereas a decoded array via ``load_audio`` is
    reliable and is how the Cohere path already loads audio.

    Args:
        pipe: Callable ASR pipeline accepting a ``{"raw", "sampling_rate"}`` dict.
        audio_filepath: Path of the audio file to decode.

    Returns:
        The ``"text"`` entry when the pipeline returns a dict ("" if the key is
        absent), otherwise ``str()`` of whatever it returned.
    """
    from transformers.audio_utils import load_audio

    target_rate = 16000
    waveform = load_audio(audio_filepath, sampling_rate=target_rate)
    output = pipe({"raw": waveform, "sampling_rate": target_rate})
    if isinstance(output, dict):
        return output.get("text", "")
    return str(output)
def _major_minor(version: str):
    """Return the leading ``(major, minor)`` integers of a version string.

    A missing minor component counts as 0. Raises ValueError when the string
    does not start with a number.
    """
    import re

    match = re.match(r"(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2) or 0)


def _load_model():
    """Load the ASR backend selected by ``config.STT_MODEL`` into module state.

    Behaviour:
      * Does nothing once a previous call has completed successfully
        (``_load_attempted`` is only set after success, so failures retry).
      * Cohere models need transformers>=5.4.0; on anything older the Whisper
        fallback is loaded instead and the load is considered complete.
      * Cohere models load a processor plus a seq2seq model moved to the
        chosen device; any other model is loaded as an ASR pipeline.
      * On failure the primary backend globals are reset to None. If the
        failing model was a Cohere one, the Whisper fallback is loaded so
        transcription still works; ``_load_attempted`` stays False so the
        primary model is retried on the next call.

    Never raises: every exception is logged and swallowed.
    """
    global _model, _processor, _pipeline, _load_attempted, _is_cohere
    if _load_attempted:
        return
    device = dtype = None
    fallback_tried = False
    try:
        import transformers

        device = _get_device()
        dtype = _dtype_for(device)
        _is_cohere = "cohere" in config.STT_MODEL.lower()
        # Compare (major, minor) so the gate matches the >=5.4.0 requirement
        # stated in the warning; checking the major alone let 5.0-5.3 through.
        if _is_cohere and _major_minor(transformers.__version__) < (5, 4):
            logger.warning(
                "[STT] Cohere Transcribe needs transformers>=5.4.0. "
                f"Current version is {transformers.__version__}. Using Whisper fallback."
            )
            fallback_tried = True
            _load_whisper_fallback(device=device, dtype=dtype)
            _load_attempted = True
            return
        logger.info(f"[STT] Loading model: {config.STT_MODEL} on {device} ({dtype})")
        if _is_cohere:
            from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

            # Use transformers' built-in (native) cohere_asr implementation, NOT the
            # Hub remote code. The remote-code path (trust_remote_code=True) produces
            # garbage output on transformers>=5.3 and is deprecated by the model
            # authors; the native path is the recommended way going forward.
            # (See model discussion #28.)
            _processor = AutoProcessor.from_pretrained(config.STT_MODEL)
            _model = AutoModelForSpeechSeq2Seq.from_pretrained(config.STT_MODEL, dtype=dtype)
            _model = _model.to(device)
        else:
            from transformers import pipeline

            _pipeline = pipeline(
                "automatic-speech-recognition",
                model=config.STT_MODEL,
                dtype=dtype,
                device=device,
                generate_kwargs={"language": "en", "task": "transcribe"},
            )
        _load_attempted = True  # Only lock after success
        logger.info("[STT] Model loaded successfully.")
    except Exception:
        logger.exception("[STT] Failed to load model")
        _model = None
        _processor = None
        _pipeline = None
        # A failed Cohere load should not leave the app without any backend.
        # Skip when the fallback itself is what just failed, or when we never
        # got far enough to know the device.
        if _is_cohere and device is not None and not fallback_tried:
            try:
                _load_whisper_fallback(device=device, dtype=dtype)
            except Exception:
                logger.exception("[STT] Whisper fallback failed to load")
def preload_model():
    """Load the STT model eagerly at startup, blocking until the attempt ends.

    Logs a warning and returns early when the STT dependencies are not
    importable. Otherwise calls ``_load_model()`` and logs whether any backend
    (primary pipeline, Cohere model, or Whisper fallback) ended up loaded.
    """
    if not is_available():
        logger.warning("[STT] Cannot preload — dependencies not available")
        return

    logger.info(f"[STT] Preloading model: {config.STT_MODEL}")
    _load_model()

    # Only reads the module-level backends, so no `global` declaration is needed.
    backends = (_pipeline, _model, _fallback_pipeline)
    if all(backend is None for backend in backends):
        logger.error("[STT] Preload failed — voice input will not work.")
    else:
        logger.info("[STT] Model preloaded and ready.")
def transcribe(audio_filepath: str) -> str:
    """Transcribe an audio file to text.

    Backend order: the Cohere model when it is loaded, else the Whisper
    fallback pipeline when it exists, else the primary pipeline. If the chosen
    backend raises and the configured model is a Cohere one, the Whisper
    fallback is loaded (if needed) and tried once.

    Args:
        audio_filepath: Path to WAV/MP3 file from Gradio.

    Returns:
        Transcribed text with surrounding whitespace stripped, or an empty
        string when no path is given, dependencies are unavailable, no backend
        is loaded, or transcription fails.
    """
    if not audio_filepath:
        logger.warning("[STT] No audio filepath provided")
        return ""
    if not is_available():
        logger.warning("[STT] Dependencies not available")
        return ""
    logger.info(f"[STT] Transcribing: {audio_filepath}")
    _load_model()
    try:
        if _is_cohere and _model is not None and _processor is not None:
            from transformers.audio_utils import load_audio

            audio = load_audio(audio_filepath, sampling_rate=16000)
            inputs = _processor(audio, sampling_rate=16000, return_tensors="pt", language="en")
            # Match the inputs to the model's own device and parameter dtype.
            model_dtype = next(_model.parameters()).dtype
            inputs = inputs.to(device=_model.device, dtype=model_dtype)
            outputs = _model.generate(**inputs, max_new_tokens=256)
            # batch_decode returns one string per batch item; we have a batch of 1.
            text = _processor.batch_decode(outputs, skip_special_tokens=True)
        elif _fallback_pipeline is not None:
            text = _run_pipeline(_fallback_pipeline, audio_filepath)
        elif _pipeline is not None:
            text = _run_pipeline(_pipeline, audio_filepath)
        else:
            logger.error("[STT] Model not loaded — cannot transcribe")
            return ""
        # Some processors return a list of strings (one per batch item).
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        # Normalise once; `or ""` guards against a backend handing back None.
        text = (text or "").strip()
        logger.info(f"[STT] Result: '{text}'")
        return text
    except Exception as e:
        # logger.exception keeps the traceback, matching _load_model().
        logger.exception("[STT] Transcription failed: %s", e)
        if _is_cohere:
            try:
                device = _get_device()
                dtype = _dtype_for(device)
                _load_whisper_fallback(device=device, dtype=dtype)
                text = (_run_pipeline(_fallback_pipeline, audio_filepath) or "").strip()
                logger.info(f"[STT] Fallback result: '{text}'")
                return text
            except Exception as fallback_err:
                logger.exception("[STT] Whisper fallback failed: %s", fallback_err)
        return ""