"""Generate speech from text with SpeechT5 and record per-sample timing metrics.

Audio files and an ``inference_metrics.json`` summary are written to
``$OUTPUT_DIR/<run_id>/`` (``OUTPUT_DIR`` defaults to ``/data/output``).
"""

import argparse
import json
import os
import time

import torch
from datasets import load_dataset
from soundfile import write
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

DEFAULT_PROMPT = "Hugging Face Jobs make cloud compute incredibly straightforward."

# Speaker x-vectors are loaded straight from the parquet export to avoid the
# deprecated dataset loading script.
SPEAKER_EMBEDDINGS_URL = "https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet"
# Row of the x-vector dataset used for voice styling.
SPEAKER_EMBEDDING_INDEX = 7306
# Sample rate (Hz) declared when writing the generated waveforms to disk.
OUTPUT_SAMPLE_RATE = 16000


def _summarize_metrics(metrics, device):
    """Build the summary dict that is serialized to ``inference_metrics.json``.

    Args:
        metrics: List of per-sample dicts, each holding a
            ``"generation_time_seconds"`` entry.
        device: Device string the run executed on (``"cuda"`` or ``"cpu"``).

    Returns:
        A dict with the sample count, the device, the mean generation time
        rounded to 3 decimals (``0.0`` when there are no samples), and the
        per-sample records.
    """
    times = [m["generation_time_seconds"] for m in metrics]
    # Guard against an empty batch: dividing by len(metrics) == 0 would raise.
    average = round(sum(times) / len(times), 3) if times else 0.0
    return {
        "total_samples": len(metrics),
        "hardware_used": device,
        "average_generation_time": average,
        "detailed_runs": metrics,
    }


def run_batch_inference(prompts, run_id=None):
    """Synthesize every prompt to a WAV file and write a metrics summary.

    Args:
        prompts: Iterable of text strings to synthesize. Sample ``idx`` is
            saved as ``audio_sample_{idx}.wav``.
        run_id: Name of the subdirectory under ``$OUTPUT_DIR`` that receives
            this run's outputs. Falls back to ``$JOB_ID`` when set and
            non-empty, else ``"local"``.

    Side effects:
        Creates the output directory, downloads/loads the SpeechT5 model,
        vocoder and speaker embeddings, and writes one WAV per prompt plus
        ``inference_metrics.json``.
    """
    # 1. Set up input/output paths (we map this directory via HF Jobs)
    base_output_dir = os.environ.get("OUTPUT_DIR", "/data/output")
    # Each run gets its own subdirectory so concurrent jobs don't collide in
    # the bucket. An empty JOB_ID is treated as unset; otherwise the join
    # below would silently resolve to the shared base directory.
    run_id = run_id or os.environ.get("JOB_ID") or "local"
    output_dir = os.path.join(base_output_dir, run_id)
    os.makedirs(output_dir, exist_ok=True)

    print("🚀 Initializing TTS Model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

    # Load speaker embeddings for voice styling.
    embeddings_dataset = load_dataset(
        "parquet", data_files=SPEAKER_EMBEDDINGS_URL, split="train"
    )
    speaker_embeddings = (
        torch.tensor(embeddings_dataset[SPEAKER_EMBEDDING_INDEX]["xvector"])
        .unsqueeze(0)
        .to(device)
    )

    metrics = []
    print(f"🎙️ Starting batch inference on {device}...")
    for idx, text in enumerate(prompts):
        start_time = time.perf_counter()

        # Tokenize and generate audio
        inputs = processor(text=text, return_tensors="pt").to(device)
        speech = model.generate_speech(
            inputs["input_ids"], speaker_embeddings, vocoder=vocoder
        )
        if device == "cuda":
            # CUDA kernels are launched asynchronously; wait for them to
            # finish so the measured time covers the actual computation.
            torch.cuda.synchronize()
        generation_time = time.perf_counter() - start_time

        audio_filename = f"audio_sample_{idx}.wav"
        audio_path = os.path.join(output_dir, audio_filename)

        # Save audio file
        write(audio_path, speech.cpu().numpy(), samplerate=OUTPUT_SAMPLE_RATE)

        # Track metrics
        metrics.append(
            {
                "sample_id": idx,
                "text": text,
                "generation_time_seconds": round(generation_time, 3),
                "audio_file": audio_filename,
            }
        )
        print(f"✅ Generated sample {idx} in {generation_time:.2f}s")

    # 2. Save your metrics JSON
    summary_metrics = _summarize_metrics(metrics, device)
    metrics_path = os.path.join(output_dir, "inference_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(summary_metrics, f, indent=4)

    print(f"📊 Metrics and audio saved to {output_dir}")


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``run_id``, ``model_id`` and ``text``.
    """
    parser = argparse.ArgumentParser(
        description="Generate speech from a text prompt using SpeechT5.",
    )
    parser.add_argument(
        "--run-id",
        default=None,
        help="Subdirectory name under OUTPUT_DIR for this run's outputs. "
        "Defaults to $JOB_ID when set, else 'local'.",
    )
    parser.add_argument(
        "--model-id",
        default=None,
        help="Hub model id requested by the caller. Currently ignored — the "
        "script always runs SpeechT5. Wire branching here when adding "
        "support for more TTS models.",
    )
    parser.add_argument(
        "text",
        nargs="?",
        default=DEFAULT_PROMPT,
        help="Sentence to synthesize. Defaults to a built-in demo prompt.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    if args.model_id:
        # The flag is accepted for caller compatibility but has no effect yet.
        print(f"ℹ️ Received --model-id={args.model_id} (ignored; running SpeechT5).")
    run_batch_inference([args.text], run_id=args.run_id)