File size: 29,636 Bytes
d155c6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 |
#!/usr/bin/env python3
"""Interactive script for generating audio using HiggsAudio with single model load."""
import click
import soundfile as sf
import langid
import jieba
import os
import re
import copy
import torchaudio
import tqdm
import yaml
from loguru import logger
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
from boson_multimodal.data_types import Message, ChatMLSample, AudioContent, TextContent
from boson_multimodal.model.higgs_audio import HiggsAudioConfig, HiggsAudioModel
from boson_multimodal.data_collator.higgs_audio_collator import HiggsAudioSampleCollator
from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
from boson_multimodal.dataset.chatml_dataset import (
ChatMLDatasetSample,
prepare_chatml_sample,
)
from boson_multimodal.model.higgs_audio.utils import revert_delay_pattern
from typing import List
from transformers import AutoConfig, AutoTokenizer
from transformers.cache_utils import StaticCache
from typing import Optional
from dataclasses import asdict
import torch
# Directory of this script; voice_prompts/ and scene_prompts/ are resolved relative to it.
CURR_DIR = os.path.dirname(os.path.abspath(__file__))

# Marker inside a system prompt that _build_system_message_with_audio_prompt turns into an audio slot.
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"

# System prompt used for multi-speaker transcripts when no reference audio is supplied.
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = (
    "You are an AI assistant designed to convert text into speech.\n"
    "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
    "If no speaker tag is present, select a suitable voice on your own."
)
def normalize_chinese_punctuation(text):
    """
    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.

    Only the non-ASCII marks listed below are rewritten, so plain ASCII text passes
    through unchanged (e.g. "1,000" stays "1,000").

    Args:
        text: The input string.

    Returns:
        The string with every mapped punctuation mark replaced.
    """
    # Keys are spelled as \u escapes so that tooling which normalizes full-width
    # characters (e.g. NFKC) cannot silently turn them into their ASCII look-alikes.
    # That would make the comma entry rewrite every ASCII "," as ", ".
    chinese_to_english_punct = {
        "\uff0c": ", ",  # full-width comma
        "\u3002": ".",  # ideographic full stop
        "\uff1a": ":",  # full-width colon
        "\uff1b": ";",  # full-width semicolon
        "\uff1f": "?",  # full-width question mark
        "\uff01": "!",  # full-width exclamation mark
        "\uff08": "(",  # full-width left parenthesis
        "\uff09": ")",  # full-width right parenthesis
        "\u3010": "[",  # left black lenticular bracket
        "\u3011": "]",  # right black lenticular bracket
        "\u300a": "<",  # left double angle bracket
        "\u300b": ">",  # right double angle bracket
        "\u201c": '"',  # left double quotation
        "\u201d": '"',  # right double quotation
        "\u2018": "'",  # left single quotation
        "\u2019": "'",  # right single quotation
        "\u3001": ",",  # ideographic (enumeration) comma
        "\u2014": "-",  # em dash
        "\u2026": "...",  # horizontal ellipsis
        "\u00b7": ".",  # middle dot
        "\u300c": '"',  # left corner bracket
        "\u300d": '"',  # right corner bracket
        "\u300e": '"',  # left white corner bracket
        "\u300f": '"',  # right white corner bracket
    }
    # Every key is a single non-ASCII character and no replacement contains a key,
    # so one translate pass equals applying the replacements one after another.
    return text.translate(str.maketrans(chinese_to_english_punct))
def prepare_chunk_text(
    text, chunk_method: Optional[str] = None, chunk_max_word_num: int = 100, chunk_max_num_turns: int = 1
):
    """Chunk the text into smaller pieces. We will later feed the chunks one by one to the model.

    Args:
        text: The transcript to split.
        chunk_method: ``None`` returns ``[text]`` unchanged. ``"speaker"`` starts a new chunk
            at every line beginning with ``[SPEAKER`` or ``<|speaker_id_start|>``. ``"word"``
            splits each paragraph (separated by a blank line) into pieces of at most
            ``chunk_max_word_num`` words.
        chunk_max_word_num: Maximum number of words per chunk for the ``"word"`` method.
        chunk_max_num_turns: For the ``"speaker"`` method, how many consecutive speaker turns
            are merged into one chunk. Values <= 1 keep one turn per chunk.

    Returns:
        The list of chunk strings.

    Raises:
        ValueError: If ``chunk_method`` is not one of the values above.
    """
    if chunk_method is None:
        return [text]
    elif chunk_method == "speaker":
        speaker_chunks = []
        speaker_utterance = ""
        for line in text.split("\n"):
            line = line.strip()
            if line.startswith("[SPEAKER") or line.startswith("<|speaker_id_start|>"):
                # A speaker tag opens a new turn; flush the previous one.
                if speaker_utterance:
                    speaker_chunks.append(speaker_utterance.strip())
                speaker_utterance = line
            elif speaker_utterance:
                # Untagged lines continue the current turn.
                speaker_utterance += "\n" + line
            else:
                speaker_utterance = line
        if speaker_utterance:
            speaker_chunks.append(speaker_utterance.strip())
        if chunk_max_num_turns > 1:
            return [
                "\n".join(speaker_chunks[i : i + chunk_max_num_turns])
                for i in range(0, len(speaker_chunks), chunk_max_num_turns)
            ]
        return speaker_chunks
    elif chunk_method == "word":
        # The language is detected once for the whole text. Chinese has no word
        # separators, so it is segmented with jieba and re-joined without spaces.
        language = langid.classify(text)[0]
        chunks = []
        for paragraph in text.split("\n\n"):
            if language == "zh":
                words = list(jieba.cut(paragraph, cut_all=False))
                joiner = ""
            else:
                words = paragraph.split(" ")
                joiner = " "
            for i in range(0, len(words), chunk_max_word_num):
                chunks.append(joiner.join(words[i : i + chunk_max_word_num]))
            # Mark the paragraph boundary on the most recent chunk. A paragraph that
            # segments into no words adds no chunk, so guard against an empty list.
            if chunks:
                chunks[-1] += "\n\n"
        return chunks
    else:
        raise ValueError(f"Unknown chunk method: {chunk_method}")
def _build_system_message_with_audio_prompt(system_message):
    """Turn a system prompt containing audio placeholders into a multi-part system ``Message``.

    Each ``AUDIO_PLACEHOLDER_TOKEN`` becomes an ``AudioContent`` with an empty URL. It is
    preceded by a ``TextContent`` holding the text before it, which may be empty. Text after
    the last placeholder is appended only when it is non-empty.
    """
    *text_before_audio, trailing_text = system_message.split(AUDIO_PLACEHOLDER_TOKEN)
    contents = []
    for segment in text_before_audio:
        contents.extend([TextContent(segment), AudioContent(audio_url="")])
    if trailing_text:
        contents.append(TextContent(trailing_text))
    return Message(role="system", content=contents)
class HiggsAudioModelClient:
    """Load a HiggsAudio model once and generate audio for chunked transcripts.

    Holds the model, the text tokenizer, the audio tokenizer and a collator. It can also
    hold pre-allocated static KV caches keyed by their maximum length.
    """

    # Used when the caller does not pass ``kv_cache_lengths``.
    _DEFAULT_KV_CACHE_LENGTHS = (1024, 4096, 8192)

    def __init__(
        self,
        model_path,
        audio_tokenizer,
        device=None,
        device_id=None,
        max_new_tokens=2048,
        kv_cache_lengths: Optional[List[int]] = None,
        use_static_kv_cache=False,
    ):
        """Load the model, tokenizers and collator.

        Args:
            model_path: Path passed to ``HiggsAudioModel.from_pretrained``,
                ``AutoTokenizer.from_pretrained`` and ``AutoConfig.from_pretrained``.
            audio_tokenizer: A loaded audio tokenizer, or a path string to load one from.
            device: Device string. Ignored when ``device_id`` is given. When both are
                ``None``, CUDA, then MPS, then CPU is chosen.
            device_id: CUDA device index; selects ``cuda:<device_id>``.
            max_new_tokens: Generation budget per chunk.
            kv_cache_lengths: Maximum lengths of the static KV caches to pre-allocate.
                Defaults to ``[1024, 4096, 8192]``.
            use_static_kv_cache: Pre-allocate static KV caches. On CUDA devices, CUDA graphs
                are captured for them as well.
        """
        if device_id is not None:
            self._device = f"cuda:{device_id}"
        elif device is not None:
            self._device = device
        elif torch.cuda.is_available():
            self._device = "cuda:0"
        elif torch.backends.mps.is_available():
            self._device = "mps"
        else:
            self._device = "cpu"
        logger.info(f"Using device: {self._device}")

        if isinstance(audio_tokenizer, str):
            # The audio tokenizer is kept on CPU when the model runs on MPS.
            audio_tokenizer_device = "cpu" if self._device == "mps" else self._device
            self._audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer, device=audio_tokenizer_device)
        else:
            self._audio_tokenizer = audio_tokenizer

        self._model = HiggsAudioModel.from_pretrained(
            model_path,
            device_map=self._device,
            torch_dtype=torch.bfloat16,
        )
        self._model.eval()
        # Copy so neither the caller's list nor a shared default is ever aliased.
        self._kv_cache_lengths = list(
            kv_cache_lengths if kv_cache_lengths is not None else self._DEFAULT_KV_CACHE_LENGTHS
        )
        self._use_static_kv_cache = use_static_kv_cache

        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._config = AutoConfig.from_pretrained(model_path)
        self._max_new_tokens = max_new_tokens
        self._collator = HiggsAudioSampleCollator(
            whisper_processor=None,
            audio_in_token_id=self._config.audio_in_token_idx,
            audio_out_token_id=self._config.audio_out_token_idx,
            audio_stream_bos_id=self._config.audio_stream_bos_id,
            audio_stream_eos_id=self._config.audio_stream_eos_id,
            encode_whisper_embed=self._config.encode_whisper_embed,
            pad_token_id=self._config.pad_token_id,
            return_audio_in_tokens=self._config.encode_audio_in_tokens,
            use_delay_pattern=self._config.use_delay_pattern,
            round_to=1,
            audio_num_codebooks=self._config.audio_num_codebooks,
        )
        self.kv_caches = None
        if use_static_kv_cache:
            self._init_static_kv_cache()

    def _init_static_kv_cache(self):
        """Allocate one ``StaticCache`` per configured length; capture CUDA graphs on CUDA devices."""
        cache_config = copy.deepcopy(self._model.config.text_config)
        cache_config.num_hidden_layers = self._model.config.text_config.num_hidden_layers
        # The audio dual-FFN layers need cache slots of their own on top of the text layers.
        if self._model.config.audio_dual_ffn_layers:
            cache_config.num_hidden_layers += len(self._model.config.audio_dual_ffn_layers)
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self._model.device,
                dtype=self._model.dtype,
            )
            for length in sorted(self._kv_cache_lengths)
        }
        if "cuda" in self._device:
            logger.info("Capturing CUDA graphs for each KV cache length")
            self._model.capture_model(self.kv_caches.values())

    def _prepare_kv_caches(self):
        """Reset every static KV cache before a new generation call."""
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    @torch.inference_mode()
    def generate(
        self,
        messages,
        audio_ids,
        chunked_text,
        generation_chunk_buffer_size,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        ras_win_len=7,
        ras_win_max_num_repeat=2,
        seed=123,
        *args,
        **kwargs,
    ):
        """Generate audio for each text chunk in order and return the concatenated result.

        Each chunk is appended as a user turn after ``messages`` and the previously generated
        chunks. The audio produced for a chunk is fed back as context for later chunks.

        Args:
            messages: Context messages (system prompt and optional reference turns).
            audio_ids: Audio token tensors for the ``AudioContent`` entries in ``messages``.
            chunked_text: Non-empty list of text chunks.
            generation_chunk_buffer_size: Keep at most this many previously generated chunks
                (and their user/assistant turns) as context. ``None`` keeps all of them.
            temperature, top_k, top_p: Sampling parameters forwarded to ``model.generate``.
            ras_win_len, ras_win_max_num_repeat: RAS settings. ``ras_win_len <= 0`` disables RAS.
            seed: Seed forwarded to ``model.generate``.

        Returns:
            ``(waveform, sample_rate, text)``. ``waveform`` is the audio tokenizer's decode of
            all chunks' audio tokens. ``text`` is the decoded text output of the last chunk.

        Raises:
            ValueError: If ``chunked_text`` is empty.
        """
        if not chunked_text:
            # Without this guard, ``outputs`` below would be unbound.
            raise ValueError("chunked_text must contain at least one chunk.")
        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None
        sr = 24000
        audio_out_ids_l = []
        generated_audio_ids = []
        generation_messages = []
        # The assistant header appended to every prompt is chunk-independent; encode it once.
        postfix = self._tokenizer.encode(
            "<|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False
        )
        for idx, chunk_text in tqdm.tqdm(
            enumerate(chunked_text), desc="Generating audio chunks", total=len(chunked_text)
        ):
            generation_messages.append(
                Message(
                    role="user",
                    content=chunk_text,
                )
            )
            chatml_sample = ChatMLSample(messages=messages + generation_messages)
            input_tokens, _, _, _ = prepare_chatml_sample(chatml_sample, self._tokenizer)
            input_tokens.extend(postfix)

            logger.info(f"========= Chunk {idx} Input =========")
            logger.info(self._tokenizer.decode(input_tokens))
            context_audio_ids = audio_ids + generated_audio_ids

            curr_sample = ChatMLDatasetSample(
                input_ids=torch.LongTensor(input_tokens),
                label_ids=None,
                audio_ids_concat=torch.concat([ele.cpu() for ele in context_audio_ids], dim=1)
                if context_audio_ids
                else None,
                # Start offset of each context audio segment within audio_ids_concat.
                audio_ids_start=torch.cumsum(
                    torch.tensor([0] + [ele.shape[1] for ele in context_audio_ids], dtype=torch.long), dim=0
                )
                if context_audio_ids
                else None,
                audio_waveforms_concat=None,
                audio_waveforms_start=None,
                audio_sample_rate=None,
                audio_speaker_indices=None,
            )

            batch_data = self._collator([curr_sample])
            batch = asdict(batch_data)
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.contiguous().to(self._device)

            if self._use_static_kv_cache:
                self._prepare_kv_caches()

            outputs = self._model.generate(
                **batch,
                max_new_tokens=self._max_new_tokens,
                use_cache=True,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
                tokenizer=self._tokenizer,
                seed=seed,
            )

            step_audio_out_ids_l = []
            for ele in outputs[1]:
                audio_out_ids = ele
                if self._config.use_delay_pattern:
                    audio_out_ids = revert_delay_pattern(audio_out_ids)
                # Clamp into the codebook range and drop the first/last frame
                # (presumably the audio-stream BOS/EOS markers -- confirm).
                step_audio_out_ids_l.append(audio_out_ids.clip(0, self._audio_tokenizer.codebook_size - 1)[:, 1:-1])
            audio_out_ids = torch.concat(step_audio_out_ids_l, dim=1)
            audio_out_ids_l.append(audio_out_ids)
            generated_audio_ids.append(audio_out_ids)

            generation_messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=""),
                )
            )
            if generation_chunk_buffer_size is not None and len(generated_audio_ids) > generation_chunk_buffer_size:
                # Slice from an explicit start index: ``lst[-0:]`` returns the WHOLE list,
                # so a buffer size of 0 would otherwise keep every chunk instead of none.
                # Each kept chunk owns two messages (user turn + assistant audio turn).
                keep = max(generation_chunk_buffer_size, 0)
                generated_audio_ids = generated_audio_ids[max(len(generated_audio_ids) - keep, 0) :]
                generation_messages = generation_messages[max(len(generation_messages) - 2 * keep, 0) :]

        text_result = self._tokenizer.decode(outputs[0][0])
        logger.info("========= Final Text output =========")
        logger.info(text_result)
        concat_audio_out_ids = torch.concat(audio_out_ids_l, dim=1)
        if concat_audio_out_ids.device.type in ["mps", "cuda"]:
            concat_audio_out_ids_cpu = concat_audio_out_ids.detach().cpu()
        else:
            concat_audio_out_ids_cpu = concat_audio_out_ids
        concat_wv = self._audio_tokenizer.decode(concat_audio_out_ids_cpu.unsqueeze(0))[0, 0]
        return concat_wv, sr, text_result
def prepare_generation_context(scene_prompt, ref_audio, ref_audio_in_system_message, audio_tokenizer, speaker_tags):
    """Prepare the context for generation.

    Builds the context messages that precede the transcript, plus the reference audio
    tokens for any voice-prompt files used.

    Args:
        scene_prompt: Scene description text, or a falsy value for none.
        ref_audio: Comma-separated voice names under ``voice_prompts/``. Entries of the form
            ``profile:<name>`` are looked up in ``voice_prompts/profile.yaml`` instead.
            ``None`` lets the model choose voices.
        ref_audio_in_system_message: Describe the reference voices in the system message
            instead of adding example user/assistant turns. Forced on when any
            ``profile:`` entry is present.
        audio_tokenizer: Object whose ``encode(path)`` returns the audio tokens for a wav file.
        speaker_tags: Distinct speaker tags found in the transcript. Only used when
            ``ref_audio`` is ``None``.

    Returns:
        ``(messages, audio_ids)``: the context messages (system message first, if any) and
        the encoded reference audio, one entry per non-profile voice.

    Raises:
        FileNotFoundError: If a voice prompt ``.wav`` or ``.txt`` file is missing.
        KeyError: If a ``profile:`` name is not present in ``profile.yaml``.
    """
    system_message = None
    messages = []
    audio_ids = []
    if ref_audio is not None:
        # Strip entries so "belinda, chadwick" works like "belinda,chadwick".
        speaker_info_l = [speaker_info.strip() for speaker_info in ref_audio.split(",")]
        num_speakers = len(speaker_info_l)
        if any(speaker_info.startswith("profile:") for speaker_info in speaker_info_l):
            # Profile voices are text descriptions (no audio file), so they can only be
            # conveyed through the system message.
            ref_audio_in_system_message = True
        if ref_audio_in_system_message:
            voice_profile = None
            speaker_desc = []
            for spk_id, character_name in enumerate(speaker_info_l):
                if character_name.startswith("profile:"):
                    if voice_profile is None:
                        # Loaded lazily, at most once per call.
                        with open(f"{CURR_DIR}/voice_prompts/profile.yaml", "r", encoding="utf-8") as f:
                            voice_profile = yaml.safe_load(f)
                    character_desc = voice_profile["profiles"][character_name[len("profile:") :].strip()]
                    speaker_desc.append(f"SPEAKER{spk_id}: {character_desc}")
                else:
                    # Becomes an AudioContent slot via _build_system_message_with_audio_prompt.
                    speaker_desc.append(f"SPEAKER{spk_id}: {AUDIO_PLACEHOLDER_TOKEN}")
            if scene_prompt:
                system_message = (
                    "Generate audio following instruction."
                    "\n\n"
                    f"<|scene_desc_start|>\n{scene_prompt}\n\n" + "\n".join(speaker_desc) + "\n<|scene_desc_end|>"
                )
            else:
                system_message = (
                    "Generate audio following instruction.\n\n"
                    + f"<|scene_desc_start|>\n"
                    + "\n".join(speaker_desc)
                    + "\n<|scene_desc_end|>"
                )
            system_message = _build_system_message_with_audio_prompt(system_message)
        elif scene_prompt:
            system_message = Message(
                role="system",
                content=f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>",
            )

        for spk_id, character_name in enumerate(speaker_info_l):
            if character_name.startswith("profile:"):
                continue
            prompt_audio_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.wav")
            prompt_text_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.txt")
            # Explicit checks rather than ``assert``, which is stripped under ``python -O``.
            if not os.path.exists(prompt_audio_path):
                raise FileNotFoundError(f"Voice prompt audio file {prompt_audio_path} does not exist.")
            if not os.path.exists(prompt_text_path):
                raise FileNotFoundError(f"Voice prompt text file {prompt_text_path} does not exist.")
            with open(prompt_text_path, "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
            audio_tokens = audio_tokenizer.encode(prompt_audio_path)
            audio_ids.append(audio_tokens)

            if not ref_audio_in_system_message:
                # Add an example exchange: the reference transcript as a user turn and the
                # reference audio as the assistant reply.
                messages.append(
                    Message(
                        role="user",
                        content=f"[SPEAKER{spk_id}] {prompt_text}" if num_speakers > 1 else prompt_text,
                    )
                )
                messages.append(
                    Message(
                        role="assistant",
                        content=AudioContent(
                            audio_url=prompt_audio_path,
                        ),
                    )
                )
    elif len(speaker_tags) > 1:
        # No reference voices: alternate feminine/masculine descriptions across speakers.
        speaker_desc = "\n".join(
            f"{tag}: {'feminine' if idx % 2 == 0 else 'masculine'}" for idx, tag in enumerate(speaker_tags)
        )
        scene_desc = "\n\n".join([scene_prompt, speaker_desc] if scene_prompt else [speaker_desc])
        system_message = Message(
            role="system",
            content=f"{MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE}\n\n<|scene_desc_start|>\n{scene_desc}\n<|scene_desc_end|>",
        )
    else:
        system_message_l = ["Generate audio following instruction."]
        if scene_prompt:
            system_message_l.append(f"<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>")
        system_message = Message(
            role="system",
            content="\n\n".join(system_message_l),
        )
    if system_message is not None:
        messages.insert(0, system_message)
    return messages, audio_ids
# Matches speaker tags such as "[SPEAKER0]"; compiled once instead of on every loop iteration.
_SPEAKER_TAG_PATTERN = re.compile(r"\[(SPEAKER\d+)\]")

# Inline sound-event tags accepted in transcripts, and the markup they are rewritten to.
_SOUND_EVENT_TAGS = (
    ("[laugh]", "<SE>[Laughter]</SE>"),
    ("[humming start]", "<SE>[Humming]</SE>"),
    ("[humming end]", "<SE_e>[Humming]</SE_e>"),
    ("[music start]", "<SE_s>[Music]</SE_s>"),
    ("[music end]", "<SE_e>[Music]</SE_e>"),
    ("[music]", "<SE>[Music]</SE>"),
    ("[sing start]", "<SE_s>[Singing]</SE_s>"),
    ("[sing end]", "<SE_e>[Singing]</SE_e>"),
    ("[applause]", "<SE>[Applause]</SE>"),
    ("[cheering]", "<SE>[Cheering]</SE>"),
    ("[cough]", "<SE>[Cough]</SE>"),
)

# A transcript that ends with none of these gets a period appended.
_TRANSCRIPT_ENDINGS = (".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>")


def _normalize_transcript(transcript):
    """Normalize punctuation, units, sound-event tags and whitespace of a raw transcript."""
    transcript = normalize_chinese_punctuation(transcript)
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")
    for tag, replacement in _SOUND_EVENT_TAGS:
        transcript = transcript.replace(tag, replacement)
    # Collapse runs of whitespace within each line and drop blank lines.
    lines = transcript.split("\n")
    transcript = "\n".join(" ".join(line.split()) for line in lines if line.strip())
    transcript = transcript.strip()
    if not transcript.endswith(_TRANSCRIPT_ENDINGS):
        transcript += "."
    return transcript


def interactive_generation_loop(
    model_client,
    audio_tokenizer,
    scene_prompt,
    ref_audio,
    ref_audio_in_system_message,
    chunk_method,
    chunk_max_word_num,
    chunk_max_num_turns,
    generation_chunk_buffer_size,
    temperature,
    top_k,
    top_p,
    ras_win_len,
    ras_win_max_num_repeat,
    seed,
    output_dir,
):
    """Main interactive loop for audio generation.

    Each iteration reads a transcript from stdin, normalizes and chunks it, and generates
    audio with ``model_client``. The result is written to ``output_dir`` as
    ``generation_NNN.wav``, skipping names that already exist. The loop stops on
    ``quit``/``exit``, Ctrl-C, or end of input. A failed generation is logged and the loop
    continues.
    """
    logger.info("Starting interactive generation mode. Enter 'quit' or 'exit' to stop.")
    logger.info("Enter your transcript and press Enter to generate audio.")
    generation_count = 0
    while True:
        try:
            print("\n" + "=" * 50)
            print("Enter transcript (or 'quit'/'exit' to stop):")
            user_input = input("> ").strip()
            if not user_input:
                continue
            if user_input.lower() in ("quit", "exit"):
                logger.info("Exiting interactive generation mode.")
                break

            # Speaker tags are taken from the raw input, before normalization.
            speaker_tags = sorted(set(_SPEAKER_TAG_PATTERN.findall(user_input)))
            transcript = _normalize_transcript(user_input)

            messages, audio_ids = prepare_generation_context(
                scene_prompt=scene_prompt,
                ref_audio=ref_audio,
                ref_audio_in_system_message=ref_audio_in_system_message,
                audio_tokenizer=audio_tokenizer,
                speaker_tags=speaker_tags,
            )
            chunked_text = prepare_chunk_text(
                transcript,
                chunk_method=chunk_method,
                chunk_max_word_num=chunk_max_word_num,
                chunk_max_num_turns=chunk_max_num_turns,
            )
            logger.info("Chunks used for generation:")
            for idx, chunk_text in enumerate(chunked_text):
                logger.info(f"Chunk {idx}:")
                logger.info(chunk_text)
                logger.info("-----")

            logger.info(f"Generating audio for input: {transcript[:50]}...")
            concat_wv, sr, _ = model_client.generate(
                messages=messages,
                audio_ids=audio_ids,
                chunked_text=chunked_text,
                generation_chunk_buffer_size=generation_chunk_buffer_size,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                seed=seed,
            )

            # The counter restarts at 0 every session; skip files left by earlier runs
            # instead of silently overwriting them.
            while True:
                generation_count += 1
                output_filename = f"generation_{generation_count:03d}.wav"
                output_path = os.path.join(output_dir, output_filename)
                if not os.path.exists(output_path):
                    break
            sf.write(output_path, concat_wv, sr)
            logger.info(f"Audio saved to: {output_path}")
            print(f"✓ Audio generated and saved to: {output_filename}")
        except KeyboardInterrupt:
            logger.info("\nInterrupted by user. Exiting...")
            break
        except EOFError:
            # stdin is closed (Ctrl-D or end of piped input). input() would keep raising,
            # so catching this below and continuing would spin forever.
            logger.info("End of input. Exiting...")
            break
        except Exception as e:
            # Keep the session alive after a failed generation, but record the traceback.
            logger.exception(f"Error during generation: {e}")
            print(f"✗ Error: {e}")
            continue
def _resolve_device(device, device_id):
    """Map the ``--device`` / ``--device_id`` options to ``(device_string, cuda_index_or_None)``."""
    if device_id is not None:
        # An explicit GPU index takes precedence over --device.
        return f"cuda:{device_id}", device_id
    if device == "auto":
        if torch.cuda.is_available():
            return "cuda:0", 0
        if torch.backends.mps.is_available():
            return "mps", None
        return "cpu", None
    if device == "cuda":
        return "cuda:0", 0
    if device == "mps":
        return "mps", None
    # "none" means CPU only.
    return "cpu", None


def _load_scene_prompt(scene_prompt):
    """Return the stripped contents of the scene prompt file, or None when unset, "empty" or missing."""
    if scene_prompt is None or scene_prompt == "empty":
        return None
    if not os.path.exists(scene_prompt):
        logger.warning(f"Scene prompt file {scene_prompt} not found; generating without a scene prompt.")
        return None
    with open(scene_prompt, "r", encoding="utf-8") as f:
        return f.read().strip()


@click.command()
@click.option(
    "--model_path",
    type=str,
    default="./higgs-audio-v2-generation-3B-base",
    help="Path to the model directory.",
)
@click.option(
    "--audio_tokenizer",
    type=str,
    default="./higgs-audio-v2-tokenizer",
    help="Path to the audio tokenizer directory.",
)
@click.option(
    "--max_new_tokens",
    type=int,
    default=2048,
    help="The maximum number of new tokens to generate.",
)
@click.option(
    "--scene_prompt",
    type=str,
    default=f"{CURR_DIR}/scene_prompts/quiet_indoor.txt",
    help="The scene description prompt to use for generation. If not set, or set to `empty`, we will leave it to empty.",
)
@click.option(
    "--temperature",
    type=float,
    default=1.0,
    help="The value used to module the next token probabilities.",
)
@click.option(
    "--top_k",
    type=int,
    default=50,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
@click.option(
    "--top_p",
    type=float,
    default=0.95,
    help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
@click.option(
    "--ras_win_len",
    type=int,
    default=7,
    help="The window length for RAS sampling. If set to 0 or a negative value, we won't use RAS sampling.",
)
@click.option(
    "--ras_win_max_num_repeat",
    type=int,
    default=2,
    help="The maximum number of times to repeat the RAS window. Only used when --ras_win_len is set.",
)
@click.option(
    "--ref_audio",
    type=str,
    default=None,
    help="The voice prompt to use for generation. If not set, we will let the model randomly pick a voice.",
)
@click.option(
    "--ref_audio_in_system_message",
    is_flag=True,
    default=False,
    help="Whether to include the voice prompt description in the system message.",
    show_default=True,
)
# Choice values must be strings. Omitting the option leaves the default None, i.e. no chunking.
@click.option(
    "--chunk_method",
    default=None,
    type=click.Choice(["speaker", "word"]),
    help="The method to use for chunking the prompt text.",
)
@click.option(
    "--chunk_max_word_num",
    default=200,
    type=int,
    help="The maximum number of words for each chunk when 'word' chunking method is used.",
)
@click.option(
    "--chunk_max_num_turns",
    default=1,
    type=int,
    help="The maximum number of turns for each chunk when 'speaker' chunking method is used.",
)
@click.option(
    "--generation_chunk_buffer_size",
    default=None,
    type=int,
    help="The maximal number of chunks to keep in the buffer.",
)
@click.option(
    "--seed",
    default=None,
    type=int,
    help="Random seed for generation.",
)
@click.option(
    "--device_id",
    type=int,
    default=None,
    help="The device to run the model on.",
)
@click.option(
    "--output_dir",
    type=str,
    default="./interactive_outputs",
    help="Directory to save generated audio files.",
)
@click.option(
    "--use_static_kv_cache",
    type=int,
    default=1,
    help="Whether to use static KV cache for faster generation. Only works when using GPU.",
)
@click.option(
    "--device",
    type=click.Choice(["auto", "cuda", "mps", "none"]),
    default="auto",
    help="Device to use: 'auto' (pick best available), 'cuda', 'mps', or 'none' (CPU only).",
)
def main(
    model_path,
    audio_tokenizer,
    max_new_tokens,
    scene_prompt,
    temperature,
    top_k,
    top_p,
    ras_win_len,
    ras_win_max_num_repeat,
    ref_audio,
    ref_audio_in_system_message,
    chunk_method,
    chunk_max_word_num,
    chunk_max_num_turns,
    generation_chunk_buffer_size,
    seed,
    device_id,
    output_dir,
    use_static_kv_cache,
    device,
):
    """Interactive audio generation - model loads once, generates multiple times."""
    # (The docstring above doubles as the click --help text, so it is kept verbatim.)
    device, device_id = _resolve_device(device, device_id)

    # On MPS the audio tokenizer runs on CPU.
    audio_tokenizer_device = "cpu" if device == "mps" else device
    audio_tokenizer_obj = load_higgs_audio_tokenizer(audio_tokenizer, device=audio_tokenizer_device)

    # Static KV caches are disabled on MPS.
    if device == "mps" and use_static_kv_cache:
        use_static_kv_cache = False

    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Output directory: {output_dir}")

    scene_prompt = _load_scene_prompt(scene_prompt)

    # The model is loaded once and reused for every generation in the loop.
    logger.info("Loading model... This may take a while.")
    model_client = HiggsAudioModelClient(
        model_path=model_path,
        audio_tokenizer=audio_tokenizer_obj,
        device=device,
        device_id=device_id,
        max_new_tokens=max_new_tokens,
        use_static_kv_cache=use_static_kv_cache,
    )
    logger.info("Model loaded successfully!")

    interactive_generation_loop(
        model_client=model_client,
        audio_tokenizer=audio_tokenizer_obj,
        scene_prompt=scene_prompt,
        ref_audio=ref_audio,
        ref_audio_in_system_message=ref_audio_in_system_message,
        chunk_method=chunk_method,
        chunk_max_word_num=chunk_max_word_num,
        chunk_max_num_turns=chunk_max_num_turns,
        generation_chunk_buffer_size=generation_chunk_buffer_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        ras_win_len=ras_win_len,
        ras_win_max_num_repeat=ras_win_max_num_repeat,
        seed=seed,
        output_dir=output_dir,
    )
# Script entry point. The stray trailing "|" (page-scrape residue) made this a syntax error.
if __name__ == "__main__":
    main()