| | """
|
| | Slim GPU service for HF Inference Endpoints.
|
| | Exposes /transcribe and /transcribe/stream using Voxtral 4B.
|
| | """
|
| |
|
import asyncio
import io
import json
import logging
import os
import queue
import threading
from contextlib import asynccontextmanager

import librosa
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from voxtral_inference import VoxtralModel
|
| |
|
| | logger = logging.getLogger("gpu_service")
|
| |
|
| |
|
| |
|
| |
|
| | TARGET_SR = 16000
|
| | MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")
|
| |
|
| |
|
| |
|
| |
|
| | _voxtral: VoxtralModel | None = None
|
| |
|
| |
|
| | def _load_voxtral():
|
| | global _voxtral
|
| | if _voxtral is None:
|
| | logger.info("Loading Voxtral from %s ...", MODEL_DIR)
|
| | _voxtral = VoxtralModel(MODEL_DIR)
|
| | logger.info("Voxtral model loaded.")
|
| | return _voxtral
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def prepare_audio(raw_bytes: bytes) -> np.ndarray:
|
| | """Read any audio format -> float32 mono @ 16 kHz."""
|
| | audio, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32")
|
| | if audio.ndim > 1:
|
| | audio = audio.mean(axis=1)
|
| | if sr != TARGET_SR:
|
| | audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR)
|
| | return audio
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | @asynccontextmanager
|
| | async def lifespan(app: FastAPI):
|
| | _load_voxtral()
|
| | yield
|
| |
|
| |
|
| | app = FastAPI(title="Voxtral Transcription Service (HF Endpoint)", lifespan=lifespan)
|
| |
|
| |
|
| | @app.get("/health")
|
| | async def health():
|
| | return {"status": "ok", "gpu_available": torch.cuda.is_available()}
|
| |
|
| |
|
| | @app.post("/transcribe")
|
| | async def transcribe(audio: UploadFile = File(...)):
|
| | try:
|
| | raw = await audio.read()
|
| | audio_16k = prepare_audio(raw)
|
| |
|
| | model = _load_voxtral()
|
| | text = model.transcribe(audio_16k)
|
| |
|
| | return {"text": text}
|
| | except Exception as e:
|
| | logger.exception("Transcription failed")
|
| | return JSONResponse(status_code=500, content={"error": str(e)})
|
| |
|
| |
|
| | @app.post("/transcribe/stream")
|
| | async def transcribe_stream(audio: UploadFile = File(...)):
|
| | try:
|
| | raw = await audio.read()
|
| | audio_16k = prepare_audio(raw)
|
| | except Exception as e:
|
| | logger.exception("Audio preparation failed")
|
| | return JSONResponse(status_code=500, content={"error": str(e)})
|
| |
|
| | async def event_generator():
|
| | try:
|
| | model = _load_voxtral()
|
| | full_text = ""
|
| |
|
| |
|
| | tokens = []
|
| |
|
| | def _run():
|
| | for tok in model.transcribe_stream(audio_16k):
|
| | tokens.append(tok)
|
| |
|
| | thread = threading.Thread(target=_run)
|
| | thread.start()
|
| | thread.join()
|
| |
|
| | for tok in tokens:
|
| | full_text += tok
|
| | yield {"event": "token", "data": json.dumps({"token": tok})}
|
| |
|
| | full_text = full_text.strip()
|
| | yield {"event": "done", "data": json.dumps({"text": full_text})}
|
| | except Exception as e:
|
| | logger.exception("Streaming transcription failed")
|
| | yield {"event": "error", "data": json.dumps({"error": str(e)})}
|
| |
|
| | return EventSourceResponse(event_generator())
|
| |
|