Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Standalone WebSocket-only STT Service | |
| Simplified service without Gradio, MCP, or web interfaces | |
| Following unmute.sh WebRTC pattern for HuggingFace Spaces | |
| """ | |
| import asyncio | |
| import json | |
| import uuid | |
| import base64 | |
| import tempfile | |
| import os | |
| import logging | |
| from datetime import datetime | |
| from typing import Optional, Dict, Any | |
| import torch | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| import torchaudio | |
| import soundfile as sf | |
| import numpy as np | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import spaces | |
| import uvicorn | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Version info | |
| __version__ = "1.0.0" | |
| __service__ = "STT WebSocket Service" | |
class STTWebSocketService:
    """Standalone Whisper-based speech-to-text service with a WebSocket-only API.

    Holds a lazily loaded Whisper model/processor pair and the set of live
    WebSocket connections, keyed by a per-connection UUID string.
    """

    def __init__(self):
        # Model and processor are loaded lazily on first use (see load_model).
        self.model = None
        self.processor = None
        self.model_size = "base"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # client_id (uuid4 string) -> live WebSocket connection
        self.active_connections: Dict[str, WebSocket] = {}
        logger.info(f"🎤 {__service__} v{__version__} initializing...")
        logger.info(f"Device: {self.device}")
        logger.info(f"Model: whisper-{self.model_size}")

    async def load_model(self):
        """Load the Whisper model and processor once; no-op if already loaded."""
        if self.model is None:
            logger.info(f"Loading Whisper {self.model_size} model...")
            model_name = f"openai/whisper-{self.model_size}"
            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            if self.device == "cuda":
                self.model = self.model.to(self.device)
            logger.info(f"✅ Model loaded on {self.device}")

    async def transcribe_audio(
        self,
        audio_path: str,
        language: str = "auto",
        model_size: str = "base",
    ) -> tuple[str, str, Dict[str, Any]]:
        """Transcribe an audio file with Whisper.

        Args:
            audio_path: Path to an audio file readable by torchaudio.
            language: Language name/code understood by Whisper, or "auto"
                to let the model detect the language.
            model_size: Recorded in the timing info only; the service does
                NOT reload a different model per request — it always uses
                the model selected by ``self.model_size``.

        Returns:
            ``(transcription, status, timing_info)`` where ``status`` is
            "success" or "error". On error the transcription is "" and
            ``timing_info`` carries the error message instead of timings.
        """
        try:
            start_time = datetime.now()
            # Ensure the model is available before processing.
            if self.model is None:
                await self.load_model()

            audio_input, sample_rate = torchaudio.load(audio_path)

            # Whisper requires 16 kHz mono input: resample and downmix.
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                audio_input = resampler(audio_input)
            if audio_input.shape[0] > 1:
                audio_input = torch.mean(audio_input, dim=0, keepdim=True)
            audio_array = audio_input.squeeze().numpy()

            inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt",
            )
            if self.device == "cuda":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Fix: honor the requested language. Previously the ``language``
            # argument was accepted but never used, so Whisper always
            # auto-detected regardless of what the client asked for.
            generate_kwargs: Dict[str, Any] = {}
            if language and language != "auto":
                generate_kwargs["forced_decoder_ids"] = (
                    self.processor.get_decoder_prompt_ids(
                        language=language, task="transcribe"
                    )
                )

            with torch.no_grad():
                predicted_ids = self.model.generate(**inputs, **generate_kwargs)
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
            )[0]

            end_time = datetime.now()
            processing_time = (end_time - start_time).total_seconds()
            timing_info = {
                "processing_time": processing_time,
                "start_time": start_time.isoformat(),
                "end_time": end_time.isoformat(),
                "model_size": model_size,
                "device": self.device,
            }
            logger.info(
                f"Transcription completed in {processing_time:.2f}s: "
                f"'{transcription[:50]}...'"
            )
            return transcription.strip(), "success", timing_info
        except Exception as e:
            logger.error(f"Transcription error: {str(e)}")
            return "", "error", {"error": str(e)}

    async def connect_websocket(self, websocket: WebSocket) -> str:
        """Accept a WebSocket connection, register it, and return its client ID."""
        client_id = str(uuid.uuid4())
        await websocket.accept()
        self.active_connections[client_id] = websocket
        # Confirm the connection so the client knows the service is ready.
        await websocket.send_text(json.dumps({
            "type": "stt_connection_confirmed",
            "client_id": client_id,
            "service": __service__,
            "version": __version__,
            "model": f"whisper-{self.model_size}",
            "device": self.device,
            "message": "STT WebSocket connected and ready"
        }))
        logger.info(f"Client {client_id} connected")
        return client_id

    async def disconnect_websocket(self, client_id: str):
        """Forget a WebSocket connection (safe to call if already removed)."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
            logger.info(f"Client {client_id} disconnected")

    async def process_audio_message(self, client_id: str, message: Dict[str, Any]):
        """Decode a base64 audio payload, transcribe it, and reply on the socket.

        Expected message keys: ``audio_data`` (base64-encoded audio bytes,
        required), ``language`` and ``model_size`` (optional overrides).
        """
        try:
            websocket = self.active_connections[client_id]

            audio_data_b64 = message.get("audio_data")
            if not audio_data_b64:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "No audio data provided"
                }))
                return

            audio_bytes = base64.b64decode(audio_data_b64)

            # Persist to a temp file because torchaudio loads from a path.
            # NOTE(review): suffix is always ".webm" — presumably clients send
            # webm/opus; other containers may still load if torchaudio's
            # backend sniffs the format. Verify against the client.
            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
                tmp_file.write(audio_bytes)
                temp_path = tmp_file.name

            try:
                transcription, status, timing = await self.transcribe_audio(
                    temp_path,
                    message.get("language", "auto"),
                    message.get("model_size", self.model_size)
                )
                if status == "success" and transcription:
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_complete",
                        "client_id": client_id,
                        "transcription": transcription,
                        "timing": timing,
                        "status": "success"
                    }))
                else:
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": "Transcription failed or empty result",
                        "timing": timing
                    }))
            finally:
                # Always remove the temp file, even if transcription raised.
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
        except Exception as e:
            logger.error(f"Error processing audio for {client_id}: {str(e)}")
            if client_id in self.active_connections:
                websocket = self.active_connections[client_id]
                try:
                    # Fix: sending on an already-dead socket raises; don't let
                    # that secondary failure mask the original error.
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": f"Processing error: {str(e)}"
                    }))
                except Exception:
                    logger.warning(f"Could not deliver error to {client_id}")
# Module-level singletons: one service instance shared by all routes.
stt_service = STTWebSocketService()

# FastAPI application exposing the WebSocket STT endpoints.
app = FastAPI(
    title="STT WebSocket Service",
    description="Standalone WebSocket-only Speech-to-Text service",
    version=__version__,
)

# Permissive CORS so browser clients on any origin can connect.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Pre-load the Whisper model when the server starts.

    Fix: registered with FastAPI via ``@app.on_event("startup")`` — without
    the decorator this coroutine was never invoked and the model was only
    loaded lazily on the first transcription request.
    """
    logger.info(f"🚀 {__service__} v{__version__} starting...")
    logger.info("Pre-loading Whisper model for optimal performance...")
    await stt_service.load_model()
    logger.info("✅ Service ready for WebSocket connections")
@app.get("/")
async def root():
    """Health/info endpoint at the service root.

    Fix: registered with FastAPI via ``@app.get("/")`` — without the
    decorator the function was unreachable dead code.

    Returns service identity, status, and the endpoint map clients use.
    """
    return {
        "service": __service__,
        "version": __version__,
        "status": "ready",
        "endpoints": {
            "websocket": "/ws/stt",
            "health": "/health"
        },
        "model": f"whisper-{stt_service.model_size}",
        "device": stt_service.device
    }
@app.get("/health")
async def health_check():
    """Detailed health check.

    Fix: registered with FastAPI via ``@app.get("/health")`` — this is the
    path the root endpoint advertises, but no route existed for it.

    Reports model-load state, live connection count, device, and timestamp.
    """
    return {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": stt_service.model is not None,
        "active_connections": len(stt_service.active_connections),
        "device": stt_service.device,
        "timestamp": datetime.now().isoformat()
    }
@app.websocket("/ws/stt")
async def websocket_stt_endpoint(websocket: WebSocket):
    """Main STT WebSocket endpoint.

    Fix: registered with FastAPI via ``@app.websocket("/ws/stt")`` — this is
    the path the root endpoint advertises, but without the decorator no
    WebSocket route existed and the service was unreachable.

    Message protocol (JSON text frames):
      - ``stt_audio_chunk``: base64 audio payload → transcription reply.
      - ``ping``: replied to with a ``pong`` carrying a timestamp.
      - anything else: logged and ignored.
    """
    client_id = None
    try:
        client_id = await stt_service.connect_websocket(websocket)
        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)
                message_type = message.get("type", "unknown")

                if message_type == "stt_audio_chunk":
                    await stt_service.process_audio_message(client_id, message)
                elif message_type == "ping":
                    await websocket.send_text(json.dumps({
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat()
                    }))
                else:
                    logger.warning(f"Unknown message type from {client_id}: {message_type}")
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                # Malformed frame: tell the client but keep the connection.
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format"
                }))
            except Exception as e:
                logger.error(f"Error handling message from {client_id}: {str(e)}")
                break
    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {str(e)}")
    finally:
        # Deregister even on abnormal exit so the connection map stays clean.
        if client_id:
            await stt_service.disconnect_websocket(client_id)
if __name__ == "__main__":
    # Hosting platforms (e.g. HF Spaces) inject PORT; default to 7860.
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"🎤 Starting {__service__} v{__version__} on port {port}")
    # Bind on all interfaces so the container's mapped port is reachable.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")