| """Inference: text → (pitch, volume) contour.""" |
|
|
| import argparse |
| import json |
| import numpy as np |
| import torch |
|
|
| from model_prosody import ProsodyPredictor |
| from extract_features import VOCAB, VOCAB_SIZE, tokenize |
|
|
|
|
def predict_prosody(text, model, norm_stats, device='cpu', frame_hop_s=0.1):
    """Run inference on a single text string.

    Args:
        text: input text; tokenized to character ids via ``tokenize``.
        model: trained ProsodyPredictor; called in eval mode under no_grad.
        norm_stats: dict with 'f0_mean', 'f0_std', 'rms_mean', 'rms_std' —
            the training-time normalization statistics to invert here.
        device: torch device string/object for the input tensors.
        frame_hop_s: seconds per output frame (default 0.1; must match the
            hop size used during feature extraction).

    Returns:
        dict with f0_hz (array), rms (array), duration_s (float)
    """
    model.eval()
    char_ids = torch.tensor([tokenize(text)], dtype=torch.long, device=device)
    char_lengths = torch.tensor([char_ids.size(1)], dtype=torch.long, device=device)

    with torch.no_grad():
        pred_f0, pred_rms, pred_log_dur, frame_lengths = model(
            char_ids, durations=None, char_lengths=char_lengths
        )

    # Keep only the valid (non-padded) frames of the single batch item.
    T = frame_lengths[0].item()
    f0_norm = pred_f0[0, :T].cpu().numpy()
    rms_norm = pred_rms[0, :T].cpu().numpy()

    # Undo z-score normalization; predictions live in log space.
    f0_log = f0_norm * norm_stats['f0_std'] + norm_stats['f0_mean']
    f0_hz = np.exp(f0_log)
    # Clamp to a plausible human pitch range to guard against outliers.
    f0_hz = np.clip(f0_hz, 50, 600)

    rms_log = rms_norm * norm_stats['rms_std'] + norm_stats['rms_mean']
    rms = np.exp(rms_log)

    # Total duration: frame count times the (now configurable) hop size.
    duration_s = T * frame_hop_s

    return {
        'f0_hz': f0_hz,
        'rms': rms,
        'duration_s': duration_s,
    }
|
|
|
|
def main():
    """CLI entry point: load a checkpoint and print prosody statistics.

    Texts come from --text (single string) or --texts_file (JSON list/dict,
    or a plain file with one text per line). Exits with an argparse error
    if no usable texts are supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--text', type=str, default=None)
    parser.add_argument('--texts_file', type=str, default=None, help='JSON file or one text per line')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(security): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    norm_stats = ckpt['norm_stats']
    # Fall back to the current vocabulary size for older checkpoints.
    vocab_size = ckpt.get('vocab_size', VOCAB_SIZE)

    model = ProsodyPredictor(vocab_size=vocab_size, d_model=128, dropout=0.0).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    print(f"Loaded model from {args.checkpoint}")

    texts = []
    if args.text:
        texts = [args.text]
    elif args.texts_file:
        if args.texts_file.endswith('.json'):
            with open(args.texts_file) as f:
                data = json.load(f)
            if isinstance(data, list):
                texts = data
            elif isinstance(data, dict):
                texts = list(data.values())
            else:
                # Previously this fell through silently with texts == [].
                parser.error("JSON texts_file must contain a list or a dict")
        else:
            with open(args.texts_file) as f:
                texts = [line.strip() for line in f if line.strip()]
    else:
        parser.error("Provide --text or --texts_file")

    if not texts:
        # An empty file used to produce no output and no explanation.
        parser.error("No texts to process (input file was empty)")

    for i, text in enumerate(texts):
        result = predict_prosody(text, model, norm_stats, device)
        f0 = result['f0_hz']
        rms = result['rms']
        print(f"\n[{i}] \"{text}\"")
        print(f" Duration: {result['duration_s']:.1f}s ({len(f0)} frames)")
        print(f" F0: mean={f0.mean():.1f} Hz, min={f0.min():.1f}, max={f0.max():.1f}")
        print(f" RMS: mean={rms.mean():.4f}, min={rms.min():.4f}, max={rms.max():.4f}")
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|