File size: 3,962 Bytes
6216b68
 
5b2eaaf
6216b68
 
 
 
 
 
 
 
 
 
 
 
 
5b2eaaf
6216b68
 
 
a0d2600
 
6216b68
 
 
 
 
 
a0d2600
 
6216b68
5b2eaaf
6216b68
a0d2600
6216b68
 
 
a0d2600
 
 
 
6216b68
a0d2600
6216b68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b2eaaf
6216b68
 
 
5b2eaaf
6216b68
 
 
 
 
 
 
 
a0d2600
6216b68
 
 
 
a0d2600
 
6216b68
 
 
 
 
 
 
 
a0d2600
6216b68
 
 
 
 
 
 
 
 
a0d2600
 
6216b68
a0d2600
 
6216b68
a0d2600
 
 
6216b68
a0d2600
 
6216b68
a0d2600
 
 
 
 
 
6216b68
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""

Slim GPU service for HF Inference Endpoints.

Exposes /transcribe and /transcribe/stream using Voxtral 4B.

"""

import asyncio
import io
import json
import logging
import os
import threading
from contextlib import asynccontextmanager

import librosa
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from voxtral_inference import VoxtralModel

logger = logging.getLogger("gpu_service")

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
TARGET_SR = 16000
MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")

# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
_voxtral: VoxtralModel | None = None


_voxtral_lock = threading.Lock()


def _load_voxtral():
    """Return the process-wide VoxtralModel singleton, loading it on first use.

    Double-checked locking makes the lazy init thread-safe: without it, two
    threads (e.g. FastAPI threadpool workers) could both observe ``_voxtral is
    None`` and load the model twice, wasting GPU memory.
    """
    global _voxtral
    if _voxtral is None:
        with _voxtral_lock:
            if _voxtral is None:  # re-check under the lock
                logger.info("Loading Voxtral from %s ...", MODEL_DIR)
                _voxtral = VoxtralModel(MODEL_DIR)
                logger.info("Voxtral model loaded.")
    return _voxtral


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def prepare_audio(raw_bytes: bytes) -> np.ndarray:
    """Decode *raw_bytes* (any soundfile-supported format) into a float32
    mono waveform at TARGET_SR (16 kHz), as Voxtral expects."""
    samples, sample_rate = sf.read(io.BytesIO(raw_bytes), dtype="float32")
    # Downmix multi-channel audio to mono by averaging the channels.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    # Resample only when the source rate differs from the target rate.
    if sample_rate != TARGET_SR:
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=TARGET_SR)
    return samples


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context: eagerly load the Voxtral model at startup.

    Loading before the server accepts traffic means the first request does
    not pay the model-load cost; there is nothing to tear down on shutdown.
    """
    _load_voxtral()
    yield


# Single app instance; `lifespan` warms the model before requests are served.
app = FastAPI(title="Voxtral Transcription Service (HF Endpoint)", lifespan=lifespan)


@app.get("/health")
async def health():
    """Liveness probe; also reports whether CUDA is visible to torch."""
    gpu_ok = torch.cuda.is_available()
    return {"status": "ok", "gpu_available": gpu_ok}


@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file in one shot.

    Returns ``{"text": ...}`` on success, or HTTP 500 with
    ``{"error": ...}`` on failure (same contract as before).

    Both audio decoding and model inference are blocking, CPU/GPU-bound
    calls; they are offloaded with ``asyncio.to_thread`` so the event loop
    (and e.g. /health) stays responsive while a request is being processed.
    """
    try:
        raw = await audio.read()
        audio_16k = await asyncio.to_thread(prepare_audio, raw)

        model = _load_voxtral()
        text = await asyncio.to_thread(model.transcribe, audio_16k)

        return {"text": text}
    except Exception as e:
        logger.exception("Transcription failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/transcribe/stream")
async def transcribe_stream(audio: UploadFile = File(...)):
    """Stream transcription tokens over SSE as the model produces them.

    Events: one ``token`` event per generated piece, then either ``done``
    (with the full stripped text) or ``error`` (with a message).

    The original implementation joined the worker thread *before* yielding
    anything, so clients received no tokens until transcription finished
    (and the event loop was blocked the whole time). Here the producer
    thread forwards each token to the event loop via
    ``loop.call_soon_threadsafe`` into an ``asyncio.Queue``, and the async
    generator yields them as they arrive — genuine streaming.
    """
    try:
        raw = await audio.read()
        # Decoding/resampling is blocking; keep it off the event loop.
        audio_16k = await asyncio.to_thread(prepare_audio, raw)
    except Exception as e:
        logger.exception("Audio preparation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    async def event_generator():
        try:
            model = _load_voxtral()
        except Exception as e:
            logger.exception("Streaming transcription failed")
            yield {"event": "error", "data": json.dumps({"error": str(e)})}
            return

        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        done = object()  # sentinel marking end of the token stream
        failure: list[BaseException] = []  # carries a producer-thread error

        def _produce() -> None:
            # Runs in a worker thread: push each token to the event loop as
            # soon as the blocking generator yields it.
            try:
                for tok in model.transcribe_stream(audio_16k):
                    loop.call_soon_threadsafe(queue.put_nowait, tok)
            except BaseException as exc:
                failure.append(exc)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, done)

        threading.Thread(target=_produce, daemon=True).start()

        full_text = ""
        while True:
            item = await queue.get()
            if item is done:
                break
            full_text += item
            yield {"event": "token", "data": json.dumps({"token": item})}

        if failure:
            logger.error("Streaming transcription failed", exc_info=failure[0])
            yield {"event": "error", "data": json.dumps({"error": str(failure[0])})}
        else:
            yield {"event": "done", "data": json.dumps({"text": full_text.strip()})}

    return EventSourceResponse(event_generator())