import dataclasses
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import amfm_decompy.basic_tools as basic
import amfm_decompy.pYAAPT as pYAAPT
import numpy as np
import parselmouth
import scipy.interpolate as scipy_interp
import torch
from datasets import Dataset
from scipy.signal import medfilt
from transformers import FeatureExtractionMixin, PreTrainedModel, PretrainedConfig


@dataclass
class SpeakerStats:
    """Per-speaker normalisation statistics for F0 and intensity."""

    f0_mean: float  # mean over non-zero F0 frames only
    f0_std: float  # std over non-zero F0 frames only
    intensity_mean: float  # mean over all intensity frames
    intensity_std: float  # std over all intensity frames

    @classmethod
    def from_features(
        cls, f0_values: List[np.ndarray], intensity_values: List[np.ndarray]
    ) -> "SpeakerStats":
        """Compute statistics by pooling frames from several utterances.

        Args:
            f0_values: One F0 contour per utterance. Frames equal to 0 are
                excluded from the F0 statistics.
            intensity_values: One intensity contour per utterance. Every
                frame is included.

        Returns:
            A ``SpeakerStats`` computed over the pooled frames. If no contour
            contains a non-zero F0 frame, ``f0_mean``/``f0_std`` are NaN
            (NumPy's behaviour for the mean/std of an empty array).

        Raises:
            ValueError: If either list is empty (``np.concatenate`` needs at
                least one array).
        """
        # Items read back from a datasets.Dataset may be plain lists.
        f0_arrays = [np.asarray(f0) for f0 in f0_values]
        intensity_arrays = [np.asarray(i) for i in intensity_values]

        f0_concat = np.concatenate([f0[f0 != 0] for f0 in f0_arrays])
        intensity_concat = np.concatenate(intensity_arrays)

        return cls(
            f0_mean=float(np.mean(f0_concat)),
            f0_std=float(np.std(f0_concat)),
            intensity_mean=float(np.mean(intensity_concat)),
            intensity_std=float(np.std(intensity_concat)),
        )


class ProsodyConfig(PretrainedConfig):
    """Configuration class for prosody preprocessing.

    Args:
        sampling_rate: Audio sampling rate in Hz.
        frame_length: Analysis frame length in ms.
        frame_space: Hop between analysis frames in ms.
        torch_dtype: Forwarded to ``PretrainedConfig``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "prosody_preprocessor"

    def __init__(
        self,
        sampling_rate: int = 16000,
        frame_length: float = 20.0,  # in ms
        frame_space: float = 5.0,  # in ms
        torch_dtype: str = "float32",
        **kwargs,
    ):
        super().__init__(torch_dtype=torch_dtype, **kwargs)
        self.sampling_rate = sampling_rate
        self.frame_length = frame_length
        self.frame_space = frame_space


class ProsodyPreprocessor(FeatureExtractionMixin):
    """Extract frame-level F0 (YAAPT) and intensity (Praat) contours.

    Also collects per-speaker F0/intensity statistics over a dataset.
    """

    config_class = ProsodyConfig

    def __init__(
        self,
        sampling_rate: int = 16000,
        frame_length: float = 20.0,  # in ms
        frame_space: float = 5.0,  # in ms
        torch_dtype: str = "float32",  # accepted for signature parity; not used here
        config: Optional[ProsodyConfig] = None,
        **kwargs,
    ):
        super().__init__()
        # Stored exactly as given (may be None); no default config is built.
        self.config = config
        self.speaker_stats: Dict[str, SpeakerStats] = {}
        self.sampling_rate = sampling_rate
        self.frame_length = frame_length
        self.frame_space = frame_space

    def extract_features(self, audio) -> Dict[str, np.ndarray]:
        """Extract F0 and intensity contours from one waveform.

        Args:
            audio: Anything ``torch.Tensor`` accepts. A 1-D input is treated
                as a single signal. For a 2-D input, every row goes through
                YAAPT but only the first row's F0 is kept, while the whole
                array is handed to ``parselmouth.Sound``.
                NOTE(review): presumably callers pass mono audio; confirm.

        Returns:
            A dict with 1-D arrays of equal length:
            ``"f0"`` (raw YAAPT F0), ``"f0_interp"`` (YAAPT's interpolated
            F0, produced by the patched ``interpolate`` below) and
            ``"intensity"`` (Praat intensity, floored at 20).
        """
        # Swap in the fixed interpolation routine defined in this module.
        # This runs on every call (it is idempotent) and mutates
        # pYAAPT.PitchObj globally.
        pYAAPT.PitchObj.interpolate = interpolate

        audio = torch.Tensor(audio)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        f0, f0_interp = self._get_f0(audio)
        # Keep the first signal only, dropping its first 6 frames.
        # NOTE(review): the 6-frame offset looks like an alignment correction
        # relative to the intensity track; confirm.
        f0 = f0[0, 0, 6:]
        f0_interpolated = f0_interp[0, 0, 6:]

        sound = parselmouth.Sound(
            audio.numpy(), sampling_frequency=self.sampling_rate, start_time=0
        )
        # Sample intensity on the same hop as the F0 track (frame_space ms) so
        # the two contours line up. With the default 5 ms hop this is 200 Hz.
        intensity = sound.to_intensity(time_step=self.frame_space / 1000.0)
        intensity_values = intensity.values.T.flatten()

        # Truncate all tracks to a common length.
        min_len = min(len(f0), len(intensity_values))
        f0 = f0[:min_len]
        f0_interpolated = f0_interpolated[:min_len]
        intensity_values = intensity_values[:min_len]
        # Floor very low values (Praat intensity is expressed in dB).
        intensity_values[intensity_values < 20] = 20

        return {
            "f0": f0,
            "f0_interp": f0_interpolated,
            "intensity": intensity_values,
        }

    def collect_stats(
        self, dataset: Dataset, num_proc: int = 4, batch_size: int = 32
    ) -> Tuple[Dataset, Dict[str, SpeakerStats]]:
        """First pass: extract features and collect per-speaker statistics.

        Args:
            dataset: Must provide ``audio`` and ``speaker_id`` columns.
            num_proc: Worker processes, forwarded to ``Dataset.map``.
            batch_size: Batch size, forwarded to ``Dataset.map``.

        Returns:
            ``(features_dataset, speaker_stats)``. ``features_dataset`` has
            the columns ``f0``, ``intensity`` and ``speaker_id``; all original
            columns are removed. ``speaker_stats`` maps each speaker id to its
            ``SpeakerStats``. ``self.speaker_stats`` is replaced by the same
            dict.
        """

        def extract_features_batch(examples):
            features_list = [self.extract_features(audio) for audio in examples['audio']]
            return {
                'f0': [f['f0'] for f in features_list],
                'intensity': [f['intensity'] for f in features_list],
                'speaker_id': examples['speaker_id'],
            }

        features_dataset = dataset.map(
            extract_features_batch,
            batched=True,
            batch_size=batch_size,
            num_proc=num_proc,
            remove_columns=dataset.column_names,
        )

        # Group per-utterance contours by speaker, keeping first-seen order.
        speaker_features = defaultdict(lambda: {'f0': [], 'intensity': []})
        for item in features_dataset:
            feats = speaker_features[item['speaker_id']]
            feats['f0'].append(item['f0'])
            feats['intensity'].append(item['intensity'])

        self.speaker_stats = {
            spk: SpeakerStats.from_features(feats['f0'], feats['intensity'])
            for spk, feats in speaker_features.items()
        }
        return features_dataset, self.speaker_stats

    def save_stats(self, path: str):
        """Save ``self.speaker_stats`` to ``path`` as a dict of plain dicts."""
        stats_dict = {
            spk: dataclasses.asdict(stats) for spk, stats in self.speaker_stats.items()
        }
        torch.save(stats_dict, path)

    @classmethod
    def load_stats(cls, path: str) -> Dict[str, SpeakerStats]:
        """Load speaker stats written by ``save_stats``.

        Returns the stats. It does not assign them to any instance.
        """
        # SECURITY: torch.load unpickles the file; only load trusted files.
        stats_dict = torch.load(path)
        return {spk: SpeakerStats(**stats) for spk, stats in stats_dict.items()}

    def _get_f0(self, audio: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Run YAAPT on every row of ``audio``.

        Returns:
            ``(f0, f0_interp)``. Each is stacked from one ``(1, 1, n_frames)``
            array per row, so its shape is ``(n_rows, 1, n_frames)``. Values
            above 500 or below 0 are set to 0 in both.
        """
        # Half an analysis frame of zero padding on each side.
        to_pad = int(self.frame_length / 1000 * self.sampling_rate) // 2
        f0s = []
        f0s_interp = []
        for y in audio.numpy().astype(np.float64):
            y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0)
            signal = basic.SignalObj(y_pad, self.sampling_rate)
            pitch = pYAAPT.yaapt(
                signal,
                frame_length=self.frame_length,
                frame_space=self.frame_space,
                nccf_thresh1=0.25,
                tda_frame_length=25.0,
            )
            f0s_interp.append(pitch.samp_interp[None, None, :])
            f0s.append(pitch.samp_values[None, None, :])

        f0 = np.vstack(f0s)
        f0_interp = np.vstack(f0s_interp)

        # Zero out anything outside the accepted 0..500 range.
        for track in (f0, f0_interp):
            track[track > 500] = 0
            track[track < 0] = 0
        return f0, f0_interp


def interpolate(self):
    """Replacement for ``pYAAPT.PitchObj.interpolate``, installed at runtime.

    Fills zero (unvoiced) frames of ``self.samp_values`` and stores the
    result in ``self.samp_interp``:

    * With fewer than two non-zero frames in the median-filtered contour,
      zeros are replaced by ``self.PTCH_TYP``.
    * Otherwise zeros are filled by PCHIP interpolation over the
      median-filtered non-zero frames.
    * If ``self.SMOOTH > 0`` the result is median-filtered again.
    * Finally, the contour is held flat before the first and after the last
      voiced/unvoiced transition, but only when the original contour starts
      or ends at 0.

    Unlike the upstream version, a contour whose first transition is at
    index 0 is not flattened entirely.
    """
    pitch = np.zeros((self.nframes))
    pitch[:] = self.samp_values
    pitch2 = medfilt(self.samp_values, self.SMOOTH_FACTOR)

    # The upstream code mishandled the extrapolated points before the first
    # voiced frame and after the last voiced frame; the edge handling below
    # is the amended version.
    edges = self.edges_finder(pitch)
    first_sample = pitch[0]
    last_sample = pitch[-1]

    if len(np.nonzero(pitch2)[0]) < 2:
        pitch[pitch == 0] = self.PTCH_TYP
    else:
        nz_pitch = pitch2[pitch2 > 0]
        pitch2 = scipy_interp.pchip(np.nonzero(pitch2)[0], nz_pitch)(range(self.nframes))
        pitch[pitch == 0] = pitch2[pitch == 0]
    if self.SMOOTH > 0:
        pitch = medfilt(pitch, self.SMOOTH_FACTOR)

    # With no transitions there is nothing to extend. (The previous version
    # relied on a bare `except: pass` catching the IndexError here.)
    if len(edges) > 0:
        if first_sample == 0:
            # A first edge at index 0 would make `pitch[:-1] = ...` overwrite
            # (flatten) the whole contour; treat it as 1 so this is a no-op.
            first_edge = max(edges[0], 1)
            pitch[:first_edge - 1] = pitch[first_edge]
        if last_sample == 0:
            pitch[edges[-1] + 1:] = pitch[edges[-1]]

    self.samp_interp = pitch