# streaming/server_wrapper.py
import io
import os
import threading
from types import SimpleNamespace

import numpy as np
import soundfile as sf
import librosa

from simulstreaming_whisper import simul_asr_factory

# Lazily-initialized model state; all access is guarded by _lock.
_lock = threading.Lock()
_initialized = False
_asr = None
_online = None


def _get_model_path():
    """Get the path to the Whisper model.

    Behavior:
    - Prefer the `WHISPER_MODEL_PATH` env var if provided.
    - Otherwise resolve a model name from `WHISPER_MODEL_NAME` /
      `WHISPER_MODEL_SIZE` (default: `large-v3`), preferring a repo-local
      `./<name>.pt` file, then the cached `~/.cache/whisper/<name>.pt`.
    - Do NOT attempt to download the model automatically (downloading at
      runtime can hang Spaces).
    - If no model is found, raise FileNotFoundError with guidance.
    """
    # Allow the user to override with an explicit file path.
    env_path = os.environ.get('WHISPER_MODEL_PATH') or os.environ.get('MODEL_PATH')
    if env_path:
        if os.path.exists(env_path):
            return env_path
        raise FileNotFoundError(f"WHISPER_MODEL_PATH is set but file not found: {env_path}")

    # Allow the user to request a model name/size (e.g. 'tiny', 'base', 'large-v3').
    model_name = os.environ.get('WHISPER_MODEL_NAME') or os.environ.get('WHISPER_MODEL_SIZE') or 'large-v3'

    # Check a repo-local file first (e.g. ./tiny.pt or ./large-v3.pt).
    local_path = f'./{model_name}.pt'
    if os.path.exists(local_path):
        return local_path

    # Then check the cache path (pre-downloaded by the build or another process).
    model_dir = os.path.expanduser('~/.cache/whisper')
    model_path = os.path.join(model_dir, f'{model_name}.pt')
    if os.path.exists(model_path):
        return model_path

    # Do not attempt to download automatically at runtime.
    raise FileNotFoundError(
        'Whisper model not found. Set WHISPER_MODEL_PATH to a local model file, '
        'or set WHISPER_MODEL_NAME to a model name (e.g. tiny) and pre-download '
        'the corresponding "<name>.pt" file into the repo or ~/.cache/whisper/.'
    )
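
# Resolution order, illustrated (a sketch; the example paths are hypothetical):
#   WHISPER_MODEL_PATH=/models/tiny.pt         -> '/models/tiny.pt' (must exist)
#   WHISPER_MODEL_NAME=tiny  with ./tiny.pt    -> './tiny.pt'
#   WHISPER_MODEL_NAME=tiny  with cache only   -> '~/.cache/whisper/tiny.pt'
#   nothing set, no files anywhere             -> FileNotFoundError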


def _make_args():
    # Minimal argument namespace required by simul_asr_factory.
    return SimpleNamespace(
        log_level='INFO',
        decoder=None,
        beams=1,                  # beam size for decoding
        model_path=_get_model_path(),
        cif_ckpt_path=None,
        frame_threshold=25,
        audio_min_len=0.0,        # seconds
        audio_max_len=30.0,       # seconds (Whisper's window)
        task='transcribe',
        never_fire=False,
        init_prompt=None,
        static_init_prompt=None,
        max_context_tokens=None,
        logdir=None,
        lan='en',
        min_chunk_size=1.2,       # seconds of audio per processing step
        vac=False,                # voice-activity controller disabled
        vac_chunk_size=0.04,
    )
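
# Fields can be tweaked on the namespace before it reaches simul_asr_factory;
# a sketch (the 'es' override is hypothetical, not used by this server):
#
#     args = _make_args()
#     args.lan = 'es'
#     args.min_chunk_size = 0.8
#     asr, online = simul_asr_factory(args)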


def init_model():
    global _initialized, _asr, _online
    with _lock:
        if _initialized:
            return
        try:
            args = _make_args()
            _asr, _online = simul_asr_factory(args)
            _initialized = True
        except FileNotFoundError as e:
            # Leave _initialized False so callers know the model is not ready.
            print(f"Model initialization aborted: {e}")
        except Exception as e:
            # Don't raise here; allow the app to keep running without a model.
            print(f"Unexpected error initializing model: {e}")


def reset():
    with _lock:
        if _online is None:
            raise RuntimeError("Model not initialized")
        _online.init()


def _read_audio_bytes(raw_bytes):
    """Decode raw audio bytes to mono float32 PCM at 16 kHz."""
    bio = io.BytesIO(raw_bytes)
    # Try soundfile first; fall back to librosa for formats it can't read.
    try:
        data, sr = sf.read(bio, dtype='float32')
    except Exception:
        bio.seek(0)
        data, sr = librosa.load(bio, sr=None, mono=True)
    # Downmix multi-channel audio to mono.
    if data.ndim > 1:
        data = np.mean(data, axis=1)
    # Resample to the 16 kHz rate the ASR expects.
    if sr != 16000:
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    return data.astype(np.float32)
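

def _selftest_read_audio():
    # Minimal self-test sketch for _read_audio_bytes: encode one second of a
    # 440 Hz tone as 8 kHz WAV bytes and confirm it comes back as mono float32
    # at 16 kHz. The tone parameters are illustrative, not part of the server
    # protocol; nothing calls this at import time.
    buf = io.BytesIO()
    t = np.linspace(0, 1.0, 8000, endpoint=False)
    tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    sf.write(buf, tone, 8000, format='WAV', subtype='PCM_16')
    audio = _read_audio_bytes(buf.getvalue())
    assert audio.dtype == np.float32 and audio.ndim == 1
    assert abs(len(audio) - 16000) < 10  # resampled from 8 kHz to 16 kHz
    return audio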


def process_chunk_from_bytes(raw_bytes):
    """Insert an audio chunk and run one processing iteration.

    Returns a JSON-able result (an empty dict if there is no output yet).
    """
    if _online is None:
        raise RuntimeError("Model not initialized")
    audio = _read_audio_bytes(raw_bytes)
    with _lock:
        _online.insert_audio_chunk(audio)
        out = _online.process_iter()
    return out or {}


def finish():
    """Flush the streaming decoder and return any final output."""
    if _online is None:
        raise RuntimeError("Model not initialized")
    with _lock:
        out = _online.finish()
    return out or {}
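

if __name__ == '__main__':
    # Smoke-test sketch: assumes a pre-downloaded model (see _get_model_path)
    # and a 'sample.wav' file next to this script -- both are assumptions for
    # local testing, not part of the server contract. Feeds the whole file as
    # a single chunk, then flushes.
    import sys
    init_model()
    if _online is None:
        sys.exit("Model not available; set WHISPER_MODEL_PATH or pre-download a model.")
    with open('sample.wav', 'rb') as f:
        print(process_chunk_from_bytes(f.read()))
    print(finish())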