// stt-gpu-service-v3 / src/model.rs
// Commit 22c8baf (pgits) — MAJOR FIX: switch from conversation model to dedicated STT model
use anyhow::Result;
use candle::{Device, Tensor, DType, IndexOp};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::api::tokio::Api;
use moshi::mimi::Mimi;
use moshi::{lm, lm_generate_multistream};
use moshi::{StreamTensor, StreamMask};
use sentencepiece::SentencePieceProcessor;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{info, warn, error};
use crate::config::ServerConfig;
// Helper function for basic voice activity detection
/// Root-mean-square amplitude of an audio buffer, used as a cheap
/// voice-activity signal. An empty buffer yields 0.0.
fn calculate_rms(audio_data: &[f32]) -> f32 {
    match audio_data.len() {
        0 => 0.0,
        n => {
            let energy: f32 = audio_data.iter().map(|&sample| sample * sample).sum();
            (energy / n as f32).sqrt()
        }
    }
}
// Real Moshi STT model using streaming architecture following moshi-backend patterns
/// Streaming Moshi speech-to-text model: a Mimi neural audio tokenizer
/// feeding a multistream LM decoder. Constructed via [`MoshiAsrModel::load`].
pub struct MoshiAsrModel {
    mimi_model: Option<Mimi>, // Mimi audio tokenizer (PCM frames -> discrete audio codes)
    lm_state: Option<lm_generate_multistream::State>, // STT state from moshi (LM + sampling state)
    text_tokenizer: Option<SentencePieceProcessor>, // left None by `load`; ModelManager owns its own copy
    device: Device, // candle device (CPU/CUDA) all tensors are placed on
    dtype: DType,   // F32 — chosen in `load` for T4 compatibility (no BF16 below CC 8.0)
    initialized: bool, // guards transcribe_* against use before a successful load
    vocab_size: usize, // Store vocab size for validation of text tokens fed back into the LM
}
impl MoshiAsrModel {
    /// Loads all moshi STT components: the Mimi audio tokenizer, the dedicated
    /// 1B STT transformer, and a fresh multistream decoding state.
    ///
    /// * `device` – candle device the weights are loaded onto.
    /// * `model_path` – local path to the already-downloaded LM safetensors.
    /// * `config` – supplies the (possibly `hf://`) path of the Mimi weights.
    ///
    /// Returns a model with `text_tokenizer` still `None`; the SentencePiece
    /// tokenizer is loaded separately by `ModelManager`.
    ///
    /// # Errors
    /// Fails if the HF hub download, safetensors load, or any moshi model
    /// construction fails.
    pub async fn load(
        device: &Device,
        model_path: &std::path::PathBuf,
        config: &crate::config::ServerConfig
    ) -> Result<Self> {
        info!("Loading moshi STT components following moshi-backend pattern");
        // Use F32 for T4 GPU compatibility (compute capability 7.5)
        // BF16 requires compute capability 8.0+ which T4 doesn't support
        let dtype = DType::F32;
        info!("Loading Mimi audio tokenizer from: {}", config.audio_tokenizer.path);
        // Load Mimi audio tokenizer from safetensors
        let api = hf_hub::api::tokio::Api::new()?;
        let mimi_path = ModelManager::download_model_file(&api, &config.audio_tokenizer.path).await?;
        let mimi_weights = candle::safetensors::load(&mimi_path, device)?;
        let mimi_vb = VarBuilder::from_tensors(mimi_weights, dtype, device);
        // Configure Mimi for 8 codebooks - MATCHING OFFICIAL MOSHI BACKEND CONFIGURATION
        // Official config.json shows "mimi_codebooks": 8, this should avoid bounds errors
        // Using exact same configuration as working kyutai/moshi backend
        let mimi_config = moshi::mimi::Config::v0_1(Some(8));
        let mimi_model = Mimi::new(mimi_config, mimi_vb)?;
        info!("Mimi audio tokenizer loaded successfully");
        info!("Loading STT transformer from: {}", config.model.path);
        // Create LM model for STT-1B model - PROPER STT CONFIGURATION
        // Use dedicated STT configuration instead of conversation model config
        let lm_config = lm::Config::asr_v0_1_1b(); // Proper 1B STT config (not conversation)
        info!("Using dedicated STT configuration for kyutai/stt-1b-en_fr model");
        // Store vocab size before moving lm_config
        let vocab_size = lm_config.text_out_vocab_size;
        let lm_model = lm::load_lm_model(lm_config, model_path, dtype, device)?;
        info!("STT transformer loaded successfully");
        // Use standard configuration - MATCHING OFFICIAL MOSHI BACKEND
        // Official backend should use proper configuration that avoids bounds errors
        // Testing with v0_1() as official backend likely uses standard config
        let mut state_config = lm_generate_multistream::Config::v0_1();
        // CRITICAL FIX: Use BOS token for STT transcription start
        state_config.text_start_token = 1; // BOS (Beginning of Sentence) token for STT mode
        info!("Using custom config with text_start_token: {}, model vocab_size: {}",
            state_config.text_start_token, vocab_size);
        // Create logits processors (required for State::new)
        // NOTE(review): first argument is presumably a sampling seed, with no
        // temperature/top-p => greedy decoding — confirm against moshi docs.
        let audio_lp = LogitsProcessor::new(0, None, None); // No temperature, no repetition penalty
        let text_lp = LogitsProcessor::new(0, None, None); // Greedy decoding
        let max_step_idx = 200; // Optimized for STT chunks (was 1000 for conversation)
        let lm_state = lm_generate_multistream::State::new(
            lm_model,
            max_step_idx,
            audio_lp,
            text_lp,
            None, // pad_mult
            None, // repetition_penalty
            None, // cfg_alpha
            state_config,
        );
        info!("LM multistream state initialized");
        Ok(MoshiAsrModel {
            mimi_model: Some(mimi_model),
            lm_state: Some(lm_state),
            text_tokenizer: None, // Will load separately in ModelManager
            device: device.clone(),
            dtype,
            initialized: true,
            vocab_size, // Store the vocab size
        })
    }
    /// Convenience wrapper: transcribe a chunk with no language hint.
    /// See [`Self::transcribe_stream_with_language`] for the full pipeline.
    pub async fn transcribe_stream(&mut self, audio_data: Vec<f32>) -> Result<String> {
        self.transcribe_stream_with_language(audio_data, None).await
    }
    /// Runs one audio chunk through the full streaming STT pipeline,
    /// optionally biased toward a language ("en"/"english" is the only hint
    /// currently acted on, and only on the first decode step).
    ///
    /// Pipeline: frame-align/level-adjust the PCM, encode it with Mimi into
    /// (batch, codebooks, steps) audio codes, then step the multistream LM
    /// once per timestep, feeding each generated text token back in.
    ///
    /// NOTE(review): the returned string is a debug summary of generated
    /// token ids (PAD/EOS/TEXT counts), not detokenized text — SentencePiece
    /// decoding is not wired in here yet.
    ///
    /// # Errors
    /// Fails if the model is uninitialized, tensor construction fails, or
    /// Mimi yields a codebook count other than 8.
    pub async fn transcribe_stream_with_language(&mut self, audio_data: Vec<f32>, language: Option<String>) -> Result<String> {
        if !self.initialized {
            return Err(anyhow::anyhow!("ASR model not initialized"));
        }
        info!("Processing {} audio samples through streaming pipeline with language: {:?}", audio_data.len(), language);
        let audio_len = audio_data.len();
        let duration_seconds = audio_len as f32 / 24000.0; // Mimi uses 24kHz
        if duration_seconds < 0.08 {
            return Ok("Audio chunk too short (< 80ms)".to_string());
        }
        // Implement proper streaming pipeline following moshi-backend patterns:
        // 1. Preprocess audio to 24kHz with frame alignment (1920 samples = 80ms)
        let processed_audio = self.preprocess_audio_for_mimi(audio_data)?;
        let processed_len = processed_audio.len();
        // 2. Convert audio to tensor and create StreamTensor
        // Mimi expects 3D tensor: (batch, channels, samples)
        let audio_tensor = Tensor::from_vec(processed_audio, (1, 1, processed_len), &self.device)?
            .to_dtype(self.dtype)?;
        let stream_tensor = StreamTensor::from_tensor(audio_tensor);
        // 3. Create mask for active processing (needs device parameter)
        let mask = StreamMask::new(vec![true], &self.device)?;
        // 4. Get mutable references after preprocessing is complete
        let mimi = self.mimi_model.as_mut().ok_or_else(|| anyhow::anyhow!("Mimi model not loaded"))?;
        let lm_state = self.lm_state.as_mut().ok_or_else(|| anyhow::anyhow!("LM state not loaded"))?;
        // 5. Process through Mimi audio tokenizer (audio → tokens)
        info!("Encoding audio through Mimi tokenizer");
        let audio_tokens = mimi.encode_step(&stream_tensor, &mask)?;
        // 6. Process through STT transformer using moshi-backend pattern
        info!("Processing tokens through STT transformer");
        let generated_tokens = if let Some(audio_tensor) = audio_tokens.as_option() {
            let (batch_size, codebooks, steps) = audio_tensor.dims3()?;
            info!("Audio tokens shape: {}x{}x{}", batch_size, codebooks, steps);
            // Following exact moshi-backend pattern from stream_both.rs
            let mut prev_text_token = lm_state.config().text_start_token;
            let mut text_tokens = Vec::new();
            // Process each timestep with error handling and early termination
            info!("Starting processing loop for {} steps", steps);
            for step in 0..steps {
                info!("Processing step {}/{}", step + 1, steps);
                // Extract audio codes for this step with error handling.
                // NOTE(review): the `?` on `.i(...)` propagates an indexing error
                // out of the whole function, while the match below breaks the
                // loop — two failure policies on one line; confirm intended.
                let mut codes = match audio_tensor.i((0, .., step))?.to_vec1::<u32>() {
                    Ok(codes) => {
                        info!("Step {} audio codes: {:?} (len: {})", step,
                            if codes.len() > 10 { &codes[..10] } else { &codes }, codes.len());
                        codes
                    },
                    Err(e) => {
                        error!("Failed to extract audio codes at step {}: {}", step, e);
                        break;
                    }
                };
                // Safety check: Ensure prev_text_token is within vocab range
                if prev_text_token as usize >= self.vocab_size {
                    error!("prev_text_token {} is out of vocab range (0-{}), using fallback",
                        prev_text_token, self.vocab_size - 1);
                    prev_text_token = 0; // Use EOS token as fallback
                }
                info!("About to call lm_state.step with prev_text_token: {}", prev_text_token);
                // Verify Mimi produces exactly 8 codebooks - MATCHING OFFICIAL CONFIGURATION
                if codes.len() != 8 {
                    error!("❌ Mimi produced {} codebooks, but official config expects exactly 8", codes.len());
                    return Err(anyhow::anyhow!("Codebook count mismatch: got {}, expected 8", codes.len()));
                }
                info!("✅ Mimi produced exactly {} codebooks - matching official Moshi backend config", codes.len());
                // DEFENSIVE BOUNDS CHECKING: Validate all inputs before calling lm_state.step()
                // This prevents the "index out of bounds" error at moshi lm_generate_multistream.rs:198
                // 1. Validate prev_text_token is within safe bounds.
                // NOTE(review): redundant with the check above, which already
                // clamped out-of-range tokens to 0 — this branch cannot fire.
                if prev_text_token as usize >= self.vocab_size {
                    warn!("🛡️ BOUNDS CHECK: prev_text_token {} >= vocab_size {}, clamping to safe value",
                        prev_text_token, self.vocab_size);
                    prev_text_token = (self.vocab_size - 1) as u32; // Use last valid token
                }
                // 2. Validate audio codes array is properly sized
                if codes.len() > 32 { // Reasonable upper bound to prevent memory issues
                    warn!("🛡️ BOUNDS CHECK: Audio codes array too large: {}, truncating to 8", codes.len());
                    codes.truncate(8);
                }
                // 3. Validate individual code values are reasonable (prevent integer overflow/underflow)
                for i in 0..codes.len() {
                    if codes[i] > 4096 { // Reasonable upper bound for audio codes
                        warn!("🛡️ BOUNDS CHECK: Audio code[{}] = {} is suspiciously large, clamping to 2047", i, codes[i]);
                        codes[i] = 2047; // Safe default within expected range
                    }
                }
                info!("🛡️ BOUNDS CHECK: All inputs validated - prev_text_token: {}, codes.len(): {}",
                    prev_text_token, codes.len());
                // Create language conditioning if provided.
                // NOTE(review): both arms evaluate to None — the only effect is
                // the prev_text_token tweak on step 0; no condition object is
                // ever passed to the model.
                let language_condition = if let Some(ref lang) = language {
                    if lang == "en" || lang == "english" {
                        info!("🌍 Using English language conditioning");
                        // For multilingual models, we hint that we expect English speech
                        // by providing a slightly different prev_text_token on first step
                        if step == 0 {
                            // Use a token that suggests English is expected
                            // For the Moshiko Q8 model, we'll use a specific pattern
                            prev_text_token = if prev_text_token == 31999 { 31998 } else { prev_text_token };
                        }
                    }
                    None // No explicit condition object for now
                } else {
                    None
                };
                // Use the step method with comprehensive error handling and language conditioning
                let text_token = match lm_state.step(prev_text_token, &codes, None, language_condition.as_ref()) {
                    Ok(token) => {
                        info!("✅ Generated text token: {} for step {}", token, step);
                        token
                    },
                    Err(e) => {
                        error!("❌ lm_state.step failed at step {}: {}", step, e);
                        warn!("Terminating processing early due to step failure");
                        break;
                    }
                };
                // Collect ALL tokens for debugging (including pad tokens)
                // TODO: Restore filtering after debugging
                text_tokens.push(text_token);
                // Log what we're getting (token ids 0 = EOS, 3 = PAD by convention here)
                if text_token == 0 {
                    info!("🔚 Got EOS token (0)");
                } else if text_token == 3 {
                    info!("📄 Got PAD token (3)");
                } else {
                    info!("📝 Got TEXT token ({})", text_token);
                }
                prev_text_token = text_token;
                // Early termination on end-of-sequence
                if text_token == 0 {
                    break;
                }
            }
            info!("Generated {} text tokens: {:?}", text_tokens.len(), text_tokens);
            text_tokens
        } else {
            info!("No audio tokens from Mimi");
            Vec::new()
        };
        let result = if generated_tokens.is_empty() {
            format!("STT: Processed {:.2}s audio chunk (no tokens generated)", duration_seconds)
        } else {
            // Count token types for debugging
            let pad_count = generated_tokens.iter().filter(|&&t| t == 3).count();
            let eos_count = generated_tokens.iter().filter(|&&t| t == 0).count();
            let text_count = generated_tokens.iter().filter(|&&t| t != 0 && t != 3).count();
            format!("STT: {} tokens from {:.2}s chunk - PAD:{}, EOS:{}, TEXT:{} - {:?}",
                generated_tokens.len(), duration_seconds, pad_count, eos_count, text_count, generated_tokens)
        };
        Ok(result)
    }
// Preprocess audio for Mimi (24kHz, frame-aligned to 1920 samples = 80ms)
fn preprocess_audio_for_mimi(&self, mut audio: Vec<f32>) -> Result<Vec<f32>> {
// Convert from 16kHz to 24kHz if needed (simple upsampling)
if audio.len() > 0 {
// For now, assume input is close to target rate
// TODO: Implement proper resampling
}
// Ensure frame alignment to 1920 samples (80ms at 24kHz)
let frame_size = 1920;
let remainder = audio.len() % frame_size;
if remainder != 0 {
// Pad to next frame boundary
let pad_size = frame_size - remainder;
audio.extend(vec![0.0; pad_size]);
}
// CRITICAL FIX: Follow moshi library patterns exactly!
// Mimi has renormalize: true in Config, so it handles normalization internally
// Our job is just to ensure the audio is in valid PCM float32 range [-1.0, 1.0]
// and let moshi do its neural-based renormalization
// Basic safety clipping (moshi expects PCM float32 in [-1.0, 1.0])
for sample in &mut audio {
*sample = sample.clamp(-1.0, 1.0);
}
// DEBUGGING: Examine actual audio data quality
let min_val = audio.iter().fold(f32::INFINITY, |acc, &x| acc.min(x));
let max_val = audio.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
let rms = (audio.iter().map(|&x| x * x).sum::<f32>() / audio.len() as f32).sqrt();
let non_zero_count = audio.iter().filter(|&&x| x.abs() > 0.001).count();
info!("🔍 AUDIO DEBUG: {} samples, range [{:.4}, {:.4}], RMS: {:.4}, non-zero: {}/{}",
audio.len(), min_val, max_val, rms, non_zero_count, audio.len());
// CRITICAL FIX: Boost extremely quiet audio for speech recognition
if rms < 0.12 { // Boost audio below good speech level
let target_rms = 0.15; // Good level for speech recognition
let boost_factor = target_rms / rms.max(0.0001); // Prevent division by zero
let max_boost = 20.0; // Reasonable limit
let applied_boost = boost_factor.min(max_boost);
for sample in &mut audio {
*sample *= applied_boost;
*sample = sample.clamp(-1.0, 1.0); // Prevent clipping
}
let new_rms = (audio.iter().map(|&x| x * x).sum::<f32>() / audio.len() as f32).sqrt();
info!("🔊 BOOSTED QUIET AUDIO: RMS {:.6} → {:.4} (boost: {:.1}x)",
rms, new_rms, applied_boost);
} else if rms < 0.05 {
warn!("⚠️ AUDIO WARNING: Low RMS {:.4} - might be too quiet for recognition", rms);
}
if non_zero_count < audio.len() / 10 {
warn!("⚠️ AUDIO WARNING: {}% samples near zero - mostly silence?",
(audio.len() - non_zero_count) * 100 / audio.len());
}
info!("Mimi audio: {} samples prepared, letting moshi handle renormalization", audio.len());
Ok(audio)
}
}
/// Service-level entry point for transcription: owns the shared ASR model,
/// the text tokenizer, and the server configuration.
pub struct ModelManager {
    model: Arc<Mutex<MoshiAsrModel>>, // shared across requests via tokio async mutex
    text_tokenizer: SentencePieceProcessor, // NOTE(review): loaded but not yet used for detokenization in this file
    config: ServerConfig, // exposed read-only via `get_config`
    device: Device, // candle device models were loaded onto
}
impl ModelManager {
    /// Builds the manager: resolves/downloads the LM weights and the
    /// SentencePiece tokenizer, loads the moshi ASR model onto `device`, and
    /// wraps it in an async mutex for shared use across requests.
    ///
    /// # Errors
    /// Fails if any download, tokenizer open, or model load fails.
    pub async fn new(config: ServerConfig, device: Device) -> Result<Self> {
        info!("Loading model from: {}", config.model.path);
        // Download model files from HuggingFace (or resolve pre-loaded/local paths)
        let api = Api::new()?;
        let model_path = Self::download_model_file(&api, &config.model.path).await?;
        let tokenizer_path = Self::download_model_file(&api, &config.text_tokenizer.path).await?;
        // Load tokenizer
        info!("Loading text tokenizer");
        let text_tokenizer = SentencePieceProcessor::open(&tokenizer_path)?;
        // Load moshi ASR model with config
        info!("Loading moshi ASR model");
        let model = MoshiAsrModel::load(&device, &model_path, &config).await?;
        let model = Arc::new(Mutex::new(model));
        Ok(ModelManager {
            model,
            text_tokenizer,
            config,
            device,
        })
    }
    /// Resolves a model file path. `hf://owner/repo/sub/path` is served from a
    /// pre-loaded `models/owner/repo/...` directory when present, otherwise
    /// downloaded from the HuggingFace hub with a small retry loop for
    /// cache-lock contention. Any non-`hf://` path is returned unchanged as a
    /// local filesystem path.
    ///
    /// # Errors
    /// Fails on malformed `hf://` paths (fewer than two components) or when
    /// the download keeps failing after `max_attempts` lock retries.
    async fn download_model_file(api: &Api, path: &str) -> Result<std::path::PathBuf> {
        if path.starts_with("hf://") {
            // unwrap is safe: the starts_with check above guarantees the prefix
            let path = path.strip_prefix("hf://").unwrap();
            let parts: Vec<&str> = path.split('/').collect();
            if parts.len() < 2 {
                return Err(anyhow::anyhow!("Invalid HuggingFace path: {}", path));
            }
            let repo = format!("{}/{}", parts[0], parts[1]);
            let file_path = parts[2..].join("/");
            // Check for pre-loaded model files first
            let local_path = std::path::PathBuf::from(format!("models/{}/{}", repo, file_path));
            if local_path.exists() {
                info!("Using pre-loaded model: {}", local_path.display());
                return Ok(local_path);
            }
            info!("Pre-loaded model not found, downloading {} from {}", file_path, repo);
            let repo = api.model(repo);
            // Simple retry logic for lock acquisition failures
            let mut attempts = 0;
            let max_attempts = 3; // Reduced since we shouldn't have multiple processes now
            loop {
                match repo.get(&file_path).await {
                    Ok(path) => return Ok(path),
                    Err(e) if attempts < max_attempts => {
                        // Only lock contention is retried; any other error fails fast
                        if e.to_string().contains("Lock acquisition failed") {
                            attempts += 1;
                            warn!("Lock acquisition failed, retrying ({}/{}): {}", attempts, max_attempts, e);
                            tokio::time::sleep(tokio::time::Duration::from_millis(2000)).await;
                            continue;
                        }
                        return Err(anyhow::anyhow!("Download failed: {}", e));
                    }
                    Err(e) => return Err(anyhow::anyhow!("Download failed after {} attempts: {}", max_attempts, e)),
                }
            }
        } else {
            Ok(std::path::PathBuf::from(path))
        }
    }
    /// Convenience wrapper: transcribe a chunk with no language hint.
    /// See [`Self::transcribe_with_language`].
    pub async fn transcribe(&self, audio_data: Vec<f32>) -> Result<String> {
        self.transcribe_with_language(audio_data, None).await
    }
pub async fn transcribe_with_language(&self, audio_data: Vec<f32>, language: Option<String>) -> Result<String> {
info!("Processing audio chunk of length: {} with language: {:?}", audio_data.len(), language);
// Resample audio to 24kHz if needed
let audio_data = self.preprocess_audio(audio_data)?;
// Run inference with real moshi ASR
let mut model = self.model.lock().await;
match model.transcribe_stream_with_language(audio_data, language).await {
Ok(transcription) => {
info!("Transcription successful: {}", transcription);
Ok(transcription)
}
Err(e) => {
error!("Moshi ASR transcription failed: {}", e);
Err(e)
}
}
}
    /// Manager-level audio preprocessing. Currently a pass-through:
    /// resampling is TODO, and normalization is deliberately NOT done here
    /// because `MoshiAsrModel::preprocess_audio_for_mimi` already applies an
    /// RMS-based level adjustment (doing it twice distorted the signal).
    fn preprocess_audio(&self, audio: Vec<f32>) -> Result<Vec<f32>> {
        // Ensure 24kHz sample rate
        // TODO: Implement proper resampling if needed
        // REMOVED: Double normalization (already normalized in preprocess_audio_for_mimi)
        info!("ModelManager: Skipping normalization (already done in Mimi preprocessing)");
        Ok(audio)
    }
    /// Read-only access to the server configuration this manager was built with.
    pub fn get_config(&self) -> &ServerConfig {
        &self.config
    }
}