File size: 4,027 Bytes
e965984
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import time
import os
import json
import torch
from datasets import load_dataset
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from soundfile import write

# Fallback sentence used as the default for the optional positional `text`
# command-line argument (see parse_args).
DEFAULT_PROMPT = "Hugging Face Jobs make cloud compute incredibly straightforward."


def _synthesize_prompts(prompts, output_dir, device):
    """Load SpeechT5 and write one WAV file per prompt into ``output_dir``.

    Args:
        prompts: List of sentences to synthesize.
        output_dir: Existing directory that receives ``audio_sample_<idx>.wav``.
        device: Torch device string (``"cuda"`` or ``"cpu"``) the models run on.

    Returns:
        A ``(metrics, raw_times)`` tuple: ``metrics`` is a list of per-sample
        dicts (id, text, rounded generation time, audio file name) and
        ``raw_times`` holds the matching unrounded durations in seconds.
    """
    print("🚀 Initializing TTS Model...")
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

    # Load speaker embeddings for voice styling (via parquet to avoid deprecated dataset script)
    parquet_url = "https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet"
    embeddings_dataset = load_dataset("parquet", data_files=parquet_url, split="train")
    # Row 7306 selects one fixed speaker voice; unsqueeze adds a leading batch axis.
    # NOTE(review): which speaker row 7306 corresponds to is not visible here — confirm.
    speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)

    # Sample rate written into every WAV header.
    # NOTE(review): presumably the vocoder's native output rate — confirm.
    sample_rate = 16000

    metrics = []
    raw_times = []

    print(f"🎙️ Starting batch inference on {device}...")
    for idx, text in enumerate(prompts):
        # perf_counter is monotonic, so the measurement is immune to wall-clock jumps.
        start_time = time.perf_counter()

        # Tokenize and generate audio
        inputs = processor(text=text, return_tensors="pt").to(device)
        speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

        # CUDA kernels are launched asynchronously; wait for them so the timer
        # covers the actual computation and not just the kernel launches.
        if device == "cuda":
            torch.cuda.synchronize()
        generation_time = time.perf_counter() - start_time

        audio_filename = f"audio_sample_{idx}.wav"
        audio_path = os.path.join(output_dir, audio_filename)

        # Save audio file
        write(audio_path, speech.cpu().numpy(), samplerate=sample_rate)

        # Track metrics
        raw_times.append(generation_time)
        metrics.append({
            "sample_id": idx,
            "text": text,
            "generation_time_seconds": round(generation_time, 3),
            "audio_file": audio_filename
        })
        print(f"✅ Generated sample {idx} in {generation_time:.2f}s")

    return metrics, raw_times


def run_batch_inference(prompts, run_id=None):
    """Synthesize every prompt to a WAV file and write a JSON metrics summary.

    Outputs land in ``<OUTPUT_DIR>/<run_id>/``, where ``OUTPUT_DIR`` is read
    from the environment (default ``/data/output``). The directory is created
    if missing. It receives ``audio_sample_<idx>.wav`` for each prompt plus
    ``inference_metrics.json``.

    Args:
        prompts: Iterable of sentences to synthesize. It is materialized into
            a list, so generators are accepted. An empty iterable is valid:
            no model is loaded and a summary with zero samples is written.
        run_id: Name of the per-run subdirectory. Falls back to the ``JOB_ID``
            environment variable, then to ``"local"``.

    Returns:
        None. Results are written to disk and progress is printed to stdout.
    """
    # Materialize once so len() works for any iterable and the input is only
    # consumed a single time.
    prompts = list(prompts)

    # 1. Set up input/output paths (we map this directory via HF Jobs)
    base_output_dir = os.environ.get("OUTPUT_DIR", "/data/output")
    # Each run gets its own subdirectory so concurrent jobs don't collide in the bucket.
    run_id = run_id or os.environ.get("JOB_ID", "local")
    output_dir = os.path.join(base_output_dir, run_id)
    os.makedirs(output_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if prompts:
        metrics, raw_times = _synthesize_prompts(prompts, output_dir, device)
    else:
        # Nothing to synthesize: skip the model and embedding downloads
        # entirely and still emit a well-formed (empty) summary.
        print("⚠️ No prompts supplied; skipping model load.")
        metrics, raw_times = [], []

    # Average the unrounded durations so rounding error is not compounded;
    # report 0.0 for an empty batch instead of dividing by zero.
    if raw_times:
        average_time = round(sum(raw_times) / len(raw_times), 3)
    else:
        average_time = 0.0

    # 2. Save your metrics JSON
    summary_metrics = {
        "total_samples": len(prompts),
        "hardware_used": device,
        "average_generation_time": average_time,
        "detailed_runs": metrics
    }

    metrics_path = os.path.join(output_dir, "inference_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(summary_metrics, f, indent=4)

    print(f"📊 Metrics and audio saved to {output_dir}")

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Recognized arguments:
        --run-id    Optional name of the output subdirectory for this run.
        --model-id  Optional Hub model id; accepted but not acted upon.
        text        Optional positional sentence; defaults to DEFAULT_PROMPT.

    Returns:
        The ``argparse.Namespace`` with ``run_id``, ``model_id`` and ``text``.
    """
    run_id_help = (
        "Subdirectory name under OUTPUT_DIR for this run's outputs. "
        "Defaults to $JOB_ID when set, else 'local'."
    )
    model_id_help = (
        "Hub model id requested by the caller. Currently ignored — the "
        "script always runs SpeechT5. Wire branching here when adding "
        "support for more TTS models."
    )
    text_help = "Sentence to synthesize. Defaults to a built-in demo prompt."

    cli = argparse.ArgumentParser(description="Generate speech from a text prompt using SpeechT5.")
    # Registration order is kept: it determines the order shown in --help.
    cli.add_argument("--run-id", default=None, help=run_id_help)
    cli.add_argument("--model-id", default=None, help=model_id_help)
    cli.add_argument("text", nargs="?", default=DEFAULT_PROMPT, help=text_help)

    namespace = cli.parse_args()
    return namespace


if __name__ == "__main__":
    # Script entry point: parse the CLI, then synthesize the single requested sentence.
    cli_args = parse_args()

    # --model-id is accepted for forward compatibility only; tell the caller
    # it has no effect rather than ignoring it silently.
    requested_model = cli_args.model_id
    if requested_model:
        print(f"ℹ️ Received --model-id={requested_model} (ignored; running SpeechT5).")

    # The batch API expects a list of prompts, so wrap the one sentence.
    single_prompt_batch = [cli_args.text]
    run_batch_inference(single_prompt_batch, run_id=cli_args.run_id)