# gpu_endpoint / app.py
# tantk's picture
# Upload app.py with huggingface_hub
# 5b2eaaf verified
"""
Slim GPU service for HF Inference Endpoints.
Exposes /transcribe and /transcribe/stream using Voxtral 4B.
"""
import asyncio
import io
import json
import logging
import os
import queue
import threading
from contextlib import asynccontextmanager

import librosa
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse

from voxtral_inference import VoxtralModel
logger = logging.getLogger("gpu_service")
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
TARGET_SR = 16000
MODEL_DIR = os.environ.get("VOXTRAL_MODEL_DIR", "/repository/voxtral-model")
# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------
_voxtral: VoxtralModel | None = None
def _load_voxtral():
global _voxtral
if _voxtral is None:
logger.info("Loading Voxtral from %s ...", MODEL_DIR)
_voxtral = VoxtralModel(MODEL_DIR)
logger.info("Voxtral model loaded.")
return _voxtral
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def prepare_audio(raw_bytes: bytes) -> np.ndarray:
"""Read any audio format -> float32 mono @ 16 kHz."""
audio, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32")
if audio.ndim > 1:
audio = audio.mean(axis=1)
if sr != TARGET_SR:
audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR)
return audio
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
_load_voxtral()
yield
app = FastAPI(title="Voxtral Transcription Service (HF Endpoint)", lifespan=lifespan)
@app.get("/health")
async def health():
return {"status": "ok", "gpu_available": torch.cuda.is_available()}
@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...)):
try:
raw = await audio.read()
audio_16k = prepare_audio(raw)
model = _load_voxtral()
text = model.transcribe(audio_16k)
return {"text": text}
except Exception as e:
logger.exception("Transcription failed")
return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/transcribe/stream")
async def transcribe_stream(audio: UploadFile = File(...)):
try:
raw = await audio.read()
audio_16k = prepare_audio(raw)
except Exception as e:
logger.exception("Audio preparation failed")
return JSONResponse(status_code=500, content={"error": str(e)})
async def event_generator():
try:
model = _load_voxtral()
full_text = ""
# Run blocking generator in a thread
tokens = []
def _run():
for tok in model.transcribe_stream(audio_16k):
tokens.append(tok)
thread = threading.Thread(target=_run)
thread.start()
thread.join()
for tok in tokens:
full_text += tok
yield {"event": "token", "data": json.dumps({"token": tok})}
full_text = full_text.strip()
yield {"event": "done", "data": json.dumps({"text": full_text})}
except Exception as e:
logger.exception("Streaming transcription failed")
yield {"event": "error", "data": json.dumps({"error": str(e)})}
return EventSourceResponse(event_generator())