use anyhow::Result; use base64::{Engine as _, engine::general_purpose}; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, RwLock}; use tokio_tungstenite::{accept_async, tungstenite::Message}; use tracing::{info, warn, error, debug}; use uuid::Uuid; use crate::audio::{AudioChunk, AudioProcessor, TranscriptionResult}; use crate::model::ModelManager; #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type")] pub enum ClientMessage { #[serde(rename = "audio")] Audio { data: String, // Base64 encoded audio data sample_rate: u32, channels: u16, timestamp: u64, }, #[serde(rename = "start")] StartStream { config: Option, }, #[serde(rename = "stop")] StopStream, #[serde(rename = "ping")] Ping, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct StreamConfig { pub enable_timestamps: Option, pub enable_vad: Option, pub language: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type")] pub enum ServerMessage { #[serde(rename = "transcription")] Transcription { result: TranscriptionResult }, #[serde(rename = "error")] Error { message: String }, #[serde(rename = "status")] Status { message: String }, #[serde(rename = "pong")] Pong, } pub struct ClientSession { id: Uuid, audio_processor: AudioProcessor, is_streaming: bool, config: Option, } impl ClientSession { fn new(id: Uuid) -> Self { Self { id, audio_processor: AudioProcessor::new(24000), // 24kHz target sample rate is_streaming: false, config: None, } } } pub struct ServerState { clients: RwLock>>>, model_manager: Arc, } pub async fn start_server(addr: String, model_manager: ModelManager) -> Result<()> { let server_state = Arc::new(ServerState { clients: RwLock::new(HashMap::new()), model_manager: Arc::new(model_manager), }); // Start HTTP health check server on a separate port let health_server_state = Arc::clone(&server_state); let health_addr = addr.replace(":8080", ":8081"); // Health check on port 8081 tokio::spawn(async move { if let Err(e) = start_health_server(health_addr, health_server_state).await { error!("Health server error: {}", e); } }); // Start main WebSocket server let listener = TcpListener::bind(&addr).await?; info!("WebSocket server listening on: {}", addr); info!("Health check available on: {}", addr.replace(":8080", ":8081")); while let Ok((stream, addr)) = listener.accept().await { let server_state = Arc::clone(&server_state); tokio::spawn(handle_connection(stream, addr.to_string(), server_state)); } Ok(()) } async fn start_health_server(addr: String, _server_state: Arc) -> Result<()> { use axum::{response::Json, routing::get, Router}; use serde_json::json; let app = Router::new() .route("/health", get(|| async { Json(json!({ "status": "healthy", "service": "kyutai-stt-server", "version": "0.1.0" })) })); let listener = tokio::net::TcpListener::bind(&addr).await?; info!("Health server listening on: {}", addr); axum::serve(listener, app).await?; Ok(()) } async fn handle_connection( stream: TcpStream, addr: String, server_state: Arc, ) { info!("New connection from: {}", addr); let ws_stream = match accept_async(stream).await { Ok(ws) => ws, Err(e) => { error!("WebSocket connection error: {}", e); return; } }; let client_id = Uuid::new_v4(); let session = Arc::new(Mutex::new(ClientSession::new(client_id))); // Register client server_state.clients.write().await.insert(client_id, Arc::clone(&session)); let (mut ws_sender, mut ws_receiver) = ws_stream.split(); // Send welcome message let welcome_msg = ServerMessage::Status { message: format!("Connected to Kyutai STT Server. Session ID: {}", client_id), }; if let Ok(msg) = serde_json::to_string(&welcome_msg) { if let Err(e) = ws_sender.send(Message::Text(msg)).await { error!("Failed to send welcome message: {}", e); } } // Handle incoming messages while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Text(text)) => { if let Err(e) = handle_text_message( text, &session, &mut ws_sender, &server_state.model_manager, ).await { error!("Error handling message: {}", e); let error_msg = ServerMessage::Error { message: format!("Error processing message: {}", e), }; if let Ok(msg) = serde_json::to_string(&error_msg) { let _ = ws_sender.send(Message::Text(msg)).await; } } } Ok(Message::Binary(data)) => { warn!("Received binary message (not supported): {} bytes", data.len()); } Ok(Message::Close(_)) => { info!("Client {} disconnected", client_id); break; } Ok(Message::Ping(data)) => { let _ = ws_sender.send(Message::Pong(data)).await; } Ok(Message::Pong(_)) => { debug!("Received pong from client {}", client_id); } Ok(Message::Frame(_)) => { // Raw frame messages are handled internally by tungstenite debug!("Received raw frame from client {}", client_id); } Err(e) => { error!("WebSocket error: {}", e); break; } } } // Cleanup server_state.clients.write().await.remove(&client_id); info!("Client {} session ended", client_id); } async fn handle_text_message( text: String, session: &Arc>, ws_sender: &mut futures_util::stream::SplitSink, Message>, model_manager: &Arc, ) -> Result<()> { let client_msg: ClientMessage = serde_json::from_str(&text)?; let mut session = session.lock().await; match client_msg { ClientMessage::StartStream { config } => { info!("Starting stream for client {}", session.id); session.is_streaming = true; session.config = config; session.audio_processor.clear_buffer(); let response = ServerMessage::Status { message: "Streaming started".to_string(), }; send_response(ws_sender, response).await?; } ClientMessage::StopStream => { info!("Stopping stream for client {}", session.id); session.is_streaming = false; session.audio_processor.clear_buffer(); let response = ServerMessage::Status { message: "Streaming stopped".to_string(), }; send_response(ws_sender, response).await?; } ClientMessage::Audio { data, sample_rate, channels, timestamp } => { if !session.is_streaming { return Ok(()); } // Decode base64 audio data let audio_bytes = general_purpose::STANDARD.decode(&data)?; let audio_data: Vec = audio_bytes .chunks(4) .map(|chunk| { let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]]; f32::from_le_bytes(bytes) }) .collect(); let audio_chunk = AudioChunk { data: audio_data, sample_rate, channels, timestamp, }; // Process audio let _processed_audio = session.audio_processor.process_chunk(audio_chunk)?; // Check if we have enough audio for transcription const MIN_AUDIO_LENGTH: usize = 24000; // 1 second at 24kHz if let Some(buffered_audio) = session.audio_processor.get_buffered_audio(MIN_AUDIO_LENGTH) { debug!("Processing {} samples for transcription", buffered_audio.len()); // Run transcription with language from config let language = session.config.as_ref().and_then(|c| c.language.clone()); match model_manager.transcribe_with_language(buffered_audio, language).await { Ok(transcription_text) => { let result = TranscriptionResult { text: transcription_text, confidence: None, start_time: None, end_time: None, words: None, }; let response = ServerMessage::Transcription { result }; send_response(ws_sender, response).await?; } Err(e) => { error!("Transcription error: {}", e); let response = ServerMessage::Error { message: format!("Transcription failed: {}", e), }; send_response(ws_sender, response).await?; } } } } ClientMessage::Ping => { let response = ServerMessage::Pong; send_response(ws_sender, response).await?; } } Ok(()) } async fn send_response( ws_sender: &mut futures_util::stream::SplitSink, Message>, response: ServerMessage, ) -> Result<()> { let json = serde_json::to_string(&response)?; ws_sender.send(Message::Text(json)).await?; Ok(()) }