Spaces:
Runtime error
Runtime error
| """ | |
| FastAPI + WebSocket backend for real-time speech transcription. | |
| Uses NeMo ASR model directly (no Triton required). | |
| """ | |
| import asyncio | |
| import json | |
| import uuid | |
| import sys | |
| from pathlib import Path | |
| from typing import Optional, AsyncIterator | |
| from datetime import datetime | |
| import numpy as np | |
| import torch | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from loguru import logger | |
# Configure logging: drop loguru's default sink and install a compact,
# colorized stderr sink instead.
logger.remove()
logger.add(
    sys.stderr,
    level="INFO",
    format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
)

# Module-level model handle; populated by load_model() at startup and
# checked for None by the endpoints before use.
ASR_MODEL = None
# Preferred inference device, resolved once at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
    """Download and initialize the NeMo ASR model into the global ASR_MODEL.

    Returns:
        bool: True if the model is ready for inference, False on any failure
        (the failure is logged, never raised).
    """
    global ASR_MODEL
    logger.info("Loading NeMo ASR Model...")
    try:
        # Imported lazily so the web app can still start (and report an
        # unhealthy state) when NeMo is not installed.
        import nemo.collections.asr as nemo_asr

        ASR_MODEL = nemo_asr.models.ASRModel.from_pretrained(
            model_name="nvidia/nemotron-speech-streaming-en-0.6b"
        )
        ASR_MODEL.eval()  # inference mode: disable dropout etc.
        if not torch.cuda.is_available():
            logger.warning("CUDA not available, using CPU (will be slow)")
        else:
            logger.info("Moving model to CUDA")
            ASR_MODEL = ASR_MODEL.cuda()
        logger.info("Model loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False
# Create FastAPI app. Routes and the startup hook are attached to this
# instance; static assets are mounted onto it at the bottom of the file.
app = FastAPI(title="Nemotron Speech Streaming")
@app.on_event("startup")
async def startup():
    """Load the ASR model once when the server starts.

    NOTE(review): this handler was defined but never registered with the
    app, so the model was never loaded; the @app.on_event("startup")
    decorator restores the intended behavior.
    """
    load_model()
@app.get("/health")
async def health():
    """Health check endpoint: report model readiness and inference device.

    NOTE(review): this handler was defined but never routed; the
    @app.get("/health") decorator restores registration — confirm the
    path against any external health probes.
    """
    return {
        "status": "healthy",
        "model_loaded": ASR_MODEL is not None,
        "device": DEVICE,
    }
@app.get("/")
async def root():
    """Serve the single-page frontend (static/index.html).

    NOTE(review): this handler was defined but never routed; the
    @app.get("/") decorator restores registration.
    """
    return FileResponse(Path(__file__).parent / "static" / "index.html")
def _hypothesis_text(hyp) -> str:
    """Best-effort extraction of the transcript string from a NeMo result.

    NeMo's transcribe() may return plain strings or Hypothesis-like objects
    (exposing .text or .pred_text) depending on model/version.
    """
    if isinstance(hyp, str):
        return hyp
    if hasattr(hyp, 'text'):
        return hyp.text
    if hasattr(hyp, 'pred_text'):
        return hyp.pred_text
    return str(hyp)


@app.websocket("/ws")
async def websocket_transcribe(websocket: WebSocket):
    """
    WebSocket endpoint for streaming transcription.

    Protocol:
    - Client sends binary PCM audio data (16-bit, 16kHz, mono)
    - Client may send JSON control messages: {"type": "reset"} or {"type": "ping"}
    - Server sends JSON: {"type": "transcript", "text": "...", "is_final": bool}

    NOTE(review): this handler was defined but never registered with the
    app; the @app.websocket("/ws") decorator restores routing. Confirm the
    "/ws" path against the WebSocket URL used by static/index.html.
    """
    await websocket.accept()
    session_id = str(uuid.uuid4())[:8]
    logger.info(f"[{session_id}] Client connected")

    # Send ready message so the client knows whether the model is usable.
    await websocket.send_json({
        "type": "ready",
        "session_id": session_id,
        "model_loaded": ASR_MODEL is not None,
    })
    if ASR_MODEL is None:
        await websocket.send_json({
            "type": "error",
            "message": "Model not loaded. Please wait and try again.",
        })
        await websocket.close()
        return

    # Rolling buffer of float32 samples in [-1, 1]; trimmed below to cap memory.
    audio_buffer = np.array([], dtype=np.float32)
    chunk_count = 0
    last_transcript = ""  # last text sent; used to suppress duplicate sends

    # Processing settings (sample counts, assuming 16kHz input).
    MIN_AUDIO_LENGTH = 8000  # 0.5 seconds at 16kHz
    MAX_AUDIO_LENGTH = 80000  # 5 seconds at 16kHz
    PROCESS_EVERY_N_CHUNKS = 3  # Process every N chunks for efficiency

    try:
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
            # Handle binary audio data
            if "bytes" in message:
                audio_bytes = message["bytes"]
                chunk_count += 1
                # Convert bytes to numpy array (expecting 16-bit PCM),
                # normalized to [-1, 1].
                audio_chunk = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
                # Add to buffer
                audio_buffer = np.concatenate([audio_buffer, audio_chunk])
                # Log periodically
                if chunk_count % 20 == 0:
                    logger.debug(f"[{session_id}] Chunks: {chunk_count}, Buffer: {len(audio_buffer)} samples")
                # Process when we have enough audio; skip chunks in between
                # so inference does not run on every frame.
                if len(audio_buffer) >= MIN_AUDIO_LENGTH and chunk_count % PROCESS_EVERY_N_CHUNKS == 0:
                    # Use at most the last MAX_AUDIO_LENGTH samples as context
                    audio_context = audio_buffer[-MAX_AUDIO_LENGTH:] if len(audio_buffer) > MAX_AUDIO_LENGTH else audio_buffer
                    try:
                        with torch.no_grad():
                            start_time = datetime.now()
                            results = ASR_MODEL.transcribe([audio_context])
                            inference_time = (datetime.now() - start_time).total_seconds() * 1000
                        if results and len(results) > 0:
                            text = _hypothesis_text(results[0]).strip()
                            # Only send when the transcript actually changed.
                            if text and text != last_transcript:
                                last_transcript = text
                                logger.info(f"[{session_id}] ({inference_time:.0f}ms) {text[:60]}...")
                                await websocket.send_json({
                                    "type": "transcript",
                                    "text": text,
                                    "is_final": False,
                                    "latency_ms": inference_time,
                                })
                    except Exception as e:
                        # Best-effort: keep the session alive on inference errors.
                        logger.error(f"[{session_id}] Inference error: {e}")
                    # Trim buffer to prevent memory growth
                    if len(audio_buffer) > MAX_AUDIO_LENGTH:
                        audio_buffer = audio_buffer[-MAX_AUDIO_LENGTH:]
            # Handle JSON control messages
            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    msg_type = data.get("type")
                    if msg_type == "reset":
                        audio_buffer = np.array([], dtype=np.float32)
                        chunk_count = 0
                        last_transcript = ""
                        logger.info(f"[{session_id}] Session reset")
                        await websocket.send_json({"type": "reset_ack"})
                    elif msg_type == "ping":
                        await websocket.send_json({"type": "pong"})
                except json.JSONDecodeError:
                    # Ignore malformed control messages.
                    pass
    except WebSocketDisconnect:
        logger.info(f"[{session_id}] Client disconnected")
    except Exception as e:
        logger.error(f"[{session_id}] WebSocket error: {e}")
    finally:
        logger.info(f"[{session_id}] Session ended (processed {chunk_count} chunks)")
# Mount static files (frontend assets) if the static/ directory exists
# next to this file; root() serves static/index.html directly.
static_path = Path(__file__).parent / "static"
if static_path.exists():
    app.mount("/static", StaticFiles(directory=str(static_path)), name="static")

if __name__ == "__main__":
    # Local/dev entry point; in production a process manager would run
    # uvicorn against this module instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)