# JAM / model.py
# Origin: thejagstudio's Hugging Face Space — commit "Update model.py" (df42459, verified)
#!/usr/bin/env python3
"""
Generate audio using JAM model
Reads from filtered test set and generates audio using CFM+DiT model.
"""
import os
import glob
import time
import json
import random
import sys
from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import accelerate
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np
from accelerate import Accelerator
from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE
# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT
def get_negative_style_prompt(device, file_path):
    """Load the negative style-prompt embedding used for classifier-free guidance.

    Args:
        device: Target device (string or ``torch.device``).
        file_path: Path to a ``.npy`` file holding a [1, 512] vocal-style embedding.

    Returns:
        A [1, 512] tensor on ``device``: half precision on CUDA, float32 otherwise.
        Falls back to a zero tensor when the file is missing.
    """
    if not os.path.exists(file_path):
        # Fallback if resource not found
        return torch.zeros(1, 512).to(device).float()
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    # Only use half precision on CUDA.
    # Fix: str(device) can be 'cuda:0' (not just 'cuda'), so match by prefix.
    if str(device).startswith('cuda'):
        vocal_style = vocal_style.half()
    else:
        vocal_style = vocal_style.float()
    return vocal_style
def normalize_audio(audio, normalize_lufs=True, sample_rate=44100):
    """DC-center and peak-normalize audio, optionally matching a LUFS target.

    Args:
        audio: Tensor of shape (channels, samples).
        normalize_lufs: When True, additionally normalize to -14 LUFS via pyloudnorm.
        sample_rate: Sample rate used for loudness metering (default 44100,
            matching the VAE output rate used elsewhere in this file).

    Returns:
        Normalized tensor of the same shape. Falls back to the peak-normalized
        audio if loudness measurement fails (e.g. on silence).
    """
    # Remove DC offset, then scale so the peak is ~1.0 (epsilon avoids /0).
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if normalize_lufs:
        meter = pyln.Meter(rate=sample_rate)
        target_lufs = -14.0
        # pyln expects a (samples, channels) numpy array.
        try:
            samples = audio.transpose(0, 1).cpu().numpy()
            loudness = meter.integrated_loudness(samples)
            normalised = pyln.normalize.loudness(samples, loudness, target_lufs)
            normalised = torch.from_numpy(normalised).transpose(0, 1)
        except Exception:
            # Silently fall back to unnormalized on silence/error (best-effort).
            normalised = audio
    else:
        normalised = audio
    return normalised
class FilteredTestSetDataset(Dataset):
    """Custom dataset for loading from filtered test set JSON.

    Each JSON entry is converted into a fake sample tuple
    ``(id, random latent, style embedding, lyrics)`` and routed through the
    wrapped DiffusionWebDataset's ``process_sample_safely`` so downstream
    collation matches training-time samples. The latent is random noise sized
    to the requested duration; only its length matters for generation.
    """

    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        """
        Args:
            test_set_path: Path to a JSON list of test-sample dicts
                (keys used: id, lrc_path, duration, audio_path / prompt_path).
            diffusion_dataset: Base DiffusionWebDataset used for sample processing.
            muq_model: MuQ-MuLan model for style embeddings (text or audio).
            num_samples: Optional cap on the number of samples used.
            random_crop_style: Crop a random window for the audio style embedding.
            num_style_secs: Style-window length in seconds (MuQ runs at 24 kHz).
            use_prompt_style: Use the text prompt instead of reference audio.
        """
        with open(test_set_path, 'r') as f:
            self.test_samples = json.load(f)
        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]
        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Build one processed sample dict (or None if processing fails)."""
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]
        # Load LRC (lyrics) data.
        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r') as f:
            lrc_data = json.load(f)
        # Normalize legacy format: wrap bare word lists under a 'word' key.
        if 'word' not in lrc_data:
            data = {'word': lrc_data}
            lrc_data = data
        # Generate style embedding — from text prompt or from reference audio.
        if self.use_prompt_style:
            # NOTE(review): key is named "prompt_path" but is used as the prompt
            # text itself — confirm against the caller (Jamify.predict passes
            # the raw style string here).
            prompt_content = test_sample.get("prompt_path", "")
            if prompt_content:
                # Truncate prompt if too long for the text encoder.
                if len(prompt_content) > 300:
                    prompt_content = prompt_content[:300]
                style_embedding = self.muq_model(texts=[prompt_content]).squeeze(0)
            else:
                style_embedding = torch.zeros(512).to(self.muq_model.device)
        else:
            audio_path = test_sample.get("audio_path")
            if audio_path and os.path.exists(audio_path):
                style_embedding = self.generate_style_embedding(audio_path)
            else:
                # No reference audio available: fall back to a zero style.
                style_embedding = torch.zeros(512).to(self.muq_model.device)
        duration = test_sample["duration"]
        # Create fake latent with correct length.
        # Assuming frame_rate from config (typically 21.5 fps for 44.1kHz).
        frame_rate = 21.5
        num_frames = int(duration * frame_rate)
        fake_latent = torch.randn(128, num_frames)  # 128 is latent dim
        # Create sample tuple matching DiffusionWebDataset format.
        fake_sample = (
            sample_id,
            fake_latent,      # latent with correct duration
            style_embedding,  # style from actual audio (or prompt/zeros)
            lrc_data          # actual LRC data
        )
        # Process through DiffusionWebDataset's process_sample_safely.
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)
        # Attach metadata for downstream bookkeeping (trimming, output naming).
        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': test_sample.get("audio_path"),
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames
            }
        return processed_sample

    def generate_style_embedding(self, audio_path):
        """Generate style embedding using MuQ model on the whole music.

        Loads the file, resamples to 24 kHz mono, and embeds either the first
        ``num_style_secs`` seconds or (if ``random_crop_style``) a random
        window of that length. Returns a zero vector on any failure.
        """
        try:
            # Load audio.
            waveform, sample_rate = torchaudio.load(audio_path)
            # Resample to 24kHz if needed (MuQ expects 24kHz).
            if sample_rate != 24000:
                resampler = torchaudio.transforms.Resample(sample_rate, 24000)
                waveform = resampler(waveform)
            # Convert to mono if stereo.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            # Squeeze out the channel dim for mono audio.
            waveform = waveform.squeeze(0)  # Now shape is (time,)
            # Move to same device as model.
            waveform = waveform.to(self.muq_model.device)
            # Generate embedding using MuQ model.
            with torch.inference_mode():
                # MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim).
                if self.random_crop_style:
                    total_samples = waveform.shape[0]
                    target_samples = 24000 * self.num_style_secs
                    if total_samples > target_samples:
                        # Pick a random window of num_style_secs seconds.
                        start_idx = random.randint(0, total_samples - target_samples)
                        wav_input = waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples]
                    else:
                        wav_input = waveform.unsqueeze(0)
                    style_embedding = self.muq_model(wavs=wav_input)
                else:
                    # Default: embed the leading num_style_secs seconds.
                    style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * self.num_style_secs])
            return style_embedding[0]
        except Exception as e:
            print(f"Error generating style embedding from {audio_path}: {e}")
            return torch.zeros(512).to(self.muq_model.device)
def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Collate a batch while preserving each item's 'test_metadata'.

    Args:
        batch: List of sample dicts (entries may be None for failed samples).
        base_collate_fn: Collate function applied after metadata is removed.

    Returns:
        The collated batch with a 'test_metadata' list attached, or None when
        every item in the batch was None.
    """
    # Drop samples that failed processing upstream.
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    # Pop metadata before collation so the base collate only sees model fields.
    # Use a default so items lacking the key don't raise KeyError.
    test_metadata = [item.pop('test_metadata', None) for item in batch]
    collated = base_collate_fn(batch)
    if collated is not None:
        collated['test_metadata'] = test_metadata
    return collated
def load_model(model_config, checkpoint_path, device):
    """Instantiate the JAM CFM+DiT model and restore weights from a checkpoint.

    Follows the loading pattern from infer.py: builds a DiT transformer from
    the config, wraps it in a CFM model, loads a safetensors state dict
    non-strictly, and returns the model in eval mode on ``device``.
    """
    dit_config = model_config["dit"].copy()
    # Older configs may omit the text vocabulary size; default to 256.
    if "text_num_embeds" not in dit_config:
        dit_config["text_num_embeds"] = 256
    transformer = DiT(**dit_config)
    model = CFM(transformer=transformer, **model_config["cfm"])
    model = model.to(device)
    # strict=False tolerates extra/missing keys between checkpoint and model.
    state_dict = load_file(checkpoint_path)
    model.load_state_dict(state_dict, strict=False)
    return model.eval()
def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
    """
    Generate latent from batch data (follows infer.py pattern).

    Args:
        model: Loaded CFM model (see load_model).
        batch: Collated batch with 'lrc', 'prompt', 'start_time',
            'duration_abs', 'duration_rel' tensors.
        sample_kwargs: Overrides for sampling defaults (cfg_strength, steps,
            batch_infer_num).
        negative_style_prompt_path: Path to the negative style .npy file, the
            literal string 'zeros' for a zero prompt, or None for the default
            'public/vocal.npy'.
        ignore_style: When True, use the negative style prompt as the positive
            conditioning too (effectively disabling style guidance).
        device: Device to run sampling on.

    Returns:
        Latents as returned by model.sample (one entry per inference pass).
    """
    with torch.inference_mode():
        batch_size = len(batch["lrc"])
        text = batch["lrc"].to(device)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)
        max_frames = model.max_frames
        # Unconditional latent condition (zeros) over the full frame window.
        # NOTE(review): 64 looks like the conditioning channel dim — confirm
        # against the CFM config.
        cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
        # Predict the entire segment in one span.
        pred_frames = [(0, max_frames)]
        default_sample_kwargs = {
            "cfg_strength": 4,
            "steps": 50,
            "batch_infer_num": 1
        }
        # Caller-supplied kwargs take precedence over defaults.
        sample_kwargs = {**default_sample_kwargs, **sample_kwargs}
        if negative_style_prompt_path is None:
            # Fallback path, or ensure file exists.
            negative_style_prompt_path = 'public/vocal.npy'
        if negative_style_prompt_path == 'zeros':
            negative_style_prompt = torch.zeros(1, 512).to(text.device)
        else:
            negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path)
        # Broadcast the single negative prompt across the batch.
        negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)
        latents, _ = model.sample(
            cond=cond,
            text=text,
            style_prompt=negative_style_prompt if ignore_style else style_prompt,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            **sample_kwargs
        )
        return latents
class Jamify:
    """End-to-end JAM inference wrapper: loads all models once, then turns a
    (reference audio / lyrics / style prompt / duration) request into an MP3.

    Outputs are written under ./outputs; at most ~10 MP3s are kept on disk.
    """

    def __init__(self):
        os.makedirs('outputs', exist_ok=True)
        # ---------------------------------------------------------
        # FIX: Automatically detect CPU vs CUDA
        # ---------------------------------------------------------
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Initializing Jamify model on: {self.device}")
        # Inference configuration (model + evaluation sections).
        config_path = 'jam_infer.yaml'
        self.config = OmegaConf.load(config_path)
        OmegaConf.resolve(self.config)
        print("Downloading main model checkpoint...")
        try:
            model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
            self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
        except Exception as e:
            # Best-effort: keep whatever checkpoint_path the config already has.
            print(f"Failed to download model (might be offline): {e}")
        # Load VAE (decoder from latents to waveform).
        vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if vae_type == 'diffrhythm':
            self.vae = DiffRhythmVAE(device=self.device).to(self.device)
        else:
            self.vae = StableAudioOpenVAE().to(self.device)
        self.vae_type = vae_type
        # Load CFM (the diffusion/flow-matching generator).
        self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, self.device)
        # Load MuQ (ensure float32 on CPU).
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
        if self.device == 'cpu':
            self.muq_model = self.muq_model.float()
        # Base dataset used only for its sample processing / collate logic.
        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id):
        """Keep at most the 9 newest MP3s and remove this request's temp JSON."""
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        if len(old_mp3_files) >= 10:
            for old_file in old_mp3_files[:-9]:
                try:
                    os.remove(old_file)
                except OSError:
                    pass
        # Best-effort removal; predict()'s finally block also deletes this file.
        try:
            os.unlink(f"outputs/{sample_id}.json")
        except OSError:
            pass

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        """Generate a song and return the path to the saved MP3.

        Args:
            reference_audio_path: Audio file used for the style embedding
                (ignored when the config enables prompt-based style).
            lyrics_json_path: Path to the LRC/word-timing JSON.
            style_prompt: Free-text style description.
            duration: Target song length in seconds.

        Returns:
            Path to the generated MP3 under ./outputs.
        """
        # Microsecond timestamp doubles as a unique sample/output id.
        sample_id = str(int(time.time() * 1000000))
        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt
        }]
        # The dataset reads its inputs from a JSON file, so write a temp one.
        json_path = f"outputs/{sample_id}.json"
        with open(json_path, "w") as f:
            json.dump(test_set, f)
        try:
            test_dataset = FilteredTestSetDataset(
                test_set_path=json_path,
                diffusion_dataset=self.base_dataset,
                muq_model=self.muq_model,
                num_samples=1,
                random_crop_style=self.config.evaluation.random_crop_style,
                num_style_secs=self.config.evaluation.num_style_secs,
                use_prompt_style=self.config.evaluation.use_prompt_style
            )
            dataloader = DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
            )
            try:
                batch = next(iter(dataloader))
            except StopIteration:
                raise ValueError("Data loader returned empty batch. Check inputs.")
            sample_kwargs = self.config.evaluation.sample_kwargs
            # Generate latents.
            generated_latents = generate_latent(
                self.cfm_model,
                batch,
                sample_kwargs,
                self.config.evaluation.negative_style_prompt,
                self.config.evaluation.ignore_style,
                device=self.device
            )
            # FIX: Correct indexing [0] for single batch item
            latent = generated_latents[0]
            test_metadata = batch['test_metadata'][0]
            original_duration = test_metadata['duration']
            # Decode latent -> audio; VAE expects (batch, latent_dim, frames).
            latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
            use_chunked = self.config.evaluation.get('use_chunked_decoding', True)
            if self.vae_type == 'diffrhythm' and use_chunked:
                # DiffRhythm chunked decode (lower peak memory for long songs).
                pred_audio = self.vae.decode(
                    latent_for_vae,
                    chunked=True,
                    overlap=self.config.evaluation.get('chunked_overlap', 32),
                    chunk_size=self.config.evaluation.get('chunked_size', 128)
                ).sample.squeeze(0).detach().cpu()
            else:
                pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
            pred_audio = normalize_audio(pred_audio)
            # Trim decoder padding back to the requested duration.
            sample_rate = 44100
            trim_samples = int(original_duration * sample_rate)
            if pred_audio.shape[1] > trim_samples:
                pred_audio_trimmed = pred_audio[:, :trim_samples]
            else:
                pred_audio_trimmed = pred_audio
            output_path = f'outputs/{sample_id}.mp3'
            torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")
            self.cleanup_old_files(sample_id)
            return output_path
        finally:
            # Always remove the temp request JSON, even on failure.
            if os.path.exists(json_path):
                os.unlink(json_path)