# Source note (from the hosting page, kept as comments so the file parses):
#   gnumanth's picture
#   FastAPI + WebSocket streaming with newspaper UI
#   ea31d8c verified
"""
FastAPI + WebSocket backend for real-time speech transcription.
Uses NeMo ASR model directly (no Triton required).
"""
import asyncio
import json
import uuid
import sys
from pathlib import Path
from typing import Optional, AsyncIterator
from datetime import datetime
import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from loguru import logger
# Configure logging
# Replace loguru's default handler with a compact, colorized stderr format.
logger.remove()
logger.add(
sys.stderr,
format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
level="INFO",
)
# Global model
# ASR_MODEL is populated by load_model() on app startup; None until then
# (endpoints check for None and degrade gracefully).
ASR_MODEL = None
# Preferred torch device for inference; also reported by the /health endpoint.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
    """Load the NeMo streaming ASR model into the global ASR_MODEL.

    Downloads/initializes the pretrained model, puts it in eval mode, and
    moves it to the module-level DEVICE so placement stays consistent with
    what /health reports.

    Returns:
        bool: True on success, False on any failure (missing package,
        download error, OOM, ...).
    """
    global ASR_MODEL
    logger.info("Loading NeMo ASR Model...")
    try:
        # Imported lazily so the server can still start (and report an
        # unhealthy model) when NeMo is not installed.
        import nemo.collections.asr as nemo_asr
        ASR_MODEL = nemo_asr.models.ASRModel.from_pretrained(
            model_name="nvidia/nemotron-speech-streaming-en-0.6b"
        )
        # Inference-only: disable dropout / training-mode behavior.
        ASR_MODEL.eval()
        if DEVICE == "cuda":
            logger.info("Moving model to CUDA")
        else:
            logger.warning("CUDA not available, using CPU (will be slow)")
        # Single placement call driven by the DEVICE constant (instead of a
        # second cuda-availability check) keeps model location and /health
        # in agreement.
        ASR_MODEL = ASR_MODEL.to(DEVICE)
        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        # Best-effort: callers (startup, /health, the websocket handler)
        # all tolerate a missing model, so log and report failure.
        logger.error(f"Failed to load model: {e}")
        return False
# Create FastAPI app
# All routes and the websocket endpoint below are registered on this instance.
app = FastAPI(title="Nemotron Speech Streaming")
@app.on_event("startup")
async def startup():
    """Load the ASR model when the server starts.

    load_model() is synchronous and can take a long time (model download +
    initialization), so run it in the default thread-pool executor instead
    of blocking the event loop. Startup still waits for it to finish.
    """
    await asyncio.get_running_loop().run_in_executor(None, load_model)
@app.get("/health")
async def health():
    """Report service liveness plus model/device status."""
    model_ready = ASR_MODEL is not None
    return {
        "status": "healthy",
        "model_loaded": model_ready,
        "device": DEVICE,
    }
@app.get("/")
async def root():
    """Serve the single-page frontend bundled next to this module."""
    index_path = Path(__file__).parent / "static" / "index.html"
    return FileResponse(index_path)
def _extract_text(hypothesis) -> str:
    """Best-effort extraction of the transcript string from a NeMo result.

    Depending on model/version, transcribe() may return plain strings or
    Hypothesis-like objects, so probe the common attributes in order.
    """
    if isinstance(hypothesis, str):
        return hypothesis.strip()
    if hasattr(hypothesis, 'text'):
        return hypothesis.text.strip()
    if hasattr(hypothesis, 'pred_text'):
        return hypothesis.pred_text.strip()
    return str(hypothesis).strip()
def _run_inference(audio):
    """Run blocking ASR inference on a float32 mono 16 kHz waveform (sync)."""
    with torch.no_grad():
        return ASR_MODEL.transcribe([audio])
@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """
    WebSocket endpoint for streaming transcription.
    Protocol:
    - Client sends binary PCM audio data (16-bit, 16kHz, mono)
    - Client may send JSON control messages: {"type": "reset"} / {"type": "ping"}
    - Server sends JSON: {"type": "transcript", "text": "...", "is_final": bool}
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())[:8]
    logger.info(f"[{session_id}] Client connected")
    # Send ready message
    await websocket.send_json({
        "type": "ready",
        "session_id": session_id,
        "model_loaded": ASR_MODEL is not None,
    })
    if ASR_MODEL is None:
        await websocket.send_json({
            "type": "error",
            "message": "Model not loaded. Please wait and try again.",
        })
        await websocket.close()
        return
    # Rolling audio buffer of normalized float32 samples for this session.
    audio_buffer = np.array([], dtype=np.float32)
    chunk_count = 0
    last_transcript = ""
    # Processing settings
    MIN_AUDIO_LENGTH = 8000  # 0.5 seconds at 16kHz
    MAX_AUDIO_LENGTH = 80000  # 5 seconds at 16kHz
    PROCESS_EVERY_N_CHUNKS = 3  # Process every N chunks for efficiency
    try:
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
            # Handle binary audio data
            if "bytes" in message:
                chunk_count += 1
                # Convert bytes to numpy array (expecting 16-bit PCM),
                # normalized to float32 in [-1, 1).
                audio_chunk = np.frombuffer(message["bytes"], dtype=np.int16).astype(np.float32) / 32768.0
                audio_buffer = np.concatenate([audio_buffer, audio_chunk])
                # Log periodically
                if chunk_count % 20 == 0:
                    logger.debug(f"[{session_id}] Chunks: {chunk_count}, Buffer: {len(audio_buffer)} samples")
                # Process when we have enough audio
                if len(audio_buffer) >= MIN_AUDIO_LENGTH and chunk_count % PROCESS_EVERY_N_CHUNKS == 0:
                    # Use at most the last MAX_AUDIO_LENGTH samples as context
                    # (slice is a no-op when the buffer is shorter).
                    audio_context = audio_buffer[-MAX_AUDIO_LENGTH:]
                    try:
                        start_time = datetime.now()
                        # Off-load the blocking inference call to a worker
                        # thread so this (and other) sessions' event-loop
                        # traffic keeps flowing during transcription.
                        results = await asyncio.get_running_loop().run_in_executor(
                            None, _run_inference, audio_context
                        )
                        inference_time = (datetime.now() - start_time).total_seconds() * 1000
                        if results and len(results) > 0:
                            text = _extract_text(results[0])
                            # Only emit when the transcript actually changed.
                            if text and text != last_transcript:
                                last_transcript = text
                                logger.info(f"[{session_id}] ({inference_time:.0f}ms) {text[:60]}...")
                                await websocket.send_json({
                                    "type": "transcript",
                                    "text": text,
                                    "is_final": False,
                                    "latency_ms": inference_time,
                                })
                    except Exception as e:
                        # Inference failures are per-chunk; keep the session alive.
                        logger.error(f"[{session_id}] Inference error: {e}")
                # Trim buffer to prevent memory growth
                if len(audio_buffer) > MAX_AUDIO_LENGTH:
                    audio_buffer = audio_buffer[-MAX_AUDIO_LENGTH:]
            # Handle JSON control messages
            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    msg_type = data.get("type")
                    if msg_type == "reset":
                        # Client asked to start over: drop all session state.
                        audio_buffer = np.array([], dtype=np.float32)
                        chunk_count = 0
                        last_transcript = ""
                        logger.info(f"[{session_id}] Session reset")
                        await websocket.send_json({"type": "reset_ack"})
                    elif msg_type == "ping":
                        await websocket.send_json({"type": "pong"})
                except json.JSONDecodeError:
                    # Ignore malformed control messages.
                    pass
    except WebSocketDisconnect:
        logger.info(f"[{session_id}] Client disconnected")
    except Exception as e:
        logger.error(f"[{session_id}] WebSocket error: {e}")
    finally:
        logger.info(f"[{session_id}] Session ended (processed {chunk_count} chunks)")
# Mount static files
# Expose bundled frontend assets under /static when the directory ships
# alongside this module; skip silently otherwise.
static_dir = Path(__file__).parent / "static"
if static_dir.exists():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn
    # Bind on all interfaces; 7860 is presumably chosen for a Hugging Face
    # Spaces deployment — confirm against the hosting environment.
    uvicorn.run(app, host="0.0.0.0", port=7860)