Setup and Inference Code
Inference speed: Stupid fast.
CODE (SETUP INSTRUCTIONS ARE UNDERNEATH):
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import numpy as np
import soundfile as sf
model_id = "/path/to/your/weights/FlashLabs/Chroma-4B"
# ----- MODEL LOADING
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map={"": 0},
    dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# ----- PROMPT CONFIG
system_prompt = (
    "You are Chroma, an advanced virtual human created by the FlashLabs. "
    "You possess the ability to understand auditory inputs and generate both text and speech."
)
conversation = [[
    {
        "role": "system",
        "content": [{"type": "text", "text": system_prompt}],
    },
    {
        "role": "user",
        "content": [{"type": "audio", "audio": "example/make_taco.wav"}],
    },
]]
# ------ SETUP
def load_prompt(speaker_name):
    text_path = f"example/prompt_text/{speaker_name}.txt"
    audio_path = f"example/prompt_audio/{speaker_name}.wav"
    with open(text_path, "r", encoding="utf-8") as f:
        prompt_text = f.read()
    return [prompt_text], [audio_path]

prompt_text, prompt_audio = load_prompt("donald_trump")  # find other speakers in example/
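Each speaker is a matching .txt/.wav pair under example/prompt_text and example/prompt_audio. If you want to see what voices you have before picking one, a quick sketch (the helper name is mine, assuming the repo layout above):

```python
import os

def list_speakers(prompt_text_dir="example/prompt_text"):
    # Each speaker has a .txt transcript; strip the extension to get the name
    # you can pass to load_prompt().
    return sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(prompt_text_dir)
        if f.endswith(".txt")
    )
```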
inputs = processor(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
    prompt_audio=prompt_audio,
    prompt_text=prompt_text,
)
device = model.device
for k, v in inputs.items():
    if torch.is_tensor(v):
        inputs[k] = v.to(device)

for k in ("input_values", "prompt_input_values"):
    if k in inputs and torch.is_tensor(inputs[k]):
        inputs[k] = inputs[k].to(dtype=model.dtype)
# CONTROLLING SEED
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
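Why bother seeding: with do_sample=True the generation is stochastic, so fixing the seed is what makes runs repeatable (uncomment manual_seed_all if you're on multiple GPUs). A minimal illustration of the principle using numpy as a stand-in for the torch RNG:

```python
import numpy as np

def sample_tokens(seed, vocab=1000, n=8):
    # Same seed -> same pseudo-random draws -> same "generation".
    rng = np.random.default_rng(seed)
    return rng.integers(0, vocab, size=n).tolist()
```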
# ------ INFERENCE
output = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=True,
)
# DECODING
audio_values = model.codec_model.decode(output.permute(0, 2, 1)).audio_values
# ------ SAVING THE OUTPUT
a = audio_values[0, 0].float().detach().cpu().numpy()
a = np.asarray(a, dtype=np.float32)
sf.write("chroma_out.wav", a, 24_000, subtype="PCM_16")
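One optional nicety (my addition, not part of the upstream example): model output can occasionally poke outside [-1, 1], and clamping it yourself before the 16-bit write makes the quantization behavior predictable rather than relying on the writer's clipping:

```python
import numpy as np

def to_pcm_range(audio: np.ndarray) -> np.ndarray:
    # Clamp float audio into [-1.0, 1.0] before PCM_16 quantization.
    return np.clip(audio.astype(np.float32), -1.0, 1.0)
```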
ENVIRONMENT
git clone https://github.com/FlashLabs-AI-Corp/FlashLabs-Chroma.git
cd FlashLabs-Chroma
python3 -m venv .venv
# activate your environment
. .venv/bin/activate
(I'm using python 3.12)
Some general Python stuff you should always upgrade once you're in your venv:
python -m pip install -U pip setuptools wheel
Installing Torch
torch (must be at least 2.7.1, I believe; I used 2.10.0+cu130):
pip3 install --index-url https://download.pytorch.org/whl/cu130 torch torchvision torchaudio
Find other versions here: https://pytorch.org/get-started/previous-versions/
transformers: make sure you're on 5.0.0rc1 or newer.
python -m pip install transformers==5.0.0rc1
Be sure to update to the full release once transformers 5.0 drops.
Audio processing
python -m pip install "av>=14.0.0" "librosa>=0.11.0" "audioread>=3.0.0" "soundfile>=0.13.0"
Other dependencies
python -m pip install "pillow>=11.0.0" "accelerate>=1.7.0" "numpy>=2.2.0" "safetensors>=0.5.0" "huggingface-hub>=1.3.0"
POSSIBLE EXTRAS (transformers may want these; run the script first and see if it complains. I was jumping between versions, so I'm not entirely sure):
pip install protobuf
python -m pip install -U sentencepiece tiktoken tokenizers
python -m pip install -U torchcodec
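If you'd rather pin everything in one place, the constraints above collapse into a requirements.txt along these lines (torch left out, since it needs the CUDA index URL):

```text
transformers==5.0.0rc1
av>=14.0.0
librosa>=0.11.0
audioread>=3.0.0
soundfile>=0.13.0
pillow>=11.0.0
accelerate>=1.7.0
numpy>=2.2.0
safetensors>=0.5.0
huggingface-hub>=1.3.0
```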
WEIGHTS CAN BE FOUND HERE:
https://huggingface.co/FlashLabs/Chroma-4B
Note: you don't technically need to clone the GitHub repo; the code runs via torch + transformers alone.