File size: 1,515 Bytes
6f91e60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Example for using HiggsAudio for generating both the transcript and audio in an interleaved manner."""

from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
import torch
import torchaudio
import time
from loguru import logger
import click

from input_samples import INPUT_SAMPLES

# Model identifier handed to HiggsAudioServeEngine as its first positional argument.
# NOTE(review): looks like a Hugging Face Hub repo id — confirm whether a local path is also accepted.
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
# Audio tokenizer identifier handed to HiggsAudioServeEngine as its second positional argument.
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"


@click.command()
@click.argument("example", type=click.Choice(list(INPUT_SAMPLES.keys())))
def main(example: str):
    """Generate interleaved text and audio for one named input sample.

    Args:
        example: Key into ``INPUT_SAMPLES``; the CLI restricts it to the
            available keys. The mapped value is called with no arguments to
            build the ChatML sample passed to the engine.

    Side effects:
        Logs the selected device, the generation time and the generated
        text, and writes the audio to ``output_<example>.wav`` (relative to
        the current working directory) when the response carries audio.
    """
    input_sample = INPUT_SAMPLES[example]()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    serve_engine = HiggsAudioServeEngine(
        MODEL_PATH,
        AUDIO_TOKENIZER_PATH,
        device=device,
    )

    logger.info("Starting generation...")
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    output: HiggsAudioResponse = serve_engine.generate(
        chat_ml_sample=input_sample,
        max_new_tokens=1024,
        temperature=1.0,
        top_p=0.95,
        top_k=50,
        stop_strings=["<|end_of_text|>", "<|eot_id|>"],
    )
    elapsed_time = time.perf_counter() - start_time
    logger.info(f"Generation time: {elapsed_time:.2f} seconds")

    # Report the transcript first so it is not lost if the audio is missing
    # or saving fails.
    logger.info(f"Generated text:\n{output.generated_text}")

    # NOTE(review): the response's audio is presumably optional (e.g. a
    # text-only generation) — confirm against HiggsAudioResponse. Guarding
    # here avoids a TypeError from torch.from_numpy(None).
    if output.audio is None:
        logger.warning("No audio was generated; skipping audio export.")
        return

    output_path = f"output_{example}.wav"
    # torchaudio.save is given a 2-D tensor, so add a leading axis to the
    # waveform returned by the engine.
    waveform = torch.from_numpy(output.audio)[None, :]
    torchaudio.save(output_path, waveform, output.sampling_rate)
    logger.info(f"Saved audio to {output_path}")

# Invoke the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()