| |
| """ |
| Unified WebSocket/HTTP Whisper Transcription Server |
| Handles real-time audio streaming, transcription using Whisper, and HTTP serving |
| """ |
|
|
| import asyncio |
| import websockets |
| import json |
| import numpy as np |
| import torch |
| import logging |
| import traceback |
| import os |
| from typing import Dict, Any |
| from aiohttp import web, WSMsgType |
| from aiohttp.web_ws import WebSocketResponse |
|
|
|
|
| |
# Configure root logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# whisper_stream is treated as optional: when it cannot be imported the module
# still loads and only the regular `whisper` package is usable.
try:
    from whisper_stream import load_streaming_model_correct
    from whisper_stream.streaming_decoding import DecodingOptions
except ImportError:
    logger.error("whisper_stream not found. Please install it or use regular whisper")

    def load_streaming_model_correct(*args, **kwargs):
        """Stand-in used when whisper_stream is missing; always raises ImportError.

        Keeping the name bound means callers get a descriptive, catchable
        error instead of a NameError.
        """
        raise ImportError(
            "whisper_stream is not installed; streaming models are unavailable"
        )

    # Only needed for streaming decoding, which cannot be reached without a
    # streaming model; bound so the name always exists.
    DecodingOptions = None

import whisper
|
|
class UnifiedTranscriptionServer:
    """aiohttp application serving the web client and a transcription WebSocket.

    HTTP routes deliver ``static/client.html`` and a health probe. The ``/ws``
    endpoint expects a JSON configuration message (text frame) followed by raw
    audio (binary frames) and replies with JSON transcription results.

    Audio bytes are interpreted as signed 16-bit PCM samples at 16 kHz.
    """

    # Sample rate used for chunk sizing and for converting samples to seconds.
    SAMPLE_RATE = 16000
    # Audio is decoded with dtype int16, i.e. two bytes per sample.
    BYTES_PER_SAMPLE = 2
    # Extra bytes (360 samples) taken by the first chunk after configuration.
    # NOTE(review): presumably look-ahead context required by the streaming
    # model - confirm against whisper_stream.
    FIRST_CHUNK_EXTRA_BYTES = 720
    # The non-streaming path zero-pads shorter inputs up to this many samples.
    MIN_FALLBACK_SAMPLES = 16000
    # Keys every configuration message must contain.
    REQUIRED_CONFIG_FIELDS = ('model_size', 'chunk_size', 'beam_size', 'language')

    def __init__(self, host: str = "0.0.0.0", port: int = 8000):
        """Create the aiohttp application and register its routes.

        Args:
            host: Interface the server binds to.
            port: TCP port the server listens on.
        """
        self.host = host
        self.port = port
        # Per-connection state keyed by client id (see websocket_handler).
        self.clients: Dict[str, Dict[str, Any]] = {}
        self.app = web.Application()
        self.setup_routes()

    def setup_routes(self):
        """Setup HTTP routes and WebSocket endpoint"""
        self.app.router.add_get('/', self.serve_index)
        self.app.router.add_get('/health', self.health_check)
        self.app.router.add_get('/ws', self.websocket_handler)

        # Static assets are only exposed when the directory exists relative to
        # the current working directory.
        if os.path.exists('static'):
            self.app.router.add_static('/static/', 'static')

    async def serve_index(self, request):
        """Serve the main HTML page.

        Returns a 404 response when ``./static/client.html`` is missing and a
        500 response for any other error while reading it.
        """
        try:
            with open("./static/client.html", "r", encoding='utf-8') as f:
                html_content = f.read()
            return web.Response(text=html_content, content_type='text/html')
        except FileNotFoundError:
            return web.Response(text="client.html not found!", status=404)
        except Exception as e:
            logger.error(f"Error serving client.html! {e}")
            return web.Response(text="Error loading page...", status=500)

    async def health_check(self, request):
        """Health check endpoint reporting liveness and CUDA availability."""
        return web.json_response({"status": "healthy", "cuda": torch.cuda.is_available()})

    async def websocket_handler(self, request):
        """Handle one WebSocket connection from upgrade to close.

        Registers per-client state, processes messages until the peer goes
        away, then always removes the state and closes the socket.
        """
        ws = WebSocketResponse()
        await ws.prepare(request)

        # id(ws) distinguishes several connections from the same remote address.
        client_id = f"{request.remote}:{id(ws)}"
        logger.info(f"New WebSocket client connected: {client_id}")

        self.clients[client_id] = {
            'websocket': ws,
            'model': None,
            'config': None,
            'buffer': bytearray(),
            'total_samples': 0,
            # Same key the audio/config handlers read and write.
            'first_chunk': True,
        }

        try:
            await self.process_websocket_messages(client_id)
        except Exception as e:
            logger.error(f"Error handling WebSocket client {client_id}: {e}")
            logger.error(traceback.format_exc())
        finally:
            self.clients.pop(client_id, None)
            if not ws.closed:
                await ws.close()

        return ws

    async def process_websocket_messages(self, client_id: str):
        """Dispatch incoming frames: text -> config, binary -> audio.

        The loop ends when the connection closes or an ERROR frame arrives.
        """
        client = self.clients[client_id]
        ws = client['websocket']

        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                await self.handle_config_message(client_id, msg.data)
            elif msg.type == WSMsgType.BINARY:
                await self.handle_audio_data(client_id, msg.data)
            elif msg.type == WSMsgType.ERROR:
                logger.error(f'WebSocket error for client {client_id}: {ws.exception()}')
                break

    @classmethod
    def _base_chunk_bytes(cls, chunk_size_ms) -> int:
        """Return the byte length of one audio chunk of ``chunk_size_ms`` milliseconds.

        Raises:
            ValueError: if the value is not a finite, positive number that
                yields at least one sample. Rejecting such values prevents a
                zero-length chunk, which would make the chunking loop spin
                forever.
        """
        message = "chunk_size must be a positive number of milliseconds"
        if isinstance(chunk_size_ms, bool) or not isinstance(chunk_size_ms, (int, float)):
            raise ValueError(message)
        try:
            # Multiply before dividing so whole-millisecond sizes are exact.
            samples = int(cls.SAMPLE_RATE * chunk_size_ms / 1000)
        except (OverflowError, ValueError):
            # inf / nan / absurdly large values
            raise ValueError(message) from None
        if samples < 1:
            raise ValueError(message)
        return samples * cls.BYTES_PER_SAMPLE

    @staticmethod
    def _load_streaming_model(model_size, chunk_size, multilingual):
        """Load a whisper_stream model, move it to CUDA if available, and reset it for streaming."""
        model = load_streaming_model_correct(model_size, chunk_size, multilingual)
        if torch.cuda.is_available():
            model = model.to("cuda")
        model.reset(use_stream=True)
        model.eval()
        return model

    @staticmethod
    def _load_fallback_model(model_size):
        """Load a regular whisper model and move it to CUDA if available."""
        model = whisper.load_model(model_size)
        if torch.cuda.is_available():
            model = model.to("cuda")
        return model

    async def handle_config_message(self, client_id: str, message: str):
        """Validate a JSON configuration message and load the requested model.

        The message must be a JSON object containing ``model_size``,
        ``chunk_size`` (milliseconds), ``beam_size`` and ``language``. A
        streaming model is tried first; if that fails, a regular whisper model
        is loaded instead and the reply carries ``"fallback": true``. Every
        failure is reported to the client as ``{"error": ...}``.

        NOTE(review): model loading runs synchronously inside the event loop,
        so other clients are stalled while it is in progress.
        """
        client = self.clients[client_id]
        ws = client['websocket']

        try:
            config = json.loads(message)
            logger.info(f"Received config from {client_id}: {config}")

            if not isinstance(config, dict):
                await ws.send_str(json.dumps({"error": "Config must be a JSON object"}))
                return

            for field in self.REQUIRED_CONFIG_FIELDS:
                if field not in config:
                    await ws.send_str(json.dumps({"error": f"Missing required field: {field}"}))
                    return

            model_size = config['model_size']
            chunk_size = config['chunk_size']

            # Reject sizes that cannot be turned into a positive chunk length
            # before they reach the audio loop.
            try:
                self._base_chunk_bytes(chunk_size)
            except ValueError as e:
                await ws.send_str(json.dumps({"error": str(e)}))
                return

            logger.info(f"Loading model {model_size} for client {client_id}")

            multilingual = config['language'] != "en"
            if multilingual and (model_size != "large-v2" or chunk_size != 300):
                await ws.send_str(json.dumps({"error": "Running multilingual transcription is available for now only on large-v2 model using chunk size of 300ms."}))
                return

            # Only model loading is guarded here: a failure to *send* the
            # acknowledgement must not trigger a fallback model load.
            fallback = False
            try:
                model = self._load_streaming_model(model_size, chunk_size, multilingual)
            except Exception as e:
                logger.error(f"Error loading streaming model: {e}")
                try:
                    model = self._load_fallback_model(model_size)
                except Exception as e2:
                    logger.error(f"Error loading fallback model: {e2}")
                    await ws.send_str(json.dumps({"error": f"Failed to load model: {e2}"}))
                    return
                fallback = True

            client['model'] = model
            client['config'] = config

            reply = {"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available()}
            if fallback:
                client['use_streaming'] = False
                reply["fallback"] = True
            else:
                # Clear a flag left behind by an earlier fallback configuration,
                # otherwise the streaming model would be driven through the
                # non-streaming path.
                client.pop('use_streaming', None)
                client['first_chunk'] = True
                if torch.cuda.is_available():
                    logger.info(f"Model loaded on GPU for client {client_id}")
                else:
                    logger.info(f"Model loaded on CPU for client {client_id}")

            await ws.send_str(json.dumps(reply))

        except json.JSONDecodeError as e:
            await ws.send_str(json.dumps({"error": f"Invalid JSON: {e}"}))
        except Exception as e:
            logger.error(f"Error handling config for client {client_id}: {e}")
            await ws.send_str(json.dumps({"error": str(e)}))

    async def handle_audio_data(self, client_id: str, audio_data: bytes):
        """Buffer incoming audio and transcribe every complete chunk.

        Bytes are appended to the client's buffer; while the buffer holds at
        least one full chunk, that chunk is removed and transcribed. The first
        chunk after configuration is ``FIRST_CHUNK_EXTRA_BYTES`` longer than
        the following ones. Transcription errors are logged and reported to
        the client without dropping the connection.
        """
        client = self.clients[client_id]
        ws = client['websocket']

        if client['config'] is None:
            await ws.send_str(json.dumps({"error": "Config not set"}))
            return

        if client['model'] is None:
            await ws.send_str(json.dumps({"error": "Model not loaded"}))
            return

        buffer = client['buffer']
        buffer.extend(audio_data)

        base_bytes = self._base_chunk_bytes(client['config']['chunk_size'])

        while True:
            # Recomputed on every pass: only the first chunk carries the extra
            # bytes, even when several chunks arrive in one message.
            needed = base_bytes
            if client.get('first_chunk', True):
                needed += self.FIRST_CHUNK_EXTRA_BYTES
            if len(buffer) < needed:
                break

            chunk = bytes(buffer[:needed])
            del buffer[:needed]
            client['first_chunk'] = False

            try:
                await self.transcribe_chunk(client_id, chunk)
            except Exception as e:
                logger.error(f"Error transcribing chunk for client {client_id}: {e}")
                await ws.send_str(json.dumps({"error": f"Transcription error: {str(e)}"}))

    async def transcribe_chunk(self, client_id: str, chunk: bytes):
        """Transcribe one audio chunk and send any recognised text to the client.

        Args:
            client_id: Key of the client in ``self.clients``.
            chunk: Raw int16 PCM bytes.

        The reply contains ``text``, ``timestamp`` (seconds of audio processed
        so far, including chunks that produced no text) and ``chunk_duration``.
        Nothing is sent when the recognised text is empty.

        Raises:
            Exception: any decoding error is logged and re-raised.

        NOTE(review): inference is called synchronously inside the event loop.
        """
        client = self.clients[client_id]
        ws = client['websocket']
        model = client['model']
        config = client['config']

        try:
            # Scale int16 samples to float32 in [-1.0, 1.0).
            pcm = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0

            # Advance the stream position for every chunk so the reported
            # timestamp does not fall behind when a chunk yields no text.
            client['total_samples'] += len(pcm)

            audio = torch.tensor(pcm)
            if torch.cuda.is_available() and next(model.parameters()).is_cuda:
                audio = audio.to("cuda")

            if hasattr(model, 'decode') and client.get('use_streaming', True):
                decoding_options = DecodingOptions(
                    language=config['language'],
                    gran=(config['chunk_size'] // 20),
                    single_frame_mel=True,
                    without_timestamps=True,
                    beam_size=config['beam_size'],
                    stream_decode=True,
                    use_ca_kv_cache=True,
                    look_ahead_blocks=model.extra_gran_blocks,
                )
                result = model.decode(audio, decoding_options, use_frames=True)
                text = result.text
            else:
                # Non-streaming path: zero-pad short inputs before transcribing.
                shortfall = self.MIN_FALLBACK_SAMPLES - len(audio)
                if shortfall > 0:
                    audio = torch.nn.functional.pad(audio, (0, shortfall))

                options = {
                    'language': config['language'],
                    'beam_size': config['beam_size'],
                }
                # 'temperature' is not a required config field; pass it only
                # when supplied so the library default applies otherwise.
                if 'temperature' in config:
                    options['temperature'] = config['temperature']

                result = model.transcribe(audio.cpu().numpy(), **options)
                text = result['text']

            text = text.strip()
            if text:
                await ws.send_str(json.dumps({
                    "text": text,
                    "timestamp": client['total_samples'] / self.SAMPLE_RATE,
                    "chunk_duration": len(pcm) / self.SAMPLE_RATE,
                }))

        except Exception:
            logger.exception(f"Error in transcription for client {client_id}")
            raise

    async def start_server(self):
        """Start the unified server and block until the task is cancelled.

        The aiohttp runner is always cleaned up on exit.
        """
        logger.info(f"Starting unified server on {self.host}:{self.port}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")

        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, self.host, self.port)
        await site.start()

        logger.info(f"Server running on http://{self.host}:{self.port}")
        logger.info(f"WebSocket endpoint: ws://{self.host}:{self.port}/ws")

        try:
            # A future that is never resolved keeps the coroutine alive.
            await asyncio.Future()
        except KeyboardInterrupt:
            logger.info("Server stopped by user")
        finally:
            await runner.cleanup()
|
|
def main():
    """Parse command-line options and run the server until it stops.

    Options: ``--host``, ``--port`` and ``--log-level``. The log level is
    matched case-insensitively against the standard logging level names, so an
    unknown value produces an argparse usage error instead of a traceback.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Unified WebSocket/HTTP Whisper Transcription Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8000, help='Port to bind to')
    parser.add_argument(
        '--log-level',
        default='INFO',
        type=str.upper,
        # Includes the legacy aliases so every previously working value still does.
        choices=('CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG', 'NOTSET'),
        help='Log level',
    )

    args = parser.parse_args()

    # setLevel accepts the validated level name directly.
    logging.getLogger().setLevel(args.log_level)

    server = UnifiedTranscriptionServer(args.host, args.port)

    try:
        asyncio.run(server.start_server())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server error: {e}")
        logger.error(traceback.format_exc())


if __name__ == '__main__':
    main()