File size: 15,921 Bytes
65e9daa
 
 
 
 
4c10907
65e9daa
 
 
 
 
 
 
7d35d1e
 
 
65e9daa
 
 
 
7d35d1e
 
65e9daa
 
 
 
 
 
 
 
7d35d1e
 
df42459
 
 
 
65e9daa
 
df42459
 
 
 
 
 
65e9daa
 
7d35d1e
65e9daa
7d35d1e
 
65e9daa
 
 
df42459
 
 
 
 
 
 
 
65e9daa
 
 
7d35d1e
65e9daa
 
 
 
 
7d35d1e
65e9daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d35d1e
65e9daa
 
 
 
 
 
 
7d35d1e
df42459
65e9daa
df42459
 
 
 
 
 
 
 
65e9daa
df42459
 
 
 
 
7d35d1e
65e9daa
7d35d1e
65e9daa
 
 
 
 
7d35d1e
65e9daa
 
 
 
 
 
 
7d35d1e
65e9daa
 
7d35d1e
65e9daa
 
 
 
df42459
65e9daa
 
 
 
 
 
 
 
 
df42459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d35d1e
 
65e9daa
 
 
 
 
 
 
 
 
 
 
 
 
4c10907
 
65e9daa
 
 
 
 
 
df42459
65e9daa
 
 
 
 
 
 
df42459
65e9daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c10907
65e9daa
 
 
 
 
 
 
 
 
 
4c10907
65e9daa
df42459
 
 
 
65e9daa
 
 
 
 
4c10907
65e9daa
 
 
 
 
 
 
 
 
 
 
4c10907
65e9daa
4c10907
 
65e9daa
 
 
4c10907
df42459
 
 
 
 
 
65e9daa
 
 
 
 
df42459
 
 
 
 
 
 
65e9daa
 
df42459
65e9daa
df42459
8e872fa
65e9daa
df42459
 
 
 
 
 
 
 
65e9daa
 
 
 
 
 
 
 
 
df42459
8e872fa
 
 
 
df42459
 
 
 
65e9daa
 
df42459
 
65e9daa
 
 
 
 
 
 
8e872fa
df42459
 
 
4c10907
df42459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65e9daa
df42459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65e9daa
df42459
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
#!/usr/bin/env python3
"""
Generate audio using JAM model
Reads from filtered test set and generates audio using CFM+DiT model.
"""

import os
import glob
import time
import json
import random
import sys
from huggingface_hub import snapshot_download
import torch
import torchaudio
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import accelerate
import pyloudnorm as pyln
from safetensors.torch import load_file
from muq import MuQMuLan
import numpy as np
from accelerate import Accelerator

from jam.dataset import enhance_webdataset_config, DiffusionWebDataset
from jam.model.vae import StableAudioOpenVAE, DiffRhythmVAE

# DiffRhythm imports for CFM+DiT model
from jam.model import CFM, DiT

def get_negative_style_prompt(device, file_path):
    """Load the negative style prompt embedding from a ``.npy`` file.

    Args:
        device: Target device, either a string (``'cpu'``, ``'cuda'``,
            ``'cuda:0'``) or a ``torch.device``.
        file_path: Path of the ``.npy`` file holding the embedding
            (expected to be ``[1, 512]``).

    Returns:
        The embedding as a tensor on ``device``: half precision on CUDA
        devices, float32 on every other device. When ``file_path`` does not
        exist, a float32 all-zero tensor of shape ``(1, 512)`` on ``device``
        is returned instead.
    """
    if not os.path.exists(file_path):
        # Fallback if resource not found
        return torch.zeros(1, 512).to(device).float()

    vocal_style = torch.from_numpy(np.load(file_path)).to(device)  # [1, 512]

    # Only use half precision on CUDA. Compare the device *type* rather than
    # its string form: an indexed device such as ``cuda:0`` (which is what
    # ``tensor.device`` reports) never equals the literal string 'cuda'.
    if torch.device(device).type == 'cuda':
        return vocal_style.half()
    return vocal_style.float()

def normalize_audio(audio, normalize_lufs=True, sample_rate=44100, target_lufs=-14.0):
    """Remove DC offset, peak-normalise and optionally loudness-normalise audio.

    Args:
        audio: Tensor whose last dimension is time. For loudness
            normalisation it is transposed with ``transpose(0, 1)``, so it is
            expected to be 2-D ``(channels, samples)``.
        normalize_lufs: When true, additionally scale the signal to
            ``target_lufs`` integrated loudness using pyloudnorm.
        sample_rate: Sample rate handed to the loudness meter.
        target_lufs: Loudness target in LUFS.

    Returns:
        The normalised audio with the same dtype as the centred input. If
        loudness normalisation is disabled, fails, or the measured loudness is
        not finite, the peak-normalised signal is returned unchanged.
    """
    # Centre each channel, then scale its peak to (just under) 1.0; the
    # epsilon keeps an all-zero channel from dividing by zero.
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    if not normalize_lufs:
        return audio

    meter = pyln.Meter(rate=sample_rate)
    try:
        # pyln expects (samples, channels) numpy array
        samples = audio.transpose(0, 1).cpu().numpy()
        loudness = meter.integrated_loudness(samples)
        if not np.isfinite(loudness):
            # A non-finite loudness (e.g. digital silence) would turn the
            # gain into inf/NaN and poison every sample without raising.
            return audio
        normalised = pyln.normalize.loudness(samples, loudness, target_lufs)
    except Exception as e:
        # Best effort: fall back to the peak-normalised signal, but say so.
        print(f"Loudness normalization skipped: {e}")
        return audio

    # NumPy scalar promotion may have widened the array; keep the input dtype.
    return torch.from_numpy(normalised).transpose(0, 1).to(audio.dtype)

class FilteredTestSetDataset(Dataset):
    """Dataset that turns entries of a filtered test-set JSON file into model inputs.

    Each entry must provide ``"id"``, ``"lrc_path"`` and ``"duration"``;
    ``"audio_path"`` and ``"prompt_path"`` are optional. For every entry the
    lyrics JSON is loaded, a style embedding is computed with ``muq_model``
    (from text or from reference audio), a random placeholder latent of the
    requested duration is created, and the resulting tuple is passed through
    ``diffusion_dataset.process_sample_safely``.
    """

    # Latent frames per second of audio. Assumed to match the model config
    # (typically 21.5 fps for 44.1kHz).
    FRAME_RATE = 21.5
    # Channel count of the placeholder latent.
    LATENT_DIM = 128
    # Size of the all-zero fallback style embedding.
    STYLE_DIM = 512
    # MuQ expects 24kHz audio.
    MUQ_SAMPLE_RATE = 24000
    # Text prompts longer than this are truncated before embedding.
    MAX_PROMPT_CHARS = 300

    def __init__(self, test_set_path, diffusion_dataset, muq_model, num_samples=None, random_crop_style=False, num_style_secs=30, use_prompt_style=False):
        """Load the test-set description and remember the collaborators.

        Args:
            test_set_path: Path to a JSON file containing a list of entries.
            diffusion_dataset: Object providing ``process_sample_safely``.
            muq_model: Callable style encoder exposing a ``device`` attribute
                and accepting ``texts=`` or ``wavs=`` keyword input.
            num_samples: If given, only the first ``num_samples`` entries are kept.
            random_crop_style: Crop the reference audio at a random offset
                instead of taking its beginning.
            num_style_secs: Length in seconds of the reference-audio excerpt.
            use_prompt_style: Derive the style embedding from the entry's
                text prompt rather than from its reference audio.
        """
        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(test_set_path, 'r', encoding='utf-8') as f:
            self.test_samples = json.load(f)

        if num_samples is not None:
            self.test_samples = self.test_samples[:num_samples]

        self.diffusion_dataset = diffusion_dataset
        self.muq_model = muq_model
        self.random_crop_style = random_crop_style
        self.num_style_secs = num_style_secs
        self.use_prompt_style = use_prompt_style
        if self.use_prompt_style:
            print("Using prompt style instead of audio style.")

    def __len__(self):
        """Return the number of test entries."""
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Build the processed sample for entry ``idx``.

        Returns:
            Whatever ``process_sample_safely`` returns. When that is not
            ``None``, a ``'test_metadata'`` dict (sample id, paths, duration,
            frame count) is added to it.
        """
        test_sample = self.test_samples[idx]
        sample_id = test_sample["id"]

        # Load LRC data; lyrics routinely contain non-ASCII text, so read as UTF-8.
        lrc_path = test_sample["lrc_path"]
        with open(lrc_path, 'r', encoding='utf-8') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            # Wrap bare content under the 'word' key.
            lrc_data = {'word': lrc_data}

        style_embedding = self._style_embedding_for(test_sample)

        duration = test_sample["duration"]

        # Create fake latent with correct length
        num_frames = int(duration * self.FRAME_RATE)
        fake_latent = torch.randn(self.LATENT_DIM, num_frames)

        # Create sample tuple matching DiffusionWebDataset format
        fake_sample = (
            sample_id,
            fake_latent,      # latent with correct duration
            style_embedding,  # style from prompt text or reference audio
            lrc_data,         # actual LRC data
        )

        # Process through DiffusionWebDataset's process_sample_safely
        processed_sample = self.diffusion_dataset.process_sample_safely(fake_sample)

        # Add metadata
        if processed_sample is not None:
            processed_sample['test_metadata'] = {
                'sample_id': sample_id,
                'audio_path': test_sample.get("audio_path"),
                'lrc_path': lrc_path,
                'duration': duration,
                'num_frames': num_frames
            }

        return processed_sample

    def _zero_style_embedding(self):
        """Return the all-zero fallback style embedding on the MuQ model's device."""
        return torch.zeros(self.STYLE_DIM).to(self.muq_model.device)

    def _style_embedding_for(self, test_sample):
        """Pick the style embedding for one entry (text prompt or reference audio)."""
        if self.use_prompt_style:
            prompt_content = test_sample.get("prompt_path", "")
            if not prompt_content:
                return self._zero_style_embedding()
            # Truncate prompt if too long
            prompt_content = prompt_content[:self.MAX_PROMPT_CHARS]
            # Same no-autograd treatment as the audio path below.
            with torch.inference_mode():
                return self.muq_model(texts=[prompt_content]).squeeze(0)

        audio_path = test_sample.get("audio_path")
        if audio_path and os.path.exists(audio_path):
            return self.generate_style_embedding(audio_path)
        return self._zero_style_embedding()

    def generate_style_embedding(self, audio_path):
        """Generate a style embedding from a reference audio file using the MuQ model.

        The audio is resampled to 24kHz, mixed down to mono and cut to at
        most ``num_style_secs`` seconds (from a random offset when
        ``random_crop_style`` is set, otherwise from the start).

        Returns:
            The first row of the model output, or an all-zero embedding if
            anything in the pipeline raises (the error is printed).
        """
        try:
            # Load audio
            waveform, sample_rate = torchaudio.load(audio_path)

            # Resample to 24kHz if needed (MuQ expects 24kHz)
            if sample_rate != self.MUQ_SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, self.MUQ_SAMPLE_RATE)
                waveform = resampler(waveform)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Squeeze out the channel dim for mono: shape becomes (time,)
            waveform = waveform.squeeze(0)

            # Move to same device as model
            waveform = waveform.to(self.muq_model.device)

            # Coerced to int so a float-valued config entry still slices.
            target_samples = int(self.MUQ_SAMPLE_RATE * self.num_style_secs)
            total_samples = waveform.shape[0]

            # Start at 0 unless a random crop is requested and the clip is
            # long enough; slicing a shorter clip simply yields all of it.
            start_idx = 0
            if self.random_crop_style and total_samples > target_samples:
                start_idx = random.randint(0, total_samples - target_samples)

            # Generate embedding using MuQ model
            with torch.inference_mode():
                # MuQ expects batch dimension and 1D audio, returns (batch, embedding_dim)
                wav_input = waveform.unsqueeze(0)[..., start_idx:start_idx + target_samples]
                style_embedding = self.muq_model(wavs=wav_input)

            return style_embedding[0]

        except Exception as e:
            # Best effort: a broken reference clip degrades to "no style".
            print(f"Error generating style embedding from {audio_path}: {e}")
            return self._zero_style_embedding()


def custom_collate_fn_with_metadata(batch, base_collate_fn):
    """Collate a batch with ``base_collate_fn`` while carrying ``test_metadata`` along.

    ``None`` entries are dropped first. Each remaining item has its
    ``'test_metadata'`` entry removed (the item dict is modified in place)
    before the base collate function sees it; the collected metadata is then
    attached to the collated result as a plain list.

    Returns:
        The collated mapping with a ``'test_metadata'`` list, or ``None`` when
        no item survives filtering or the base collate function returns ``None``.
    """
    usable_items = list(filter(lambda entry: entry is not None, batch))
    if len(usable_items) == 0:
        return None

    # Strip the metadata before delegating so the base collate function only
    # receives the keys it knows about.
    metadata_per_item = []
    for entry in usable_items:
        metadata_per_item.append(entry.pop('test_metadata'))

    merged = base_collate_fn(usable_items)
    if merged is None:
        return merged

    merged['test_metadata'] = metadata_per_item
    return merged


def load_model(model_config, checkpoint_path, device):
    """Build the JAM CFM model and load its weights (follows infer.py pattern).

    Args:
        model_config: Mapping with a ``"dit"`` section (keyword arguments for
            ``DiT``) and a ``"cfm"`` section (extra keyword arguments for ``CFM``).
        checkpoint_path: Path of the safetensors checkpoint to load.
        device: Device the model is moved to before the weights are loaded.

    Returns:
        The model in evaluation mode.
    """
    transformer_kwargs = model_config["dit"].copy()
    # Supply the default only when the config leaves it out.
    if "text_num_embeds" not in transformer_kwargs:
        transformer_kwargs["text_num_embeds"] = 256

    backbone = DiT(**transformer_kwargs)
    model = CFM(transformer=backbone, **model_config["cfm"])
    model = model.to(device)

    # Non-strict load: keys that do not match between checkpoint and model
    # are not treated as an error.
    state_dict = load_file(checkpoint_path)
    model.load_state_dict(state_dict, strict=False)

    return model.eval()


def generate_latent(model, batch, sample_kwargs, negative_style_prompt_path=None, ignore_style=False, device='cuda'):
    """Sample latents for a collated batch (follows infer.py pattern).

    Args:
        model: Object exposing ``max_frames`` and a ``sample`` method.
        batch: Mapping with tensor entries ``"lrc"``, ``"prompt"``,
            ``"start_time"``, ``"duration_abs"`` and ``"duration_rel"``.
        sample_kwargs: Overrides for the sampling defaults
            (``cfg_strength=4``, ``steps=50``, ``batch_infer_num=1``).
        negative_style_prompt_path: ``'zeros'`` for an all-zero negative
            prompt, a ``.npy`` path otherwise; ``None`` falls back to
            ``'public/vocal.npy'``.
        ignore_style: Feed the negative prompt as the style prompt too.
        device: Device the batch tensors are moved to.

    Returns:
        The first element of the pair returned by ``model.sample``.
    """
    with torch.inference_mode():
        lyrics = batch["lrc"]
        n_items = len(lyrics)

        text = lyrics.to(device)
        style_prompt = batch["prompt"].to(device)
        start_time = batch["start_time"].to(device)
        duration_abs = batch["duration_abs"].to(device)
        duration_rel = batch["duration_rel"].to(device)

        # All-zero conditioning spanning the full frame range, and a single
        # prediction segment covering that same range.
        total_frames = model.max_frames
        cond = torch.zeros(n_items, total_frames, 64).to(text.device)
        segments = [(0, total_frames)]

        # Caller-supplied values win over the defaults.
        sampling_options = {
            **{"cfg_strength": 4, "steps": 50, "batch_infer_num": 1},
            **sample_kwargs,
        }

        prompt_source = negative_style_prompt_path
        if prompt_source is None:
            # Fallback path, or ensure file exists
            prompt_source = 'public/vocal.npy'

        if prompt_source != 'zeros':
            negative = get_negative_style_prompt(text.device, prompt_source)
        else:
            negative = torch.zeros(1, 512).to(text.device)

        # One copy of the negative prompt per batch item.
        negative = negative.repeat(n_items, 1)
        effective_style = negative if ignore_style else style_prompt

        latents, _ = model.sample(
            cond=cond,
            text=text,
            style_prompt=effective_style,
            duration_abs=duration_abs,
            duration_rel=duration_rel,
            negative_style_prompt=negative,
            start_time=start_time,
            latent_pred_segments=segments,
            **sampling_options
        )

    return latents


class Jamify:
    """End-to-end song generator: loads the JAM models once, then renders MP3s.

    Construction loads the inference config, downloads the checkpoint, and
    builds the VAE, the CFM model, the MuQ-MuLan style encoder and the base
    dataset. ``predict`` then turns one (reference audio, lyrics, prompt,
    duration) request into an MP3 file under ``outputs/``.
    """

    # Directory (relative to the working directory) for generated files.
    OUTPUT_DIR = 'outputs'
    # Sample rate used for trimming and for saving the generated audio.
    SAMPLE_RATE = 44100
    # Number of MP3 files kept in OUTPUT_DIR by cleanup_old_files.
    MAX_KEPT_MP3 = 9

    def __init__(self):
        """Load configuration and all models onto CUDA when available, else CPU."""
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

        # Automatically detect CPU vs CUDA
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Initializing Jamify model on: {self.device}")

        config_path = 'jam_infer.yaml'
        self.config = OmegaConf.load(config_path)
        OmegaConf.resolve(self.config)

        print("Downloading main model checkpoint...")
        try:
            model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
            self.config.evaluation.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
        except Exception as e:
            # Best effort: keep whatever checkpoint path the config already has.
            print(f"Failed to download model (might be offline): {e}")

        # Load VAE
        vae_type = self.config.evaluation.get('vae_type', 'stable_audio')
        if vae_type == 'diffrhythm':
            self.vae = DiffRhythmVAE(device=self.device).to(self.device)
        else:
            self.vae = StableAudioOpenVAE().to(self.device)

        self.vae_type = vae_type

        # Load CFM
        self.cfm_model = load_model(self.config.model, self.config.evaluation.checkpoint_path, self.device)

        # Load MuQ (ensure float32 on CPU)
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
        if self.device == 'cpu':
            self.muq_model = self.muq_model.float()

        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.base_dataset = DiffusionWebDataset(**dataset_cfg)

    def cleanup_old_files(self, sample_id):
        """Prune old MP3s and remove the request JSON belonging to ``sample_id``.

        MP3 files in the output directory are sorted by file name and all but
        the last ``MAX_KEPT_MP3`` are deleted. Files that cannot be removed
        (including a JSON file that no longer exists) are skipped silently.
        """
        old_mp3_files = sorted(glob.glob(f"{self.OUTPUT_DIR}/*.mp3"))
        # The slice is empty whenever there are MAX_KEPT_MP3 files or fewer.
        for old_file in old_mp3_files[:-self.MAX_KEPT_MP3]:
            try:
                os.remove(old_file)
            except OSError:
                pass
        try:
            os.unlink(f"{self.OUTPUT_DIR}/{sample_id}.json")
        except OSError:
            pass

    def _load_single_batch(self, json_path):
        """Build the one-item batch described by the request JSON at ``json_path``.

        Raises:
            ValueError: If the loader yields nothing, or the collate function
                returns ``None`` (no usable sample).
        """
        test_dataset = FilteredTestSetDataset(
            test_set_path=json_path,
            diffusion_dataset=self.base_dataset,
            muq_model=self.muq_model,
            num_samples=1,
            random_crop_style=self.config.evaluation.random_crop_style,
            num_style_secs=self.config.evaluation.num_style_secs,
            use_prompt_style=self.config.evaluation.use_prompt_style
        )

        dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=lambda batch: custom_collate_fn_with_metadata(batch, self.base_dataset.custom_collate_fn)
        )

        try:
            batch = next(iter(dataloader))
        except StopIteration:
            raise ValueError("Data loader returned empty batch. Check inputs.") from None

        # The collate function returns None when the sample was rejected;
        # fail here with a clear message instead of a TypeError later on.
        if batch is None:
            raise ValueError("Data loader produced no usable sample. Check inputs.")
        return batch

    def _decode_latent(self, latent):
        """Decode one generated latent with the VAE and return the audio on CPU."""
        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
        evaluation_cfg = self.config.evaluation
        use_chunked = evaluation_cfg.get('use_chunked_decoding', True)

        # Pure inference: no autograd graph is needed for decoding.
        with torch.inference_mode():
            if self.vae_type == 'diffrhythm' and use_chunked:
                # DiffRhythm chunked decode
                decoded = self.vae.decode(
                    latent_for_vae,
                    chunked=True,
                    overlap=evaluation_cfg.get('chunked_overlap', 32),
                    chunk_size=evaluation_cfg.get('chunked_size', 128)
                )
            else:
                decoded = self.vae.decode(latent_for_vae)
            return decoded.sample.squeeze(0).detach().cpu()

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration):
        """Generate a song and return the path of the saved MP3.

        Args:
            reference_audio_path: Reference audio for the style embedding.
            lyrics_json_path: Path to the lyrics JSON file.
            style_prompt: Text prompt, used for style when the config enables
                ``use_prompt_style``.
            duration: Requested duration in seconds; the decoded audio is
                trimmed to it.

        Returns:
            Path of the written MP3 (``outputs/<sample_id>.mp3``).

        Raises:
            ValueError: If no usable batch can be built from the inputs.
        """
        # Microsecond timestamp doubles as the file stem for this request.
        sample_id = str(int(time.time() * 1000000))

        test_set = [{
            "id": sample_id,
            "audio_path": reference_audio_path,
            "lrc_path": lyrics_json_path,
            "duration": duration,
            "prompt_path": style_prompt
        }]

        # The dataset reads its entries from disk, so stage the request as JSON.
        json_path = f"{self.OUTPUT_DIR}/{sample_id}.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(test_set, f)

        try:
            batch = self._load_single_batch(json_path)

            # Generate latents
            generated_latents = generate_latent(
                self.cfm_model,
                batch,
                self.config.evaluation.sample_kwargs,
                self.config.evaluation.negative_style_prompt,
                self.config.evaluation.ignore_style,
                device=self.device
            )

            # Single batch item: take index 0 of latents and metadata.
            latent = generated_latents[0]
            original_duration = batch['test_metadata'][0]['duration']

            # Decode
            pred_audio = self._decode_latent(latent)
            pred_audio = normalize_audio(pred_audio)

            # Trim to the requested duration; slicing past the end is a no-op.
            trim_samples = int(original_duration * self.SAMPLE_RATE)
            pred_audio_trimmed = pred_audio[:, :trim_samples]

            output_path = f'{self.OUTPUT_DIR}/{sample_id}.mp3'
            torchaudio.save(output_path, pred_audio_trimmed, self.SAMPLE_RATE, format="mp3")
            self.cleanup_old_files(sample_id)
            return output_path

        finally:
            # Always remove the staged request file, even when generation fails.
            if os.path.exists(json_path):
                os.unlink(json_path)