Spaces:
Sleeping
Sleeping
| use std::usize; | |
| use anyhow::anyhow; | |
| use ort::{ | |
| inputs, | |
| session::{Session, builder::GraphOptimizationLevel}, | |
| value::Tensor, | |
| }; | |
| use serde::Serialize; | |
| pub fn initialize_forced_aligner() -> Result<Session, anyhow::Error> { | |
| let lib_path = if cfg!(target_os = "windows") { | |
| "./lib/onnxruntime.dll" | |
| } else { | |
| "./onnxruntime-linux-x64-1.24.4/lib/libonnxruntime.so" | |
| }; | |
| let ok = ort::init_from(lib_path)?.commit(); | |
| if !ok { | |
| anyhow::bail!("Failed to initialize ONNX Runtime") | |
| } | |
| Ok(Session::builder()? | |
| .with_optimization_level(GraphOptimizationLevel::Level3) | |
| .map_err(|_| anyhow!("Can't set optimization"))? | |
| .with_intra_threads(2) | |
| .map_err(|_| anyhow!("Can't set intra threads"))? | |
| .commit_from_file("./wav2vec2_q4f16.onnx")?) | |
| } | |
/// Maps a transcript character to its token id in the model vocabulary.
///
/// Matching is ASCII case-insensitive. Letters and the apostrophe occupy ids
/// `5..=31` in the order given by `SYMBOL_ORDER`, a space maps to `4`, and
/// every other character maps to `3` (`<unk>`).
fn char_to_token(c: char) -> usize {
    // Symbols in vocabulary order; the first one has id `FIRST_SYMBOL_ID`.
    const SYMBOL_ORDER: &str = "ETAONIHSRDLUMWCFGYPBVK'XJQZ";
    const FIRST_SYMBOL_ID: usize = 5;
    const SPACE_ID: usize = 4;
    const UNKNOWN_ID: usize = 3; // <unk>

    let upper = c.to_ascii_uppercase();
    if upper == ' ' {
        return SPACE_ID;
    }

    // All symbols are ASCII, so the char position equals the vocabulary offset.
    SYMBOL_ORDER
        .chars()
        .position(|symbol| symbol == upper)
        .map_or(UNKNOWN_ID, |offset| FIRST_SYMBOL_ID + offset)
}
/// Time span of one aligned word.
#[derive(Debug, Clone, PartialEq)]
pub struct WordAlignment {
    /// The word's characters as they appear in the transcript.
    pub word: String,
    /// Start of the word's first aligned character, in seconds.
    pub start: f32,
    /// End of the word's last aligned character, in seconds.
    pub end: f32,
}

/// A mouth shape and the time at which it begins.
#[derive(Debug, Clone, PartialEq)]
pub struct VisemeEntry {
    /// Viseme label (for example `"aa"`).
    pub viseme: String,
    /// Start time in seconds.
    pub start: f32,
}

/// Output of forced alignment: word timings plus viseme cues.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// Aligned words in transcript order.
    pub words: Vec<WordAlignment>,
    /// Viseme cues in transcript order.
    pub visemes: Vec<VisemeEntry>,
}
/// Returns the viseme label for a character, or `None` when it has none.
///
/// Only the five vowels carry a viseme (ASCII case-insensitive); this works
/// pretty well in practice. Consonants and everything else yield `None`,
/// i.e. mouth closed.
fn char_to_viseme(c: char) -> Option<&'static str> {
    const VOWEL_VISEMES: [(char, &str); 5] = [
        ('A', "aa"),
        ('E', "ee"),
        ('I', "ih"),
        ('O', "oh"),
        ('U', "ou"),
    ];

    let upper = c.to_ascii_uppercase();
    VOWEL_VISEMES
        .iter()
        .find(|(vowel, _)| *vowel == upper)
        .map(|&(_, viseme)| viseme)
}
| // CTC Forced alignment algorithm | |
| pub fn forced_align( | |
| session: &mut Session, | |
| audio: Vec<f32>, | |
| transcript: &str, | |
| ) -> Result<AlignmentResult, anyhow::Error> { | |
| // Inference, extract probabilities | |
| let len = audio.len(); | |
| let tensor = Tensor::from_array(([1usize, len], audio.into_boxed_slice()))?; | |
| let outputs = session.run(inputs!["input_values" => tensor])?; | |
| let logits = outputs["logits"].try_extract_tensor::<f32>()?; | |
| let shape = logits.0; | |
| let time_steps = shape[1] as usize; | |
| let vocab_size = shape[2] as usize; | |
| let data = logits.1.iter().as_slice(); | |
| // Build CTC label sequence with blanks | |
| // [blank, t0, blank, t1, blank, t2, blank, ...] | |
| let tokens: Vec<usize> = transcript.chars().map(char_to_token).collect(); | |
| let mut seq = vec![0usize]; // start with blank | |
| for &t in &tokens { | |
| seq.push(t); | |
| seq.push(0); // blank after each token | |
| } | |
| let s_len = seq.len(); | |
| // Viterbi forward pass | |
| let neg_inf = f32::NEG_INFINITY; | |
| let mut dp = vec![vec![neg_inf; s_len]; time_steps]; | |
| let mut bt = vec![vec![0usize; s_len]; time_steps]; | |
| let log_prob = |t: usize, tok: usize| -> f32 { data[t * vocab_size + tok] }; | |
| // Initialize first framee | |
| dp[0][0] = log_prob(0, seq[0]); | |
| if s_len > 1 { | |
| dp[0][1] = log_prob(0, seq[1]); | |
| } | |
| for t in 1..time_steps { | |
| for s in 0..s_len { | |
| let mut best_score = neg_inf; | |
| let mut best_prev = s; | |
| // Transition from s (stay) | |
| if dp[t - 1][s] > best_score { | |
| best_score = dp[t - 1][s]; | |
| best_prev = s; | |
| } | |
| // Transition from s-1 | |
| if s >= 1 && dp[t - 1][s - 1] > best_score { | |
| best_score = dp[t - 1][s - 1]; | |
| best_prev = s - 1; | |
| } | |
| // Transition from s-2 (skip blank) | |
| if s >= 2 && seq[s - 1] == 0 && seq[s] != seq[s - 2] && dp[t - 1][s - 2] > best_score { | |
| best_score = dp[t - 1][s - 2]; | |
| best_prev = s - 2; | |
| } | |
| dp[t][s] = best_score + log_prob(t, seq[s]); | |
| bt[t][s] = best_prev; | |
| } | |
| } | |
| // --- Backtrack --- | |
| let mut path = vec![0usize; time_steps]; | |
| // Start backtrack from the best final state (last token or last blank) | |
| path[time_steps - 1] = if dp[time_steps - 1][s_len - 1] > dp[time_steps - 1][s_len - 2] { | |
| s_len - 1 | |
| } else { | |
| s_len - 2 | |
| }; | |
| for t in (0..time_steps - 1).rev() { | |
| path[t] = bt[t + 1][path[t + 1]]; | |
| } | |
| // --- Convert path to char timestamps --- | |
| // wav2vec2 base: 1 frame = 320 samples at 16kHz = 20ms | |
| let frame_duration = 320.0 / 16000.0; | |
| let transcript_chars: Vec<char> = transcript.chars().collect(); | |
| // Collect per-token (start_frame, end_frame), skipping blanks | |
| let mut token_spans: Vec<(usize, usize, char)> = Vec::new(); | |
| let mut t = 0; | |
| while t < time_steps { | |
| let s = path[t]; | |
| if seq[s] != 0 { | |
| let start_frame = t; | |
| while t < time_steps && path[t] == s { | |
| t += 1; | |
| } | |
| let ch = transcript_chars[s / 2]; // seq is [blank, c0, blank, c1, ...] so char index = s/2 | |
| token_spans.push((start_frame, t, ch)); | |
| } else { | |
| t += 1; | |
| } | |
| } | |
| // --- Group chars into words at ' ' boundaries --- | |
| let mut alignments = Vec::new(); | |
| let mut current_word = String::new(); | |
| let mut word_start = 0.0f32; | |
| let mut word_end = 0.0f32; | |
| let mut visemes = Vec::new(); | |
| for (start_frame, end_frame, ch) in token_spans { | |
| let start_sec = start_frame as f32 * frame_duration; | |
| let end_sec = end_frame as f32 * frame_duration; | |
| if let Some(viseme) = char_to_viseme(ch) { | |
| visemes.push(VisemeEntry { | |
| viseme: viseme.to_string(), | |
| start: start_sec, | |
| }); | |
| } | |
| if ch == ' ' { | |
| if !current_word.is_empty() { | |
| alignments.push(WordAlignment { | |
| word: current_word.clone(), | |
| start: word_start, | |
| end: word_end, | |
| }); | |
| current_word.clear(); | |
| } | |
| } else { | |
| if current_word.is_empty() { | |
| word_start = start_sec; | |
| } | |
| current_word.push(ch); | |
| word_end = end_sec; | |
| } | |
| } | |
| // Push last word | |
| if !current_word.is_empty() { | |
| alignments.push(WordAlignment { | |
| word: current_word, | |
| start: word_start, | |
| end: word_end, | |
| }); | |
| } | |
| Ok(AlignmentResult { | |
| words: alignments, | |
| visemes, | |
| }) | |
| } | |