SenseVoiceSmall / app.py
hz6666's picture
Fix: API error
a023c4b
import base64
import io
import os
from pathlib import Path
from typing import Tuple
import gradio as gr
import numpy as np
import soundfile as sf
# Global model handle (lazy-loaded)
_model = None
def _download_or_resolve_model() -> str:
"""Ensure model exists in a writable cache dir and return local path.
Uses ModelScope's default cache (~/.cache/modelscope/hub) which is
writable on Hugging Face Spaces.
"""
try:
from modelscope.hub.snapshot_download import snapshot_download
cache_dir = Path.home() / ".cache" / "modelscope" / "hub"
cache_dir.mkdir(parents=True, exist_ok=True)
model_path = snapshot_download(
model_id="iic/SenseVoiceSmall",
cache_dir=str(cache_dir),
revision="master",
)
return str(model_path)
except Exception as e:
# If anything goes wrong, fall back to a conventional project path
# (still try to keep it writable on Spaces)
fallback = str(Path.home() / "models" / "SenseVoiceSmall")
os.makedirs(fallback, exist_ok=True)
print(f"[WARN] Model download failed, using fallback dir: {e}")
return fallback
def _load_model():
global _model
if _model is not None:
return _model
model_path = _download_or_resolve_model()
print(f"[INIT] Loading SenseVoice model from: {model_path}")
from funasr import AutoModel
_model = AutoModel(model=model_path, trust_remote_code=True)
print("[INIT] SenseVoice model loaded")
return _model
def _decode_audio_b64(b64_data: str) -> Tuple[np.ndarray, int]:
"""Decode base64-encoded audio (wav/ogg/opus) into mono float32 PCM and sample rate."""
audio_bytes = base64.b64decode(b64_data)
with sf.SoundFile(io.BytesIO(audio_bytes)) as f:
audio = f.read(dtype='float32')
sr = f.samplerate
# Convert to mono if multi-channel
if audio.ndim > 1:
audio = np.mean(audio, axis=1)
return audio, sr
def transcribe_audio(gr_audio) -> str:
"""Gradio fn: accepts audio either from microphone (temp file) or from base64 JSON.
- If gradio mic/upload is used, `gr_audio` is (sr, numpy.ndarray)
- If a dict with {"name":"...","data":"<base64>"} is posted via /api/predict,
handle that path too.
"""
try:
model = _load_model()
# Case 1: standard Gradio input: (sample_rate, np.ndarray)
if isinstance(gr_audio, tuple) and len(gr_audio) == 2:
sr, audio = gr_audio
if audio is None or len(audio) == 0:
return "No audio received"
# funasr expects file path or raw array; we'll save temp wav for simplicity
wav_path = "_tmp.wav"
sf.write(wav_path, audio, sr)
result = model.generate(input=wav_path)
text = result[0]["text"] if isinstance(result, list) else str(result)
try:
os.remove(wav_path)
except Exception:
pass
return text
# Case 2: API style: dict with base64
if isinstance(gr_audio, dict) and "data" in gr_audio:
try:
audio, sr = _decode_audio_b64(gr_audio["data"])
wav_path = "_tmp.wav"
sf.write(wav_path, audio, sr)
result = model.generate(input=wav_path)
text = result[0]["text"] if isinstance(result, list) else str(result)
try:
os.remove(wav_path)
except Exception:
pass
return text
except Exception as e:
return f"Failed to decode/process audio: {e}"
return "Unsupported input format"
except Exception as e:
return f"Error during transcription: {str(e)}"
with gr.Blocks() as demo:
gr.Markdown("""
# SenseVoiceSmall ASR (Gradio)
Upload a short audio file or record via microphone to get a transcript.
""")
with gr.Row():
audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
with gr.Row():
out = gr.Textbox(label="Transcript")
btn = gr.Button("Transcribe")
btn.click(fn=transcribe_audio, inputs=audio, outputs=out, api_name="predict")
if __name__ == "__main__":
demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))