two_agents_live / src /forced_alignment.rs
inventwithdean
initial push
552370e
use std::usize;
use anyhow::anyhow;
use ort::{
inputs,
session::{Session, builder::GraphOptimizationLevel},
value::Tensor,
};
use serde::Serialize;
/// Loads the ONNX Runtime shared library and opens the alignment model.
///
/// The runtime library is looked up at a relative path chosen at compile time
/// (`target_os = "windows"` versus everything else), and the model is read from
/// `./wav2vec2_q4f16.onnx`. The session is built with level-3 graph
/// optimization and two intra-op threads.
///
/// # Errors
///
/// Fails when the runtime library cannot be loaded or committed, when a
/// session option is rejected, or when the model file cannot be opened.
pub fn initialize_forced_aligner() -> Result<Session, anyhow::Error> {
    const WINDOWS_RUNTIME: &str = "./lib/onnxruntime.dll";
    const DEFAULT_RUNTIME: &str = "./onnxruntime-linux-x64-1.24.4/lib/libonnxruntime.so";
    const MODEL_PATH: &str = "./wav2vec2_q4f16.onnx";

    let runtime_path = if cfg!(target_os = "windows") {
        WINDOWS_RUNTIME
    } else {
        DEFAULT_RUNTIME
    };

    // `commit` reports whether the environment was actually set up.
    if !ort::init_from(runtime_path)?.commit() {
        anyhow::bail!("Failed to initialize ONNX Runtime");
    }

    // The builder errors are replaced by fixed messages, as before.
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .map_err(|_| anyhow!("Can't set optimization"))?
        .with_intra_threads(2)
        .map_err(|_| anyhow!("Can't set intra threads"))?
        .commit_from_file(MODEL_PATH)?;
    Ok(session)
}
/// Maps a transcript character to its token id in the alignment vocabulary.
///
/// ASCII letters are matched case-insensitively. Letters and the apostrophe
/// occupy ids 5..=31, a space is id 4, and every other character (digits,
/// punctuation, non-ASCII letters) falls back to id 3, the `<unk>` token.
fn char_to_token(c: char) -> usize {
    // Symbols in vocabulary order: the first entry has id `FIRST_SYMBOL_ID`,
    // the next one `FIRST_SYMBOL_ID + 1`, and so on.
    const VOCAB_ORDER: &str = "ETAONIHSRDLUMWCFGYPBVK'XJQZ";
    const FIRST_SYMBOL_ID: usize = 5;
    const SPACE_ID: usize = 4;
    const UNKNOWN_ID: usize = 3; // <unk>

    let upper = c.to_ascii_uppercase();
    if upper == ' ' {
        return SPACE_ID;
    }
    VOCAB_ORDER
        .chars()
        .position(|symbol| symbol == upper)
        .map_or(UNKNOWN_ID, |offset| FIRST_SYMBOL_ID + offset)
}
/// Timing of one transcript word within the aligned audio.
#[derive(Debug, Clone, PartialEq)]
pub struct WordAlignment {
    /// The word's characters as they appear in the transcript (spaces excluded).
    pub word: String,
    /// Time, in seconds, at which the word's first character starts.
    pub start: f32,
    /// Time, in seconds, at which the word's last character ends.
    pub end: f32,
}
/// A mouth-shape cue: which viseme to show and when it begins.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct VisemeEntry {
    /// Viseme label as produced by `char_to_viseme` (for example `"aa"` or `"oh"`).
    pub viseme: String,
    /// Time, in seconds, at which the viseme starts.
    pub start: f32,
}
/// Everything `forced_align` produces for one transcript/audio pair.
pub struct AlignmentResult {
    /// Word timings in time order; words are delimited by spaces in the transcript.
    pub words: Vec<WordAlignment>,
    /// Viseme cues in time order; one per aligned character that `char_to_viseme` maps to a viseme.
    pub visemes: Vec<VisemeEntry>,
}
/// Returns the viseme label for a vowel, or `None` for any other character.
///
/// Only the five ASCII vowels (case-insensitive) have a viseme; consonants and
/// everything else yield `None`, which callers treat as "mouth closed".
fn char_to_viseme(c: char) -> Option<&'static str> {
    const VOWEL_VISEMES: [(char, &str); 5] = [
        ('A', "aa"),
        ('E', "ee"),
        ('I', "ih"),
        ('O', "oh"),
        ('U', "ou"),
    ];

    let upper = c.to_ascii_uppercase();
    VOWEL_VISEMES
        .iter()
        .find(|(vowel, _)| *vowel == upper)
        .map(|&(_, viseme)| viseme)
}
// CTC Forced alignment algorithm
pub fn forced_align(
session: &mut Session,
audio: Vec<f32>,
transcript: &str,
) -> Result<AlignmentResult, anyhow::Error> {
// Inference, extract probabilities
let len = audio.len();
let tensor = Tensor::from_array(([1usize, len], audio.into_boxed_slice()))?;
let outputs = session.run(inputs!["input_values" => tensor])?;
let logits = outputs["logits"].try_extract_tensor::<f32>()?;
let shape = logits.0;
let time_steps = shape[1] as usize;
let vocab_size = shape[2] as usize;
let data = logits.1.iter().as_slice();
// Build CTC label sequence with blanks
// [blank, t0, blank, t1, blank, t2, blank, ...]
let tokens: Vec<usize> = transcript.chars().map(char_to_token).collect();
let mut seq = vec![0usize]; // start with blank
for &t in &tokens {
seq.push(t);
seq.push(0); // blank after each token
}
let s_len = seq.len();
// Viterbi forward pass
let neg_inf = f32::NEG_INFINITY;
let mut dp = vec![vec![neg_inf; s_len]; time_steps];
let mut bt = vec![vec![0usize; s_len]; time_steps];
let log_prob = |t: usize, tok: usize| -> f32 { data[t * vocab_size + tok] };
// Initialize first framee
dp[0][0] = log_prob(0, seq[0]);
if s_len > 1 {
dp[0][1] = log_prob(0, seq[1]);
}
for t in 1..time_steps {
for s in 0..s_len {
let mut best_score = neg_inf;
let mut best_prev = s;
// Transition from s (stay)
if dp[t - 1][s] > best_score {
best_score = dp[t - 1][s];
best_prev = s;
}
// Transition from s-1
if s >= 1 && dp[t - 1][s - 1] > best_score {
best_score = dp[t - 1][s - 1];
best_prev = s - 1;
}
// Transition from s-2 (skip blank)
if s >= 2 && seq[s - 1] == 0 && seq[s] != seq[s - 2] && dp[t - 1][s - 2] > best_score {
best_score = dp[t - 1][s - 2];
best_prev = s - 2;
}
dp[t][s] = best_score + log_prob(t, seq[s]);
bt[t][s] = best_prev;
}
}
// --- Backtrack ---
let mut path = vec![0usize; time_steps];
// Start backtrack from the best final state (last token or last blank)
path[time_steps - 1] = if dp[time_steps - 1][s_len - 1] > dp[time_steps - 1][s_len - 2] {
s_len - 1
} else {
s_len - 2
};
for t in (0..time_steps - 1).rev() {
path[t] = bt[t + 1][path[t + 1]];
}
// --- Convert path to char timestamps ---
// wav2vec2 base: 1 frame = 320 samples at 16kHz = 20ms
let frame_duration = 320.0 / 16000.0;
let transcript_chars: Vec<char> = transcript.chars().collect();
// Collect per-token (start_frame, end_frame), skipping blanks
let mut token_spans: Vec<(usize, usize, char)> = Vec::new();
let mut t = 0;
while t < time_steps {
let s = path[t];
if seq[s] != 0 {
let start_frame = t;
while t < time_steps && path[t] == s {
t += 1;
}
let ch = transcript_chars[s / 2]; // seq is [blank, c0, blank, c1, ...] so char index = s/2
token_spans.push((start_frame, t, ch));
} else {
t += 1;
}
}
// --- Group chars into words at ' ' boundaries ---
let mut alignments = Vec::new();
let mut current_word = String::new();
let mut word_start = 0.0f32;
let mut word_end = 0.0f32;
let mut visemes = Vec::new();
for (start_frame, end_frame, ch) in token_spans {
let start_sec = start_frame as f32 * frame_duration;
let end_sec = end_frame as f32 * frame_duration;
if let Some(viseme) = char_to_viseme(ch) {
visemes.push(VisemeEntry {
viseme: viseme.to_string(),
start: start_sec,
});
}
if ch == ' ' {
if !current_word.is_empty() {
alignments.push(WordAlignment {
word: current_word.clone(),
start: word_start,
end: word_end,
});
current_word.clear();
}
} else {
if current_word.is_empty() {
word_start = start_sec;
}
current_word.push(ch);
word_end = end_sec;
}
}
// Push last word
if !current_word.is_empty() {
alignments.push(WordAlignment {
word: current_word,
start: word_start,
end: word_end,
});
}
Ok(AlignmentResult {
words: alignments,
visemes,
})
}