#!/usr/bin/env python3
"""
Generate audio using JAM model
Reads from filtered test set and generates audio using CFM+DiT model.
"""

import os
import glob
import time
import json
import random
import sys

from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import accelerate
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np
from accelerate import Accelerator

from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE

# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT


def get_negative_style_prompt(device, file_path):
    """Load the negative style prompt stored as a NumPy file at ``file_path``.

    Args:
        device: Device the returned tensor is moved to.
        file_path: Path to a ``.npy`` file holding the prompt.

    Returns:
        The loaded tensor on ``device``, or a ``(1, 512)`` float tensor of
        zeros when ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        # Fallback if resource not found
        return torch.zeros(1, 512).to(device).float()

    vocal_style = torch.from_numpy(np.load(file_path)).to(device)  # [1, 512]

    # Only use half precision on CUDA.
    # NOTE(review): this is an exact string comparison, so an indexed device
    # such as 'cuda:0' takes the float branch. Confirm the intended precision
    # before changing it; the model loaded by load_model() is never cast to
    # half in this file.
    if str(device) == 'cuda':
        return vocal_style.half()
    return vocal_style.float()


def normalize_audio(audio, normalize_lufs=True, sample_rate=44100, target_lufs=-14.0):
    """Remove DC offset, peak-normalize and optionally loudness-normalize audio.

    Args:
        audio: Tensor whose last dimension is time. The loudness step
            transposes dims 0 and 1, so it expects ``(channels, samples)``.
        normalize_lufs: When true, additionally normalize integrated
            loudness to ``target_lufs`` using pyloudnorm.
        sample_rate: Sample rate handed to the loudness meter.
        target_lufs: Target integrated loudness.

    Returns:
        The normalized tensor. If loudness normalization is requested but
        cannot be applied (measurement raises, or the measured loudness is
        not finite, e.g. for silence) the peak-normalized audio is returned.
    """
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if not normalize_lufs:
        return audio

    try:
        # pyln expects (samples, channels) numpy array
        samples = audio.transpose(0, 1).cpu().numpy()
        meter = pyln.Meter(rate=sample_rate)
        loudness = meter.integrated_loudness(samples)
        if not np.isfinite(loudness):
            # A non-finite loudness would turn into an infinite gain and a
            # NaN-filled result without raising, so bail out explicitly.
            return audio
        normalised = pyln.normalize.loudness(samples, loudness, target_lufs)
    except Exception as e:
        # Best effort: fall back to the peak-normalized audio on any error.
        print(f"Loudness normalization skipped: {e}")
        return audio

    # Cast back so the NumPy round-trip cannot change the dtype.
    return torch.from_numpy(normalised).transpose(0, 1).to(audio.dtype)


class FilteredTestSetDataset(Dataset):
    """Custom dataset for loading from filtered test set JSON"""

    # Latent frames per second of audio used to size the placeholder latent.
    # Assumed to match the model config (typically 21.5 fps for 44.1kHz).
    FRAME_RATE = 21.5
    # Channel dimension of the placeholder latent.
    LATENT_DIM = 128
    # Size of the zero style embedding used as a fallback.
    STYLE_DIM = 512
    # Sample rate audio is resampled to before it is passed to MuQ.
    MUQ_SAMPLE_RATE = 24000
    # Text prompts longer than this many characters are truncated.
    MAX_PROMPT_CHARS = 300

    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None,
                 random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        """Read the test-set JSON and keep the collaborators used per sample.

        Args:
            test_set_path: Path to a JSON file containing a list of samples.
            diffusion_dataset: Object providing ``process_sample_safely``.
            muq_model: Model called with ``wavs=`` or ``texts=`` to produce
                style embeddings; must expose ``.device``.
            num_samples: If given, only the first ``num_samples`` entries
                are kept.
            random_crop_style: Use a random crop of the reference audio
                rather than its beginning.
            num_style_secs: Length in seconds of the reference audio excerpt.
            use_prompt_style: Derive the style from the text prompt instead
                of the reference audio.
        """
        with open(test_set_path, 'r', encoding='utf-8') as f:
            self.test_samples = json.load(f)

        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]

        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        return len(self.test_samples)

    def _zero_style(self):
        """Return an all-zero style embedding on the MuQ model's device."""
        return torch.zeros(self.STYLE_DIM).to(self.muq_model.device)

    def __getitem__(self, idx):
        """Build one processed sample, with ``test_metadata`` attached.

        Returns whatever ``process_sample_safely`` returns; when that is not
        ``None`` a ``'test_metadata'`` dict is added to it.
        """
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]

        # Load LRC data
        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r', encoding='utf-8') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            # Wrap the raw content so downstream code finds it under 'word'.
            lrc_data = {'word': lrc_data}

        # Generate style embedding
        if self.use_prompt_style:
            # Despite the key name, the value is used as the prompt text.
            prompt_content = test_sample.get("prompt_path", "")
            if prompt_content:
                # Truncate prompt if too long
                prompt_content = prompt_content[:self.MAX_PROMPT_CHARS]
                # No gradients are needed here; this mirrors the audio path
                # in generate_style_embedding().
                with torch.inference_mode():
                    style_embedding = self.muq_model(texts=[prompt_content]).squeeze(0)
            else:
                style_embedding = self._zero_style()
        else:
            audio_path = test_sample.get("audio_path")
            if audio_path and os.path.exists(audio_path):
                style_embedding = self.generate_style_embedding(audio_path)
            else:
                style_embedding = self._zero_style()

        duration = test_sample["duration"]

        # Create fake latent with correct length
        num_frames = int(duration * self.FRAME_RATE)
        fake_latent = torch.randn(self.LATENT_DIM, num_frames)

        # Create sample tuple matching DiffusionWebDataset format
        fake_sample = (
            sample_id,
            fake_latent,      # latent with correct duration
            style_embedding,  # style from audio or prompt (zeros as fallback)
            lrc_data          # actual LRC data
        )

        # Process through DiffusionWebDataset's process_sample_safely
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)

        # Add metadata
        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': test_sample.get("audio_path"),
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames
            }

        return processed_sample

    def generate_style_embedding(self, audio_path):
        """Generate a style embedding from a reference audio file using MuQ.

        The audio is resampled to 24 kHz, mixed down to mono and cut to at
        most ``num_style_secs`` seconds (a random window when
        ``random_crop_style`` is set, otherwise the beginning).

        Returns:
            The first row of the MuQ output, or a zero embedding if anything
            fails (the error is printed).
        """
        try:
            # Load audio
            waveform, sample_rate = torchaudio.load(audio_path)

            # Resample to 24kHz if needed (MuQ expects 24kHz)
            if sample_rate != self.MUQ_SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, self.MUQ_SAMPLE_RATE)
                waveform = resampler(waveform)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Squeeze out the channel dim for mono: shape is now (time,)
            waveform = waveform.squeeze(0)

            # Move to same device as model
            waveform = waveform.to(self.muq_model.device)

            target_samples = self.MUQ_SAMPLE_RATE * self.num_style_secs
            total_samples = waveform.shape[0]

            # Pick the excerpt start. Audio no longer than the target is
            # used from the beginning; slicing past the end is harmless.
            start_idx = 0
            if self.random_crop_style and total_samples > target_samples:
                start_idx = random.randint(0, total_samples - target_samples)

            # MuQ expects batch dimension and 1D audio
            wav_input = waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples]

            with torch.inference_mode():
                style_embedding = self.muq_model(wavs=wav_input)

            return style_embedding[0]

        except Exception as e:
            print(f"Error generating style embedding from {audio_path}: {e}")
            return self._zero_style()


def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Custom collate function that preserves test_metadata

    ``None`` items are dropped first. Returns ``None`` when nothing is left,
    otherwise the result of ``base_collate_fn`` with the per-item
    ``test_metadata`` dicts re-attached as a list.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return None

    # Take the metadata out before the base collate function sees the items.
    test_metadata = [item.pop('test_metadata') for item in batch]
    collated = base_collate_fn(batch)

    if collated is not None:
        collated['test_metadata'] = test_metadata

    return collated


def load_model(model_config, checkpoint_path, device):
    """
    Load JAM CFM model from checkpoint (follows infer.py pattern)

    Args:
        model_config: Mapping with ``"dit"`` and ``"cfm"`` sections.
        checkpoint_path: Path to a safetensors checkpoint.
        device: Device the model is moved to.

    Returns:
        The CFM model in eval mode.
    """
    dit_config = model_config["dit"].copy()
    if "text_num_embeds" not in dit_config:
        dit_config["text_num_embeds"] = 256

    cfm = CFM(
        transformer=DiT(**dit_config),
        **model_config["cfm"]
    )
    cfm = cfm.to(device)

    # Load checkpoint
    checkpoint = load_file(checkpoint_path)
    # strict=False tolerates key mismatches; report them rather than
    # silently running with partially initialised weights.
    load_result = cfm.load_state_dict(checkpoint, strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} model keys missing from checkpoint")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in checkpoint")

    return cfm.eval()


def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None,
                    ignore_style=False, device='cuda'):
    """
    Generate latent from batch data (follows infer.py pattern)

    Args:
        model: Model exposing ``max_frames`` and ``sample``.
        batch: Collated batch with ``lrc``, ``prompt``, ``start_time``,
            ``duration_abs`` and ``duration_rel`` tensors.
        sample_kwargs: Overrides for the default sampling arguments
            (``cfg_strength=4``, ``steps=50``, ``batch_infer_num=1``);
            ``None`` means no overrides.
        negative_style_prompt_path: ``.npy`` path, the literal ``'zeros'``,
            or ``None`` for the default ``'public/vocal.npy'``.
        ignore_style: Use the negative style prompt as the style prompt too.
        device: Device the batch tensors are moved to.

    Returns:
        The first element of the pair returned by ``model.sample``.
    """
    with torch.inference_mode():
        batch_size = len(batch["lrc"])
        text = batch["lrc"].to(device)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)

        max_frames = model.max_frames
        cond = torch.zeros(batch_size, max_frames, 64).to(text.device)
        pred_frames = [(0, max_frames)]

        default_sample_kwargs = {
            "cfg_strength": 4,
            "steps": 50,
            "batch_infer_num": 1
        }
        if sample_kwargs is None:
            sample_kwargs = {}
        # Caller-supplied values win over the defaults.
        sample_kwargs = {**default_sample_kwargs, **sample_kwargs}

        if negative_style_prompt_path is None:
            # Fallback path, or ensure file exists
            negative_style_prompt_path = 'public/vocal.npy'

        if negative_style_prompt_path == 'zeros':
            negative_style_prompt = torch.zeros(1, 512).to(text.device)
        else:
            negative_style_prompt = get_negative_style_prompt(
                text.device, negative_style_prompt_path
            )

        negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

        latents, _ = model.sample(
            cond=cond,
            text=text,
            style_prompt=negative_style_prompt if ignore_style else style_prompt,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            latent_pred_segments=pred_frames,
            **sample_kwargs
        )

        return latents


class Jamify:
    """End-to-end wrapper: loads the models once and renders songs to MP3."""

    def __init__(self):
        """Load configuration, checkpoint, VAE, CFM, MuQ and the base dataset.

        Creates the ``outputs`` directory and reads ``jam_infer.yaml`` from
        the current working directory.
        """
        os.makedirs('outputs', exist_ok=True)

        # Automatically detect CPU vs CUDA
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Initializing Jamify model on: {self.device}")

        config_path = 'jam_infer.yaml'
        self.config = OmegaConf.load(config_path)
        OmegaConf.resolve(self.config)

        print("Downloading main model checkpoint...")
        try:
            model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
            self.config.evaluation.checkpoint_path = os.path.join(
                model_repo_path, "jam-0_5.safetensors"
            )
        except Exception as e:
            # Best effort: keep whatever checkpoint_path the config holds.
            print(f"Failed to download model (might be offline): {e}")

        # Load VAE
        vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if vae_type == 'diffrhythm':
            self.vae = DiffRhythmVAE(device=self.device).to(self.device)
        else:
            self.vae = StableAudioOpenVAE().to(self.device)
        self.vae_type = vae_type

        # Load CFM
        self.cfm_model = load_model(
            self.config.model, self.config.evaluation.checkpoint_path, self.device
        )

        # Load MuQ (ensure float32 on CPU)
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
        if self.device == 'cpu':
            self.muq_model = self.muq_model.float()

        dataset_cfg = OmegaConf.merge(
            self.config.data.train_dataset, self.config.evaluation.dataset
        )
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id):
        """Prune old MP3s in ``outputs`` and remove this sample's JSON file.

        Once ten or more MP3 files exist, all but the last nine in sorted
        filename order are deleted. Removal errors are ignored.
        """
        old_mp3_files = sorted(glob.glob("outputs/*.mp3"))
        if len(old_mp3_files) >= 10:
            for old_file in old_mp3_files[:-9]:
                try:
                    os.remove(old_file)
                except OSError:
                    pass
        try:
            os.unlink(f"outputs/{sample_id}.json")
        except OSError:
            pass

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        """Generate one song and write it to ``outputs/<sample_id>.mp3``.

        Args:
            reference_audio_path: Reference audio for the style embedding
                (used unless the config enables ``use_prompt_style``).
            lyrics_json_path: Path to the lyrics JSON file.
            style_prompt: Text style prompt (used when the config enables
                ``use_prompt_style``).
            duration: Requested length in seconds; the output is trimmed to
                it.

        Returns:
            Path of the written MP3 file.

        Raises:
            ValueError: If no usable batch could be built from the inputs.
        """
        # Microsecond timestamp doubles as a unique, sortable file stem.
        sample_id = str(int(time.time() * 1000000))
        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt
        }]

        json_path = f"outputs/{sample_id}.json"
        with open(json_path, "w") as f:
            json.dump(test_set, f)

        try:
            test_dataset = FilteredTestSetDataset(
                test_set_path=json_path,
                diffusion_dataset=self.base_dataset,
                muq_model=self.muq_model,
                num_samples=1,
                random_crop_style=self.config.evaluation.random_crop_style,
                num_style_secs=self.config.evaluation.num_style_secs,
                use_prompt_style=self.config.evaluation.use_prompt_style
            )

            dataloader = DataLoader(
                test_dataset,
                batch_size=1,
                shuffle=False,
                collate_fn=lambda batch: custom_collate_fn_with_metadata(
                    batch, self.base_dataset.custom_collate_fn
                )
            )

            try:
                batch = next(iter(dataloader))
            except StopIteration:
                raise ValueError("Data loader returned empty batch. Check inputs.") from None

            # The collate function yields None when the sample was dropped;
            # fail here with a clear message instead of deep inside sampling.
            if batch is None:
                raise ValueError("Data loader returned empty batch. Check inputs.")

            sample_kwargs = self.config.evaluation.sample_kwargs

            # Generate latents
            generated_latents = generate_latent(
                self.cfm_model,
                batch,
                sample_kwargs,
                self.config.evaluation.negative_style_prompt,
                self.config.evaluation.ignore_style,
                device=self.device
            )

            # Single batch item: take index 0 of both latents and metadata.
            latent = generated_latents[0]
            test_metadata = batch['test_metadata'][0]
            original_duration = test_metadata['duration']

            # Decode
            latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
            use_chunked = self.config.evaluation.get('use_chunked_decoding', True)

            if self.vae_type == 'diffrhythm' and use_chunked:
                # DiffRhythm chunked decode
                decoded = self.vae.decode(
                    latent_for_vae,
                    chunked=True,
                    overlap=self.config.evaluation.get('chunked_overlap', 32),
                    chunk_size=self.config.evaluation.get('chunked_size', 128)
                )
            else:
                decoded = self.vae.decode(latent_for_vae)
            pred_audio = decoded.sample.squeeze(0).detach().cpu()

            pred_audio = normalize_audio(pred_audio)

            # Trim to the requested duration.
            sample_rate = 44100
            trim_samples = int(original_duration * sample_rate)
            if pred_audio.shape[1] > trim_samples:
                pred_audio_trimmed = pred_audio[:, :trim_samples]
            else:
                pred_audio_trimmed = pred_audio

            output_path = f'outputs/{sample_id}.mp3'
            torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3")

            self.cleanup_old_files(sample_id)
            return output_path

        finally:
            # Always remove the temporary test-set file, even on failure.
            if os.path.exists(json_path):
                os.unlink(json_path)