Spaces:
Sleeping
Sleeping
| use anyhow::Result; | |
| use base64::{Engine as _, engine::general_purpose}; | |
| use futures_util::{SinkExt, StreamExt}; | |
| use serde::{Deserialize, Serialize}; | |
| use std::collections::HashMap; | |
| use std::sync::Arc; | |
| use tokio::net::{TcpListener, TcpStream}; | |
| use tokio::sync::{Mutex, RwLock}; | |
| use tokio_tungstenite::{accept_async, tungstenite::Message}; | |
| use tracing::{info, warn, error, debug}; | |
| use uuid::Uuid; | |
| use crate::audio::{AudioChunk, AudioProcessor, TranscriptionResult}; | |
| use crate::model::ModelManager; | |
| pub enum ClientMessage { | |
| Audio { | |
| data: String, // Base64 encoded audio data | |
| sample_rate: u32, | |
| channels: u16, | |
| timestamp: u64, | |
| }, | |
| StartStream { | |
| config: Option<StreamConfig>, | |
| }, | |
| StopStream, | |
| Ping, | |
| } | |
| pub struct StreamConfig { | |
| pub enable_timestamps: Option<bool>, | |
| pub enable_vad: Option<bool>, | |
| pub language: Option<String>, | |
| } | |
| pub enum ServerMessage { | |
| Transcription { result: TranscriptionResult }, | |
| Error { message: String }, | |
| Status { message: String }, | |
| Pong, | |
| } | |
| pub struct ClientSession { | |
| id: Uuid, | |
| audio_processor: AudioProcessor, | |
| is_streaming: bool, | |
| config: Option<StreamConfig>, | |
| } | |
| impl ClientSession { | |
| fn new(id: Uuid) -> Self { | |
| Self { | |
| id, | |
| audio_processor: AudioProcessor::new(24000), // 24kHz target sample rate | |
| is_streaming: false, | |
| config: None, | |
| } | |
| } | |
| } | |
| pub struct ServerState { | |
| clients: RwLock<HashMap<Uuid, Arc<Mutex<ClientSession>>>>, | |
| model_manager: Arc<ModelManager>, | |
| } | |
| pub async fn start_server(addr: String, model_manager: ModelManager) -> Result<()> { | |
| let server_state = Arc::new(ServerState { | |
| clients: RwLock::new(HashMap::new()), | |
| model_manager: Arc::new(model_manager), | |
| }); | |
| // Start HTTP health check server on a separate port | |
| let health_server_state = Arc::clone(&server_state); | |
| let health_addr = addr.replace(":8080", ":8081"); // Health check on port 8081 | |
| tokio::spawn(async move { | |
| if let Err(e) = start_health_server(health_addr, health_server_state).await { | |
| error!("Health server error: {}", e); | |
| } | |
| }); | |
| // Start main WebSocket server | |
| let listener = TcpListener::bind(&addr).await?; | |
| info!("WebSocket server listening on: {}", addr); | |
| info!("Health check available on: {}", addr.replace(":8080", ":8081")); | |
| while let Ok((stream, addr)) = listener.accept().await { | |
| let server_state = Arc::clone(&server_state); | |
| tokio::spawn(handle_connection(stream, addr.to_string(), server_state)); | |
| } | |
| Ok(()) | |
| } | |
| async fn start_health_server(addr: String, _server_state: Arc<ServerState>) -> Result<()> { | |
| use axum::{response::Json, routing::get, Router}; | |
| use serde_json::json; | |
| let app = Router::new() | |
| .route("/health", get(|| async { | |
| Json(json!({ | |
| "status": "healthy", | |
| "service": "kyutai-stt-server", | |
| "version": "0.1.0" | |
| })) | |
| })); | |
| let listener = tokio::net::TcpListener::bind(&addr).await?; | |
| info!("Health server listening on: {}", addr); | |
| axum::serve(listener, app).await?; | |
| Ok(()) | |
| } | |
| async fn handle_connection( | |
| stream: TcpStream, | |
| addr: String, | |
| server_state: Arc<ServerState>, | |
| ) { | |
| info!("New connection from: {}", addr); | |
| let ws_stream = match accept_async(stream).await { | |
| Ok(ws) => ws, | |
| Err(e) => { | |
| error!("WebSocket connection error: {}", e); | |
| return; | |
| } | |
| }; | |
| let client_id = Uuid::new_v4(); | |
| let session = Arc::new(Mutex::new(ClientSession::new(client_id))); | |
| // Register client | |
| server_state.clients.write().await.insert(client_id, Arc::clone(&session)); | |
| let (mut ws_sender, mut ws_receiver) = ws_stream.split(); | |
| // Send welcome message | |
| let welcome_msg = ServerMessage::Status { | |
| message: format!("Connected to Kyutai STT Server. Session ID: {}", client_id), | |
| }; | |
| if let Ok(msg) = serde_json::to_string(&welcome_msg) { | |
| if let Err(e) = ws_sender.send(Message::Text(msg)).await { | |
| error!("Failed to send welcome message: {}", e); | |
| } | |
| } | |
| // Handle incoming messages | |
| while let Some(msg) = ws_receiver.next().await { | |
| match msg { | |
| Ok(Message::Text(text)) => { | |
| if let Err(e) = handle_text_message( | |
| text, | |
| &session, | |
| &mut ws_sender, | |
| &server_state.model_manager, | |
| ).await { | |
| error!("Error handling message: {}", e); | |
| let error_msg = ServerMessage::Error { | |
| message: format!("Error processing message: {}", e), | |
| }; | |
| if let Ok(msg) = serde_json::to_string(&error_msg) { | |
| let _ = ws_sender.send(Message::Text(msg)).await; | |
| } | |
| } | |
| } | |
| Ok(Message::Binary(data)) => { | |
| warn!("Received binary message (not supported): {} bytes", data.len()); | |
| } | |
| Ok(Message::Close(_)) => { | |
| info!("Client {} disconnected", client_id); | |
| break; | |
| } | |
| Ok(Message::Ping(data)) => { | |
| let _ = ws_sender.send(Message::Pong(data)).await; | |
| } | |
| Ok(Message::Pong(_)) => { | |
| debug!("Received pong from client {}", client_id); | |
| } | |
| Ok(Message::Frame(_)) => { | |
| // Raw frame messages are handled internally by tungstenite | |
| debug!("Received raw frame from client {}", client_id); | |
| } | |
| Err(e) => { | |
| error!("WebSocket error: {}", e); | |
| break; | |
| } | |
| } | |
| } | |
| // Cleanup | |
| server_state.clients.write().await.remove(&client_id); | |
| info!("Client {} session ended", client_id); | |
| } | |
| async fn handle_text_message( | |
| text: String, | |
| session: &Arc<Mutex<ClientSession>>, | |
| ws_sender: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, Message>, | |
| model_manager: &Arc<ModelManager>, | |
| ) -> Result<()> { | |
| let client_msg: ClientMessage = serde_json::from_str(&text)?; | |
| let mut session = session.lock().await; | |
| match client_msg { | |
| ClientMessage::StartStream { config } => { | |
| info!("Starting stream for client {}", session.id); | |
| session.is_streaming = true; | |
| session.config = config; | |
| session.audio_processor.clear_buffer(); | |
| let response = ServerMessage::Status { | |
| message: "Streaming started".to_string(), | |
| }; | |
| send_response(ws_sender, response).await?; | |
| } | |
| ClientMessage::StopStream => { | |
| info!("Stopping stream for client {}", session.id); | |
| session.is_streaming = false; | |
| session.audio_processor.clear_buffer(); | |
| let response = ServerMessage::Status { | |
| message: "Streaming stopped".to_string(), | |
| }; | |
| send_response(ws_sender, response).await?; | |
| } | |
| ClientMessage::Audio { data, sample_rate, channels, timestamp } => { | |
| if !session.is_streaming { | |
| return Ok(()); | |
| } | |
| // Decode base64 audio data | |
| let audio_bytes = general_purpose::STANDARD.decode(&data)?; | |
| let audio_data: Vec<f32> = audio_bytes | |
| .chunks(4) | |
| .map(|chunk| { | |
| let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]]; | |
| f32::from_le_bytes(bytes) | |
| }) | |
| .collect(); | |
| let audio_chunk = AudioChunk { | |
| data: audio_data, | |
| sample_rate, | |
| channels, | |
| timestamp, | |
| }; | |
| // Process audio | |
| let _processed_audio = session.audio_processor.process_chunk(audio_chunk)?; | |
| // Check if we have enough audio for transcription | |
| const MIN_AUDIO_LENGTH: usize = 24000; // 1 second at 24kHz | |
| if let Some(buffered_audio) = session.audio_processor.get_buffered_audio(MIN_AUDIO_LENGTH) { | |
| debug!("Processing {} samples for transcription", buffered_audio.len()); | |
| // Run transcription with language from config | |
| let language = session.config.as_ref().and_then(|c| c.language.clone()); | |
| match model_manager.transcribe_with_language(buffered_audio, language).await { | |
| Ok(transcription_text) => { | |
| let result = TranscriptionResult { | |
| text: transcription_text, | |
| confidence: None, | |
| start_time: None, | |
| end_time: None, | |
| words: None, | |
| }; | |
| let response = ServerMessage::Transcription { result }; | |
| send_response(ws_sender, response).await?; | |
| } | |
| Err(e) => { | |
| error!("Transcription error: {}", e); | |
| let response = ServerMessage::Error { | |
| message: format!("Transcription failed: {}", e), | |
| }; | |
| send_response(ws_sender, response).await?; | |
| } | |
| } | |
| } | |
| } | |
| ClientMessage::Ping => { | |
| let response = ServerMessage::Pong; | |
| send_response(ws_sender, response).await?; | |
| } | |
| } | |
| Ok(()) | |
| } | |
| async fn send_response( | |
| ws_sender: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, Message>, | |
| response: ServerMessage, | |
| ) -> Result<()> { | |
| let json = serde_json::to_string(&response)?; | |
| ws_sender.send(Message::Text(json)).await?; | |
| Ok(()) | |
| } |