# stt-gpu-service / websocket_stt_server.py
# Author: Peter Michael Gits
# Commit 69f7704 — feat: Add standalone WebSocket-only STT service v1.0.0
#!/usr/bin/env python3
"""
Standalone WebSocket-only STT Service
Simplified service without Gradio, MCP, or web interfaces
Following unmute.sh WebRTC pattern for HuggingFace Spaces
"""
import asyncio
import json
import uuid
import base64
import tempfile
import os
import logging
from datetime import datetime
from typing import Optional, Dict, Any
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torchaudio
import soundfile as sf
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import spaces
import uvicorn
# Configure root logging once at import time; this module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Version info — reported in the health endpoints and the WebSocket handshake
__version__ = "1.0.0"
__service__ = "STT WebSocket Service"
class STTWebSocketService:
    """Standalone STT service with a WebSocket-only interface.

    Holds a lazily loaded Whisper model/processor and a registry of active
    WebSocket connections keyed by a generated client id.
    """
    def __init__(self):
        # Model/processor are loaded lazily (see load_model) so the process
        # can start before any GPU allocation happens (ZeroGPU pattern).
        self.model = None
        self.processor = None
        self.model_size = "base"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # client_id -> WebSocket for every currently connected client
        self.active_connections: Dict[str, WebSocket] = {}
        logger.info(f"🎤 {__service__} v{__version__} initializing...")
        logger.info(f"Device: {self.device}")
        logger.info(f"Model: whisper-{self.model_size}")

    async def load_model(self):
        """Load the Whisper model and processor once (idempotent)."""
        if self.model is None:
            logger.info(f"Loading Whisper {self.model_size} model...")
            model_name = f"openai/whisper-{self.model_size}"
            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            if self.device == "cuda":
                self.model = self.model.to(self.device)
            logger.info(f"✅ Model loaded on {self.device}")

    @spaces.GPU(duration=30)
    async def transcribe_audio(
        self,
        audio_path: str,
        language: str = "auto",
        model_size: str = "base"
    ) -> tuple[str, str, Dict[str, Any]]:
        """Transcribe an audio file using Whisper.

        Args:
            audio_path: Path to an audio file readable by torchaudio.
            language: Language hint (e.g. "en"); "auto" lets Whisper detect
                the language. (The original version silently ignored this.)
            model_size: Recorded in the timing metadata only; the model
                actually used is whichever one load_model() loaded.

        Returns:
            (transcription, status, timing_info) where status is "success"
            or "error". On error the transcription is "" and timing_info
            contains only an "error" key.
        """
        try:
            start_time = datetime.now()
            # Ensure model is loaded (lazy initialization)
            if self.model is None:
                await self.load_model()
            audio_input, sample_rate = torchaudio.load(audio_path)
            # Whisper requires 16 kHz mono input: resample then downmix.
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                audio_input = resampler(audio_input)
            if audio_input.shape[0] > 1:
                audio_input = torch.mean(audio_input, dim=0, keepdim=True)
            audio_array = audio_input.squeeze().numpy()
            inputs = self.processor(
                audio_array,
                sampling_rate=16000,
                return_tensors="pt"
            )
            if self.device == "cuda":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Honor the language hint: force the decoder prompt when a
            # concrete language is requested instead of auto-detection.
            generate_kwargs = {}
            if language and language != "auto":
                generate_kwargs["forced_decoder_ids"] = (
                    self.processor.get_decoder_prompt_ids(
                        language=language, task="transcribe"
                    )
                )
            with torch.no_grad():
                predicted_ids = self.model.generate(**inputs, **generate_kwargs)
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]
            end_time = datetime.now()
            processing_time = (end_time - start_time).total_seconds()
            timing_info = {
                "processing_time": processing_time,
                "start_time": start_time.isoformat(),
                "end_time": end_time.isoformat(),
                "model_size": model_size,
                "device": self.device
            }
            logger.info(f"Transcription completed in {processing_time:.2f}s: '{transcription[:50]}...'")
            return transcription.strip(), "success", timing_info
        except Exception as e:
            # Return an error tuple instead of raising so the WebSocket
            # loop can relay the message to the client.
            logger.error(f"Transcription error: {str(e)}")
            return "", "error", {"error": str(e)}

    async def connect_websocket(self, websocket: WebSocket) -> str:
        """Accept a WebSocket connection, register it, and return its client id."""
        client_id = str(uuid.uuid4())
        await websocket.accept()
        self.active_connections[client_id] = websocket
        # Send connection confirmation so the client knows it can stream audio
        await websocket.send_text(json.dumps({
            "type": "stt_connection_confirmed",
            "client_id": client_id,
            "service": __service__,
            "version": __version__,
            "model": f"whisper-{self.model_size}",
            "device": self.device,
            "message": "STT WebSocket connected and ready"
        }))
        logger.info(f"Client {client_id} connected")
        return client_id

    async def disconnect_websocket(self, client_id: str):
        """Remove a client from the registry (no-op for unknown ids)."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
            logger.info(f"Client {client_id} disconnected")

    async def process_audio_message(self, client_id: str, message: Dict[str, Any]):
        """Decode a base64 audio payload, transcribe it, and reply on the
        client's socket with either a completion or an error message.
        """
        try:
            # Use .get so a client that disconnected mid-flight does not
            # raise KeyError (the original indexed the dict directly).
            websocket = self.active_connections.get(client_id)
            if websocket is None:
                logger.warning(f"Dropping audio for unknown client {client_id}")
                return
            audio_data_b64 = message.get("audio_data")
            if not audio_data_b64:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "No audio data provided"
                }))
                return
            audio_bytes = base64.b64decode(audio_data_b64)
            # Persist to a temp file because transcribe_audio loads from a path.
            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
                tmp_file.write(audio_bytes)
                temp_path = tmp_file.name
            try:
                transcription, status, timing = await self.transcribe_audio(
                    temp_path,
                    message.get("language", "auto"),
                    message.get("model_size", self.model_size)
                )
                if status == "success" and transcription:
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_complete",
                        "client_id": client_id,
                        "transcription": transcription,
                        "timing": timing,
                        "status": "success"
                    }))
                else:
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": "Transcription failed or empty result",
                        "timing": timing
                    }))
            finally:
                # Always clean up the temp file, even if sending failed.
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
        except Exception as e:
            logger.error(f"Error processing audio for {client_id}: {str(e)}")
            websocket = self.active_connections.get(client_id)
            if websocket is not None:
                try:
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": f"Processing error: {str(e)}"
                    }))
                except Exception:
                    # Socket already closed; nothing more we can do.
                    logger.debug(f"Could not deliver error to {client_id}")
# Module-level singleton shared by the HTTP endpoints and the WebSocket handler
stt_service = STTWebSocketService()
# Create FastAPI app
app = FastAPI(
    title="STT WebSocket Service",
    description="Standalone WebSocket-only Speech-to-Text service",
    version=__version__
)
# Add CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# fully permissive — tighten origins before exposing outside HF Spaces.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Warm the service at boot: announce the version, then pre-load Whisper
    so the first transcription request does not pay the model-load cost."""
    banner = f"🚀 {__service__} v{__version__} starting..."
    logger.info(banner)
    logger.info("Pre-loading Whisper model for optimal performance...")
    await stt_service.load_model()
    logger.info("✅ Service ready for WebSocket connections")
@app.get("/")
async def root():
    """Service banner / basic liveness probe at the HTTP root path."""
    info = {
        "service": __service__,
        "version": __version__,
        "status": "ready",
    }
    info["endpoints"] = {"websocket": "/ws/stt", "health": "/health"}
    info["model"] = f"whisper-{stt_service.model_size}"
    info["device"] = stt_service.device
    return info
@app.get("/health")
async def health_check():
    """Report model and connection state for monitoring probes."""
    report = {}
    report["service"] = __service__
    report["version"] = __version__
    report["status"] = "healthy"
    report["model_loaded"] = stt_service.model is not None
    report["active_connections"] = len(stt_service.active_connections)
    report["device"] = stt_service.device
    report["timestamp"] = datetime.now().isoformat()
    return report
@app.websocket("/ws/stt")
async def websocket_stt_endpoint(websocket: WebSocket):
    """Main STT WebSocket endpoint.

    Protocol (JSON text frames):
      - {"type": "stt_audio_chunk", "audio_data": <base64>, ...} -> transcription
        result or stt_transcription_error, sent by the service layer.
      - {"type": "ping"} -> {"type": "pong", ...} keep-alive reply.
      - any other type is logged and ignored.
    """
    client_id = None
    try:
        # Accept connection; client_id stays None if this raises, so the
        # finally block below only cleans up successfully registered clients.
        client_id = await stt_service.connect_websocket(websocket)
        # Handle messages until the client disconnects or a fatal error occurs
        while True:
            try:
                # Receive message (text frames only; binary frames are not handled)
                data = await websocket.receive_text()
                message = json.loads(data)
                # Process based on message type
                message_type = message.get("type", "unknown")
                if message_type == "stt_audio_chunk":
                    await stt_service.process_audio_message(client_id, message)
                elif message_type == "ping":
                    # Respond to ping (keep-alive)
                    await websocket.send_text(json.dumps({
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat()
                    }))
                else:
                    logger.warning(f"Unknown message type from {client_id}: {message_type}")
            except WebSocketDisconnect:
                # Normal client hang-up: leave the receive loop
                break
            except json.JSONDecodeError:
                # Malformed frame: report it and keep the connection open
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format"
                }))
            except Exception as e:
                # Any other per-message failure ends the session
                logger.error(f"Error handling message from {client_id}: {str(e)}")
                break
    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {str(e)}")
    finally:
        # Deregister the client regardless of how the session ended
        if client_id:
            await stt_service.disconnect_websocket(client_id)
if __name__ == "__main__":
    # PORT may be injected by the hosting platform; 7860 is the HF Spaces default.
    listen_port = int(os.environ.get("PORT", 7860))
    logger.info(f"🎤 Starting {__service__} v{__version__} on port {listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")