File size: 4,515 Bytes
99c0c40
 
 
 
 
07c06d2
88b685d
99c0c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88b685d
99c0c40
 
 
 
 
 
 
 
 
 
07c06d2
88b685d
 
 
 
 
 
99c0c40
 
 
07c06d2
99c0c40
 
 
 
 
07c06d2
99c0c40
07c06d2
99c0c40
 
 
 
07c06d2
99c0c40
 
 
 
 
 
 
 
 
88b685d
99c0c40
 
 
 
 
 
 
07c06d2
99c0c40
 
07c06d2
99c0c40
 
 
07c06d2
99c0c40
 
 
 
 
07c06d2
 
 
 
 
 
 
 
 
 
 
 
 
88b685d
07c06d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88b685d
07c06d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99c0c40
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import asyncio
import os
import secrets
import tempfile

import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, Header, HTTPException, status, Form
from fastapi.responses import PlainTextResponse
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

app = FastAPI(title="SenseVoice ASR API")

# Bearer token that clients must present. Read from the AUTH_TOKEN environment
# variable; the fallback below is committed to source, so it is not a secret.
_DEFAULT_AUTH_TOKEN = "myLinuxTypeless888"
AUTH_TOKEN = os.environ.get("AUTH_TOKEN", _DEFAULT_AUTH_TOKEN)
if AUTH_TOKEN == _DEFAULT_AUTH_TOKEN:
    # Don't refuse to start (existing deployments may rely on the default),
    # but make the insecure configuration visible in the startup logs.
    print(
        "WARNING: AUTH_TOKEN is not set; falling back to the built-in default "
        "token. Set the AUTH_TOKEN environment variable to a private value."
    )

# Model configurations. Model paths live under MODEL_CACHE_DIR (a relative
# path resolves against the process working directory); override it with the
# MODEL_CACHE_DIR environment variable. Default is unchanged: "./models".
MODEL_CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "./models")
SENSE_VOICE_SMALL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "SenseVoiceSmall")
VAD_MODEL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "fsmn-vad")

# Device config (force CPU for free HF Space tier)
device = "cpu"

# Load once at import time so every request handler reuses this instance.
print(f"Loading SenseVoice model on {device}...")
asr_model = AutoModel(
    model=SENSE_VOICE_SMALL_LOCAL_PATH,
    trust_remote_code=False,
    # Voice-activity-detection model, presumably used to split long audio
    # into segments before recognition. 30000 is presumably the maximum
    # segment length in milliseconds (30 s) — confirm against FunASR docs.
    vad_model=VAD_MODEL_LOCAL_PATH,
    vad_kwargs={"max_single_segment_time": 30000},
    device=device,
    # NOTE(review): presumably disables FunASR's online version/update check;
    # hub="hf" selects Hugging Face-style model resolution — confirm.
    disable_update=True,
    hub="hf",
)
print("Model loaded successfully.")


@app.get("/")
def read_root():
    """Root endpoint used by the Hugging Face Space as its health check.

    Returns a fixed status payload confirming the API process is serving.
    """
    payload = {"status": "ok", "message": "SenseVoice ASR API is running"}
    return payload


@app.post("/transcribe")
async def transcribe(
    file: UploadFile = File(...),
    authorization: str | None = Header(None),
):
    """Transcribe an uploaded audio file with the SenseVoice model.

    Requires an ``Authorization: Bearer <AUTH_TOKEN>`` header.

    Args:
        file: Uploaded audio (multipart/form-data). It is written to a
            temporary file with a ``.wav`` suffix and passed to the model by
            path.
        authorization: Raw ``Authorization`` header value.

    Returns:
        ``{"text": <post-processed transcription>}``. ``text`` is ``""`` when
        the model returns no result.

    Raises:
        HTTPException: 401 if the header is missing or the token is wrong,
            400 if the uploaded file is empty, 500 if transcription fails.
    """
    # Verify auth token
    if not authorization:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header missing",
            headers={"WWW-Authenticate": "Bearer"},
        )

    parts = authorization.split()
    # compare_digest runs in constant time so the token cannot be recovered
    # from response-timing differences. Compare bytes: the str overload raises
    # TypeError on non-ASCII input, which would surface as a 500.
    if not (
        len(parts) == 2
        and parts[0].lower() == "bearer"
        and secrets.compare_digest(parts[1].encode(), AUTH_TOKEN.encode())
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Read the upload before creating the temp file so a failed read cannot
    # leave an orphaned file on disk.
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded audio file is empty",
        )

    audio_path = None
    try:
        # delete=False: the file must outlive the `with` because the model
        # reopens it by path after it has been closed.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
            tmp.write(audio_bytes)

        try:
            # NOTE(review): generate() is synchronous work inside an async
            # handler, so it blocks the event loop (including /health) for the
            # duration of the transcription. Offloading it to a thread would
            # also make calls concurrent; confirm FunASR is thread-safe first.
            res = asr_model.generate(
                input=audio_path,
                cache={},
                language="auto",
                use_itn=True,
                batch_size_s=60,
                merge_vad=True,
            )
            # An empty result list (nothing recognised) yields empty text
            # instead of an IndexError.
            text = rich_transcription_postprocess(res[0]["text"]) if res else ""
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Transcription error: {str(e)}",
            ) from e

        return {"text": text}
    finally:
        if audio_path is not None and os.path.exists(audio_path):
            os.unlink(audio_path)


@app.post("/v1/audio/transcriptions")
async def openai_audio_transcriptions(
    file: UploadFile = File(...),
    model: str | None = Form(default=None),
    response_format: str | None = Form(default=None),
    authorization: str | None = Header(None),
):
    """
    OpenAI-compatible audio transcription endpoint.

    Request (multipart/form-data):
      - file: audio file
      - model: optional model identifier (ignored, uses SenseVoice)
      - response_format: optional; only "text" changes the output. Every other
        value (json, srt, vtt, verbose_json, anything unrecognised, or
        omitted) returns JSON, because no segment timestamps are produced.

    Response:
      - text: raw text body (text/plain)
      - otherwise: {"text": "..."}

    Raises:
      HTTPException: 401 if an Authorization header is sent with a wrong
      token, 400 if the uploaded file is empty, 500 if transcription fails.
    """
    # Verify auth token when provided.
    # NOTE(review): requests WITHOUT an Authorization header skip auth
    # entirely, so AUTH_TOKEN does not protect this endpoint. Kept as-is for
    # clients that rely on it; consider requiring the header as /transcribe
    # does.
    if authorization:
        parts = authorization.split()
        # Constant-time comparison on bytes (the str overload of
        # compare_digest raises TypeError on non-ASCII input).
        if not (
            len(parts) == 2
            and parts[0].lower() == "bearer"
            and secrets.compare_digest(parts[1].encode(), AUTH_TOKEN.encode())
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authorization token",
                headers={"WWW-Authenticate": "Bearer"},
            )

    # Only plain text differs from the JSON default.
    wants_text = (response_format or "json").strip().lower() == "text"

    # Read before creating the temp file so a failed read leaves nothing on disk.
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded audio file is empty",
        )

    audio_path = None
    try:
        # delete=False: the model reopens the file by path after it is closed.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            audio_path = tmp.name
            tmp.write(audio_bytes)

        try:
            # NOTE(review): synchronous call inside an async handler; it blocks
            # the event loop while transcribing (see thread-safety caveat
            # before offloading to a worker thread).
            res = asr_model.generate(
                input=audio_path,
                cache={},
                language="auto",
                use_itn=True,
                batch_size_s=60,
                merge_vad=True,
            )
            # Empty result list -> empty text rather than IndexError.
            text = rich_transcription_postprocess(res[0]["text"]) if res else ""
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Transcription error: {str(e)}",
            ) from e
    finally:
        if audio_path is not None and os.path.exists(audio_path):
            os.unlink(audio_path)

    if wants_text:
        return PlainTextResponse(content=text)

    return {"text": text}


@app.get("/health")
def health():
    """Lightweight health endpoint; always answers with an "ok" status."""
    health_payload = {"status": "ok"}
    return health_payload