File size: 1,515 Bytes
6f91e60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
"""Example for using HiggsAudio for generating both the transcript and audio in an interleaved manner."""
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
import torch
import torchaudio
import time
from loguru import logger
import click
from input_samples import INPUT_SAMPLES
# Hugging Face Hub identifiers for the generation model and its matching
# audio tokenizer; both are downloaded by HiggsAudioServeEngine at startup.
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
@click.command()
@click.argument("example", type=click.Choice(list(INPUT_SAMPLES.keys())))
def main(example: str):
    """Generate interleaved transcript + audio for one named input sample.

    Loads the serve engine, runs generation on the chosen sample, logs the
    generated text, and writes the waveform to ``output_<example>.wav``.
    """
    sample = INPUT_SAMPLES[example]()

    # Prefer GPU when one is visible; otherwise fall back to CPU.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {target_device}")

    engine = HiggsAudioServeEngine(
        MODEL_PATH,
        AUDIO_TOKENIZER_PATH,
        device=target_device,
    )

    logger.info("Starting generation...")
    t0 = time.time()
    response: HiggsAudioResponse = engine.generate(
        chat_ml_sample=sample,
        max_new_tokens=1024,
        temperature=1.0,
        top_p=0.95,
        top_k=50,
        stop_strings=["<|end_of_text|>", "<|eot_id|>"],
    )
    elapsed = time.time() - t0
    logger.info(f"Generation time: {elapsed:.2f} seconds")

    # torchaudio.save expects a (channels, samples) tensor; add a channel
    # axis to the mono waveform returned by the engine.
    waveform = torch.from_numpy(response.audio).unsqueeze(0)
    torchaudio.save(f"output_{example}.wav", waveform, response.sampling_rate)

    logger.info(f"Generated text:\n{response.generated_text}")
    logger.info(f"Saved audio to output_{example}.wav")


if __name__ == "__main__":
    main()
|