Spaces:
Sleeping
Sleeping
| use anyhow::Result; | |
| use candle::{Device, Tensor, DType, IndexOp}; | |
| use candle_nn::VarBuilder; | |
| use candle_transformers::generation::LogitsProcessor; | |
| use hf_hub::api::tokio::Api; | |
| use moshi::mimi::Mimi; | |
| use moshi::{lm, lm_generate_multistream}; | |
| use moshi::{StreamTensor, StreamMask}; | |
| use sentencepiece::SentencePieceProcessor; | |
| use std::sync::Arc; | |
| use tokio::sync::Mutex; | |
| use tracing::{info, warn, error}; | |
| use crate::config::ServerConfig; | |
/// Root-mean-square amplitude of a PCM buffer, used as a cheap
/// voice-activity signal. Returns 0.0 for an empty slice.
fn calculate_rms(audio_data: &[f32]) -> f32 {
    match audio_data.len() {
        0 => 0.0,
        n => {
            let energy: f32 = audio_data.iter().map(|s| s * s).sum();
            (energy / n as f32).sqrt()
        }
    }
}
// Real Moshi STT model using streaming architecture following moshi-backend patterns.
/// Streaming speech-to-text model: the Mimi audio tokenizer feeding a moshi
/// multistream language-model decode state. Populated by `MoshiAsrModel::load`.
pub struct MoshiAsrModel {
    // Mimi neural audio codec: PCM audio -> discrete codebook tokens
    mimi_model: Option<Mimi>,
    lm_state: Option<lm_generate_multistream::State>, // STT decode state from moshi
    // SentencePiece text tokenizer; left None here and loaded by ModelManager
    text_tokenizer: Option<SentencePieceProcessor>,
    // Device the weights were loaded onto
    device: Device,
    // Weight dtype (F32 — see `load` for the T4 compatibility rationale)
    dtype: DType,
    // Set once `load` completes; checked by the transcription entry points
    initialized: bool,
    vocab_size: usize, // text vocab size, kept for bounds-checking generated tokens
}
| impl MoshiAsrModel { | |
    /// Loads the full STT stack onto `device`:
    ///   1. the Mimi audio tokenizer (safetensors weights, 8 codebooks),
    ///   2. the 1B STT language model from `model_path`,
    ///   3. a streaming multistream decode state with greedy sampling.
    ///
    /// The text tokenizer is NOT loaded here; `ModelManager::new` attaches it
    /// separately. Errors bubble up from download, weight loading, or state
    /// construction.
    pub async fn load(
        device: &Device,
        model_path: &std::path::PathBuf,
        config: &crate::config::ServerConfig
    ) -> Result<Self> {
        info!("Loading moshi STT components following moshi-backend pattern");
        // Use F32 for T4 GPU compatibility (compute capability 7.5);
        // BF16 requires compute capability 8.0+, which the T4 doesn't support.
        let dtype = DType::F32;
        info!("Loading Mimi audio tokenizer from: {}", config.audio_tokenizer.path);
        // Resolve (possibly download) and load the Mimi safetensors weights.
        let api = hf_hub::api::tokio::Api::new()?;
        let mimi_path = ModelManager::download_model_file(&api, &config.audio_tokenizer.path).await?;
        let mimi_weights = candle::safetensors::load(&mimi_path, device)?;
        let mimi_vb = VarBuilder::from_tensors(mimi_weights, dtype, device);
        // 8 codebooks matches the official moshi backend config
        // ("mimi_codebooks": 8) and avoids codebook-bounds errors when decoding.
        let mimi_config = moshi::mimi::Config::v0_1(Some(8));
        let mimi_model = Mimi::new(mimi_config, mimi_vb)?;
        info!("Mimi audio tokenizer loaded successfully");
        info!("Loading STT transformer from: {}", config.model.path);
        // Dedicated 1B STT configuration (not the conversation-model config).
        let lm_config = lm::Config::asr_v0_1_1b();
        info!("Using dedicated STT configuration for kyutai/stt-1b-en_fr model");
        // Grab the vocab size before lm_config is moved into the loader;
        // it is used at inference time to bounds-check generated tokens.
        let vocab_size = lm_config.text_out_vocab_size;
        let lm_model = lm::load_lm_model(lm_config, model_path, dtype, device)?;
        info!("STT transformer loaded successfully");
        // Standard multistream config, matching the official moshi backend.
        let mut state_config = lm_generate_multistream::Config::v0_1();
        // Start transcription from the BOS (Beginning of Sentence) token.
        state_config.text_start_token = 1;
        info!("Using custom config with text_start_token: {}, model vocab_size: {}",
              state_config.text_start_token, vocab_size);
        // Logits processors are required by State::new; temperature 0, no
        // penalties => deterministic greedy decoding for both streams.
        let audio_lp = LogitsProcessor::new(0, None, None);
        let text_lp = LogitsProcessor::new(0, None, None);
        let max_step_idx = 200; // sized for short STT chunks (conversation mode used 1000)
        let lm_state = lm_generate_multistream::State::new(
            lm_model,
            max_step_idx,
            audio_lp,
            text_lp,
            None, // pad_mult
            None, // repetition_penalty
            None, // cfg_alpha
            state_config,
        );
        info!("LM multistream state initialized");
        Ok(MoshiAsrModel {
            mimi_model: Some(mimi_model),
            lm_state: Some(lm_state),
            text_tokenizer: None, // attached later by ModelManager
            device: device.clone(),
            dtype,
            initialized: true,
            vocab_size,
        })
    }
| pub async fn transcribe_stream(&mut self, audio_data: Vec<f32>) -> Result<String> { | |
| self.transcribe_stream_with_language(audio_data, None).await | |
| } | |
| pub async fn transcribe_stream_with_language(&mut self, audio_data: Vec<f32>, language: Option<String>) -> Result<String> { | |
| if !self.initialized { | |
| return Err(anyhow::anyhow!("ASR model not initialized")); | |
| } | |
| info!("Processing {} audio samples through streaming pipeline with language: {:?}", audio_data.len(), language); | |
| let audio_len = audio_data.len(); | |
| let duration_seconds = audio_len as f32 / 24000.0; // Mimi uses 24kHz | |
| if duration_seconds < 0.08 { | |
| return Ok("Audio chunk too short (< 80ms)".to_string()); | |
| } | |
| // Implement proper streaming pipeline following moshi-backend patterns: | |
| // 1. Preprocess audio to 24kHz with frame alignment (1920 samples = 80ms) | |
| let processed_audio = self.preprocess_audio_for_mimi(audio_data)?; | |
| let processed_len = processed_audio.len(); | |
| // 2. Convert audio to tensor and create StreamTensor | |
| // Mimi expects 3D tensor: (batch, channels, samples) | |
| let audio_tensor = Tensor::from_vec(processed_audio, (1, 1, processed_len), &self.device)? | |
| .to_dtype(self.dtype)?; | |
| let stream_tensor = StreamTensor::from_tensor(audio_tensor); | |
| // 3. Create mask for active processing (needs device parameter) | |
| let mask = StreamMask::new(vec![true], &self.device)?; | |
| // 4. Get mutable references after preprocessing is complete | |
| let mimi = self.mimi_model.as_mut().ok_or_else(|| anyhow::anyhow!("Mimi model not loaded"))?; | |
| let lm_state = self.lm_state.as_mut().ok_or_else(|| anyhow::anyhow!("LM state not loaded"))?; | |
| // 5. Process through Mimi audio tokenizer (audio → tokens) | |
| info!("Encoding audio through Mimi tokenizer"); | |
| let audio_tokens = mimi.encode_step(&stream_tensor, &mask)?; | |
| // 6. Process through STT transformer using moshi-backend pattern | |
| info!("Processing tokens through STT transformer"); | |
| let generated_tokens = if let Some(audio_tensor) = audio_tokens.as_option() { | |
| let (batch_size, codebooks, steps) = audio_tensor.dims3()?; | |
| info!("Audio tokens shape: {}x{}x{}", batch_size, codebooks, steps); | |
| // Following exact moshi-backend pattern from stream_both.rs | |
| let mut prev_text_token = lm_state.config().text_start_token; | |
| let mut text_tokens = Vec::new(); | |
| // Process each timestep with error handling and early termination | |
| info!("Starting processing loop for {} steps", steps); | |
| for step in 0..steps { | |
| info!("Processing step {}/{}", step + 1, steps); | |
| // Extract audio codes for this step with error handling | |
| let mut codes = match audio_tensor.i((0, .., step))?.to_vec1::<u32>() { | |
| Ok(codes) => { | |
| info!("Step {} audio codes: {:?} (len: {})", step, | |
| if codes.len() > 10 { &codes[..10] } else { &codes }, codes.len()); | |
| codes | |
| }, | |
| Err(e) => { | |
| error!("Failed to extract audio codes at step {}: {}", step, e); | |
| break; | |
| } | |
| }; | |
| // Safety check: Ensure prev_text_token is within vocab range | |
| if prev_text_token as usize >= self.vocab_size { | |
| error!("prev_text_token {} is out of vocab range (0-{}), using fallback", | |
| prev_text_token, self.vocab_size - 1); | |
| prev_text_token = 0; // Use EOS token as fallback | |
| } | |
| info!("About to call lm_state.step with prev_text_token: {}", prev_text_token); | |
| // Verify Mimi produces exactly 8 codebooks - MATCHING OFFICIAL CONFIGURATION | |
| if codes.len() != 8 { | |
| error!("❌ Mimi produced {} codebooks, but official config expects exactly 8", codes.len()); | |
| return Err(anyhow::anyhow!("Codebook count mismatch: got {}, expected 8", codes.len())); | |
| } | |
| info!("✅ Mimi produced exactly {} codebooks - matching official Moshi backend config", codes.len()); | |
| // DEFENSIVE BOUNDS CHECKING: Validate all inputs before calling lm_state.step() | |
| // This prevents the "index out of bounds" error at moshi lm_generate_multistream.rs:198 | |
| // 1. Validate prev_text_token is within safe bounds | |
| if prev_text_token as usize >= self.vocab_size { | |
| warn!("🛡️ BOUNDS CHECK: prev_text_token {} >= vocab_size {}, clamping to safe value", | |
| prev_text_token, self.vocab_size); | |
| prev_text_token = (self.vocab_size - 1) as u32; // Use last valid token | |
| } | |
| // 2. Validate audio codes array is properly sized | |
| if codes.len() > 32 { // Reasonable upper bound to prevent memory issues | |
| warn!("🛡️ BOUNDS CHECK: Audio codes array too large: {}, truncating to 8", codes.len()); | |
| codes.truncate(8); | |
| } | |
| // 3. Validate individual code values are reasonable (prevent integer overflow/underflow) | |
| for i in 0..codes.len() { | |
| if codes[i] > 4096 { // Reasonable upper bound for audio codes | |
| warn!("🛡️ BOUNDS CHECK: Audio code[{}] = {} is suspiciously large, clamping to 2047", i, codes[i]); | |
| codes[i] = 2047; // Safe default within expected range | |
| } | |
| } | |
| info!("🛡️ BOUNDS CHECK: All inputs validated - prev_text_token: {}, codes.len(): {}", | |
| prev_text_token, codes.len()); | |
| // Create language conditioning if provided | |
| let language_condition = if let Some(ref lang) = language { | |
| if lang == "en" || lang == "english" { | |
| info!("🌍 Using English language conditioning"); | |
| // For multilingual models, we hint that we expect English speech | |
| // by providing a slightly different prev_text_token on first step | |
| if step == 0 { | |
| // Use a token that suggests English is expected | |
| // For the Moshiko Q8 model, we'll use a specific pattern | |
| prev_text_token = if prev_text_token == 31999 { 31998 } else { prev_text_token }; | |
| } | |
| } | |
| None // No explicit condition object for now | |
| } else { | |
| None | |
| }; | |
| // Use the step method with comprehensive error handling and language conditioning | |
| let text_token = match lm_state.step(prev_text_token, &codes, None, language_condition.as_ref()) { | |
| Ok(token) => { | |
| info!("✅ Generated text token: {} for step {}", token, step); | |
| token | |
| }, | |
| Err(e) => { | |
| error!("❌ lm_state.step failed at step {}: {}", step, e); | |
| warn!("Terminating processing early due to step failure"); | |
| break; | |
| } | |
| }; | |
| // Collect ALL tokens for debugging (including pad tokens) | |
| // TODO: Restore filtering after debugging | |
| text_tokens.push(text_token); | |
| // Log what we're getting | |
| if text_token == 0 { | |
| info!("🔚 Got EOS token (0)"); | |
| } else if text_token == 3 { | |
| info!("📄 Got PAD token (3)"); | |
| } else { | |
| info!("📝 Got TEXT token ({})", text_token); | |
| } | |
| prev_text_token = text_token; | |
| // Early termination on end-of-sequence | |
| if text_token == 0 { | |
| break; | |
| } | |
| } | |
| info!("Generated {} text tokens: {:?}", text_tokens.len(), text_tokens); | |
| text_tokens | |
| } else { | |
| info!("No audio tokens from Mimi"); | |
| Vec::new() | |
| }; | |
| let result = if generated_tokens.is_empty() { | |
| format!("STT: Processed {:.2}s audio chunk (no tokens generated)", duration_seconds) | |
| } else { | |
| // Count token types for debugging | |
| let pad_count = generated_tokens.iter().filter(|&&t| t == 3).count(); | |
| let eos_count = generated_tokens.iter().filter(|&&t| t == 0).count(); | |
| let text_count = generated_tokens.iter().filter(|&&t| t != 0 && t != 3).count(); | |
| format!("STT: {} tokens from {:.2}s chunk - PAD:{}, EOS:{}, TEXT:{} - {:?}", | |
| generated_tokens.len(), duration_seconds, pad_count, eos_count, text_count, generated_tokens) | |
| }; | |
| Ok(result) | |
| } | |
| // Preprocess audio for Mimi (24kHz, frame-aligned to 1920 samples = 80ms) | |
| fn preprocess_audio_for_mimi(&self, mut audio: Vec<f32>) -> Result<Vec<f32>> { | |
| // Convert from 16kHz to 24kHz if needed (simple upsampling) | |
| if audio.len() > 0 { | |
| // For now, assume input is close to target rate | |
| // TODO: Implement proper resampling | |
| } | |
| // Ensure frame alignment to 1920 samples (80ms at 24kHz) | |
| let frame_size = 1920; | |
| let remainder = audio.len() % frame_size; | |
| if remainder != 0 { | |
| // Pad to next frame boundary | |
| let pad_size = frame_size - remainder; | |
| audio.extend(vec![0.0; pad_size]); | |
| } | |
| // CRITICAL FIX: Follow moshi library patterns exactly! | |
| // Mimi has renormalize: true in Config, so it handles normalization internally | |
| // Our job is just to ensure the audio is in valid PCM float32 range [-1.0, 1.0] | |
| // and let moshi do its neural-based renormalization | |
| // Basic safety clipping (moshi expects PCM float32 in [-1.0, 1.0]) | |
| for sample in &mut audio { | |
| *sample = sample.clamp(-1.0, 1.0); | |
| } | |
| // DEBUGGING: Examine actual audio data quality | |
| let min_val = audio.iter().fold(f32::INFINITY, |acc, &x| acc.min(x)); | |
| let max_val = audio.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x)); | |
| let rms = (audio.iter().map(|&x| x * x).sum::<f32>() / audio.len() as f32).sqrt(); | |
| let non_zero_count = audio.iter().filter(|&&x| x.abs() > 0.001).count(); | |
| info!("🔍 AUDIO DEBUG: {} samples, range [{:.4}, {:.4}], RMS: {:.4}, non-zero: {}/{}", | |
| audio.len(), min_val, max_val, rms, non_zero_count, audio.len()); | |
| // CRITICAL FIX: Boost extremely quiet audio for speech recognition | |
| if rms < 0.12 { // Boost audio below good speech level | |
| let target_rms = 0.15; // Good level for speech recognition | |
| let boost_factor = target_rms / rms.max(0.0001); // Prevent division by zero | |
| let max_boost = 20.0; // Reasonable limit | |
| let applied_boost = boost_factor.min(max_boost); | |
| for sample in &mut audio { | |
| *sample *= applied_boost; | |
| *sample = sample.clamp(-1.0, 1.0); // Prevent clipping | |
| } | |
| let new_rms = (audio.iter().map(|&x| x * x).sum::<f32>() / audio.len() as f32).sqrt(); | |
| info!("🔊 BOOSTED QUIET AUDIO: RMS {:.6} → {:.4} (boost: {:.1}x)", | |
| rms, new_rms, applied_boost); | |
| } else if rms < 0.05 { | |
| warn!("⚠️ AUDIO WARNING: Low RMS {:.4} - might be too quiet for recognition", rms); | |
| } | |
| if non_zero_count < audio.len() / 10 { | |
| warn!("⚠️ AUDIO WARNING: {}% samples near zero - mostly silence?", | |
| (audio.len() - non_zero_count) * 100 / audio.len()); | |
| } | |
| info!("Mimi audio: {} samples prepared, letting moshi handle renormalization", audio.len()); | |
| Ok(audio) | |
| } | |
| } | |
/// Shared entry point for transcription: owns the ASR model behind an async
/// mutex (serializing transcriptions), plus the text tokenizer and server config.
pub struct ModelManager {
    // ASR model; Arc allows sharing across tasks, Mutex serializes access
    model: Arc<Mutex<MoshiAsrModel>>,
    // SentencePiece tokenizer for decoding text tokens into strings
    text_tokenizer: SentencePieceProcessor,
    // Full server configuration, exposed read-only via `get_config`
    config: ServerConfig,
    // Compute device the model was loaded onto
    device: Device,
}
| impl ModelManager { | |
| pub async fn new(config: ServerConfig, device: Device) -> Result<Self> { | |
| info!("Loading model from: {}", config.model.path); | |
| // Download model files from HuggingFace | |
| let api = Api::new()?; | |
| let model_path = Self::download_model_file(&api, &config.model.path).await?; | |
| let tokenizer_path = Self::download_model_file(&api, &config.text_tokenizer.path).await?; | |
| // Load tokenizer | |
| info!("Loading text tokenizer"); | |
| let text_tokenizer = SentencePieceProcessor::open(&tokenizer_path)?; | |
| // Load moshi ASR model with config | |
| info!("Loading moshi ASR model"); | |
| let model = MoshiAsrModel::load(&device, &model_path, &config).await?; | |
| let model = Arc::new(Mutex::new(model)); | |
| Ok(ModelManager { | |
| model, | |
| text_tokenizer, | |
| config, | |
| device, | |
| }) | |
| } | |
| async fn download_model_file(api: &Api, path: &str) -> Result<std::path::PathBuf> { | |
| if path.starts_with("hf://") { | |
| let path = path.strip_prefix("hf://").unwrap(); | |
| let parts: Vec<&str> = path.split('/').collect(); | |
| if parts.len() < 2 { | |
| return Err(anyhow::anyhow!("Invalid HuggingFace path: {}", path)); | |
| } | |
| let repo = format!("{}/{}", parts[0], parts[1]); | |
| let file_path = parts[2..].join("/"); | |
| // Check for pre-loaded model files first | |
| let local_path = std::path::PathBuf::from(format!("models/{}/{}", repo, file_path)); | |
| if local_path.exists() { | |
| info!("Using pre-loaded model: {}", local_path.display()); | |
| return Ok(local_path); | |
| } | |
| info!("Pre-loaded model not found, downloading {} from {}", file_path, repo); | |
| let repo = api.model(repo); | |
| // Simple retry logic for lock acquisition failures | |
| let mut attempts = 0; | |
| let max_attempts = 3; // Reduced since we shouldn't have multiple processes now | |
| loop { | |
| match repo.get(&file_path).await { | |
| Ok(path) => return Ok(path), | |
| Err(e) if attempts < max_attempts => { | |
| if e.to_string().contains("Lock acquisition failed") { | |
| attempts += 1; | |
| warn!("Lock acquisition failed, retrying ({}/{}): {}", attempts, max_attempts, e); | |
| tokio::time::sleep(tokio::time::Duration::from_millis(2000)).await; | |
| continue; | |
| } | |
| return Err(anyhow::anyhow!("Download failed: {}", e)); | |
| } | |
| Err(e) => return Err(anyhow::anyhow!("Download failed after {} attempts: {}", max_attempts, e)), | |
| } | |
| } | |
| } else { | |
| Ok(std::path::PathBuf::from(path)) | |
| } | |
| } | |
| pub async fn transcribe(&self, audio_data: Vec<f32>) -> Result<String> { | |
| self.transcribe_with_language(audio_data, None).await | |
| } | |
| pub async fn transcribe_with_language(&self, audio_data: Vec<f32>, language: Option<String>) -> Result<String> { | |
| info!("Processing audio chunk of length: {} with language: {:?}", audio_data.len(), language); | |
| // Resample audio to 24kHz if needed | |
| let audio_data = self.preprocess_audio(audio_data)?; | |
| // Run inference with real moshi ASR | |
| let mut model = self.model.lock().await; | |
| match model.transcribe_stream_with_language(audio_data, language).await { | |
| Ok(transcription) => { | |
| info!("Transcription successful: {}", transcription); | |
| Ok(transcription) | |
| } | |
| Err(e) => { | |
| error!("Moshi ASR transcription failed: {}", e); | |
| Err(e) | |
| } | |
| } | |
| } | |
| fn preprocess_audio(&self, audio: Vec<f32>) -> Result<Vec<f32>> { | |
| // Ensure 24kHz sample rate | |
| // TODO: Implement proper resampling if needed | |
| // REMOVED: Double normalization (already normalized in preprocess_audio_for_mimi) | |
| // The audio has already been properly normalized with RMS-based approach | |
| info!("ModelManager: Skipping normalization (already done in Mimi preprocessing)"); | |
| Ok(audio) | |
| } | |
    /// Read-only access to the server configuration this manager was built with.
    pub fn get_config(&self) -> &ServerConfig {
        &self.config
    }
| } |