# prosody-predictor / infer_prosody.py
# hidude562: Upload infer_prosody.py with huggingface_hub (commit ff02e9f, verified)
"""Inference: text → (pitch, volume) contour."""
import argparse
import json
import numpy as np
import torch
from model_prosody import ProsodyPredictor
from extract_features import VOCAB, VOCAB_SIZE, tokenize
def predict_prosody(text, model, norm_stats, device='cpu'):
    """Predict pitch and loudness contours for one text string.

    Args:
        text: Input string; converted to ids with ``tokenize``.
        model: Predictor invoked as
            ``model(char_ids, durations=None, char_lengths=...)``.
            It is switched to eval mode by this function.
        norm_stats: Mapping providing ``f0_mean``, ``f0_std``, ``rms_mean``
            and ``rms_std``; used to undo normalization before
            exponentiating.
        device: Device on which the input tensors are created.

    Returns:
        dict with ``f0_hz`` (array, clipped to the 50-600 range), ``rms``
        (array) and ``duration_s`` (float: frame count times 0.1).
    """
    model.eval()

    # Batch of one: the token ids and the matching sequence length.
    token_batch = torch.tensor(
        [tokenize(text)], dtype=torch.long, device=device
    )
    token_counts = torch.tensor(
        [token_batch.shape[1]], dtype=torch.long, device=device
    )

    with torch.no_grad():
        f0_out, rms_out, _log_dur, frame_counts = model(
            token_batch, durations=None, char_lengths=token_counts
        )

    # Keep only the valid frames of the single batch item.
    n_frames = frame_counts[0].item()
    f0_scaled = f0_out[0, :n_frames].cpu().numpy()
    rms_scaled = rms_out[0, :n_frames].cpu().numpy()

    # Undo the normalization (scale by std, shift by mean), then leave the
    # log domain; pitch is additionally limited to 50-600.
    pitch = np.clip(
        np.exp(f0_scaled * norm_stats['f0_std'] + norm_stats['f0_mean']),
        50,
        600,
    )
    loudness = np.exp(
        rms_scaled * norm_stats['rms_std'] + norm_stats['rms_mean']
    )

    # Each predicted frame covers 100 ms.
    return dict(f0_hz=pitch, rms=loudness, duration_s=n_frames * 0.1)
def main():
    """Command-line entry point: load a checkpoint and print prosody stats.

    Command-line arguments:
        --checkpoint: Path to a checkpoint containing ``'model'`` (state
            dict) and ``'norm_stats'`` entries, and optionally
            ``'vocab_size'`` (falls back to ``VOCAB_SIZE``).
        --text: A single text to run.
        --texts_file: Either a ``.json`` file (a list of texts, or a dict
            whose values are the texts) or a text file with one text per
            non-empty line.

    ``--text`` takes precedence when both inputs are given. For every text
    the predicted duration, frame count and F0/RMS mean/min/max are printed.
    Exits through ``parser.error`` when no input is given or when the JSON
    file has an unsupported top-level type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--text', type=str, default=None)
    parser.add_argument('--texts_file', type=str, default=None, help='JSON file or one text per line')
    args = parser.parse_args()

    # Fail fast: reject a missing input before the checkpoint load.
    if not args.text and not args.texts_file:
        parser.error("Provide --text or --texts_file")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load checkpoint.
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from a trusted
    # source.
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    norm_stats = ckpt['norm_stats']
    vocab_size = ckpt.get('vocab_size', VOCAB_SIZE)

    # NOTE(review): d_model is hard-coded; presumably it has to match the
    # value used when the checkpoint was trained -- confirm.
    model = ProsodyPredictor(vocab_size=vocab_size, d_model=128, dropout=0.0).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    print(f"Loaded model from {args.checkpoint}")

    if args.text:
        texts = [args.text]
    elif args.texts_file.endswith('.json'):
        # Explicit encoding so non-ASCII text does not depend on the locale.
        with open(args.texts_file, encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, list):
            texts = data
        elif isinstance(data, dict):
            texts = list(data.values())
        else:
            # Previously this case silently produced no output at all.
            parser.error(
                f"{args.texts_file}: expected a JSON list or object, "
                f"got {type(data).__name__}"
            )
    else:
        with open(args.texts_file, encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]

    for i, text in enumerate(texts):
        result = predict_prosody(text, model, norm_stats, device)
        f0 = result['f0_hz']
        rms = result['rms']
        print(f"\n[{i}] \"{text}\"")
        print(f" Duration: {result['duration_s']:.1f}s ({len(f0)} frames)")
        if f0.size == 0 or rms.size == 0:
            # min()/max() raise ValueError on zero-size arrays, so report
            # the empty prediction instead of crashing the whole run.
            print(" (no frames predicted)")
            continue
        print(f" F0: mean={f0.mean():.1f} Hz, min={f0.min():.1f}, max={f0.max():.1f}")
        print(f" RMS: mean={rms.mean():.4f}, min={rms.min():.4f}, max={rms.max():.4f}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()