spark-tts
commited on
Commit
·
6f15685
1
Parent(s):
832ac1a
support voice creation
Browse files- cli/SparkTTS.py +121 -20
- cli/inference.py +38 -10
cli/SparkTTS.py
CHANGED
|
@@ -15,12 +15,13 @@
|
|
| 15 |
|
| 16 |
import re
|
| 17 |
import torch
|
|
|
|
| 18 |
from pathlib import Path
|
| 19 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 20 |
|
| 21 |
from sparktts.utils.file import load_config
|
| 22 |
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
| 23 |
-
from sparktts.utils.token_parser import TASK_TOKEN_MAP
|
| 24 |
|
| 25 |
|
| 26 |
class SparkTTS:
|
|
@@ -49,36 +50,36 @@ class SparkTTS:
|
|
| 49 |
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
| 50 |
self.model.to(self.device)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
def inference(
|
| 54 |
self,
|
| 55 |
text: str,
|
| 56 |
prompt_speech_path: Path,
|
| 57 |
prompt_text: str = None,
|
| 58 |
-
|
| 59 |
-
top_k: float = 50,
|
| 60 |
-
top_p: float = 0.95,
|
| 61 |
-
) -> torch.Tensor:
|
| 62 |
"""
|
| 63 |
-
|
| 64 |
|
| 65 |
Args:
|
| 66 |
text (str): The text input to be converted to speech.
|
| 67 |
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
| 68 |
prompt_text (str, optional): Transcript of the prompt audio.
|
| 69 |
-
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
|
| 70 |
-
top_k (float, optional): Top-k sampling parameter. Default is 50.
|
| 71 |
-
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
|
| 72 |
|
| 73 |
-
|
| 74 |
-
torch.Tensor:
|
| 75 |
"""
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# Prepare the input tokens for the model
|
| 80 |
if prompt_text is not None:
|
| 81 |
-
semantic_tokens = "".join(
|
|
|
|
|
|
|
| 82 |
inputs = [
|
| 83 |
TASK_TOKEN_MAP["tts"],
|
| 84 |
"<|start_content|>",
|
|
@@ -103,7 +104,94 @@ class SparkTTS:
|
|
| 103 |
]
|
| 104 |
|
| 105 |
inputs = "".join(inputs)
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# Generate speech using the model
|
| 109 |
generated_ids = self.model.generate(
|
|
@@ -117,14 +205,27 @@ class SparkTTS:
|
|
| 117 |
|
| 118 |
# Trim the output tokens to remove the input tokens
|
| 119 |
generated_ids = [
|
| 120 |
-
output_ids[len(input_ids) :]
|
|
|
|
| 121 |
]
|
| 122 |
|
| 123 |
# Decode the generated tokens into text
|
| 124 |
predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 125 |
|
| 126 |
# Extract semantic token IDs from the generated text
|
| 127 |
-
pred_semantic_ids =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
# Convert semantic tokens back to waveform
|
| 130 |
wav = self.audio_tokenizer.detokenize(
|
|
@@ -132,4 +233,4 @@ class SparkTTS:
|
|
| 132 |
pred_semantic_ids.to(self.device),
|
| 133 |
)
|
| 134 |
|
| 135 |
-
return wav
|
|
|
|
| 15 |
|
| 16 |
import re
|
| 17 |
import torch
|
| 18 |
+
from typing import Tuple
|
| 19 |
from pathlib import Path
|
| 20 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 21 |
|
| 22 |
from sparktts.utils.file import load_config
|
| 23 |
from sparktts.models.audio_tokenizer import BiCodecTokenizer
|
| 24 |
+
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
|
| 25 |
|
| 26 |
|
| 27 |
class SparkTTS:
|
|
|
|
| 50 |
self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
|
| 51 |
self.model.to(self.device)
|
| 52 |
|
| 53 |
+
def process_prompt(
|
|
|
|
| 54 |
self,
|
| 55 |
text: str,
|
| 56 |
prompt_speech_path: Path,
|
| 57 |
prompt_text: str = None,
|
| 58 |
+
) -> Tuple[str, torch.Tensor]:
|
|
|
|
|
|
|
|
|
|
| 59 |
"""
|
| 60 |
+
Process input for voice cloning.
|
| 61 |
|
| 62 |
Args:
|
| 63 |
text (str): The text input to be converted to speech.
|
| 64 |
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
| 65 |
prompt_text (str, optional): Transcript of the prompt audio.
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
Return:
|
| 68 |
+
Tuple[str, torch.Tensor]: Input prompt; global tokens
|
| 69 |
"""
|
| 70 |
+
|
| 71 |
+
global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
|
| 72 |
+
prompt_speech_path
|
| 73 |
+
)
|
| 74 |
+
global_tokens = "".join(
|
| 75 |
+
[f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
|
| 76 |
+
)
|
| 77 |
|
| 78 |
# Prepare the input tokens for the model
|
| 79 |
if prompt_text is not None:
|
| 80 |
+
semantic_tokens = "".join(
|
| 81 |
+
[f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
|
| 82 |
+
)
|
| 83 |
inputs = [
|
| 84 |
TASK_TOKEN_MAP["tts"],
|
| 85 |
"<|start_content|>",
|
|
|
|
| 104 |
]
|
| 105 |
|
| 106 |
inputs = "".join(inputs)
|
| 107 |
+
|
| 108 |
+
return inputs, global_token_ids
|
| 109 |
+
|
| 110 |
+
def process_prompt_control(
|
| 111 |
+
self,
|
| 112 |
+
gender: str,
|
| 113 |
+
pitch: str,
|
| 114 |
+
speed: str,
|
| 115 |
+
text: str,
|
| 116 |
+
):
|
| 117 |
+
"""
|
| 118 |
+
Process input for voice creation.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
gender (str): female | male.
|
| 122 |
+
pitch (str): very_low | low | moderate | high | very_high
|
| 123 |
+
speed (str): very_low | low | moderate | high | very_high
|
| 124 |
+
text (str): The text input to be converted to speech.
|
| 125 |
+
|
| 126 |
+
Return:
|
| 127 |
+
str: Input prompt
|
| 128 |
+
"""
|
| 129 |
+
assert gender in GENDER_MAP.keys()
|
| 130 |
+
assert pitch in LEVELS_MAP.keys()
|
| 131 |
+
assert speed in LEVELS_MAP.keys()
|
| 132 |
+
|
| 133 |
+
gender_id = GENDER_MAP[gender]
|
| 134 |
+
pitch_level_id = LEVELS_MAP[pitch]
|
| 135 |
+
speed_level_id = LEVELS_MAP[speed]
|
| 136 |
+
|
| 137 |
+
pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
|
| 138 |
+
speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
|
| 139 |
+
gender_tokens = f"<|gender_{gender_id}|>"
|
| 140 |
+
|
| 141 |
+
attribte_tokens = "".join(
|
| 142 |
+
[gender_tokens, pitch_label_tokens, speed_label_tokens]
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
control_tts_inputs = [
|
| 146 |
+
TASK_TOKEN_MAP["controllable_tts"],
|
| 147 |
+
"<|start_content|>",
|
| 148 |
+
text,
|
| 149 |
+
"<|end_content|>",
|
| 150 |
+
"<|start_style_label|>",
|
| 151 |
+
attribte_tokens,
|
| 152 |
+
"<|end_style_label|>",
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
return "".join(control_tts_inputs)
|
| 156 |
+
|
| 157 |
+
@torch.no_grad()
|
| 158 |
+
def inference(
|
| 159 |
+
self,
|
| 160 |
+
text: str,
|
| 161 |
+
prompt_speech_path: Path = None,
|
| 162 |
+
prompt_text: str = None,
|
| 163 |
+
gender: str = None,
|
| 164 |
+
pitch: str = None,
|
| 165 |
+
speed: str = None,
|
| 166 |
+
temperature: float = 0.8,
|
| 167 |
+
top_k: float = 50,
|
| 168 |
+
top_p: float = 0.95,
|
| 169 |
+
) -> torch.Tensor:
|
| 170 |
+
"""
|
| 171 |
+
Performs inference to generate speech from text, incorporating prompt audio and/or text.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
text (str): The text input to be converted to speech.
|
| 175 |
+
prompt_speech_path (Path): Path to the audio file used as a prompt.
|
| 176 |
+
prompt_text (str, optional): Transcript of the prompt audio.
|
| 177 |
+
gender (str): female | male.
|
| 178 |
+
pitch (str): very_low | low | moderate | high | very_high
|
| 179 |
+
speed (str): very_low | low | moderate | high | very_high
|
| 180 |
+
temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
|
| 181 |
+
top_k (float, optional): Top-k sampling parameter. Default is 50.
|
| 182 |
+
top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
torch.Tensor: Generated waveform as a tensor.
|
| 186 |
+
"""
|
| 187 |
+
if gender is not None:
|
| 188 |
+
prompt = self.process_prompt_control(gender, pitch, speed, text)
|
| 189 |
+
|
| 190 |
+
else:
|
| 191 |
+
prompt, global_token_ids = self.process_prompt(
|
| 192 |
+
text, prompt_speech_path, prompt_text
|
| 193 |
+
)
|
| 194 |
+
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
|
| 195 |
|
| 196 |
# Generate speech using the model
|
| 197 |
generated_ids = self.model.generate(
|
|
|
|
| 205 |
|
| 206 |
# Trim the output tokens to remove the input tokens
|
| 207 |
generated_ids = [
|
| 208 |
+
output_ids[len(input_ids) :]
|
| 209 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
| 210 |
]
|
| 211 |
|
| 212 |
# Decode the generated tokens into text
|
| 213 |
predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 214 |
|
| 215 |
# Extract semantic token IDs from the generated text
|
| 216 |
+
pred_semantic_ids = (
|
| 217 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
|
| 218 |
+
.long()
|
| 219 |
+
.unsqueeze(0)
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
if gender is not None:
|
| 223 |
+
global_token_ids = (
|
| 224 |
+
torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
|
| 225 |
+
.long()
|
| 226 |
+
.unsqueeze(0)
|
| 227 |
+
.unsqueeze(0)
|
| 228 |
+
)
|
| 229 |
|
| 230 |
# Convert semantic tokens back to waveform
|
| 231 |
wav = self.audio_tokenizer.detokenize(
|
|
|
|
| 233 |
pred_semantic_ids.to(self.device),
|
| 234 |
)
|
| 235 |
|
| 236 |
+
return wav
|
cli/inference.py
CHANGED
|
@@ -12,16 +12,35 @@ def parse_args():
|
|
| 12 |
"""Parse command-line arguments."""
|
| 13 |
parser = argparse.ArgumentParser(description="Run TTS inference.")
|
| 14 |
|
| 15 |
-
parser.add_argument(
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
|
| 20 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 21 |
parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
|
| 22 |
-
parser.add_argument(
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
return parser.parse_args()
|
| 26 |
|
| 27 |
|
|
@@ -47,14 +66,23 @@ def run_tts(args):
|
|
| 47 |
|
| 48 |
# Perform inference and save the output audio
|
| 49 |
with torch.no_grad():
|
| 50 |
-
wav = model.inference(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
sf.write(save_path, wav, samplerate=16000)
|
| 52 |
|
| 53 |
logging.info(f"Audio saved at: {save_path}")
|
| 54 |
|
| 55 |
|
| 56 |
if __name__ == "__main__":
|
| 57 |
-
logging.basicConfig(
|
|
|
|
|
|
|
| 58 |
|
| 59 |
args = parse_args()
|
| 60 |
run_tts(args)
|
|
|
|
| 12 |
"""Parse command-line arguments."""
|
| 13 |
parser = argparse.ArgumentParser(description="Run TTS inference.")
|
| 14 |
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"--model_dir",
|
| 17 |
+
type=str,
|
| 18 |
+
default="pretrained_models/Spark-TTS-0.5B",
|
| 19 |
+
help="Path to the model directory",
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--save_dir",
|
| 23 |
+
type=str,
|
| 24 |
+
default="example/results",
|
| 25 |
+
help="Directory to save generated audio files",
|
| 26 |
+
)
|
| 27 |
parser.add_argument("--device", type=int, default=0, help="CUDA device number")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--text", type=str, required=True, help="Text for TTS generation"
|
| 30 |
+
)
|
| 31 |
parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--prompt_speech_path",
|
| 34 |
+
type=str,
|
| 35 |
+
help="Path to the prompt audio file",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument("--gender", choices=["male", "pitch"])
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
|
| 43 |
+
)
|
| 44 |
return parser.parse_args()
|
| 45 |
|
| 46 |
|
|
|
|
| 66 |
|
| 67 |
# Perform inference and save the output audio
|
| 68 |
with torch.no_grad():
|
| 69 |
+
wav = model.inference(
|
| 70 |
+
args.text,
|
| 71 |
+
args.prompt_speech_path,
|
| 72 |
+
prompt_text=args.prompt_text,
|
| 73 |
+
gender=args.gender,
|
| 74 |
+
pitch=args.pitch,
|
| 75 |
+
speed=args.speed,
|
| 76 |
+
)
|
| 77 |
sf.write(save_path, wav, samplerate=16000)
|
| 78 |
|
| 79 |
logging.info(f"Audio saved at: {save_path}")
|
| 80 |
|
| 81 |
|
| 82 |
if __name__ == "__main__":
|
| 83 |
+
logging.basicConfig(
|
| 84 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 85 |
+
)
|
| 86 |
|
| 87 |
args = parse_args()
|
| 88 |
run_tts(args)
|