""" Slim GPU service for HF Inference Endpoints. Exposes /transcribe and /transcribe/stream using Voxtral 4B. """ import io import json import logging import os import threading from contextlib import asynccontextmanager import numpy as np import soundfile as sf import librosa import torch from fastapi import FastAPI, File, UploadFile from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from voxtral_inference import VoxtralModel logger = logging.getLogger("gpu_service") # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- TARGET_SR = 16000 MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model") # --------------------------------------------------------------------------- # Singleton # --------------------------------------------------------------------------- _voxtral: VoxtralModel | None = None def _load_voxtral(): global _voxtral if _voxtral is None: logger.info("Loading Voxtral from %s ...", MODEL_DIR) _voxtral = VoxtralModel(MODEL_DIR) logger.info("Voxtral model loaded.") return _voxtral # --------------------------------------------------------------------------- # Audio helpers # --------------------------------------------------------------------------- def prepare_audio(raw_bytes: bytes) -> np.ndarray: """Read any audio format -> float32 mono @ 16 kHz.""" audio, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32") if audio.ndim > 1: audio = audio.mean(axis=1) if sr != TARGET_SR: audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR) return audio # --------------------------------------------------------------------------- # App # --------------------------------------------------------------------------- @asynccontextmanager async def lifespan(app: FastAPI): _load_voxtral() yield app = FastAPI(title="Voxtral Transcription Service (HF Endpoint)", lifespan=lifespan) 
@app.get("/health")
async def health():
    """Liveness probe: reports service status and whether CUDA is visible."""
    return {"status": "ok", "gpu_available": torch.cuda.is_available()}


@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file in one shot.

    Returns ``{"text": ...}`` on success, or a 500 JSON ``{"error": ...}``
    payload on failure.
    """
    import asyncio  # stdlib; local import keeps this edit self-contained

    try:
        raw = await audio.read()
        # Decoding/resampling and model inference are blocking; run them in
        # worker threads so the event loop keeps serving other requests.
        audio_16k = await asyncio.to_thread(prepare_audio, raw)
        model = _load_voxtral()
        text = await asyncio.to_thread(model.transcribe, audio_16k)
        return {"text": text}
    except Exception as e:
        logger.exception("Transcription failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/transcribe/stream")
async def transcribe_stream(audio: UploadFile = File(...)):
    """Stream transcription tokens over SSE as they are generated.

    Events emitted:
      * ``token`` -- one incremental token, data ``{"token": ...}``
      * ``done``  -- final stripped transcript, data ``{"text": ...}``
      * ``error`` -- data ``{"error": ...}`` if inference fails mid-stream

    The previous implementation joined the producer thread before yielding
    anything, so every token was buffered and the event loop was blocked for
    the entire inference. Tokens are now forwarded through a queue as the
    blocking generator produces them.
    """
    import asyncio  # stdlib; local import keeps this edit self-contained
    import queue

    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)
    except Exception as e:
        logger.exception("Audio preparation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    async def event_generator():
        _done = object()  # sentinel: producer finished cleanly
        token_q: queue.Queue = queue.Queue()

        def _produce():
            # Runs in a worker thread: drive the blocking token generator and
            # hand results (or the terminating exception) to the consumer.
            try:
                model = _load_voxtral()
                for tok in model.transcribe_stream(audio_16k):
                    token_q.put(tok)
                token_q.put(_done)
            except Exception as exc:
                token_q.put(exc)

        producer = threading.Thread(target=_produce, daemon=True)
        producer.start()

        full_text = ""
        try:
            while True:
                # Wait for the next token in a thread, not on the event loop,
                # so concurrent requests stay responsive during inference.
                item = await asyncio.to_thread(token_q.get)
                if item is _done:
                    break
                if isinstance(item, Exception):
                    raise item
                full_text += item
                yield {"event": "token", "data": json.dumps({"token": item})}
            yield {"event": "done", "data": json.dumps({"text": full_text.strip()})}
        except Exception as e:
            logger.exception("Streaming transcription failed")
            yield {"event": "error", "data": json.dumps({"error": str(e)})}

    return EventSourceResponse(event_generator())