"""Speaker Diarization API - FastAPI Application.""" import asyncio import tempfile import traceback from pathlib import Path from typing import Optional, List import os import torch from fastapi import ( FastAPI, File, UploadFile, Form, WebSocket, WebSocketDisconnect, HTTPException, Query, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse from pydantic import BaseModel from loguru import logger class SegmentOut(BaseModel): start: float end: float duration: float speaker: str class DiarizationResponse(BaseModel): status: str = "success" num_speakers: int audio_duration: float processing_time: float sample_rate: int speakers: List[str] segments: List[SegmentOut] class HealthResponse(BaseModel): status: str device: str version: str = "1.0.0" app = FastAPI( title="Speaker Diarization API", description="Who Spoke When - Speaker diarization using ECAPA-TDNN + AHC Clustering", version="1.0.0", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) _pipeline = None def get_pipeline(): global _pipeline if _pipeline is None: from app.pipeline import DiarizationPipeline cache_dir = os.getenv( "CACHE_DIR", str(Path(tempfile.gettempdir()) / "model_cache"), ) _pipeline = DiarizationPipeline( device="auto", use_pyannote_vad=True, use_pyannote_diarization=os.getenv("USE_PYANNOTE_DIARIZATION", "true").lower() in {"1", "true", "yes"}, pyannote_diarization_model=os.getenv("PYANNOTE_DIARIZATION_MODEL", "pyannote/speaker-diarization-3.1"), hf_token=os.getenv("HF_TOKEN"), max_speakers=int(os.getenv("MAX_SPEAKERS", "6")), cache_dir=cache_dir, ) return _pipeline @app.get("/health", response_model=HealthResponse, tags=["System"]) async def health_check(): device = "cuda" if torch.cuda.is_available() else "cpu" return HealthResponse(status="healthy", device=device) @app.post("/diarize", response_model=DiarizationResponse, tags=["Diarization"]) async def diarize_audio( file: UploadFile = File(...), num_speakers: Optional[int] = Form(None, ge=1, le=20), ): """Diarize an uploaded audio file. Returns timestamped speaker labels.""" allowed = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".webm"} suffix = Path(file.filename).suffix.lower() if suffix not in allowed: raise HTTPException(status_code=415, detail=f"Unsupported format '{suffix}'") audio_bytes = await file.read() with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name try: pipeline = get_pipeline() result = pipeline.process(tmp_path, num_speakers=num_speakers) return DiarizationResponse( num_speakers=result.num_speakers, audio_duration=result.audio_duration, processing_time=result.processing_time, sample_rate=result.sample_rate, speakers=sorted(set(s.speaker for s in result.segments)), segments=[SegmentOut(**s.to_dict()) for s in result.segments], ) except Exception as e: logger.error(f"Diarization failed: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(e)) finally: Path(tmp_path).unlink(missing_ok=True) @app.post("/diarize/url", response_model=DiarizationResponse, tags=["Diarization"]) async def diarize_from_url( audio_url: str = Query(...), num_speakers: Optional[int] = Query(None, ge=1, le=20), ): """Diarize audio from a URL.""" import httpx try: async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.get(audio_url) resp.raise_for_status() except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to fetch audio: {e}") with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(resp.content) tmp_path = tmp.name try: pipeline = get_pipeline() result = pipeline.process(tmp_path, num_speakers=num_speakers) return DiarizationResponse( num_speakers=result.num_speakers, audio_duration=result.audio_duration, processing_time=result.processing_time, sample_rate=result.sample_rate, speakers=sorted(set(s.speaker for s in result.segments)), segments=[SegmentOut(**s.to_dict()) for s in result.segments], ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) finally: Path(tmp_path).unlink(missing_ok=True) @app.websocket("/ws/stream") async def stream_diarization(websocket: WebSocket): """Real-time streaming diarization via WebSocket.""" await websocket.accept() import numpy as np audio_buffer = bytearray() sample_rate = 16000 num_speakers = None chunk_count = 0 try: config_msg = await websocket.receive_json() sample_rate = config_msg.get("sample_rate", 16000) num_speakers = config_msg.get("num_speakers", None) await websocket.send_json( { "type": "progress", "data": {"message": "Config received. Send audio chunks.", "chunks_received": 0}, } ) while True: try: msg = await asyncio.wait_for(websocket.receive(), timeout=30.0) except asyncio.TimeoutError: await websocket.send_json({"type": "error", "data": {"message": "Timeout"}}) break if "bytes" in msg: audio_buffer.extend(msg["bytes"]) chunk_count += 1 await websocket.send_json( { "type": "progress", "data": { "message": f"Received chunk {chunk_count}", "chunks_received": chunk_count, }, } ) elif "text" in msg: import json data = json.loads(msg["text"]) if data.get("type") == "eof": break if not audio_buffer: await websocket.send_json({"type": "error", "data": {"message": "No audio received"}}) return import torch as torch_local audio_np = np.frombuffer(audio_buffer, dtype=np.float32).copy() audio_tensor = torch_local.from_numpy(audio_np) await websocket.send_json( { "type": "progress", "data": {"message": "Running diarization pipeline..."}, } ) loop = asyncio.get_event_loop() pipeline = get_pipeline() result = await loop.run_in_executor( None, lambda: pipeline.process(audio_tensor, sample_rate=sample_rate, num_speakers=num_speakers), ) for seg in result.segments: await websocket.send_json({"type": "segment", "data": seg.to_dict()}) await websocket.send_json( { "type": "done", "data": { "num_speakers": result.num_speakers, "total_segments": len(result.segments), "audio_duration": result.audio_duration, "processing_time": result.processing_time, }, } ) except WebSocketDisconnect: logger.info("WebSocket client disconnected") except Exception as e: logger.error(f"WebSocket error: {traceback.format_exc()}") try: await websocket.send_json({"type": "error", "data": {"message": str(e)}}) except Exception: pass @app.get("/", response_class=HTMLResponse, include_in_schema=False) async def serve_ui(): ui_path = Path(__file__).resolve().parent.parent / "static" / "index.html" if ui_path.exists(): return HTMLResponse(ui_path.read_text(encoding="utf-8")) return HTMLResponse("

Speaker Diarization API

API Docs

") @app.get("/debug", tags=["System"]) async def debug(): import inspect import speechbrain from speechbrain.inference.classifiers import EncoderClassifier cache_dir = os.getenv( "CACHE_DIR", str(Path(tempfile.gettempdir()) / "model_cache"), ) sig = str(inspect.signature(EncoderClassifier.from_hparams)) return { "speechbrain_version": speechbrain.__version__, "temp_dir": tempfile.gettempdir(), "temp_writable": os.access(tempfile.gettempdir(), os.W_OK), "cache_dir": cache_dir, "cache_exists": os.path.exists(cache_dir), "from_hparams_signature": sig, } static_dir = Path(__file__).resolve().parent.parent / "static" if static_dir.exists(): app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")