import os
from typing import List, Tuple
import uuid
import json
import argparse

import gradio as gr
import torch
import torchaudio
from safetensors.torch import load_file
from tqdm import tqdm

from model import LocalSongModel
from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
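

# Tag vocabulary: maps lowercase tag strings to the integer class indices the
# model was trained on. num_classes is kept as a fixed constant matching the
# checkpoint's embedding table rather than derived from the mapping file.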
class TagEmbedder:
    def __init__(self, mapping_file: str = "checkpoints/tag_mapping.json"):
        with open(mapping_file, 'r', encoding='utf-8') as f:
            self.tag_mapping = json.load(f)
        self.num_classes = 2304
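

# Thin wrapper around the ACE-Step MusicDCAE autoencoder. Latents are handled
# in normalized form, so decode() first de-normalizes with per-channel
# mean/std. The eight values below are assumed to be dataset statistics
# computed at training time; they must match the ones used for encoding.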
class AudioVAE:
    def __init__(self, device: torch.device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel latent statistics (one value per latent channel).
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # Undo the training-time normalization, then decode to waveforms.
            latents = latents * self.latent_std + self.latent_mean
            _, audio_list = self.model.decode(latents, sr=48000)
            return torch.stack(audio_list).to(self.device)
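

# Rectified-flow sampler. Starting from Gaussian noise at t = 1, it integrates
# dz/dt = v(z, t) back to t = 0 with plain Euler steps, z <- z - dt * v.
# Classifier-free guidance runs the conditional and unconditional branches in
# a single batched forward pass and combines them as
#     v = v_uncond + cfg * (v_cond - v_uncond),
# so cfg = 1.0 recovers the purely conditional velocity.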
class RF:
    def __init__(self, model: torch.nn.Module):
        self.model = model

    def sample(
        self,
        z: torch.Tensor,
        cond: List[List[int]],
        null_cond: List[List[int]] | None = None,
        sample_steps: int = 100,
        cfg: float = 3.0,
    ) -> List[torch.Tensor]:
        batch = z.size(0)
        dt = 1.0 / sample_steps
        # Broadcast dt over all non-batch dimensions.
        dt = torch.tensor([dt] * batch, device=z.device).view([batch, *([1] * len(z.shape[1:]))])
        images = [z]
        for i in tqdm(range(sample_steps, 0, -1), desc="Generating", unit="step"):
            t = torch.tensor([i / sample_steps] * batch, device=z.device)

            if null_cond is not None:
                # One batched forward pass covers both the conditional and the
                # unconditional (null) branches, then apply CFG.
                z_batched = torch.cat([z, z], dim=0)
                t_batched = torch.cat([t, t], dim=0)
                cond_batched = cond + null_cond
                v_batched = self.model(z_batched, t_batched, cond_batched)
                vc, vu = v_batched.chunk(2, dim=0)
                vc = vu + cfg * (vc - vu)
            else:
                vc = self.model(z, t, cond)

            z = z - dt * vc
            images.append(z)
        return images
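

# Lazily initialized module-level singletons, populated by load_resources()
# the first time the interface is built.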
model: torch.nn.Module | None = None
vae: AudioVAE | None = None
tag_embedder: TagEmbedder | None = None
rf_sampler: RF | None = None
device: torch.device | None = None
_available_tags: List[str] | None = None
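

# One-time setup: build the model, load the checkpoint, and cache the sorted
# tag vocabulary. Safe to call repeatedly; later calls return the cached tags.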
def load_resources(checkpoint_path: str) -> List[str]:
    torch.set_float32_matmul_precision('high')

    global model, vae, tag_embedder, rf_sampler, device, _available_tags

    if _available_tags is not None:
        return _available_tags

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tag_embedder = TagEmbedder()

    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=tag_embedder.num_classes,
        max_tags=8,
    ).to(device)

    print(f"Loading checkpoint: {checkpoint_path}")
    state_dict = load_file(checkpoint_path, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    vae = AudioVAE(device)
    rf_sampler = RF(model)

    _available_tags = sorted(tag_embedder.tag_mapping.keys())
    return _available_tags
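

# Unknown tags are dropped silently; only tags present in the training
# vocabulary contribute conditioning indices.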
def _tags_to_indices(tags: List[str]) -> List[int]:
    assert tag_embedder is not None
    indices = []
    for tag in tags:
        tag_lower = tag.lower().strip()
        if tag_lower in tag_embedder.tag_mapping:
            indices.append(tag_embedder.tag_mapping[tag_lower])
    return indices
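

# Full generation pipeline: sample a (1, 8, 16, 512) latent with the
# rectified-flow sampler under classifier-free guidance, decode it with the
# DCAE, and return both an in-memory (sr, ndarray) pair for the Gradio player
# and a WAV path for download. The clip length is fixed by the latent width
# (512 frames); the exact duration depends on the DCAE's temporal compression.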
def generate_audio(
    tags: List[str],
    cfg: float,
    sample_steps: int,
) -> Tuple[Tuple[int, object], str]:
    assert model is not None and vae is not None and rf_sampler is not None and device is not None

    # Normalize a missing selection to an empty list.
    if not tags:
        tags = []
    if len(tags) > 8:
        raise gr.Error("A maximum of 8 tags is supported.")

    tag_indices = _tags_to_indices(tags)

    batch = 1
    channels = 8
    height = 16
    width = 512

    z = torch.randn(batch, channels, height, width, device=device)
    cond = [tag_indices]
    # An empty tag list serves as the unconditional (null) branch for CFG.
    null_cond = [[]]

    with torch.no_grad():
        sampled_latents = rf_sampler.sample(
            z=z,
            cond=cond,
            null_cond=null_cond,
            sample_steps=sample_steps,
            cfg=cfg,
        )[-1]
        audio = vae.decode(sampled_latents)

    audio_tensor = audio[0].cpu()
    sr = 48000
    # Gradio expects (samples, channels); torchaudio uses (channels, samples).
    audio_numpy = audio_tensor.transpose(0, 1).numpy()

    os.makedirs("generated", exist_ok=True)
    output_path = f"generated/generated_{uuid.uuid4().hex}.wav"
    torchaudio.save(output_path, audio_tensor, sr)

    return (sr, audio_numpy), output_path
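

# Gradio UI: a tag multiselect with one-click presets, CFG and step sliders,
# a seed box, and paired player/download outputs.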
def build_interface(checkpoint_path: str) -> gr.Blocks:
    available_tags = load_resources(checkpoint_path)

    presets = [
        ["soundtrack1", "female vocalist", "rock", "melodic"],
        ["soundtrack", "chrono trigger", "emotional", "piano", "strings"],
        ["soundtrack", "touhou 10", "trumpet"],
        ["soundtrack", "christmas music", "winter", "melodic"],
        ["soundtrack2", "male vocalist", "pop", "melodic", "acoustic guitar", "ballad"],
    ]

    with gr.Blocks(title="LocalSong") as demo:
        gr.Markdown("# LocalSong")

        with gr.Row():
            tag_input = gr.Dropdown(
                label="Tags (select up to 8)",
                choices=available_tags,
                multiselect=True,
                max_choices=8,
                value=presets[0],
            )

        gr.Markdown("**Presets:**")
        with gr.Row():
            for preset in presets:
                btn = gr.Button(" + ".join(preset), size="sm")

                # Bind the current preset by value; a bare lambda in the loop
                # would capture the loop variable by reference instead.
                def make_preset_fn(p):
                    return lambda: p

                btn.click(fn=make_preset_fn(preset), inputs=None, outputs=tag_input)

        with gr.Row():
            cfg_slider = gr.Slider(
                label="CFG Scale",
                minimum=1.0,
                maximum=7.0,
                step=0.5,
                value=3.5,
            )
            sample_steps_slider = gr.Slider(
                label="Sample Steps",
                minimum=50,
                maximum=200,
                step=10,
                value=200,
            )

        with gr.Row():
            seed_input = gr.Number(
                label="Seed",
                value=45,
                precision=0,
            )

        generate_button = gr.Button("Generate Audio", variant="primary")
        audio_output = gr.Audio(label="Generated Audio", type="numpy")
        download_output = gr.File(label="Download WAV")

        def generate_wrapper(tags, cfg, steps, seed):
            # Seed both CPU and CUDA RNGs for reproducible generations.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
            return generate_audio(tags, cfg, steps)

        generate_button.click(
            fn=generate_wrapper,
            inputs=[tag_input, cfg_slider, sample_steps_slider, seed_input],
            outputs=[audio_output, download_output],
        )

    return demo
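

# Entry point. Example invocation (assuming this file is saved as app.py):
#   python app.py --checkpoint checkpoints/checkpoint_461260.safetensors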
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LocalSong Gradio Interface")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the model checkpoint",
    )
    args = parser.parse_args()

    demo = build_interface(args.checkpoint)
    demo.launch()