Setup and Inference Code

#11 by E10H1M


Inference speed: Stupid fast.

CODE (SETUP INSTRUCTIONS ARE UNDERNEATH):

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import numpy as np
import soundfile as sf

model_id = "/path/to/your/weights/FlashLabs/Chroma-4B"  # or just "FlashLabs/Chroma-4B" to pull from the Hub

# ----- MODEL LOADING
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,   # Chroma ships custom modeling code on the Hub
    device_map={"": 0},       # put the whole model on GPU 0
    dtype=torch.bfloat16,
)

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
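
# Quick sanity check that the weights landed where you expect:
print(model.device, model.dtype)
print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")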




# ----- PROMPT CONFIG
system_prompt = (
    "You are Chroma, an advanced virtual human created by the FlashLabs. "
    "You possess the ability to understand auditory inputs and generate both text and speech."
)

conversation = [[
    {
        "role": "system",
        "content": [{"type": "text", "text": system_prompt}],
    },
    {
        "role": "user",
        "content": [{"type": "audio", "audio": "example/make_taco.wav"}],
    },
]]
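
Side note: if you want to feed plain text instead of a wav, the same chat structure should accept a text part in the user turn, mirroring the system message above. I haven't verified every modality combo, so treat this as an assumption:

# assumed text-only variant (untested) -- same structure as the system turn
conversation_text_only = [[
    {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
    {"role": "user", "content": [{"type": "text", "text": "How do I make a taco?"}]},
]]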



# ------ SETUP
# load a voice prompt: the reference transcript plus the matching audio clip for a speaker
def load_prompt(speaker_name):
    text_path = f"example/prompt_text/{speaker_name}.txt"
    audio_path = f"example/prompt_audio/{speaker_name}.wav"
    with open(text_path, "r", encoding="utf-8") as f:
        prompt_text = f.read()
    return [prompt_text], [audio_path]

prompt_text, prompt_audio = load_prompt("donald_trump")  # other speakers are in example/
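
To see which speakers you actually have, list the prompt folder (assuming the layout load_prompt expects above):

import os
# each speaker is a matching .wav/.txt pair under example/
speakers = sorted(f.removesuffix(".wav") for f in os.listdir("example/prompt_audio") if f.endswith(".wav"))
print(speakers)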

inputs = processor(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
    prompt_audio=prompt_audio,
    prompt_text=prompt_text,
)

# move every tensor input onto the model's device
device = model.device
for k, v in inputs.items():
    if torch.is_tensor(v):
        inputs[k] = v.to(device)

# the audio features need to match the model dtype (bfloat16 here)
for k in ("input_values", "prompt_input_values"):
    if k in inputs and torch.is_tensor(inputs[k]):
        inputs[k] = inputs[k].to(dtype=model.dtype)


# CONTROLLING SEED
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # only matters on multi-GPU setups




# ------ INFERENCE
output = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=True,
)
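
If you want actual numbers behind the "stupid fast" claim, you can re-run the generate call under a timer (rough sketch; steps/sec will obviously depend on your GPU):

import time

# time a generation pass (CUDA is async, so synchronize around the call)
torch.cuda.synchronize()
t0 = time.perf_counter()
output = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=True,
)
torch.cuda.synchronize()
dt = time.perf_counter() - t0
print(f"{output.shape[1]} steps in {dt:.2f}s ({output.shape[1] / dt:.1f} steps/s)")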



# DECODING
# decode the generated codec tokens back into a waveform
# (permute swaps the last two dims into the layout the codec decoder expects)
audio_values = model.codec_model.decode(output.permute(0, 2, 1)).audio_values

# ------ SAVING THE OUTPUT
a = audio_values[0, 0].float().detach().cpu().numpy()  # first batch item, first channel
a = np.asarray(a, dtype=np.float32)

sf.write("chroma_out.wav", a, 24_000, subtype="PCM_16")
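
Quick sanity check on the result before you listen to it:

# duration and peak amplitude of the generated clip
print(f"{len(a) / 24_000:.2f}s of audio, peak {np.abs(a).max():.3f}")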

ENVIRONMENT

git clone https://github.com/FlashLabs-AI-Corp/FlashLabs-Chroma.git
cd FlashLabs-Chroma
python3 -m venv .venv

# activate your environment
. .venv/bin/activate

(I'm using python 3.12)

Some general Python stuff you should always install first thing inside the venv:
python -m pip install -U pip setuptools wheel

Installing Torch

torch (minimum 2.7.1, I believe; I used 2.10.0+cu130)
pip3 install --index-url https://download.pytorch.org/whl/cu130 torch torchvision torchaudio

Find other versions here: https://pytorch.org/get-started/previous-versions/
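
Verify the install actually sees your GPU before moving on:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"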

transformers: make sure it's 5.0.0rc1 or newer.

python -m pip install transformers==5.0.0rc1

Be sure to update to the full release once transformers 5.0 drops.
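
Same idea here, confirm the version took:

python -c "import transformers; print(transformers.__version__)"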

Audio processing

python -m pip install "av>=14.0.0" "librosa>=0.11.0" "audioread>=3.0.0" "soundfile>=0.13.0"

Other dependencies

python -m pip install "pillow>=11.0.0" "accelerate>=1.7.0" "numpy>=2.2.0" "safetensors>=0.5.0" "huggingface-hub>=1.3.0"

POSSIBLES (transformers may want these, but run the script first and only install what it actually demands; I was cycling through different versions, so I'm not entirely sure):

pip install protobuf
python -m pip install -U sentencepiece tiktoken tokenizers
python -m pip install -U torchcodec

WEIGHTS CAN BE FOUND HERE:

https://huggingface.co/FlashLabs/Chroma-4B

Note: you don't technically need to clone the GitHub repo; the code runs via torch + transformers (trust_remote_code pulls the model code from the Hub).
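
If you'd rather script the download than click around on the Hub, huggingface_hub (already installed above) can fetch the weights into a local dir, which you then point model_id at. The target path is just an example:

from huggingface_hub import snapshot_download

# pull the Chroma-4B weights into a local directory (example path)
snapshot_download("FlashLabs/Chroma-4B", local_dir="/path/to/your/weights/FlashLabs/Chroma-4B")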
