import base64
import contextlib
import io
import os
import tempfile
from pathlib import Path
from typing import Tuple

import gradio as gr
import numpy as np
import soundfile as sf

# Global model handle (lazy-loaded so the Space starts fast; the first
# request pays the download/load cost).
_model = None


def _download_or_resolve_model() -> str:
    """Ensure the model exists in a writable cache dir and return its local path.

    Uses ModelScope's default cache (~/.cache/modelscope/hub), which is
    writable on Hugging Face Spaces.

    Returns:
        Local directory path containing the model (or an empty fallback dir
        if the download failed).
    """
    try:
        from modelscope.hub.snapshot_download import snapshot_download

        cache_dir = Path.home() / ".cache" / "modelscope" / "hub"
        cache_dir.mkdir(parents=True, exist_ok=True)
        model_path = snapshot_download(
            model_id="iic/SenseVoiceSmall",
            cache_dir=str(cache_dir),
            revision="master",
        )
        return str(model_path)
    except Exception as e:
        # Best-effort fallback: a conventional, still-writable project path.
        # NOTE(review): this dir is created empty, so loading will still fail
        # downstream unless weights were pre-mounted there — confirm intent.
        fallback = str(Path.home() / "models" / "SenseVoiceSmall")
        os.makedirs(fallback, exist_ok=True)
        print(f"[WARN] Model download failed, using fallback dir: {e}")
        return fallback


def _load_model():
    """Lazily load and memoize the SenseVoice AutoModel in module global `_model`."""
    global _model
    if _model is not None:
        return _model
    model_path = _download_or_resolve_model()
    print(f"[INIT] Loading SenseVoice model from: {model_path}")
    from funasr import AutoModel

    _model = AutoModel(model=model_path, trust_remote_code=True)
    print("[INIT] SenseVoice model loaded")
    return _model


def _decode_audio_b64(b64_data: str) -> Tuple[np.ndarray, int]:
    """Decode base64-encoded audio (wav/ogg/opus) into mono float32 PCM and sample rate."""
    audio_bytes = base64.b64decode(b64_data)
    with sf.SoundFile(io.BytesIO(audio_bytes)) as f:
        audio = f.read(dtype='float32')
        sr = f.samplerate
    # Downmix multi-channel audio to mono by averaging channels.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    return audio, sr


def _transcribe_array(model, audio: np.ndarray, sr: int) -> str:
    """Write audio to a unique temp wav, run ASR on it, and always clean up.

    A unique tempfile (instead of a fixed "_tmp.wav") avoids clobbering when
    several queued requests are processed concurrently.
    """
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # soundfile reopens by path; keep no dangling descriptor
    try:
        sf.write(wav_path, audio, sr)
        result = model.generate(input=wav_path)
        return result[0]["text"] if isinstance(result, list) else str(result)
    finally:
        with contextlib.suppress(OSError):
            os.remove(wav_path)


def transcribe_audio(gr_audio) -> str:
    """Gradio fn: accepts audio either from microphone (temp file) or from base64 JSON.

    - If gradio mic/upload is used, `gr_audio` is (sr, numpy.ndarray).
    - If a dict with {"name": "...", "data": "<base64>"} is posted via
      /api/predict, that path is handled too.

    Returns:
        The transcript text, or a human-readable error message string
        (this endpoint never raises to the caller).
    """
    try:
        model = _load_model()

        # Case 1: standard Gradio input: (sample_rate, np.ndarray)
        if isinstance(gr_audio, tuple) and len(gr_audio) == 2:
            sr, audio = gr_audio
            if audio is None or len(audio) == 0:
                return "No audio received"
            return _transcribe_array(model, audio, sr)

        # Case 2: API style: dict carrying a base64 payload
        if isinstance(gr_audio, dict) and "data" in gr_audio:
            try:
                audio, sr = _decode_audio_b64(gr_audio["data"])
                return _transcribe_array(model, audio, sr)
            except Exception as e:
                return f"Failed to decode/process audio: {e}"

        return "Unsupported input format"
    except Exception as e:
        return f"Error during transcription: {str(e)}"


with gr.Blocks() as demo:
    gr.Markdown("""
    # SenseVoiceSmall ASR (Gradio)
    Upload a short audio file or record via microphone to get a transcript.
    """)
    with gr.Row():
        audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
    with gr.Row():
        out = gr.Textbox(label="Transcript")
    btn = gr.Button("Transcribe")
    btn.click(fn=transcribe_audio, inputs=audio, outputs=out, api_name="predict")


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))