# File size: 4,936 Bytes
# 6f91e60
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import subprocess
import time
import torch
import torchaudio
from cog import BasePredictor, Input, Path
from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
from boson_multimodal.serve.serve_engine import HiggsAudioResponse, HiggsAudioServeEngine
MODEL_PATH = "higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "higgs-audio-v2-tokenizer"
MODEL_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-generation-3B-base/model.tar"
TOKENIZER_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-tokenizer/model.tar"
def download_weights(url, dest):
    """Download a weights tarball from *url* and extract it into *dest* via pget."""
    started_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # "-xf" makes pget extract the tar archive in-flight into dest.
    # NOTE(review): close_fds=False presumably for faster process spawn — confirm.
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    elapsed = time.time() - started_at
    print("downloading took: ", elapsed)
class Predictor(BasePredictor):
    """Cog predictor wrapping the Higgs Audio V2 text-to-speech serve engine."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Download weights on first run; later runs reuse the extracted directories.
        if not os.path.exists(MODEL_PATH):
            download_weights(MODEL_URL, MODEL_PATH)
        if not os.path.exists(AUDIO_TOKENIZER_PATH):
            download_weights(TOKENIZER_URL, AUDIO_TOKENIZER_PATH)

        # Prefer GPU when available; the engine also runs (slowly) on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Initialize the serve engine that performs tokenization + generation.
        self.serve_engine = HiggsAudioServeEngine(
            MODEL_PATH,
            AUDIO_TOKENIZER_PATH,
            device=self.device,
        )
        print("Higgs Audio V2 model loaded successfully")

    def predict(
        self,
        text: str = Input(
            description="Text to convert to speech",
            default="The sun rises in the east and sets in the west",
        ),
        temperature: float = Input(
            description="Controls randomness in generation. Lower values are more deterministic.",
            ge=0.1,
            le=1.0,
            default=0.3,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter. Controls diversity of generated audio.",
            ge=0.1,
            le=1.0,
            default=0.95,
        ),
        top_k: int = Input(
            description="Top-k sampling parameter. Limits vocabulary to top k tokens.", ge=1, le=100, default=50
        ),
        max_new_tokens: int = Input(
            description="Maximum number of audio tokens to generate", ge=256, le=2048, default=1024
        ),
        scene_description: str = Input(
            description="Scene description for audio context", default="Audio is recorded from a quiet room."
        ),
        system_message: str = Input(description="Custom system message (optional)", default=""),
        ref_audio: Path = Input(
            description="Reference audio file for voice cloning (optional). Supports WAV, MP3, etc.",
            default=None,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Builds a ChatML conversation (system prompt, optional reference audio
        for voice cloning, user text), runs the serve engine, and writes the
        generated waveform to a WAV file.

        Raises:
            RuntimeError: if generation or saving fails; the original
                exception is preserved via explicit chaining.
        """
        # A custom system message overrides the default scene-based prompt.
        if system_message:
            system_prompt = system_message
        else:
            system_prompt = f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_description}\n<|scene_desc_end|>"

        messages = [
            Message(
                role="system",
                content=system_prompt,
            ),
        ]
        # A prior assistant turn containing the reference audio steers the
        # model to clone that voice.
        if ref_audio is not None:
            messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=str(ref_audio)),
                )
            )
        messages.append(
            Message(
                role="user",
                content=text,
            )
        )

        # Keep the try narrow: only generation and file writing can
        # meaningfully fail; message construction above is pure data.
        try:
            output: HiggsAudioResponse = self.serve_engine.generate(
                chat_ml_sample=ChatMLSample(messages=messages),
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
            )
            # Save output audio to a temporary file with a clear filename.
            output_path = "/tmp/audio_output.wav"
            # torchaudio expects a (channels, samples) tensor; add channel dim.
            audio_tensor = torch.from_numpy(output.audio)[None, :]
            torchaudio.save(output_path, audio_tensor, output.sampling_rate, format="wav")
            return Path(output_path)
        except Exception as e:
            # Chain the cause so the full traceback survives for debugging.
            raise RuntimeError(f"Audio generation failed: {str(e)}") from e