use std::usize; use anyhow::anyhow; use ort::{ inputs, session::{Session, builder::GraphOptimizationLevel}, value::Tensor, }; use serde::Serialize; pub fn initialize_forced_aligner() -> Result { let lib_path = if cfg!(target_os = "windows") { "./lib/onnxruntime.dll" } else { "./onnxruntime-linux-x64-1.24.4/lib/libonnxruntime.so" }; let ok = ort::init_from(lib_path)?.commit(); if !ok { anyhow::bail!("Failed to initialize ONNX Runtime") } Ok(Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3) .map_err(|_| anyhow!("Can't set optimization"))? .with_intra_threads(2) .map_err(|_| anyhow!("Can't set intra threads"))? .commit_from_file("./wav2vec2_q4f16.onnx")?) } fn char_to_token(c: char) -> usize { match c.to_ascii_uppercase() { 'E' => 5, 'T' => 6, 'A' => 7, 'O' => 8, 'N' => 9, 'I' => 10, 'H' => 11, 'S' => 12, 'R' => 13, 'D' => 14, 'L' => 15, 'U' => 16, 'M' => 17, 'W' => 18, 'C' => 19, 'F' => 20, 'G' => 21, 'Y' => 22, 'P' => 23, 'B' => 24, 'V' => 25, 'K' => 26, '\'' => 27, 'X' => 28, 'J' => 29, 'Q' => 30, 'Z' => 31, ' ' => 4, _ => 3, // } } pub struct WordAlignment { pub word: String, pub start: f32, pub end: f32, } #[derive(Serialize)] pub struct VisemeEntry { pub viseme: String, pub start: f32, } pub struct AlignmentResult { pub words: Vec, pub visemes: Vec, } // For vowels only, works pretty good fn char_to_viseme(c: char) -> Option<&'static str> { match c.to_ascii_uppercase() { 'A' => Some("aa"), 'E' => Some("ee"), 'I' => Some("ih"), 'O' => Some("oh"), 'U' => Some("ou"), // consonants - no visemes, mouth closed _ => None, } } // CTC Forced alignment algorithm pub fn forced_align( session: &mut Session, audio: Vec, transcript: &str, ) -> Result { // Inference, extract probabilities let len = audio.len(); let tensor = Tensor::from_array(([1usize, len], audio.into_boxed_slice()))?; let outputs = session.run(inputs!["input_values" => tensor])?; let logits = outputs["logits"].try_extract_tensor::()?; let shape = logits.0; let time_steps = shape[1] as usize; let vocab_size = shape[2] as usize; let data = logits.1.iter().as_slice(); // Build CTC label sequence with blanks // [blank, t0, blank, t1, blank, t2, blank, ...] let tokens: Vec = transcript.chars().map(char_to_token).collect(); let mut seq = vec![0usize]; // start with blank for &t in &tokens { seq.push(t); seq.push(0); // blank after each token } let s_len = seq.len(); // Viterbi forward pass let neg_inf = f32::NEG_INFINITY; let mut dp = vec![vec![neg_inf; s_len]; time_steps]; let mut bt = vec![vec![0usize; s_len]; time_steps]; let log_prob = |t: usize, tok: usize| -> f32 { data[t * vocab_size + tok] }; // Initialize first framee dp[0][0] = log_prob(0, seq[0]); if s_len > 1 { dp[0][1] = log_prob(0, seq[1]); } for t in 1..time_steps { for s in 0..s_len { let mut best_score = neg_inf; let mut best_prev = s; // Transition from s (stay) if dp[t - 1][s] > best_score { best_score = dp[t - 1][s]; best_prev = s; } // Transition from s-1 if s >= 1 && dp[t - 1][s - 1] > best_score { best_score = dp[t - 1][s - 1]; best_prev = s - 1; } // Transition from s-2 (skip blank) if s >= 2 && seq[s - 1] == 0 && seq[s] != seq[s - 2] && dp[t - 1][s - 2] > best_score { best_score = dp[t - 1][s - 2]; best_prev = s - 2; } dp[t][s] = best_score + log_prob(t, seq[s]); bt[t][s] = best_prev; } } // --- Backtrack --- let mut path = vec![0usize; time_steps]; // Start backtrack from the best final state (last token or last blank) path[time_steps - 1] = if dp[time_steps - 1][s_len - 1] > dp[time_steps - 1][s_len - 2] { s_len - 1 } else { s_len - 2 }; for t in (0..time_steps - 1).rev() { path[t] = bt[t + 1][path[t + 1]]; } // --- Convert path to char timestamps --- // wav2vec2 base: 1 frame = 320 samples at 16kHz = 20ms let frame_duration = 320.0 / 16000.0; let transcript_chars: Vec = transcript.chars().collect(); // Collect per-token (start_frame, end_frame), skipping blanks let mut token_spans: Vec<(usize, usize, char)> = Vec::new(); let mut t = 0; while t < time_steps { let s = path[t]; if seq[s] != 0 { let start_frame = t; while t < time_steps && path[t] == s { t += 1; } let ch = transcript_chars[s / 2]; // seq is [blank, c0, blank, c1, ...] so char index = s/2 token_spans.push((start_frame, t, ch)); } else { t += 1; } } // --- Group chars into words at ' ' boundaries --- let mut alignments = Vec::new(); let mut current_word = String::new(); let mut word_start = 0.0f32; let mut word_end = 0.0f32; let mut visemes = Vec::new(); for (start_frame, end_frame, ch) in token_spans { let start_sec = start_frame as f32 * frame_duration; let end_sec = end_frame as f32 * frame_duration; if let Some(viseme) = char_to_viseme(ch) { visemes.push(VisemeEntry { viseme: viseme.to_string(), start: start_sec, }); } if ch == ' ' { if !current_word.is_empty() { alignments.push(WordAlignment { word: current_word.clone(), start: word_start, end: word_end, }); current_word.clear(); } } else { if current_word.is_empty() { word_start = start_sec; } current_word.push(ch); word_end = end_sec; } } // Push last word if !current_word.is_empty() { alignments.push(WordAlignment { word: current_word, start: word_start, end: word_end, }); } Ok(AlignmentResult { words: alignments, visemes, }) }