| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import os |
| | import argparse |
| | import torch |
| | import soundfile as sf |
| | import logging |
| | from datetime import datetime |
| | import platform |
| |
|
| | from cli.SparkTTS import SparkTTS |
| |
|
| |
|
def parse_args():
    """Parse command-line arguments for TTS inference.

    Returns:
        argparse.Namespace: model/save paths, CUDA device index, the text
        to synthesize, optional prompt-audio fields, and optional voice
        controls (gender, pitch, speed).
    """
    # The pitch and speed options share the same five-level scale.
    five_levels = ["very_low", "low", "moderate", "high", "very_high"]

    parser = argparse.ArgumentParser(description="Run TTS inference.")
    parser.add_argument(
        "--model_dir",
        default="pretrained_models/Spark-TTS-0.5B",
        type=str,
        help="Path to the model directory",
    )
    parser.add_argument(
        "--save_dir",
        default="example/results",
        type=str,
        help="Directory to save generated audio files",
    )
    parser.add_argument("--device", default=0, type=int, help="CUDA device number")
    parser.add_argument(
        "--text", required=True, type=str, help="Text for TTS generation"
    )
    parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
    parser.add_argument(
        "--prompt_speech_path",
        type=str,
        help="Path to the prompt audio file",
    )
    parser.add_argument("--gender", choices=["male", "female"])
    parser.add_argument("--pitch", choices=five_levels)
    parser.add_argument("--speed", choices=five_levels)
    return parser.parse_args()
| |
|
| |
|
def run_tts(args):
    """Synthesize speech for ``args.text`` and save it as a WAV file.

    Args:
        args: Parsed CLI namespace carrying ``model_dir``, ``save_dir``,
            ``device``, ``text``, and the optional prompt/voice-control
            fields consumed by ``SparkTTS.inference``.

    Side effects:
        Creates ``args.save_dir`` if it does not exist and writes a
        timestamp-named 16 kHz WAV file into it.
    """
    logging.info(f"Using model from: {args.model_dir}")
    logging.info(f"Saving audio to: {args.save_dir}")

    # Prefer Apple-silicon MPS, then CUDA, falling back to plain CPU.
    on_mac_with_mps = (
        platform.system() == "Darwin" and torch.backends.mps.is_available()
    )
    if on_mac_with_mps:
        device = torch.device(f"mps:{args.device}")
        logging.info(f"Using MPS device: {device}")
    elif torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
        logging.info(f"Using CUDA device: {device}")
    else:
        device = torch.device("cpu")
        logging.info("GPU acceleration not available, using CPU")

    os.makedirs(args.save_dir, exist_ok=True)

    model = SparkTTS(args.model_dir, device)

    # Timestamped filename avoids clobbering earlier runs in save_dir.
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    out_path = os.path.join(args.save_dir, f"{stamp}.wav")

    logging.info("Starting inference...")

    # Inference only — no gradients needed.
    with torch.no_grad():
        audio = model.inference(
            args.text,
            args.prompt_speech_path,
            prompt_text=args.prompt_text,
            gender=args.gender,
            pitch=args.pitch,
            speed=args.speed,
        )
        sf.write(out_path, audio, samplerate=16000)

    logging.info(f"Audio saved at: {out_path}")
| |
|
| |
|
if __name__ == "__main__":
    # Configure root logging before any inference messages are emitted.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    run_tts(parse_args())
| |
|