| | |
| |
|
| | use crate::{Error, Result}; |
| | use ndarray::{Array1, Array2, Array, IxDyn}; |
| | use std::collections::HashMap; |
| | use std::path::Path; |
| |
|
| | use super::OnnxSession; |
| |
|
| | |
| | pub struct SpeakerEncoder { |
| | session: Option<OnnxSession>, |
| | embedding_dim: usize, |
| | } |
| |
|
| | impl SpeakerEncoder { |
| | |
| | pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let session = OnnxSession::load(path)?; |
| | Ok(Self { |
| | session: Some(session), |
| | embedding_dim: 192, |
| | }) |
| | } |
| |
|
| | |
| | pub fn new_placeholder(embedding_dim: usize) -> Self { |
| | Self { |
| | session: None, |
| | embedding_dim, |
| | } |
| | } |
| |
|
| | |
| | pub fn encode(&self, mel_spectrogram: &Array2<f32>) -> Result<Array1<f32>> { |
| | if let Some(ref session) = self.session { |
| | |
| | let input = mel_spectrogram |
| | .clone() |
| | .into_shape(IxDyn(&[1, mel_spectrogram.nrows(), mel_spectrogram.ncols()]))?; |
| |
|
| | let mut inputs = HashMap::new(); |
| | inputs.insert("mel".to_string(), input); |
| |
|
| | let outputs = session.run(inputs)?; |
| |
|
| | let embedding = outputs |
| | .get("embedding") |
| | .ok_or_else(|| Error::Model("Missing embedding output".into()))?; |
| |
|
| | |
| | let flat: Vec<f32> = embedding.iter().cloned().collect(); |
| | Ok(Array1::from_vec(flat)) |
| | } else { |
| | |
| | Ok(Array1::from_vec(vec![0.0f32; self.embedding_dim])) |
| | } |
| | } |
| |
|
| | |
| | pub fn encode_audio(&self, audio_path: &str) -> Result<Array1<f32>> { |
| | use crate::audio::{compute_mel_from_file, AudioConfig}; |
| |
|
| | let config = AudioConfig::default(); |
| | let mel = compute_mel_from_file(audio_path, &config)?; |
| | self.encode(&mel) |
| | } |
| |
|
| | |
| | pub fn embedding_dim(&self) -> usize { |
| | self.embedding_dim |
| | } |
| |
|
| | |
| | pub fn normalize_embedding(&self, embedding: &Array1<f32>) -> Array1<f32> { |
| | let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); |
| | if norm > 1e-8 { |
| | embedding / norm |
| | } else { |
| | embedding.clone() |
| | } |
| | } |
| |
|
| | |
| | pub fn cosine_similarity(&self, emb1: &Array1<f32>, emb2: &Array1<f32>) -> f32 { |
| | let norm1 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt(); |
| | let norm2 = emb2.iter().map(|x| x * x).sum::<f32>().sqrt(); |
| |
|
| | if norm1 < 1e-8 || norm2 < 1e-8 { |
| | return 0.0; |
| | } |
| |
|
| | let dot: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum(); |
| | dot / (norm1 * norm2) |
| | } |
| | } |
| |
|
| | |
| | pub struct EmotionEncoder { |
| | |
| | emotion_matrix: Array2<f32>, |
| | |
| | num_dims: usize, |
| | |
| | dim_sizes: Vec<usize>, |
| | } |
| |
|
| | impl EmotionEncoder { |
| | |
| | pub fn new(num_dims: usize, dim_sizes: Vec<usize>, embedding_dim: usize) -> Self { |
| | let total_emotions: usize = dim_sizes.iter().sum(); |
| | let emotion_matrix = Array2::zeros((total_emotions, embedding_dim)); |
| |
|
| | Self { |
| | emotion_matrix, |
| | num_dims, |
| | dim_sizes, |
| | } |
| | } |
| |
|
| | |
| | pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let path = path.as_ref(); |
| | if !path.exists() { |
| | return Err(Error::FileNotFound(path.display().to_string())); |
| | } |
| |
|
| | |
| | let file_data = std::fs::read(path)?; |
| | let tensors = safetensors::SafeTensors::deserialize(&file_data) |
| | .map_err(|e| Error::ModelLoading(format!("Failed to load safetensors: {}", e)))?; |
| |
|
| | |
| | let tensor = tensors |
| | .tensor("emotion_matrix") |
| | .map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?; |
| |
|
| | let shape = tensor.shape(); |
| | let data: Vec<f32> = tensor.data().chunks(4).map(|b| { |
| | f32::from_le_bytes([b[0], b[1], b[2], b[3]]) |
| | }).collect(); |
| |
|
| | let emotion_matrix = Array2::from_shape_vec((shape[0], shape[1]), data) |
| | .map_err(|e| Error::ModelLoading(format!("Shape mismatch: {}", e)))?; |
| |
|
| | |
| | let num_dims = 8; |
| | let dim_sizes = vec![5, 6, 8, 6, 5, 4, 7, 6]; |
| |
|
| | Ok(Self { |
| | emotion_matrix, |
| | num_dims, |
| | dim_sizes, |
| | }) |
| | } |
| |
|
| | |
| | pub fn encode(&self, emotion_vector: &[f32]) -> Result<Array1<f32>> { |
| | if emotion_vector.len() != self.num_dims { |
| | return Err(Error::ShapeMismatch { |
| | expected: format!("{} dimensions", self.num_dims), |
| | actual: format!("{} dimensions", emotion_vector.len()), |
| | }); |
| | } |
| |
|
| | let embedding_dim = self.emotion_matrix.ncols(); |
| | let mut embedding = vec![0.0f32; embedding_dim]; |
| |
|
| | let mut offset = 0; |
| | for (dim_idx, (&value, &dim_size)) in emotion_vector.iter().zip(self.dim_sizes.iter()).enumerate() { |
| | |
| | let continuous_idx = value * (dim_size - 1) as f32; |
| | let lower_idx = continuous_idx.floor() as usize; |
| | let upper_idx = (lower_idx + 1).min(dim_size - 1); |
| | let alpha = continuous_idx - lower_idx as f32; |
| |
|
| | |
| | for i in 0..embedding_dim { |
| | let lower_val = self.emotion_matrix[[offset + lower_idx, i]]; |
| | let upper_val = self.emotion_matrix[[offset + upper_idx, i]]; |
| | embedding[i] += lower_val * (1.0 - alpha) + upper_val * alpha; |
| | } |
| |
|
| | offset += dim_size; |
| | } |
| |
|
| | |
| | let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); |
| | if norm > 1e-8 { |
| | for e in embedding.iter_mut() { |
| | *e /= norm; |
| | } |
| | } |
| |
|
| | Ok(Array1::from_vec(embedding)) |
| | } |
| |
|
| | |
| | pub fn neutral(&self) -> Vec<f32> { |
| | vec![0.5f32; self.num_dims] |
| | } |
| |
|
| | |
| | pub fn preset(&self, name: &str) -> Vec<f32> { |
| | match name { |
| | "happy" => vec![0.9, 0.7, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5], |
| | "sad" => vec![0.2, 0.3, 0.4, 0.5, 0.6, 0.5, 0.5, 0.5], |
| | "angry" => vec![0.8, 0.9, 0.7, 0.5, 0.3, 0.5, 0.5, 0.5], |
| | "fearful" => vec![0.3, 0.4, 0.8, 0.5, 0.7, 0.5, 0.5, 0.5], |
| | "surprised" => vec![0.7, 0.8, 0.7, 0.5, 0.5, 0.5, 0.5, 0.5], |
| | "neutral" | _ => self.neutral(), |
| | } |
| | } |
| |
|
| | |
| | pub fn interpolate(&self, emot1: &[f32], emot2: &[f32], alpha: f32) -> Vec<f32> { |
| | emot1 |
| | .iter() |
| | .zip(emot2.iter()) |
| | .map(|(&a, &b)| a * (1.0 - alpha) + b * alpha) |
| | .collect() |
| | } |
| |
|
| | |
| | pub fn apply_strength(&self, emotion: &[f32], strength: f32) -> Vec<f32> { |
| | let neutral = self.neutral(); |
| | self.interpolate(&neutral, emotion, strength) |
| | } |
| | } |
| |
|
| | |
| | pub struct SemanticEncoder { |
| | session: Option<OnnxSession>, |
| | embedding_dim: usize, |
| | } |
| |
|
| | impl SemanticEncoder { |
| | |
| | pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let session = OnnxSession::load(path)?; |
| | Ok(Self { |
| | session: Some(session), |
| | embedding_dim: 1024, |
| | }) |
| | } |
| |
|
| | |
| | pub fn new_placeholder() -> Self { |
| | Self { |
| | session: None, |
| | embedding_dim: 1024, |
| | } |
| | } |
| |
|
| | |
| | pub fn encode(&self, audio: &[f32], sample_rate: u32) -> Result<Vec<i64>> { |
| | if let Some(ref session) = self.session { |
| | let input = Array::from_shape_vec( |
| | IxDyn(&[1, audio.len()]), |
| | audio.to_vec(), |
| | )?; |
| |
|
| | let mut inputs = HashMap::new(); |
| | inputs.insert("audio".to_string(), input); |
| |
|
| | let outputs = session.run(inputs)?; |
| |
|
| | let codes = outputs |
| | .get("codes") |
| | .ok_or_else(|| Error::Model("Missing codes output".into()))?; |
| |
|
| | Ok(codes.iter().map(|&x| x as i64).collect()) |
| | } else { |
| | |
| | let num_codes = audio.len() / (sample_rate as usize / 50); |
| | Ok(vec![0i64; num_codes.max(1)]) |
| | } |
| | } |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Emotion encoder with the standard 8-dimension layout and 256-wide embeddings.
    fn standard_emotion_encoder() -> EmotionEncoder {
        EmotionEncoder::new(8, vec![5, 6, 8, 6, 5, 4, 7, 6], 256)
    }

    #[test]
    fn test_speaker_encoder_placeholder() {
        let dim = SpeakerEncoder::new_placeholder(192).embedding_dim();
        assert_eq!(dim, 192);
    }

    #[test]
    fn test_emotion_encoder() {
        let neutral = standard_emotion_encoder().neutral();
        assert_eq!(neutral.len(), 8);
        for &value in &neutral {
            assert!((value - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_emotion_presets() {
        let happy = standard_emotion_encoder().preset("happy");
        assert_eq!(happy.len(), 8);
        assert!(happy[0] > 0.5);
    }

    #[test]
    fn test_emotion_interpolation() {
        let encoder = standard_emotion_encoder();
        let (happy, sad) = (encoder.preset("happy"), encoder.preset("sad"));
        let mid = encoder.interpolate(&happy, &sad, 0.5);

        // At alpha = 0.5 every component must be the mean of the two presets.
        (0..8).for_each(|i| {
            let expected = (happy[i] + sad[i]) / 2.0;
            assert!((mid[i] - expected).abs() < 1e-6);
        });
    }

    #[test]
    fn test_cosine_similarity() {
        let encoder = SpeakerEncoder::new_placeholder(3);
        let x_axis = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let x_axis_copy = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let y_axis = Array1::from_vec(vec![0.0, 1.0, 0.0]);

        let identical = encoder.cosine_similarity(&x_axis, &x_axis_copy);
        assert!((identical - 1.0).abs() < 1e-6);

        let orthogonal = encoder.cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < 1e-6);
    }
}
| |
|