Spaces:
Running
Running
import asyncio
import os
import secrets
import tempfile

import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, Header, HTTPException, status, Form
from fastapi.responses import PlainTextResponse
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
app = FastAPI(title="SenseVoice ASR API")

# Bearer token checked by the API endpoints. Configure it through the
# AUTH_TOKEN environment variable (e.g. a Space secret). The fallback value
# below is committed in this source file, so anyone who can read the source
# can authenticate while it is in effect; a warning is printed at startup.
_DEFAULT_AUTH_TOKEN = "myLinuxTypeless888"
AUTH_TOKEN = os.environ.get("AUTH_TOKEN", _DEFAULT_AUTH_TOKEN)
if AUTH_TOKEN == _DEFAULT_AUTH_TOKEN:
    print(
        "WARNING: AUTH_TOKEN is not set (or equals the built-in default); "
        "the API is protected only by the publicly known default token."
    )

# Model configuration. Models are loaded from local directories under
# MODEL_CACHE_DIR, which defaults to "./models" (relative to the working
# directory) and can be overridden with the MODEL_CACHE_DIR environment
# variable.
MODEL_CACHE_DIR = os.environ.get("MODEL_CACHE_DIR") or "./models"
SENSE_VOICE_SMALL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "SenseVoiceSmall")
VAD_MODEL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "fsmn-vad")

# Device config (force CPU for free HF Space tier)
device = "cpu"

# Loaded once at import time; every request handler shares this instance.
print(f"Loading SenseVoice model on {device}...")
asr_model = AutoModel(
    model=SENSE_VOICE_SMALL_LOCAL_PATH,
    trust_remote_code=False,
    # Voice-activity-detection model loaded from the local cache directory.
    vad_model=VAD_MODEL_LOCAL_PATH,
    # NOTE(review): presumably milliseconds (30 s max per VAD segment) — confirm
    # against the FunASR documentation.
    vad_kwargs={"max_single_segment_time": 30000},
    device=device,
    disable_update=True,
    hub="hf",
)
print("Model loaded successfully.")
def read_root():
    """Hugging Face Space health check endpoint.

    Returns:
        A static status payload with an ``"ok"`` status and a short
        human-readable message.
    """
    payload = {"status": "ok"}
    payload["message"] = "SenseVoice ASR API is running"
    return payload
async def transcribe(
    file: UploadFile = File(...),
    authorization: str | None = Header(None),
):
    """Transcribe an uploaded audio file with the SenseVoice model.

    The request must carry an ``Authorization: Bearer <AUTH_TOKEN>`` header.

    Args:
        file: Uploaded audio file (multipart field ``file``).
        authorization: Value of the ``Authorization`` request header, or
            ``None`` when the header is absent.

    Returns:
        ``{"text": ...}`` with the post-processed transcription; the text is
        an empty string when the model returns no results.

    Raises:
        HTTPException: 401 when the header is missing or the token does not
            match, 400 when the uploaded file is empty, 500 when writing the
            temporary file or running the model fails.
    """
    # Verify auth token
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header missing",
            headers={"WWW-Authenticate": "Bearer"},
        )
    parts = authorization.split()
    # compare_digest takes time independent of where the values differ, so
    # response timing does not reveal how much of the token was correct.
    # Bytes are compared because the str overload rejects non-ASCII input.
    is_valid = (
        len(parts) == 2
        and parts[0].lower() == "bearer"
        and secrets.compare_digest(
            parts[1].encode("utf-8"), AUTH_TOKEN.encode("utf-8")
        )
    )
    if not is_valid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    audio = await file.read()
    if not audio:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded audio file is empty",
        )

    audio_path = None
    try:
        # Persist the upload so the model can read it from a path. The path is
        # recorded before writing so the finally-block removes the file even
        # if the write itself fails.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
            tmp.write(audio)
        # generate() is blocking, CPU-bound work; running it in a worker
        # thread keeps the event loop free for other requests (such as
        # health checks) while a transcription is in progress.
        # NOTE(review): overlapping requests can now call asr_model.generate
        # concurrently; confirm FunASR's AutoModel tolerates that, or
        # serialize the calls with a lock.
        res = await asyncio.to_thread(
            asr_model.generate,
            input=audio_path,
            cache={},
            language="auto",
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
        )
        # An empty result list yields an empty transcription instead of an
        # IndexError being reported as a 500.
        text = rich_transcription_postprocess(res[0]["text"]) if res else ""
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Transcription error: {str(e)}",
        ) from e
    finally:
        if audio_path is not None and os.path.exists(audio_path):
            os.unlink(audio_path)
    return {"text": text}
async def openai_audio_transcriptions(
    file: UploadFile = File(...),
    model: str | None = Form(default=None),
    response_format: str | None = Form(default=None),
    authorization: str | None = Header(None),
):
    """
    OpenAI-compatible audio transcription endpoint.

    Request (multipart/form-data):
    - file: audio file
    - model: optional model identifier (ignored, uses SenseVoice)
    - response_format: json | text | srt | vtt (default json). Only "text"
      changes the response body; "srt", "vtt" and any unrecognized value are
      answered with the JSON body (no subtitle rendering is implemented).

    Authentication:
    - An ``Authorization: Bearer <AUTH_TOKEN>`` header is validated only when
      it is present; requests without the header are accepted.

    Response:
    - json: {"text": "..."}
    - text: raw text body

    Raises:
    - HTTPException 401 for a present but invalid Authorization header,
      400 for an empty upload, 500 when writing the temporary file or
      running the model fails.
    """
    # Verify auth token when provided.
    # NOTE(review): omitting the Authorization header skips this check
    # entirely, so the token does not actually protect this endpoint. Confirm
    # that is intended; otherwise reject requests that lack the header.
    if authorization:
        parts = authorization.split()
        # Constant-time comparison on bytes (the str overload rejects
        # non-ASCII input) so timing does not leak the token.
        is_valid = (
            len(parts) == 2
            and parts[0].lower() == "bearer"
            and secrets.compare_digest(
                parts[1].encode("utf-8"), AUTH_TOKEN.encode("utf-8")
            )
        )
        if not is_valid:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authorization token",
                headers={"WWW-Authenticate": "Bearer"},
            )

    # Unknown formats fall back to JSON; srt/vtt are accepted for client
    # compatibility but also produce the JSON body below.
    fmt = (response_format or "json").strip().lower()
    if fmt not in {"json", "text", "srt", "vtt"}:
        fmt = "json"

    audio = await file.read()
    if not audio:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded audio file is empty",
        )

    audio_path = None
    try:
        # Record the path before writing so the finally-block cleans up even
        # when the write fails.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
            tmp.write(audio)
        # Run the blocking, CPU-bound model call in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        # NOTE(review): overlapping requests can now call asr_model.generate
        # concurrently; confirm FunASR's AutoModel tolerates that, or
        # serialize the calls with a lock.
        res = await asyncio.to_thread(
            asr_model.generate,
            input=audio_path,
            cache={},
            language="auto",
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
        )
        # An empty result list yields an empty transcription instead of an
        # IndexError being reported as a 500.
        text = rich_transcription_postprocess(res[0]["text"]) if res else ""
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Transcription error: {str(e)}",
        ) from e
    finally:
        if audio_path is not None and os.path.exists(audio_path):
            os.unlink(audio_path)

    if fmt == "text":
        return PlainTextResponse(content=text)
    return {"text": text}
def health():
    """Minimal liveness probe.

    Returns:
        A fresh ``{"status": "ok"}`` dictionary on every call.
    """
    return dict(status="ok")