Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Generate audio using JAM model | |
| Reads from filtered test set and generates audio using CFM+DiT model. | |
| """ | |
| import os | |
| import glob | |
| import time | |
| import json | |
| import random | |
| import sys | |
| from huggingface_hub import snapshot_download | |
| import torch | |
| import torchaudio | |
| from omegaconf import OmegaConf | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm.auto import tqdm | |
| import accelerate | |
| import pyloudnorm as pyln | |
| from safetensors.torch import load_file | |
| from muq import MuQMuLan | |
| import numpy as np | |
| from accelerate import Accelerator | |
| from jam.dataset import enhance_webdataset_config, DiffusionWebDataset | |
| from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE | |
| # DiffRhythm imports for CFM+DiT model | |
| from jam.model import CFM, DiT | |
| def get_negative_style_prompt(device, file_path): | |
| if not os.path.exists(file_path): | |
| # Fallback if resource not found | |
| return torch.zeros(1, 512).to(device).float() | |
| vocal_stlye = np.load(file_path) | |
| vocal_stlye = torch.from_numpy(vocal_stlye).to(device) # [1, 512] | |
| # Only use half precision on CUDA | |
| if str(device) == 'cuda': | |
| vocal_stlye = vocal_stlye.half() | |
| else: | |
| vocal_stlye = vocal_stlye.float() | |
| return vocal_stlye | |
| def normalize_audio(audio, normalize_lufs=True): | |
| audio = audio - audio.mean(-1, keepdim=True) | |
| audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8) | |
| if normalize_lufs: | |
| meter = pyln.Meter(rate=44100) | |
| target_lufs = -14.0 | |
| # pyln expects (samples, channels) numpy array | |
| try: | |
| loudness = meter.integrated_loudness(audio.transpose(0, 1).cpu().numpy()) | |
| normalised = pyln.normalize.loudness(audio.transpose(0, 1).cpu().numpy(), loudness, target_lufs) | |
| normalised = torch.from_numpy(normalised).transpose(0, 1) | |
| except Exception as e: | |
| # Silently fail back to unnormalized if silence/error | |
| normalised = audio | |
| else: | |
| normalised = audio | |
| return normalised | |
| class FilteredTestSetDataset(Dataset): | |
| """Custom dataset for loading from filtered test set JSON""" | |
| def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False): | |
| with open(test_set_path, 'r') as f: | |
| self.test_samples = json.load(f) | |
| if num_samples is not None: | |
| self.test_samples = self.test_samples[:num_samples] | |
| self.diffusion_dataset = diffusion_dataset | |
| self.muq_model = muq_model | |
| self.random_crop_style = random_crop_style | |
| self.num_style_secs = num_style_secs | |
| self.use_prompt_style = use_prompt_style | |
| if self.use_prompt_style: | |
| print("Using prompt style instead of audio style.") | |
| def __len__(self): | |
| return len(self.test_samples) | |
| def __getitem__(self, idx): | |
| test_sample = self.test_samples[idx] | |
| sample_id = test_sample["id"] | |
| # Load LRC data | |
| lrc_path = test_sample["lrc_path"] | |
| with open(lrc_path, 'r') as f: | |
| lrc_data = json.load(f) | |
| if 'word' not in lrc_data: | |
| data = {'word': lrc_data} | |
| lrc_data = data | |
| # Generate style embedding | |
| if self.use_prompt_style: | |
| prompt_content = test_sample.get("prompt_path", "") | |
| if prompt_content: | |
| # Truncate prompt if too long | |
| if len(prompt_content) > 300: | |
| prompt_content = prompt_content[:300] | |
| style_embedding = self.muq_model(texts=[prompt_content]).squeeze(0) | |
| else: | |
| style_embedding = torch.zeros(512).to(self.muq_model.device) | |
| else: | |
| audio_path = test_sample.get("audio_path") | |
| if audio_path and os.path.exists(audio_path): | |
| style_embedding = self.generate_style_embedding(audio_path) | |
| else: | |
| style_embedding = torch.zeros(512).to(self.muq_model.device) | |
| duration = test_sample["duration"] | |
| # Create fake latent with correct length | |
| # Assuming frame_rate from config (typically 21.5 fps for 44.1kHz) | |
| frame_rate = 21.5 | |
| num_frames = int(duration * frame_rate) | |
| fake_latent = torch.randn(128, num_frames) # 128 is latent dim | |
| # Create sample tuple matching DiffusionWebDataset format | |
| fake_sample = ( | |
| sample_id, | |
| fake_latent, # latent with correct duration | |
| style_embedding, # style from actual audio | |
| lrc_data # actual LRC data | |
| ) | |
| # Process through DiffusionWebDataset's process_sample_safely | |
| processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample) | |
| # Add metadata | |
| if processed_sample is not None: | |
| processed_sample['test_metadata'] = { | |
| 'sample_id': sample_id, | |
| 'audio_path': test_sample.get("audio_path"), | |
| 'lrc_path': lrc_path, | |
| 'duration': duration, | |
| 'num_frames': num_frames | |
| } | |
| return processed_sample | |
| def generate_style_embedding(self, audio_path): | |
| """Generate style embedding using MuQ model on the whole music""" | |
| try: | |
| # Load audio | |
| waveform, sample_rate = torchaudio.load(audio_path) | |
| # Resample to 24kHz if needed (MuQ expects 24kHz) | |
| if sample_rate != 24000: | |
| resampler = torchaudio.transforms.Resample(sample_rate, 24000) | |
| waveform = resampler(waveform) | |
| # Convert to mono if stereo | |
| if waveform.shape[0] > 1: | |
| waveform = waveform.mean(dim=0, keepdim=True) | |
| # Ensure waveform is 2D (channels, time) - squeeze out channel dim for mono | |
| waveform = waveform.squeeze(0) # Now shape is (time,) | |
| # Move to same device as model | |
| waveform = waveform.to(self.muq_model.device) | |
| # Generate embedding using MuQ model | |
| with torch.inference_mode(): | |
| # MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim) | |
| if self.random_crop_style: | |
| total_samples = waveform.shape[0] | |
| target_samples = 24000 * self.num_style_secs | |
| if total_samples > target_samples: | |
| start_idx = random.randint(0, total_samples - target_samples) | |
| wav_input = waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples] | |
| else: | |
| wav_input = waveform.unsqueeze(0) | |
| style_embedding = self.muq_model(wavs=wav_input) | |
| else: | |
| style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * self.num_style_secs]) | |
| return style_embedding[0] | |
| except Exception as e: | |
| print(f"Error generating style embedding from {audio_path}: {e}") | |
| return torch.zeros(512).to(self.muq_model.device) | |
| def custom_collate_fn_with_metadata(batch, base_collate_fn): | |
| """Custom collate function that preserves test_metadata""" | |
| batch = [item for item in batch if item is not None] | |
| if not batch: | |
| return None | |
| test_metadata = [item.pop('test_metadata') for item in batch] | |
| collated = base_collate_fn(batch) | |
| if collated is not None: | |
| collated['test_metadata'] = test_metadata | |
| return collated | |
| def load_model(model_config, checkpoint_path, device): | |
| """ | |
| Load JAM CFM model from checkpoint (follows infer.py pattern) | |
| """ | |
| dit_config = model_config["dit"].copy() | |
| if "text_num_embeds" not in dit_config: | |
| dit_config["text_num_embeds"] = 256 | |
| cfm = CFM( | |
| transformer=DiT(**dit_config), | |
| **model_config["cfm"] | |
| ) | |
| cfm = cfm.to(device) | |
| # Load checkpoint | |
| checkpoint = load_file(checkpoint_path) | |
| cfm.load_state_dict(checkpoint, strict=False) | |
| return cfm.eval() | |
| def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'): | |
| """ | |
| Generate latent from batch data (follows infer.py pattern) | |
| """ | |
| with torch.inference_mode(): | |
| batch_size = len(batch["lrc"]) | |
| text = batch["lrc"].to(device) | |
| style_prompt = batch["prompt"].to(device) | |
| start_time = batch["start_time"].to(device) | |
| duration_abs = batch["duration_abs"].to(device) | |
| duration_rel = batch["duration_rel"].to(device) | |
| max_frames = model.max_frames | |
| cond = torch.zeros(batch_size, max_frames, 64).to(text.device) | |
| pred_frames = [(0, max_frames)] | |
| default_sample_kwargs = { | |
| "cfg_strength": 4, | |
| "steps": 50, | |
| "batch_infer_num": 1 | |
| } | |
| sample_kwargs = {**default_sample_kwargs, **sample_kwargs} | |
| if negative_style_prompt_path is None: | |
| # Fallback path, or ensure file exists | |
| negative_style_prompt_path = 'public/vocal.npy' | |
| if negative_style_prompt_path == 'zeros': | |
| negative_style_prompt = torch.zeros(1, 512).to(text.device) | |
| else: | |
| negative_style_prompt = get_negative_style_prompt(text.device, negative_style_prompt_path) | |
| negative_style_prompt = negative_style_prompt.repeat(batch_size, 1) | |
| latents, _ = model.sample( | |
| cond=cond, | |
| text=text, | |
| style_prompt=negative_style_prompt if ignore_style else style_prompt, | |
| duration_abs=duration_abs, | |
| duration_rel=duration_rel, | |
| negative_style_prompt=negative_style_prompt, | |
| start_time=start_time, | |
| latent_pred_segments=pred_frames, | |
| **sample_kwargs | |
| ) | |
| return latents | |
| class Jamify: | |
| def __init__(self): | |
| os.makedirs('outputs', exist_ok=True) | |
| # --------------------------------------------------------- | |
| # FIX: Automatically detect CPU vs CUDA | |
| # --------------------------------------------------------- | |
| self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| print(f"Initializing Jamify model on: {self.device}") | |
| config_path = 'jam_infer.yaml' | |
| self.config = OmegaConf.load(config_path) | |
| OmegaConf.resolve(self.config) | |
| print("Downloading main model checkpoint...") | |
| try: | |
| model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5") | |
| self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors") | |
| except Exception as e: | |
| print(f"Failed to download model (might be offline): {e}") | |
| # Load VAE | |
| vae_type = self.config.evaluation.get('vae_type', 'stable_audio') | |
| if vae_type == 'diffrhythm': | |
| self.vae = DiffRhythmVAE(device=self.device).to(self.device) | |
| else: | |
| self.vae = StableAudioOpenVAE().to(self.device) | |
| self.vae_type = vae_type | |
| # Load CFM | |
| self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, self.device) | |
| # Load MuQ (ensure float32 on CPU) | |
| self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval() | |
| if self.device == 'cpu': | |
| self.muq_model = self.muq_model.float() | |
| dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset) | |
| enhance_webdataset_config(dataset_cfg) | |
| dataset_cfg.multiple_styles = False | |
| self.base_dataset = DiffusionWebDataset(**dataset_cfg) | |
| def cleanup_old_files(self, sample_id): | |
| old_mp3_files = sorted(glob.glob("outputs/*.mp3")) | |
| if len(old_mp3_files) >= 10: | |
| for old_file in old_mp3_files[:-9]: | |
| try: | |
| os.remove(old_file) | |
| except OSError: | |
| pass | |
| try: | |
| os.unlink(f"outputs/{sample_id}.json") | |
| except OSError: | |
| pass | |
| def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration): | |
| sample_id = str(int(time.time() * 1000000)) | |
| test_set = [{ | |
| "id": sample_id, | |
| "audio_path": reference_audio_path, | |
| "lrc_path": lyrics_json_path, | |
| "duration": duration, | |
| "prompt_path": style_prompt | |
| }] | |
| json_path = f"outputs/{sample_id}.json" | |
| with open(json_path, "w") as f: | |
| json.dump(test_set, f) | |
| try: | |
| test_dataset = FilteredTestSetDataset( | |
| test_set_path=json_path, | |
| diffusion_dataset=self.base_dataset, | |
| muq_model=self.muq_model, | |
| num_samples=1, | |
| random_crop_style=self.config.evaluation.random_crop_style, | |
| num_style_secs=self.config.evaluation.num_style_secs, | |
| use_prompt_style=self.config.evaluation.use_prompt_style | |
| ) | |
| dataloader = DataLoader( | |
| test_dataset, | |
| batch_size=1, | |
| shuffle=False, | |
| collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn) | |
| ) | |
| try: | |
| batch = next(iter(dataloader)) | |
| except StopIteration: | |
| raise ValueError("Data loader returned empty batch. Check inputs.") | |
| sample_kwargs = self.config.evaluation.sample_kwargs | |
| # Generate latents | |
| generated_latents = generate_latent( | |
| self.cfm_model, | |
| batch, | |
| sample_kwargs, | |
| self.config.evaluation.negative_style_prompt, | |
| self.config.evaluation.ignore_style, | |
| device=self.device | |
| ) | |
| # FIX: Correct indexing [0] for single batch item | |
| latent = generated_latents[0] | |
| test_metadata = batch['test_metadata'][0] | |
| original_duration = test_metadata['duration'] | |
| # Decode | |
| latent_for_vae = latent.transpose(0, 1).unsqueeze(0) | |
| use_chunked = self.config.evaluation.get('use_chunked_decoding', True) | |
| if self.vae_type == 'diffrhythm' and use_chunked: | |
| # DiffRhythm chunked decode | |
| pred_audio = self.vae.decode( | |
| latent_for_vae, | |
| chunked=True, | |
| overlap=self.config.evaluation.get('chunked_overlap', 32), | |
| chunk_size=self.config.evaluation.get('chunked_size', 128) | |
| ).sample.squeeze(0).detach().cpu() | |
| else: | |
| pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu() | |
| pred_audio = normalize_audio(pred_audio) | |
| sample_rate = 44100 | |
| trim_samples = int(original_duration * sample_rate) | |
| if pred_audio.shape[1] > trim_samples: | |
| pred_audio_trimmed = pred_audio[:, :trim_samples] | |
| else: | |
| pred_audio_trimmed = pred_audio | |
| output_path = f'outputs/{sample_id}.mp3' | |
| torchaudio.save(output_path, pred_audio_trimmed, sample_rate, format="mp3") | |
| self.cleanup_old_files(sample_id) | |
| return output_path | |
| finally: | |
| if os.path.exists(json_path): | |
| os.unlink(json_path) |