"""
Slim GPU service for HF Inference Endpoints.
Exposes /diarize, /embed, /transcribe, and /transcribe/stream endpoints.
"""
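
# Example client call (a sketch only; the base URL is a hypothetical placeholder
# and `requests` is an assumed dependency of the caller, not of this service):
#
#   import requests
#   with open("meeting.wav", "rb") as f:
#       r = requests.post("http://localhost:8000/diarize", files={"audio": f})
#   print(r.json()["segments"])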

import io
import json
import logging
import os
import re
import threading
from contextlib import asynccontextmanager

import numpy as np
import soundfile as sf
import librosa
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette.sse import EventSourceResponse

logger = logging.getLogger("gpu_service")

HF_TOKEN = os.environ.get("HF_TOKEN", "")
PYANNOTE_MODEL = "pyannote/speaker-diarization-community-1"
FUNASR_MODEL = "iic/speech_campplus_sv_zh-cn_16k-common"
PYANNOTE_MIN_SPEAKERS = int(os.environ.get("PYANNOTE_MIN_SPEAKERS", "1"))
PYANNOTE_MAX_SPEAKERS = int(os.environ.get("PYANNOTE_MAX_SPEAKERS", "10"))
TARGET_SR = 16000
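
# The values above can be overridden per deployment, e.g. (hypothetical launch
# command; the module name "main" is an assumption):
#
#   HF_TOKEN=hf_... PYANNOTE_MAX_SPEAKERS=4 uvicorn main:app --host 0.0.0.0 --port 8000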

_diarize_pipeline = None
_embed_model = None
_voxtral_model = None
_voxtral_processor = None

VOXTRAL_MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"

# Streaming markers emitted by Voxtral in its decoded text; they are stripped
# from transcripts before being returned.
_MARKER_RE = re.compile(r"\[STREAMING_PAD\]|\[STREAMING_WORD\]")


def _load_diarize_pipeline():
    """Lazy-load the pyannote diarization pipeline onto the GPU (first call only)."""
    global _diarize_pipeline
    if _diarize_pipeline is None:
        from pyannote.audio import Pipeline as PyannotePipeline

        # Pass None rather than an empty string so an unset HF_TOKEN falls back
        # to the cached Hugging Face credentials.
        _diarize_pipeline = PyannotePipeline.from_pretrained(
            PYANNOTE_MODEL, token=HF_TOKEN or None
        )
        _diarize_pipeline = _diarize_pipeline.to(torch.device("cuda"))
    return _diarize_pipeline


def _load_embed_model():
    """Lazy-load the FunASR speaker-embedding model (first call only)."""
    global _embed_model
    if _embed_model is None:
        from funasr import AutoModel

        _embed_model = AutoModel(model=FUNASR_MODEL)
    return _embed_model


def _load_voxtral():
    """Lazy-load Voxtral model and processor (first call only)."""
    global _voxtral_model, _voxtral_processor
    if _voxtral_model is None:
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

        logger.info("Loading Voxtral model %s ...", VOXTRAL_MODEL_ID)
        _voxtral_processor = AutoProcessor.from_pretrained(
            VOXTRAL_MODEL_ID, trust_remote_code=True
        )
        _voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            VOXTRAL_MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True
        ).to("cuda")
        logger.info("Voxtral model loaded.")
    return _voxtral_model, _voxtral_processor


def _clean_voxtral_text(text: str) -> str:
    """Strip Voxtral streaming markers and collapse whitespace."""
    text = _MARKER_RE.sub("", text)
    return " ".join(text.split()).strip()
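
# Illustrative behaviour of the helper above (the marker placement is an
# assumed example, not real model output):
#   _clean_voxtral_text("hello[STREAMING_PAD]  [STREAMING_WORD]world")  ->  "hello world"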


def prepare_audio(raw_bytes: bytes) -> np.ndarray:
    """Decode audio bytes (any libsndfile-readable format) -> float32 mono @ 16 kHz."""
    audio, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != TARGET_SR:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR)
    return audio


def prepare_audio_slice(raw_bytes: bytes, start_time: float, end_time: float) -> np.ndarray:
    """Read audio, slice by time, return float32 mono @ 16 kHz."""
    seg = AudioSegment.from_file(io.BytesIO(raw_bytes))
    seg = seg[int(start_time * 1000):int(end_time * 1000)]
    seg = seg.set_frame_rate(TARGET_SR).set_channels(1).set_sample_width(2)
    # 16-bit samples, so dividing by 32768 scales the signal into [-1.0, 1.0).
    return np.array(seg.get_array_of_samples(), dtype=np.float32) / 32768.0


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Warm the diarization pipeline at startup so the first request does not pay
    # the model-loading cost.
    _load_diarize_pipeline()
    yield


app = FastAPI(title="GPU Service (HF Endpoint)", lifespan=lifespan)


@app.get("/health")
async def health():
    return {"status": "ok", "gpu_available": torch.cuda.is_available()}


@app.post("/diarize")
async def diarize(
    audio: UploadFile = File(...),
    min_speakers: int | None = Form(None),
    max_speakers: int | None = Form(None),
):
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)

        pipeline = _load_diarize_pipeline()
        waveform = torch.from_numpy(audio_16k).unsqueeze(0).float()
        input_data = {"waveform": waveform, "sample_rate": TARGET_SR}

        result = pipeline(
            input_data,
            min_speakers=min_speakers or PYANNOTE_MIN_SPEAKERS,
            max_speakers=max_speakers or PYANNOTE_MAX_SPEAKERS,
        )
        # community-1 pipelines return an object whose .speaker_diarization attribute
        # holds the annotation; fall back to the result itself for pipelines that
        # return the annotation directly.
        diarization = getattr(result, "speaker_diarization", result)

        segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append(
                {
                    "speaker": speaker,
                    "start": round(turn.start, 3),
                    "end": round(turn.end, 3),
                    "duration": round(turn.end - turn.start, 3),
                }
            )
        segments.sort(key=lambda s: s["start"])
        return {"segments": segments}
    except Exception as e:
        logger.exception("Diarization failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/embed")
async def embed(
    audio: UploadFile = File(...),
    start_time: float | None = Form(None),
    end_time: float | None = Form(None),
):
    try:
        raw = await audio.read()
        if start_time is not None and end_time is not None:
            audio_16k = prepare_audio_slice(raw, start_time, end_time)
        else:
            audio_16k = prepare_audio(raw)

        model = _load_embed_model()
        result = model.generate(input=audio_16k, output_dir=None)
        raw_emb = result[0]["spk_embedding"]
        if hasattr(raw_emb, "cpu"):
            raw_emb = raw_emb.cpu().numpy()
        emb = np.array(raw_emb).flatten()

        # L2-normalise the embedding (skipped for an all-zero vector).
        norm = np.linalg.norm(emb)
        if norm > 0:
            emb = emb / norm

        return {"embedding": emb.tolist(), "dim": len(emb)}
    except Exception as e:
        logger.exception("Embedding extraction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    prompt: str = Form("Transcribe this audio."),
):
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)

        model, processor = _load_voxtral()
        inputs = processor(
            audios=audio_16k,
            sampling_rate=TARGET_SR,
            text=prompt,
            return_tensors="pt",
        ).to("cuda")

        output_ids = model.generate(**inputs, max_new_tokens=1024)
        text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        text = _clean_voxtral_text(text)

        return {"text": text}
    except Exception as e:
        logger.exception("Transcription failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/transcribe/stream")
async def transcribe_stream(
    audio: UploadFile = File(...),
    prompt: str = Form("Transcribe this audio."),
):
    try:
        raw = await audio.read()
        audio_16k = prepare_audio(raw)
    except Exception as e:
        logger.exception("Audio preparation failed")
        return JSONResponse(status_code=500, content={"error": str(e)})

    async def event_generator():
        try:
            from transformers import TextIteratorStreamer

            model, processor = _load_voxtral()
            inputs = processor(
                audios=audio_16k,
                sampling_rate=TARGET_SR,
                text=prompt,
                return_tensors="pt",
            ).to("cuda")

            streamer = TextIteratorStreamer(
                processor.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            gen_kwargs = {**inputs, "max_new_tokens": 1024, "streamer": streamer}

            # Run generation in a worker thread; the streamer yields decoded text
            # chunks as they become available.
            thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()

            full_text = ""
            for chunk in streamer:
                chunk = _MARKER_RE.sub("", chunk)
                if chunk:
                    full_text += chunk
                    yield {"event": "token", "data": json.dumps({"token": chunk})}

            thread.join()
            # Re-clean the assembled text: markers split across chunk boundaries can
            # escape the per-chunk regex above.
            full_text = _clean_voxtral_text(full_text)
            yield {"event": "done", "data": json.dumps({"text": full_text})}
        except Exception as e:
            logger.exception("Streaming transcription failed")
            yield {"event": "error", "data": json.dumps({"error": str(e)})}

    return EventSourceResponse(event_generator())
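
# Consuming the SSE stream from a client (sketch; `httpx` and the base URL are
# assumptions, and sse-starlette frames each event as "event: <name>" /
# "data: <json>" lines):
#
#   import httpx, json
#   with open("clip.wav", "rb") as f, httpx.stream(
#       "POST", "http://localhost:8000/transcribe/stream",
#       files={"audio": f}, timeout=None,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line.startswith("data: "):
#               print(json.loads(line[len("data: "):]))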