""" FastAPI + WebSocket backend for real-time speech transcription. Uses NeMo ASR model directly (no Triton required). """ import asyncio import json import uuid import sys from pathlib import Path from typing import Optional, AsyncIterator from datetime import datetime import numpy as np import torch from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse from loguru import logger # Configure logging logger.remove() logger.add( sys.stderr, format="{time:HH:mm:ss} | {level: <8} | {message}", level="INFO", ) # Global model ASR_MODEL = None DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def load_model(): """Load the NeMo ASR model.""" global ASR_MODEL logger.info("Loading NeMo ASR Model...") try: import nemo.collections.asr as nemo_asr ASR_MODEL = nemo_asr.models.ASRModel.from_pretrained( model_name="nvidia/nemotron-speech-streaming-en-0.6b" ) ASR_MODEL.eval() if torch.cuda.is_available(): logger.info("Moving model to CUDA") ASR_MODEL = ASR_MODEL.cuda() else: logger.warning("CUDA not available, using CPU (will be slow)") logger.info("Model loaded successfully!") return True except Exception as e: logger.error(f"Failed to load model: {e}") return False # Create FastAPI app app = FastAPI(title="Nemotron Speech Streaming") @app.on_event("startup") async def startup(): """Load model on startup.""" load_model() @app.get("/health") async def health(): """Health check endpoint.""" return { "status": "healthy", "model_loaded": ASR_MODEL is not None, "device": DEVICE, } @app.get("/") async def root(): """Serve the frontend.""" return FileResponse(Path(__file__).parent / "static" / "index.html") @app.websocket("/ws/transcribe") async def websocket_transcribe(websocket: WebSocket): """ WebSocket endpoint for streaming transcription. Protocol: - Client sends binary PCM audio data (16-bit, 16kHz, mono) - Server sends JSON: {"type": "transcript", "text": "...", "is_final": bool} """ await websocket.accept() session_id = str(uuid.uuid4())[:8] logger.info(f"[{session_id}] Client connected") # Send ready message await websocket.send_json({ "type": "ready", "session_id": session_id, "model_loaded": ASR_MODEL is not None, }) if ASR_MODEL is None: await websocket.send_json({ "type": "error", "message": "Model not loaded. Please wait and try again.", }) await websocket.close() return # Audio buffer audio_buffer = np.array([], dtype=np.float32) chunk_count = 0 last_transcript = "" # Processing settings MIN_AUDIO_LENGTH = 8000 # 0.5 seconds at 16kHz MAX_AUDIO_LENGTH = 80000 # 5 seconds at 16kHz PROCESS_EVERY_N_CHUNKS = 3 # Process every N chunks for efficiency try: while True: message = await websocket.receive() if message["type"] == "websocket.disconnect": break # Handle binary audio data if "bytes" in message: audio_bytes = message["bytes"] chunk_count += 1 # Convert bytes to numpy array (expecting 16-bit PCM) audio_chunk = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 # Add to buffer audio_buffer = np.concatenate([audio_buffer, audio_chunk]) # Log periodically if chunk_count % 20 == 0: logger.debug(f"[{session_id}] Chunks: {chunk_count}, Buffer: {len(audio_buffer)} samples") # Process when we have enough audio if len(audio_buffer) >= MIN_AUDIO_LENGTH and chunk_count % PROCESS_EVERY_N_CHUNKS == 0: # Use last N samples for context audio_context = audio_buffer[-MAX_AUDIO_LENGTH:] if len(audio_buffer) > MAX_AUDIO_LENGTH else audio_buffer try: with torch.no_grad(): start_time = datetime.now() results = ASR_MODEL.transcribe([audio_context]) inference_time = (datetime.now() - start_time).total_seconds() * 1000 if results and len(results) > 0: hyp = results[0] # Extract text if isinstance(hyp, str): text = hyp elif hasattr(hyp, 'text'): text = hyp.text elif hasattr(hyp, 'pred_text'): text = hyp.pred_text else: text = str(hyp) text = text.strip() if text and text != last_transcript: last_transcript = text logger.info(f"[{session_id}] ({inference_time:.0f}ms) {text[:60]}...") await websocket.send_json({ "type": "transcript", "text": text, "is_final": False, "latency_ms": inference_time, }) except Exception as e: logger.error(f"[{session_id}] Inference error: {e}") # Trim buffer to prevent memory growth if len(audio_buffer) > MAX_AUDIO_LENGTH: audio_buffer = audio_buffer[-MAX_AUDIO_LENGTH:] # Handle JSON control messages elif "text" in message: try: data = json.loads(message["text"]) msg_type = data.get("type") if msg_type == "reset": audio_buffer = np.array([], dtype=np.float32) chunk_count = 0 last_transcript = "" logger.info(f"[{session_id}] Session reset") await websocket.send_json({"type": "reset_ack"}) elif msg_type == "ping": await websocket.send_json({"type": "pong"}) except json.JSONDecodeError: pass except WebSocketDisconnect: logger.info(f"[{session_id}] Client disconnected") except Exception as e: logger.error(f"[{session_id}] WebSocket error: {e}") finally: logger.info(f"[{session_id}] Session ended (processed {chunk_count} chunks)") # Mount static files static_path = Path(__file__).parent / "static" if static_path.exists(): app.mount("/static", StaticFiles(directory=str(static_path)), name="static") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)