"""
Slim GPU service for HF Inference Endpoints.
Exposes /transcribe and /transcribe/stream using Voxtral 4B.
"""
import io
import json
import logging
import os
import threading
from contextlib import asynccontextmanager
import numpy as np
import soundfile as sf
import librosa
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from voxtral_inference import VoxtralModel
logger = logging.getLogger("gpu_service")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
TARGET_SR = 16000  # Hz; sample rate prepare_audio() normalizes to before inference
MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")  # overridable via env
# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
# Lazily-initialized process-wide model singleton. The lock makes first-use
# initialization safe under concurrency: the streaming endpoint spawns worker
# threads, so two callers could otherwise race and load the model twice.
_voxtral: VoxtralModel | None = None
_voxtral_lock = threading.Lock()
def _load_voxtral() -> VoxtralModel:
    """Return the shared Voxtral model, loading it on first call.

    Uses double-checked locking: the unlocked fast path avoids lock
    contention once the model is loaded, and the locked re-check prevents
    a duplicate (expensive) load when two threads race on first use.
    """
    global _voxtral
    if _voxtral is None:  # fast path: already initialized
        with _voxtral_lock:
            if _voxtral is None:  # re-check: another thread may have won the race
                logger.info("Loading Voxtral from %s ...", MODEL_DIR)
                _voxtral = VoxtralModel(MODEL_DIR)
                logger.info("Voxtral model loaded.")
    return _voxtral
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def prepare_audio(raw_bytes: bytes) -> np.ndarray:
    """Decode encoded audio bytes into float32 mono at 16 kHz.

    Args:
        raw_bytes: Contents of an audio file in any format libsndfile reads.

    Returns:
        1-D float32 numpy array resampled to ``TARGET_SR``.

    Raises:
        Whatever ``soundfile.read`` raises for undecodable input.
    """
    audio, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32")
    if audio.ndim > 1:
        # Downmix multi-channel audio to mono by averaging the channels.
        audio = audio.mean(axis=1)
    if sr != TARGET_SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR)
    # Resampling backends may upcast; enforce the documented float32 contract
    # (no copy is made when the dtype is already correct).
    return np.asarray(audio, dtype=np.float32)
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm the model at startup so the first request pays no load latency."""
    _load_voxtral()
    # No teardown needed: the model lives for the life of the process.
    yield
# FastAPI application; the lifespan hook preloads the model before serving.
app = FastAPI(title="Voxtral Transcription Service (HF Endpoint)", lifespan=lifespan)
@app.get("/health")
async def health():
    """Liveness probe; also reports whether a CUDA device is visible."""
    gpu_ok = torch.cuda.is_available()
    return {"status": "ok", "gpu_available": gpu_ok}
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the full text.

    Returns ``{"text": ...}`` on success, or a 500 JSON error payload
    if decoding or inference fails.
    """
    try:
        payload = await audio.read()
        waveform = prepare_audio(payload)
        voxtral = _load_voxtral()
        return {"text": voxtral.transcribe(waveform)}
    except Exception as e:
        logger.exception("Transcription failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/transcribe/stream")
async def transcribe_stream(audio: UploadFile = File(...)):
    """Stream transcription tokens over SSE as they are generated.

    Events emitted:
        token: {"token": str} — one per generated token, as soon as available
        done:  {"text": str}  — the full stripped transcript
        error: {"error": str} — on any failure

    BUG FIX: the previous implementation ``join()``ed the worker thread
    before yielding anything — a blocking call inside an async function
    that stalled the event loop AND buffered every token, so nothing was
    actually streamed. Tokens are now forwarded through a queue and
    consumed via ``asyncio.to_thread`` so they flush incrementally.
    """
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)
    except Exception as e:
        logger.exception("Audio preparation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    async def event_generator():
        _DONE = object()  # sentinel marking end of the token stream
        token_queue: queue.Queue = queue.Queue()
        worker_error: list[BaseException] = []
        try:
            model = _load_voxtral()
            def _run():
                # Runs the blocking generator off the event loop, forwarding
                # each token through the queue the moment it is produced.
                try:
                    for tok in model.transcribe_stream(audio_16k):
                        token_queue.put(tok)
                except BaseException as exc:
                    worker_error.append(exc)  # re-raised by the consumer below
                finally:
                    token_queue.put(_DONE)
            thread = threading.Thread(target=_run, daemon=True)
            thread.start()
            parts: list[str] = []
            while True:
                # Block in a helper thread, not on the event loop, so the
                # server stays responsive and tokens stream incrementally.
                tok = await asyncio.to_thread(token_queue.get)
                if tok is _DONE:
                    break
                parts.append(tok)
                yield {"event": "token", "data": json.dumps({"token": tok})}
            thread.join()
            if worker_error:
                raise worker_error[0]
            yield {"event": "done", "data": json.dumps({"text": "".join(parts).strip()})}
        except Exception as e:
            logger.exception("Streaming transcription failed")
            yield {"event": "error", "data": json.dumps({"error": str(e)})}
    return EventSourceResponse(event_generator())