"""Gradio front-end for generating audio from tags with a LocalSong checkpoint."""
import os
from pathlib import Path
from typing import List, Tuple
import uuid
import json
import argparse

import gradio as gr
import torch
import torchaudio
from safetensors.torch import load_file
from tqdm import tqdm

from model import LocalSongModel
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE

# Sample rate requested from the MusicDCAE decoder; also used when returning
# audio to Gradio and when writing the WAV file.
SAMPLE_RATE = 48000
# Maximum number of tags per generation. Passed to LocalSongModel as
# ``max_tags`` and enforced in the UI and in generate_audio.
MAX_TAGS = 8


class TagEmbedder:
    """Holds the tag-name -> class-index mapping used to condition the model."""

    def __init__(self, mapping_file: str = "checkpoints/tag_mapping.json"):
        """Load the tag mapping from a JSON file.

        Args:
            mapping_file: Path to a JSON object mapping tag strings to the
                integer indices returned by ``_tags_to_indices``.
        """
        with open(mapping_file, 'r', encoding='utf-8') as f:
            self.tag_mapping = json.load(f)
        # Tag vocabulary size handed to LocalSongModel as ``num_classes``.
        # NOTE(review): presumably must exceed every index in tag_mapping —
        # confirm against the checkpoint.
        self.num_classes = 2304


class AudioVAE:
    """MusicDCAE wrapper that un-normalises latents before decoding them."""

    def __init__(self, device: torch.device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel latent statistics (8 channels), shaped (1, C, 1, 1) so
        # they broadcast over (B, C, H, W) latents.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalised latents into a batch of waveforms.

        Args:
            latents: Normalised latents; assumed (B, 8, H, W) so the
                per-channel statistics broadcast — TODO confirm.

        Returns:
            The decoder's waveforms stacked along a new leading batch
            dimension, moved to ``self.device``.
        """
        with torch.no_grad():
            latents = latents * self.latent_std + self.latent_mean
            # The decoder also reports a sample rate; callers use SAMPLE_RATE.
            _sr, audio_list = self.model.decode(latents, sr=SAMPLE_RATE)
            audio_batch = torch.stack(audio_list).to(self.device)
        return audio_batch


class RF:
    """Euler sampler that integrates a velocity model from t=1 towards t=0."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def sample(
        self,
        z: torch.Tensor,
        cond: List[List[int]],
        null_cond: List[List[int]] | None = None,
        sample_steps: int = 100,
        cfg: float = 3.0,
    ) -> List[torch.Tensor]:
        """Run ``sample_steps`` Euler steps starting from ``z``.

        Args:
            z: Starting latent; the first dimension is the batch.
            cond: One list of tag indices per batch item.
            null_cond: Unconditional tag lists (one per batch item). When
                given, classifier-free guidance is applied as
                ``vu + cfg * (vc - vu)``.
            sample_steps: Number of integration steps.
            cfg: Guidance scale, used only when ``null_cond`` is given.

        Returns:
            ``sample_steps + 1`` latents: the initial ``z`` followed by the
            state after each step. The last entry is the final sample.
        """
        batch = z.size(0)
        # Uniform step size, reshaped to (B, 1, ..., 1) to broadcast over z.
        dt = torch.tensor([1.0 / sample_steps] * batch, device=z.device).view(
            [batch, *([1] * len(z.shape[1:]))]
        )
        # Conditional and unconditional inputs share one doubled batch. The
        # combined condition list does not change between steps.
        cond_batched = cond + null_cond if null_cond is not None else None

        images = [z]
        for i in tqdm(range(sample_steps, 0, -1), desc="Generating", unit="step"):
            t = torch.tensor([i / sample_steps] * batch, device=z.device)
            if null_cond is not None:
                z_batched = torch.cat([z, z], dim=0)
                t_batched = torch.cat([t, t], dim=0)
                v_batched = self.model(z_batched, t_batched, cond_batched)
                vc, vu = v_batched.chunk(2, dim=0)
                vc = vu + cfg * (vc - vu)
            else:
                vc = self.model(z, t, cond)
            z = z - dt * vc
            images.append(z)
        return images


# Lazily-initialised singletons populated by load_resources().
model: torch.nn.Module | None = None
vae: AudioVAE | None = None
tag_embedder: TagEmbedder | None = None
rf_sampler: RF | None = None
device: torch.device | None = None
_available_tags: List[str] | None = None


def load_resources(checkpoint_path: str) -> List[str]:
    """Load the model, VAE and tag vocabulary once and cache them globally.

    Args:
        checkpoint_path: Path to a safetensors checkpoint for LocalSongModel.

    Returns:
        Sorted list of known tag names.

    Note:
        Only the first call loads anything. Later calls return the cached tag
        list and ignore ``checkpoint_path``.
    """
    global model, vae, tag_embedder, rf_sampler, device, _available_tags
    torch.set_float32_matmul_precision('high')

    if _available_tags is not None:
        return _available_tags

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tag_embedder = TagEmbedder()

    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=tag_embedder.num_classes,
        max_tags=MAX_TAGS,
    ).to(device)

    print(f"Loading checkpoint: {checkpoint_path}")
    state_dict = load_file(checkpoint_path, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    vae = AudioVAE(device)
    rf_sampler = RF(model)

    _available_tags = sorted(tag_embedder.tag_mapping.keys())
    return _available_tags


def _tags_to_indices(tags: List[str]) -> List[int]:
    """Map tag names to class indices.

    Each tag is lower-cased and stripped before lookup. Tags missing from the
    mapping are silently dropped.

    Raises:
        RuntimeError: If load_resources() has not been called.
    """
    if tag_embedder is None:
        raise RuntimeError("Resources not loaded; call load_resources() first.")
    mapping = tag_embedder.tag_mapping
    normalised = (tag.lower().strip() for tag in tags)
    return [mapping[key] for key in normalised if key in mapping]


def generate_audio(
    tags: List[str],
    cfg: float,
    sample_steps: int,
) -> Tuple[Tuple[int, object], str]:
    """Generate one clip for ``tags`` and write it to ``generated/``.

    Args:
        tags: Up to MAX_TAGS tag names; ``None`` or empty means no tags.
        cfg: Classifier-free guidance scale.
        sample_steps: Number of sampler steps. Coerced to ``int`` because
            Gradio sliders deliver floats.

    Returns:
        ``((sample_rate, numpy_audio), wav_path)`` for a Gradio Audio + File
        output pair.

    Raises:
        gr.Error: Too many tags, or fewer than 1 sample step.
        RuntimeError: If load_resources() has not been called.
    """
    if model is None or vae is None or rf_sampler is None or device is None:
        raise RuntimeError("Resources not loaded; call load_resources() first.")

    tags = list(tags) if tags else []
    if len(tags) > MAX_TAGS:
        raise gr.Error(f"A maximum of {MAX_TAGS} tags is supported.")

    # range() and the step-size division need a positive int.
    sample_steps = int(sample_steps)
    if sample_steps < 1:
        raise gr.Error("Sample steps must be at least 1.")

    tag_indices = _tags_to_indices(tags)

    batch = 1
    channels = 8
    height = 16
    width = 512
    z = torch.randn(batch, channels, height, width, device=device)

    cond = [tag_indices]
    # An empty tag list serves as the unconditional branch for guidance.
    null_cond = [[]]

    with torch.no_grad():
        sampled_latents = rf_sampler.sample(
            z=z,
            cond=cond,
            null_cond=null_cond,
            sample_steps=sample_steps,
            cfg=cfg,
        )[-1]
        audio = vae.decode(sampled_latents)

    audio_tensor = audio[0].cpu()
    sr = SAMPLE_RATE
    # NOTE(review): assumes audio_tensor is (channels, samples), which is the
    # layout torchaudio.save expects by default. Gradio's numpy audio wants
    # (samples, channels), hence the transpose.
    audio_numpy = audio_tensor.transpose(0, 1).numpy()

    os.makedirs("generated", exist_ok=True)
    output_path = f"generated/generated_{uuid.uuid4().hex}.wav"
    torchaudio.save(output_path, audio_tensor, sr)

    return (sr, audio_numpy), output_path


def build_interface(checkpoint_path: str) -> gr.Blocks:
    """Load resources and build the LocalSong Gradio UI.

    Args:
        checkpoint_path: Forwarded to load_resources().

    Returns:
        The (not yet launched) Gradio Blocks app.
    """
    available_tags = load_resources(checkpoint_path)

    # Preset tag combinations shown as one-click buttons.
    presets = [
        ["soundtrack1", "female vocalist", "rock", "melodic"],
        ["soundtrack", "chrono trigger", "emotional", "piano", "strings"],
        ["soundtrack", "touhou 10", "trumpet"],
        ["soundtrack", "christmas music", "winter", "melodic"],
        ["soundtrack2", "male vocalist", "pop", "melodic", "acoustic guitar", "ballad"],
    ]

    def make_preset_fn(p):
        # Bind each preset eagerly to avoid the late-binding closure pitfall.
        return lambda: p

    with gr.Blocks(title="LocalSong") as demo:
        gr.Markdown("# LocalSong")

        with gr.Row():
            tag_input = gr.Dropdown(
                label=f"Tags (select up to {MAX_TAGS})",
                choices=available_tags,
                multiselect=True,
                max_choices=MAX_TAGS,
                value=presets[0],
            )

        gr.Markdown("**Presets:**")
        with gr.Row():
            for preset in presets:
                btn = gr.Button(f"{' + '.join(preset)}", size="sm")
                btn.click(fn=make_preset_fn(preset), inputs=None, outputs=tag_input)

        with gr.Row():
            cfg_slider = gr.Slider(
                label="CFG Scale",
                minimum=1.0,
                maximum=7.0,
                step=0.5,
                value=3.5,
            )
            sample_steps_slider = gr.Slider(
                label="Sample Steps",
                minimum=50,
                maximum=200,
                step=10,
                value=200,
            )

        with gr.Row():
            seed_input = gr.Number(
                label="Seed",
                value=45,
                precision=0,
            )

        generate_button = gr.Button("Generate Audio", variant="primary")
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        download_output = gr.File(label="Download WAV")

        def generate_wrapper(tags, cfg, steps, seed):
            # A cleared Number field yields None, which torch.manual_seed rejects.
            if seed is None:
                raise gr.Error("Please enter a seed.")
            # torch.manual_seed seeds the RNGs of all devices, CUDA included.
            torch.manual_seed(int(seed))
            return generate_audio(tags, cfg, steps)

        generate_button.click(
            fn=generate_wrapper,
            inputs=[
                tag_input,
                cfg_slider,
                sample_steps_slider,
                seed_input,
            ],
            outputs=[
                audio_output,
                download_output,
            ],
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LocalSong Gradio Interface")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the model checkpoint",
    )
    args = parser.parse_args()

    demo = build_interface(args.checkpoint)
    demo.launch()