// stt-gpu-service-v3 / src/server.rs
// Peter Michael Gits
// v1.4.8: Add language conditioning for multilingual STT model
// 26d8204
use anyhow::Result;
use base64::{Engine as _, engine::general_purpose};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, RwLock};
use tokio_tungstenite::{accept_async, tungstenite::Message};
use tracing::{info, warn, error, debug};
use uuid::Uuid;
use crate::audio::{AudioChunk, AudioProcessor, TranscriptionResult};
use crate::model::ModelManager;
/// JSON messages a client may send over the WebSocket as text frames.
///
/// The enum is internally tagged: the `"type"` field of the JSON object
/// selects the variant (`"audio"`, `"start"`, `"stop"` or `"ping"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ClientMessage {
    /// One chunk of audio (wire tag `"audio"`).
    Audio {
        /// Base64 encoded audio data; the server decodes the bytes as
        /// little-endian `f32` samples.
        data: String,
        /// Sample rate of `data`, forwarded to the audio processor.
        sample_rate: u32,
        /// Channel count of `data`, forwarded to the audio processor.
        channels: u16,
        /// Client-supplied timestamp, forwarded unchanged.
        /// NOTE(review): unit/epoch is not defined in this file.
        timestamp: u64,
    },
    /// Begin a streaming session (wire tag `"start"`), optionally with settings.
    #[serde(rename = "start")]
    StartStream { config: Option<StreamConfig> },
    /// End the streaming session (wire tag `"stop"`).
    #[serde(rename = "stop")]
    StopStream,
    /// Application-level keep-alive (wire tag `"ping"`); answered with `pong`.
    Ping,
}
/// Optional per-stream settings a client may attach to a `"start"` message.
///
/// Every field is optional; the whole struct is stored on the session when
/// streaming starts.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamConfig {
/// NOTE(review): accepted from the client but not read anywhere in the
/// visible part of this file.
pub enable_timestamps: Option<bool>,
/// NOTE(review): accepted from the client but not read anywhere in the
/// visible part of this file.
pub enable_vad: Option<bool>,
/// Language hint passed to `ModelManager::transcribe_with_language` for
/// each transcription of this stream.
pub language: Option<String>,
}
/// JSON messages the server sends to a client as WebSocket text frames.
///
/// The enum is internally tagged: the `"type"` field carries the lowercase
/// variant name (`"transcription"`, `"error"`, `"status"`, `"pong"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ServerMessage {
    /// Result of transcribing buffered audio.
    Transcription { result: TranscriptionResult },
    /// A request could not be processed or transcription failed.
    Error { message: String },
    /// Informational notice (connection welcome, stream started/stopped).
    Status { message: String },
    /// Reply to a client `ping` message.
    Pong,
}
/// Per-connection state, created when a WebSocket handshake succeeds.
pub struct ClientSession {
// Identifier generated for the connection; used in log lines and as the key
// in `ServerState::clients`.
id: Uuid,
// Buffers and processes the audio chunks received from this client.
audio_processor: AudioProcessor,
// True between a "start" and the following "stop" message; audio messages
// received while this is false are ignored.
is_streaming: bool,
// Settings from the most recent "start" message, if the client sent any.
config: Option<StreamConfig>,
}
impl ClientSession {
    /// Creates an idle session for connection `id`: not streaming, no
    /// configuration, and a fresh audio processor.
    fn new(id: Uuid) -> Self {
        // 24kHz target sample rate
        let audio_processor = AudioProcessor::new(24000);

        Self {
            id,
            audio_processor,
            is_streaming: false,
            config: None,
        }
    }
}
/// State shared by every connection task of the server.
pub struct ServerState {
// Sessions of the currently connected clients, keyed by session id. An entry
// is inserted after the WebSocket handshake and removed when the
// connection's receive loop ends.
clients: RwLock<HashMap<Uuid, Arc<Mutex<ClientSession>>>>,
// Speech-to-text model handle used by all connections for transcription.
model_manager: Arc<ModelManager>,
}
pub async fn start_server(addr: String, model_manager: ModelManager) -> Result<()> {
let server_state = Arc::new(ServerState {
clients: RwLock::new(HashMap::new()),
model_manager: Arc::new(model_manager),
});
// Start HTTP health check server on a separate port
let health_server_state = Arc::clone(&server_state);
let health_addr = addr.replace(":8080", ":8081"); // Health check on port 8081
tokio::spawn(async move {
if let Err(e) = start_health_server(health_addr, health_server_state).await {
error!("Health server error: {}", e);
}
});
// Start main WebSocket server
let listener = TcpListener::bind(&addr).await?;
info!("WebSocket server listening on: {}", addr);
info!("Health check available on: {}", addr.replace(":8080", ":8081"));
while let Ok((stream, addr)) = listener.accept().await {
let server_state = Arc::clone(&server_state);
tokio::spawn(handle_connection(stream, addr.to_string(), server_state));
}
Ok(())
}
async fn start_health_server(addr: String, _server_state: Arc<ServerState>) -> Result<()> {
use axum::{response::Json, routing::get, Router};
use serde_json::json;
let app = Router::new()
.route("/health", get(|| async {
Json(json!({
"status": "healthy",
"service": "kyutai-stt-server",
"version": "0.1.0"
}))
}));
let listener = tokio::net::TcpListener::bind(&addr).await?;
info!("Health server listening on: {}", addr);
axum::serve(listener, app).await?;
Ok(())
}
async fn handle_connection(
stream: TcpStream,
addr: String,
server_state: Arc<ServerState>,
) {
info!("New connection from: {}", addr);
let ws_stream = match accept_async(stream).await {
Ok(ws) => ws,
Err(e) => {
error!("WebSocket connection error: {}", e);
return;
}
};
let client_id = Uuid::new_v4();
let session = Arc::new(Mutex::new(ClientSession::new(client_id)));
// Register client
server_state.clients.write().await.insert(client_id, Arc::clone(&session));
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
// Send welcome message
let welcome_msg = ServerMessage::Status {
message: format!("Connected to Kyutai STT Server. Session ID: {}", client_id),
};
if let Ok(msg) = serde_json::to_string(&welcome_msg) {
if let Err(e) = ws_sender.send(Message::Text(msg)).await {
error!("Failed to send welcome message: {}", e);
}
}
// Handle incoming messages
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Err(e) = handle_text_message(
text,
&session,
&mut ws_sender,
&server_state.model_manager,
).await {
error!("Error handling message: {}", e);
let error_msg = ServerMessage::Error {
message: format!("Error processing message: {}", e),
};
if let Ok(msg) = serde_json::to_string(&error_msg) {
let _ = ws_sender.send(Message::Text(msg)).await;
}
}
}
Ok(Message::Binary(data)) => {
warn!("Received binary message (not supported): {} bytes", data.len());
}
Ok(Message::Close(_)) => {
info!("Client {} disconnected", client_id);
break;
}
Ok(Message::Ping(data)) => {
let _ = ws_sender.send(Message::Pong(data)).await;
}
Ok(Message::Pong(_)) => {
debug!("Received pong from client {}", client_id);
}
Ok(Message::Frame(_)) => {
// Raw frame messages are handled internally by tungstenite
debug!("Received raw frame from client {}", client_id);
}
Err(e) => {
error!("WebSocket error: {}", e);
break;
}
}
}
// Cleanup
server_state.clients.write().await.remove(&client_id);
info!("Client {} session ended", client_id);
}
async fn handle_text_message(
text: String,
session: &Arc<Mutex<ClientSession>>,
ws_sender: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, Message>,
model_manager: &Arc<ModelManager>,
) -> Result<()> {
let client_msg: ClientMessage = serde_json::from_str(&text)?;
let mut session = session.lock().await;
match client_msg {
ClientMessage::StartStream { config } => {
info!("Starting stream for client {}", session.id);
session.is_streaming = true;
session.config = config;
session.audio_processor.clear_buffer();
let response = ServerMessage::Status {
message: "Streaming started".to_string(),
};
send_response(ws_sender, response).await?;
}
ClientMessage::StopStream => {
info!("Stopping stream for client {}", session.id);
session.is_streaming = false;
session.audio_processor.clear_buffer();
let response = ServerMessage::Status {
message: "Streaming stopped".to_string(),
};
send_response(ws_sender, response).await?;
}
ClientMessage::Audio { data, sample_rate, channels, timestamp } => {
if !session.is_streaming {
return Ok(());
}
// Decode base64 audio data
let audio_bytes = general_purpose::STANDARD.decode(&data)?;
let audio_data: Vec<f32> = audio_bytes
.chunks(4)
.map(|chunk| {
let bytes = [chunk[0], chunk[1], chunk[2], chunk[3]];
f32::from_le_bytes(bytes)
})
.collect();
let audio_chunk = AudioChunk {
data: audio_data,
sample_rate,
channels,
timestamp,
};
// Process audio
let _processed_audio = session.audio_processor.process_chunk(audio_chunk)?;
// Check if we have enough audio for transcription
const MIN_AUDIO_LENGTH: usize = 24000; // 1 second at 24kHz
if let Some(buffered_audio) = session.audio_processor.get_buffered_audio(MIN_AUDIO_LENGTH) {
debug!("Processing {} samples for transcription", buffered_audio.len());
// Run transcription with language from config
let language = session.config.as_ref().and_then(|c| c.language.clone());
match model_manager.transcribe_with_language(buffered_audio, language).await {
Ok(transcription_text) => {
let result = TranscriptionResult {
text: transcription_text,
confidence: None,
start_time: None,
end_time: None,
words: None,
};
let response = ServerMessage::Transcription { result };
send_response(ws_sender, response).await?;
}
Err(e) => {
error!("Transcription error: {}", e);
let response = ServerMessage::Error {
message: format!("Transcription failed: {}", e),
};
send_response(ws_sender, response).await?;
}
}
}
}
ClientMessage::Ping => {
let response = ServerMessage::Pong;
send_response(ws_sender, response).await?;
}
}
Ok(())
}
/// Serializes `response` to JSON and sends it to the client as a text frame.
///
/// # Errors
///
/// Returns an error if serialization fails or the WebSocket send fails.
async fn send_response(
    ws_sender: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, Message>,
    response: ServerMessage,
) -> Result<()> {
    let frame = serde_json::to_string(&response).map(Message::Text)?;
    Ok(ws_sender.send(frame).await?)
}