 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import time

import torch
import torchaudio
from cog import BasePredictor, Input, Path

from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
from boson_multimodal.serve.serve_engine import HiggsAudioResponse, HiggsAudioServeEngine


# Local directory names the weight tarballs are extracted into (relative to the
# working directory); setup() skips the download when they already exist.
MODEL_PATH = "higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "higgs-audio-v2-tokenizer"
# Replicate-hosted tarballs holding the model and audio-tokenizer weights.
MODEL_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-generation-3B-base/model.tar"
TOKENIZER_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-tokenizer/model.tar"


def download_weights(url, dest):
    """Fetch the tarball at *url* with pget and extract it into *dest*.

    Raises subprocess.CalledProcessError if the download or extraction fails.
    """
    began_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget -x extracts the tar stream in flight; -f overwrites an existing dest.
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - began_at)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Download weights on first run; later runs reuse the extracted dirs.
        if not os.path.exists(MODEL_PATH):
            download_weights(MODEL_URL, MODEL_PATH)
        if not os.path.exists(AUDIO_TOKENIZER_PATH):
            download_weights(TOKENIZER_URL, AUDIO_TOKENIZER_PATH)

        # Prefer GPU when available; fall back to CPU otherwise.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Initialize the serve engine
        self.serve_engine = HiggsAudioServeEngine(
            MODEL_PATH,
            AUDIO_TOKENIZER_PATH,
            device=self.device)
        print("Higgs Audio V2 model loaded successfully")

    def _build_messages(self, text, scene_description, system_message, ref_audio):
        """Assemble the ChatML message list for one generation request.

        Order matters: system prompt first, then (optionally) an assistant turn
        carrying the reference audio for voice cloning, then the user's text.
        """
        # A non-empty custom system message overrides the scene-based default.
        if system_message:
            system_prompt = system_message
        else:
            system_prompt = f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_description}\n<|scene_desc_end|>"

        messages = [
            Message(
                role="system",
                content=system_prompt,
            ),
        ]

        # A prior assistant turn containing audio steers the model to imitate
        # that voice (voice cloning).
        if ref_audio is not None:
            messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=str(ref_audio)),
                )
            )

        messages.append(
            Message(
                role="user",
                content=text,
            )
        )
        return messages

    def predict(
        self,
        text: str = Input(
            description="Text to convert to speech",
            default="The sun rises in the east and sets in the west",
        ),
        temperature: float = Input(
            description="Controls randomness in generation. Lower values are more deterministic.",
            ge=0.1,
            le=1.0,
            default=0.3,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter. Controls diversity of generated audio.",
            ge=0.1,
            le=1.0,
            default=0.95,
        ),
        top_k: int = Input(
            description="Top-k sampling parameter. Limits vocabulary to top k tokens.", ge=1, le=100, default=50
        ),
        max_new_tokens: int = Input(
            description="Maximum number of audio tokens to generate", ge=256, le=2048, default=1024
        ),
        scene_description: str = Input(
            description="Scene description for audio context", default="Audio is recorded from a quiet room."
        ),
        system_message: str = Input(description="Custom system message (optional)", default=""),
        ref_audio: Path = Input(
            description="Reference audio file for voice cloning (optional). Supports WAV, MP3, etc.",
            default=None,
        ),
    ) -> Path:
        """Generate speech for *text* and return the path to a WAV file.

        Raises RuntimeError (chained to the underlying exception) on any
        failure during prompt assembly, generation, or file writing.
        """
        try:
            messages = self._build_messages(text, scene_description, system_message, ref_audio)

            # Generate audio
            output: HiggsAudioResponse = self.serve_engine.generate(
                chat_ml_sample=ChatMLSample(messages=messages),
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
            )

            # Save output audio to a temporary file with a clear filename.
            output_path = "/tmp/audio_output.wav"
            # output.audio is a 1-D numpy waveform; torchaudio.save expects a
            # (channels, samples) tensor, so add a leading channel axis.
            audio_tensor = torch.from_numpy(output.audio)[None, :]
            torchaudio.save(output_path, audio_tensor, output.sampling_rate, format="wav")
            return Path(output_path)

        except Exception as e:
            # `from e` chains the original exception explicitly so the real
            # traceback is preserved in Cog's error reporting.
            raise RuntimeError(f"Audio generation failed: {str(e)}") from e