Spaces:
Running
Running
File size: 4,515 Bytes
import os
import secrets
import tempfile

import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, Header, HTTPException, status, Form
from fastapi.responses import PlainTextResponse
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# FastAPI application object; the route handlers in this file register on it.
app = FastAPI(title="SenseVoice ASR API")

# Bearer token the endpoints compare against. Taken from the AUTH_TOKEN
# environment variable, with a built-in fallback when it is unset.
# NOTE(review): the fallback value is a secret committed to source control;
# deployments should always set AUTH_TOKEN explicitly.
AUTH_TOKEN = os.getenv("AUTH_TOKEN", "myLinuxTypeless888")

# Local model locations, all rooted at MODEL_CACHE_DIR.
MODEL_CACHE_DIR = "./models"
SENSE_VOICE_SMALL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "SenseVoiceSmall")
VAD_MODEL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "fsmn-vad")

# Inference device; pinned to CPU (free HF Space tier).
device = "cpu"

# The model is built once at import time and shared by every request handler.
print(f"Loading SenseVoice model on {device}...")
_asr_model_options = {
    "model": SENSE_VOICE_SMALL_LOCAL_PATH,
    "trust_remote_code": False,
    "vad_model": VAD_MODEL_LOCAL_PATH,
    "vad_kwargs": {"max_single_segment_time": 30000},
    "device": device,
    "disable_update": True,
    "hub": "hf",
}
asr_model = AutoModel(**_asr_model_options)
print("Model loaded successfully.")
@app.get("/")
def read_root():
    """Return a fixed status payload; serves as the Hugging Face Space health check."""
    payload = dict(status="ok", message="SenseVoice ASR API is running")
    return payload
def _check_bearer_token(authorization: str) -> None:
    """Raise 401 unless *authorization* is ``Bearer <AUTH_TOKEN>``.

    The header must consist of exactly two whitespace-separated parts; the
    scheme is matched case-insensitively.

    Raises:
        HTTPException: 401 "Invalid authorization token" on any mismatch.
    """
    parts = authorization.split()
    token_ok = (
        len(parts) == 2
        and parts[0].lower() == "bearer"
        # Constant-time comparison so the token is not leaked through timing.
        # Compared as bytes because compare_digest() raises TypeError for
        # str arguments containing non-ASCII characters.
        and secrets.compare_digest(
            parts[1].encode("utf-8"), AUTH_TOKEN.encode("utf-8")
        )
    )
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization token",
        )


async def _transcribe_upload(file: UploadFile) -> str:
    """Spool the upload to a temporary file, run the ASR model, return the text.

    The temporary file is always removed before returning, including when
    reading the upload or writing the file fails.

    Raises:
        HTTPException: 500 "Transcription error: ..." if the model call or the
            post-processing raises.
    """
    # delete=False: the model is handed a path rather than the open handle, so
    # the file must outlive the `with` below; the `finally` removes it.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    audio_path = tmp.name
    try:
        with tmp:
            tmp.write(await file.read())
        try:
            # NOTE(review): this is a synchronous call inside an async handler,
            # so it presumably blocks the event loop while the model runs.
            # Moving it to a threadpool would need the model to be safe for
            # concurrent use -- confirm before changing.
            res = asr_model.generate(
                input=audio_path,
                cache={},
                language="auto",
                use_itn=True,
                batch_size_s=60,
                merge_vad=True,
            )
            # An empty result list is treated as "nothing transcribed" instead
            # of failing with IndexError.
            # NOTE(review): assumes an empty list means no speech was found.
            return rich_transcription_postprocess(res[0]["text"]) if res else ""
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Transcription error: {str(e)}",
            ) from e
    finally:
        try:
            os.unlink(audio_path)
        except FileNotFoundError:
            pass


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    authorization: str | None = Header(None),
):
    """Transcribe an uploaded audio file; a valid bearer token is required.

    Request (multipart/form-data):
    - file: audio file
    Headers:
    - Authorization: ``Bearer <token>``
    Response:
    - {"text": "..."}

    Raises:
        HTTPException: 401 if the Authorization header is missing or invalid,
            500 if transcription fails.
    """
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header missing",
        )
    _check_bearer_token(authorization)
    return {"text": await _transcribe_upload(file)}


@app.post("/v1/audio/transcriptions")
async def openai_audio_transcriptions(
    file: UploadFile = File(...),
    model: str | None = Form(default=None),
    response_format: str | None = Form(default=None),
    authorization: str | None = Header(None),
):
    """
    OpenAI-compatible audio transcription endpoint.
    Request (multipart/form-data):
    - file: audio file
    - model: optional model identifier (accepted but not used)
    - response_format: json | text | srt | vtt (default json). Only "text"
      changes the output; srt, vtt and unrecognised values all produce the
      JSON body, since no subtitle rendering is implemented here.
    Response:
    - json: {"text": "..."}
    - text: raw text body

    Raises:
        HTTPException: 401 if an Authorization header is sent but invalid,
            500 if transcription fails.
    """
    # The token is only verified when a header is provided.
    # NOTE(review): a request that omits the Authorization header bypasses
    # authentication on this route entirely -- confirm this is intended.
    if authorization:
        _check_bearer_token(authorization)

    fmt = (response_format or "json").strip().lower()
    if fmt not in {"json", "text", "srt", "vtt"}:
        fmt = "json"

    text = await _transcribe_upload(file)
    if fmt == "text":
        return PlainTextResponse(content=text)
    return {"text": text}
@app.get("/health")
def health():
    """Return a constant ``{"status": "ok"}`` payload."""
    return dict(status="ok")
|