#!/usr/bin/env python3
"""
Standalone WebSocket-only STT Service
Simplified service without Gradio, MCP, or web interfaces
Following unmute.sh WebRTC pattern for HuggingFace Spaces
"""

import asyncio
import json
import uuid
import base64
import tempfile
import os
import logging
from datetime import datetime
from typing import Optional, Dict, Any

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torchaudio
import soundfile as sf
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import spaces
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Version info
__version__ = "1.0.0"
__service__ = "STT WebSocket Service"

# Sample rate the audio is resampled to before it is handed to the processor.
_TARGET_SAMPLE_RATE = 16000


async def _send_json(websocket: WebSocket, payload: Dict[str, Any]) -> None:
    """Serialize *payload* to JSON and send it as a single text frame."""
    await websocket.send_text(json.dumps(payload))


class STTWebSocketService:
    """Standalone STT service with WebSocket-only interface.

    Holds the lazily loaded Whisper model/processor and a registry of the
    currently connected WebSocket clients, keyed by a generated client ID.
    """

    def __init__(self):
        self.model: Optional[WhisperForConditionalGeneration] = None
        self.processor: Optional[WhisperProcessor] = None
        self.model_size = "base"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.active_connections: Dict[str, WebSocket] = {}

        logger.info("🎤 %s v%s initializing...", __service__, __version__)
        logger.info("Device: %s", self.device)
        logger.info("Model: whisper-%s", self.model_size)

    async def load_model(self):
        """Load the Whisper processor and model once; later calls are no-ops."""
        if self.model is not None:
            return

        logger.info("Loading Whisper %s model...", self.model_size)
        model_name = f"openai/whisper-{self.model_size}"

        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)

        if self.device == "cuda":
            self.model = self.model.to(self.device)

        logger.info("✅ Model loaded on %s", self.device)

    # NOTE(review): this decorates a coroutine function and the inference below
    # runs synchronously inside the event loop - confirm that the deployed
    # `spaces` version supports `async def` targets and that blocking other
    # clients during inference is acceptable.
    @spaces.GPU(duration=30)
    async def transcribe_audio(
        self,
        audio_path: str,
        language: str = "auto",
        model_size: str = "base"
    ) -> tuple[str, str, Dict[str, Any]]:
        """Transcribe an audio file with the loaded Whisper model.

        Args:
            audio_path: Path of the audio file to read with ``torchaudio.load``.
            language: Accepted for interface compatibility; this implementation
                does not forward it to the model.
            model_size: Model size requested by the caller. The service always
                uses the model it has loaded (``self.model_size``); a mismatch
                is only logged.

        Returns:
            ``(text, status, info)``. On success ``status`` is ``"success"``
            and ``info`` holds timing data plus the model size and device that
            were actually used. On any failure the result is
            ``("", "error", {"error": <message>})``; exceptions are not raised.
        """
        try:
            start_time = datetime.now()

            # Ensure model is loaded
            if self.model is None:
                await self.load_model()

            if model_size != self.model_size:
                logger.warning(
                    "Requested model size '%s' ignored; using loaded model '%s'",
                    model_size,
                    self.model_size,
                )

            # Load and preprocess audio (following unmute.sh pattern)
            audio_input, sample_rate = torchaudio.load(audio_path)

            # Convert to 16kHz mono (Whisper requirement)
            if sample_rate != _TARGET_SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(
                    sample_rate, _TARGET_SAMPLE_RATE
                )
                audio_input = resampler(audio_input)

            if audio_input.shape[0] > 1:
                # Down-mix multi-channel audio by averaging the channels.
                audio_input = torch.mean(audio_input, dim=0, keepdim=True)

            audio_array = audio_input.squeeze().numpy()

            # Process with Whisper
            inputs = self.processor(
                audio_array,
                sampling_rate=_TARGET_SAMPLE_RATE,
                return_tensors="pt"
            )

            if self.device == "cuda":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Generate transcription
            with torch.no_grad():
                predicted_ids = self.model.generate(**inputs)

            transcription = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]

            # Calculate timing
            end_time = datetime.now()
            processing_time = (end_time - start_time).total_seconds()

            timing_info = {
                "processing_time": processing_time,
                "start_time": start_time.isoformat(),
                "end_time": end_time.isoformat(),
                # Report the model that actually ran, not the one requested.
                "model_size": self.model_size,
                "device": self.device,
            }

            logger.info(
                "Transcription completed in %.2fs: '%s...'",
                processing_time,
                transcription[:50],
            )

            return transcription.strip(), "success", timing_info

        except Exception as e:
            # Boundary handler: callers receive an error tuple, never a raise.
            logger.exception("Transcription error: %s", e)
            return "", "error", {"error": str(e)}

    async def connect_websocket(self, websocket: WebSocket) -> str:
        """Accept a WebSocket connection and return its new client ID.

        The connection is added to ``active_connections`` only after the
        confirmation message has been sent. If the send fails, the exception
        propagates and nothing is left registered; the caller never receives
        the ID in that case and so could not deregister it itself.
        """
        client_id = str(uuid.uuid4())
        await websocket.accept()

        # Send connection confirmation
        await _send_json(websocket, {
            "type": "stt_connection_confirmed",
            "client_id": client_id,
            "service": __service__,
            "version": __version__,
            "model": f"whisper-{self.model_size}",
            "device": self.device,
            "message": "STT WebSocket connected and ready",
        })

        self.active_connections[client_id] = websocket
        logger.info("Client %s connected", client_id)
        return client_id

    async def disconnect_websocket(self, client_id: str):
        """Remove *client_id* from the registry; unknown IDs are ignored."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
            logger.info("Client %s disconnected", client_id)

    async def process_audio_message(self, client_id: str, message: Dict[str, Any]):
        """Transcribe one audio message and send the result to the client.

        Args:
            client_id: ID of a client present in ``active_connections``.
            message: Decoded client message. ``audio_data`` is read as
                base64-encoded audio; ``language`` and ``model_size`` are
                optional and passed through to ``transcribe_audio``.

        Any failure is logged and, while the client is still registered,
        reported to it as an ``stt_transcription_error`` message.
        """
        try:
            websocket = self.active_connections[client_id]

            # Extract audio data (base64 encoded)
            audio_data_b64 = message.get("audio_data")
            if not audio_data_b64:
                await _send_json(websocket, {
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "No audio data provided",
                })
                return

            # Decode base64 audio
            audio_bytes = base64.b64decode(audio_data_b64)

            # Save to temporary file. The path is captured before writing so
            # that the file is removed even if the write itself fails.
            tmp_file = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
            temp_path = tmp_file.name

            try:
                with tmp_file:
                    tmp_file.write(audio_bytes)

                # Transcribe audio
                transcription, status, timing = await self.transcribe_audio(
                    temp_path,
                    message.get("language", "auto"),
                    message.get("model_size", self.model_size),
                )

                # Send result back
                if status == "success" and transcription:
                    await _send_json(websocket, {
                        "type": "stt_transcription_complete",
                        "client_id": client_id,
                        "transcription": transcription,
                        "timing": timing,
                        "status": "success",
                    })
                else:
                    await _send_json(websocket, {
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": "Transcription failed or empty result",
                        "timing": timing,
                    })

            finally:
                # Clean up temp file
                try:
                    os.unlink(temp_path)
                except FileNotFoundError:
                    pass

        except Exception as e:
            logger.exception("Error processing audio for %s: %s", client_id, e)
            if client_id in self.active_connections:
                websocket = self.active_connections[client_id]
                await _send_json(websocket, {
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": f"Processing error: {str(e)}",
                })


# Initialize service
stt_service = STTWebSocketService()

# Create FastAPI app
app = FastAPI(
    title="STT WebSocket Service",
    description="Standalone WebSocket-only Speech-to-Text service",
    version=__version__
)

# Add CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Pre-load the Whisper model when the application starts."""
    logger.info("🚀 %s v%s starting...", __service__, __version__)
    logger.info("Pre-loading Whisper model for optimal performance...")
    await stt_service.load_model()
    logger.info("✅ Service ready for WebSocket connections")


@app.get("/")
async def root():
    """Return basic service information and the available endpoints."""
    return {
        "service": __service__,
        "version": __version__,
        "status": "ready",
        "endpoints": {
            "websocket": "/ws/stt",
            "health": "/health"
        },
        "model": f"whisper-{stt_service.model_size}",
        "device": stt_service.device
    }


@app.get("/health")
async def health_check():
    """Return model-loaded state, connection count, device and a timestamp."""
    return {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": stt_service.model is not None,
        "active_connections": len(stt_service.active_connections),
        "device": stt_service.device,
        "timestamp": datetime.now().isoformat()
    }


@app.websocket("/ws/stt")
async def websocket_stt_endpoint(websocket: WebSocket):
    """Main STT WebSocket endpoint.

    Accepts the connection, then dispatches JSON text messages by their
    ``type`` field: ``stt_audio_chunk`` is transcribed, ``ping`` is answered
    with ``pong``, anything else is logged and ignored. Malformed JSON and
    JSON values that are not objects are answered with an error message and
    the connection stays open. The client is deregistered on exit.
    """
    client_id = None

    try:
        # Accept connection
        client_id = await stt_service.connect_websocket(websocket)

        # Handle messages
        while True:
            try:
                # Receive message
                data = await websocket.receive_text()
                message = json.loads(data)

                # Valid JSON that is not an object (list, string, number...)
                # has no "type" field; reject it instead of letting `.get`
                # raise and tear the connection down.
                if not isinstance(message, dict):
                    await _send_json(websocket, {
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": "Message must be a JSON object",
                    })
                    continue

                # Process based on message type
                message_type = message.get("type", "unknown")

                if message_type == "stt_audio_chunk":
                    await stt_service.process_audio_message(client_id, message)

                elif message_type == "ping":
                    # Respond to ping
                    await _send_json(websocket, {
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat(),
                    })

                else:
                    logger.warning(
                        "Unknown message type from %s: %s", client_id, message_type
                    )

            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await _send_json(websocket, {
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format",
                })
            except Exception as e:
                logger.exception(
                    "Error handling message from %s: %s", client_id, e
                )
                break

    except WebSocketDisconnect:
        logger.info("Client %s disconnected normally", client_id)
    except Exception as e:
        logger.exception("WebSocket error for %s: %s", client_id, e)
    finally:
        if client_id is not None:
            await stt_service.disconnect_websocket(client_id)


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    logger.info("🎤 Starting %s v%s on port %s", __service__, __version__, port)

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info"
    )