# prosody-predictor / extract_features.py
# Uploaded by hidude562 via huggingface_hub (commit fb18897, verified).
"""Extract F0 (pyin) + RMS from WAVs, tokenize text, compute durations."""
import argparse
import json
import os
import numpy as np
import torch
import librosa
# Character vocabulary. Fixed ids: PAD (chr(0)) = 0, '<UNK>' = 1, space = 2.
# Then contiguous ranges: a-z -> 3..28, 0-9 -> 29..38, PUNCT -> 39..50.
PUNCT = ".,;:!?'-\"()/"
_CHAR_GROUPS = (
    (3, 'abcdefghijklmnopqrstuvwxyz'),
    (29, '0123456789'),
    (39, PUNCT),
)
VOCAB = {chr(0): 0, '<UNK>': 1, ' ': 2}  # PAD, unknown-character fallback, space
for _start, _chars in _CHAR_GROUPS:
    # Each group occupies consecutive ids beginning at its start id.
    VOCAB.update((ch, _start + k) for k, ch in enumerate(_chars))
VOCAB_SIZE = max(VOCAB.values()) + 1
def tokenize(text):
    """Convert *text* into a list of integer character IDs.

    The text is lower-cased first. Every character is then looked up in
    ``VOCAB``; characters with no entry map to the ``'<UNK>'`` ID.
    """
    lowered = text.lower()
    unk_id = VOCAB['<UNK>']  # looked up once instead of per character
    return [VOCAB.get(ch, unk_id) for ch in lowered]
def proportional_durations(n_chars, n_frames):
    """Spread *n_frames* frames as evenly as possible over *n_chars* characters.

    Each character gets ``n_frames // n_chars`` frames. The first
    ``n_frames % n_chars`` characters get one extra frame, so for positive
    ``n_chars`` the list sums to ``n_frames``. If ``n_chars`` is 0, an empty
    list is returned and no frames are assigned.
    """
    if n_chars == 0:
        return []
    share, leftover = divmod(n_frames, n_chars)
    return [share + 1 if idx < leftover else share for idx in range(n_chars)]
def extract_one(wav_path, text, sr=24000, hop_length=2400,
                fmin=50, fmax=600, f0_ceiling=300):
    """Extract prosody features and character IDs for a single utterance.

    Args:
        wav_path: Path to the audio file. It is loaded as mono at sample rate ``sr``.
        text: Transcript of the audio, converted to IDs with ``tokenize``.
        sr: Sample rate passed to ``librosa.load`` and ``librosa.pyin``.
        hop_length: Hop in samples, shared by pyin and RMS so both yield values
            on the same frame grid. Also used as the RMS frame length.
        fmin, fmax: Pitch search range for pyin, in Hz.
        f0_ceiling: Frames with F0 above this value (Hz) are marked unvoiced,
            to suppress octave jumps. Remaining F0 values are clipped to
            ``[fmin, f0_ceiling]``.

    Returns:
        dict with keys:
          ``char_ids`` (int64 array), ``durations`` (int64 array, frames per
          character from ``proportional_durations``), ``log_f0`` (float32,
          0 on unvoiced frames), ``log_rms`` (float32), ``voiced_mask`` (bool
          array), ``n_frames`` (int) and the original ``text``.
    """
    y, _ = librosa.load(wav_path, sr=sr, mono=True)

    # F0 via pyin; NaN entries (unvoiced) are mapped to 0 below.
    f0, voiced_flag, _ = librosa.pyin(
        y, fmin=fmin, fmax=fmax, sr=sr, hop_length=hop_length
    )
    # Treat values above the ceiling as octave errors: mark them unvoiced and
    # clip the rest into [fmin, f0_ceiling].
    too_high = ~np.isnan(f0) & (f0 > f0_ceiling)
    voiced_flag[too_high] = False
    f0 = np.where(np.isnan(f0), 0.0, np.clip(f0, fmin, f0_ceiling))

    # RMS energy on the same hop grid (frame length == hop length).
    rms = librosa.feature.rms(y=y, hop_length=hop_length, frame_length=hop_length)[0]

    # pyin and rms may return different lengths; truncate everything to the shorter one.
    n_frames = min(len(f0), len(rms))
    f0 = f0[:n_frames]
    voiced_flag = voiced_flag[:n_frames]
    rms = rms[:n_frames]

    # Log space. Unvoiced frames keep log_f0 == 0 and are identified via voiced_mask.
    voiced_mask = voiced_flag.astype(bool)
    log_f0 = np.zeros_like(f0)
    log_f0[voiced_mask] = np.log(f0[voiced_mask])
    log_rms = np.log(rms + 1e-8)  # epsilon avoids log(0) on silent frames

    char_ids = tokenize(text)
    # NOTE: with an empty transcript, durations is empty even though
    # n_frames > 0, so sum(durations) != n_frames for such samples.
    durations = proportional_durations(len(char_ids), n_frames)

    return {
        'char_ids': np.array(char_ids, dtype=np.int64),
        'durations': np.array(durations, dtype=np.int64),
        'log_f0': log_f0.astype(np.float32),
        'log_rms': log_rms.astype(np.float32),
        'voiced_mask': voiced_mask,
        'n_frames': n_frames,
        'text': text,
    }
def _compute_norm_stats(samples):
    """Return dataset-wide mean/std of voiced log-F0 and of log-RMS.

    Only voiced frames contribute to the F0 statistics. A standard deviation
    of exactly 0 (e.g. a single voiced frame) is replaced by 1.0, so the later
    z-scoring never divides by zero.

    Raises:
        SystemExit: if no sample contains a voiced frame.
    """
    voiced_f0 = [s['log_f0'][s['voiced_mask']] for s in samples if s['voiced_mask'].any()]
    if not voiced_f0:
        raise SystemExit("No voiced frames found; cannot compute F0 normalization stats.")
    all_voiced_f0 = np.concatenate(voiced_f0)
    all_log_rms = np.concatenate([s['log_rms'] for s in samples])
    return {
        'f0_mean': float(np.mean(all_voiced_f0)),
        # `or 1.0` guards the degenerate constant-feature case (std == 0).
        'f0_std': float(np.std(all_voiced_f0)) or 1.0,
        'rms_mean': float(np.mean(all_log_rms)),
        'rms_std': float(np.std(all_log_rms)) or 1.0,
    }


def _print_summary(samples):
    """Print the sample count plus voiced-ratio and frame-count statistics."""
    voiced_ratios = [s['voiced_mask'].mean() for s in samples]
    frame_counts = [s['n_frames'] for s in samples]
    print(f"Samples: {len(samples)}")
    print(f"Voiced ratio: {np.mean(voiced_ratios):.3f} ± {np.std(voiced_ratios):.3f}")
    print(f"Frame counts: {np.mean(frame_counts):.1f} ± {np.std(frame_counts):.1f} "
          f"(min={np.min(frame_counts)}, max={np.max(frame_counts)})")


def main():
    """Command-line entry point: extract, normalize and save features.

    Reads a JSON object from ``--transcripts`` whose ``sample_*.wav`` keys name
    files in ``--audio_dir`` and whose values are their transcripts. It extracts
    features for each existing file and z-scores voiced log-F0 and log-RMS
    using dataset-wide statistics. It then writes
    ``{'samples', 'norm_stats', 'vocab_size'}`` to ``--output`` with
    ``torch.save``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--audio_dir', required=True)
    parser.add_argument('--transcripts', required=True)
    parser.add_argument('--output', default='features.pt')
    args = parser.parse_args()

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(args.transcripts, encoding='utf-8') as f:
        transcripts = json.load(f)

    # Filter to sample_*.wav keys only
    keys = sorted(k for k in transcripts if k.startswith('sample_') and k.endswith('.wav'))
    print(f"Processing {len(keys)} samples...")

    samples = []
    for i, key in enumerate(keys):
        wav_path = os.path.join(args.audio_dir, key)
        if not os.path.exists(wav_path):
            print(f" SKIP {key}: file not found")
            continue
        samples.append(extract_one(wav_path, transcripts[key]))
        if (i + 1) % 200 == 0:
            print(f" {i+1}/{len(keys)} done")

    # Fail with a clear message instead of numpy's "need at least one array".
    if not samples:
        raise SystemExit("No samples were processed; check --audio_dir and --transcripts.")

    # Global normalization stats
    norm_stats = _compute_norm_stats(samples)
    print(f"Norm stats: {norm_stats}")

    # Z-score normalize; unvoiced log_f0 entries are left at 0.
    for s in samples:
        voiced = s['voiced_mask']
        s['log_f0'][voiced] = (s['log_f0'][voiced] - norm_stats['f0_mean']) / norm_stats['f0_std']
        s['log_rms'] = (s['log_rms'] - norm_stats['rms_mean']) / norm_stats['rms_std']

    _print_summary(samples)

    # Convert array fields to tensors for saving
    for s in samples:
        for field in ('char_ids', 'durations', 'log_f0', 'log_rms', 'voiced_mask'):
            s[field] = torch.from_numpy(s[field])

    torch.save({'samples': samples, 'norm_stats': norm_stats, 'vocab_size': VOCAB_SIZE}, args.output)
    print(f"Saved to {args.output}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()