Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,990 Bytes
e00b5e2 746b11d e00b5e2 746b11d fa5a4cf 746b11d e00b5e2 746b11d e00b5e2 746b11d e00b5e2 746b11d e00b5e2 8340c5c 746b11d e00b5e2 746b11d e00b5e2 746b11d e00b5e2 746b11d e00b5e2 746b11d e00b5e2 746b11d e00b5e2 746b11d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import os
import numpy as np
import torch
import soundfile as sf
import librosa
import gradio as gr
import spaces # For ZeroGPU
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from xcodec2.configuration_bigcodec import BigCodecConfig
from xcodec2.modeling_xcodec2 import XCodec2Model
# ====== Settings ======
# Only the fine-tuned 44.1 kHz checkpoint is used; its repo id can be
# overridden through the FT_REPO environment variable.
FT_REPO = os.environ.get("FT_REPO", "NandemoGHS/Anime-XCodec2-44.1kHz-v2")
# Sample rate the XCodec2 encoder expects; every input is resampled to it.
TARGET_SR = 16_000
# Duration cap (seconds) used when the caller passes None or a non-positive value.
MAX_SECONDS_DEFAULT = 30
def _ensure_models():
    """Load the fine-tuned XCodec2 model onto the CPU once and cache it.

    The model is stored in the module-level ``_model_ft``. Once it is set,
    further calls return immediately, so request handlers may call this freely.
    """
    global _model_ft
    if _model_ft is not None:
        return
    weights_path = hf_hub_download(repo_id=FT_REPO, filename="model.safetensors")
    with safe_open(weights_path, framework="pt", device="cpu") as reader:
        # Keys containing ".beta" are stored under ".bias" instead; presumably
        # this matches the parameter names XCodec2Model expects (TODO confirm).
        state_dict = {
            name.replace(".beta", ".bias"): reader.get_tensor(name)
            for name in reader.keys()
        }
    config = BigCodecConfig.from_pretrained(FT_REPO)
    _model_ft = XCodec2Model.from_pretrained(None, config=config, state_dict=state_dict)
    _model_ft.eval().to("cpu")
# ====== Globals ======
# The fine-tuned model is loaded onto the CPU eagerly, at import time (the call
# below), so the checkpoint download/load happens once at startup instead of on
# the first request. `run` moves the model to the inference device per call.
_model_ft = None
_ensure_models()
def _load_audio(filepath: str, max_seconds: int):
    """
    Load an audio file and prepare it for XCodec2.

    Decodes wav/flac/ogg/mp3, downmixes to mono, resamples to ``TARGET_SR``,
    keeps at most ``max_seconds`` from the beginning, and rescales the signal
    if its peak exceeds 1.0.

    Args:
        filepath: Path to the audio file.
        max_seconds: Maximum duration to keep, in seconds. ``None`` or a
            non-positive value falls back to ``MAX_SECONDS_DEFAULT``.

    Returns:
        ``(waveform, sr)`` where ``waveform`` is a float32 ``torch.Tensor`` of
        shape (1, T) and ``sr`` equals ``TARGET_SR``.

    Raises:
        gr.Error: If the decoded audio contains no samples.
    """
    # Try soundfile first, then fall back to librosa. The two libraries use
    # opposite multi-channel layouts (soundfile: (frames, ch); librosa:
    # (ch, frames)), so normalize to (frames, ch) here rather than guessing
    # the layout from the array shape later.
    try:
        wav, sr = sf.read(filepath, dtype="float32", always_2d=False)
    except Exception:
        wav, sr = librosa.load(filepath, sr=None, mono=False)
        wav = np.asarray(wav).T  # no-op for mono (1-D) output
    wav = np.asarray(wav, dtype=np.float32)

    if wav.size == 0:
        raise gr.Error("The uploaded audio file contains no samples.")

    # Mono: time is axis 0 for both loaders, so average every other axis.
    if wav.ndim > 1:
        wav = wav.mean(axis=tuple(range(1, wav.ndim)))

    # Resample to the encoder's expected rate.
    if sr != TARGET_SR:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=TARGET_SR)
        sr = TARGET_SR

    # Length cap (keep the beginning).
    if max_seconds is None or max_seconds <= 0:
        max_seconds = MAX_SECONDS_DEFAULT
    max_len = int(sr * max_seconds)
    if wav.shape[0] > max_len:
        wav = wav[:max_len]

    # Light safety normalization: only attenuates, never amplifies.
    peak = np.max(np.abs(wav))
    if peak > 1.0:
        wav = wav / (peak + 1e-8)

    wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # (1, T)
    return wav_tensor, sr
def _codes_to_tensor(codes, device):
"""
Normalize the output of xcodec2.encode_code to a tensor with shape (1, 1, N).
Handles version differences where the return type/shape may vary.
"""
if isinstance(codes, torch.Tensor):
return codes.to(device)
try:
t = torch.as_tensor(codes[0][0], device=device)
return t.unsqueeze(0).unsqueeze(0) if t.ndim == 1 else t
except Exception:
return torch.as_tensor(codes, device=device)
def _reconstruct(model: XCodec2Model, waveform: torch.Tensor, device: str) -> np.ndarray:
    """Round-trip ``waveform`` through the codec: encode to discrete codes, then decode.

    Args:
        model: The XCodec2 model, already on ``device``.
        waveform: Input audio tensor as produced by ``_load_audio``.
        device: Device the computation runs on.

    Returns:
        The decoded audio as a float32 NumPy array with every sample clipped
        to [-1, 1] (all size-1 dimensions squeezed away).
    """
    with torch.inference_mode():
        codes = _codes_to_tensor(
            model.encode_code(input_waveform=waveform.to(device)),
            device=device,
        )
        decoded = model.decode_code(codes)
        samples = decoded.squeeze().detach().cpu().numpy().astype(np.float32)
        return np.clip(samples, -1.0, 1.0)
@spaces.GPU(duration=60)  # ZeroGPU: the GPU is reserved only for the duration of this call
def run(audio_path, max_seconds):
    """Gradio click handler: reconstruct an uploaded file with the fine-tuned codec.

    Args:
        audio_path: Filesystem path from the upload widget, or None.
        max_seconds: Duration cap forwarded to ``_load_audio``.

    Returns:
        ``(sample_rate, samples)``, the value format ``gr.Audio`` accepts.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if audio_path is None:
        raise gr.Error("Please upload an audio file.")
    _ensure_models()
    waveform, _input_sr = _load_audio(audio_path, max_seconds)
    target = "cuda" if torch.cuda.is_available() else "cpu"
    model = _model_ft.to(target)
    samples = _reconstruct(model, waveform, target)
    # This checkpoint decodes at 44.1 kHz, so the output rate is fixed here.
    output_sr = 44100
    return output_sr, samples
# ====== UI ======
# Markdown header rendered at the top of the single-model demo page.
DESCRIPTION = (
    "\n"
    "# Anime-XCodec2-44.1kHz-v2 Reconstruction Demo\n"
    "This demo reconstructs audio using the **44.1 kHz fine-tuned (NandemoGHS/Anime-XCodec2-44.1kHz-v2)** model.\n"
    "- Supported inputs: wav / flac / ogg / mp3\n"
    "- Input is automatically converted to **16 kHz** (as required by XCodec2).\n"
    "- ZeroGPU ready. If no GPU is available, it falls back to CPU (slower).\n"
)
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: input file, duration cap, run button, model info.
        with gr.Column(scale=1):
            inp = gr.Audio(
                sources=["upload"],
                type="filepath",  # `run` receives a path on disk
                label="Upload an audio file",
                waveform_options={"show_controls": True},
            )
            max_sec = gr.Slider(
                3, 60, value=MAX_SECONDS_DEFAULT, step=1,
                label="Max length (seconds)",
                info="If the input is longer, only the first N seconds will be processed.",
            )
            run_btn = gr.Button("Run", variant="primary")
            # A single "\n" is not a line break in Markdown, so the two facts
            # would render on one line; use a paragraph break instead.
            gr.Markdown(
                f"**44.1 kHz model**: `{FT_REPO}`\n\n"
                "**Inference device**: auto (GPU on ZeroGPU)"
            )
        # Right column: the single reconstruction output.
        with gr.Column(scale=1):
            out_ft = gr.Audio(
                # Follow FT_REPO so the label stays accurate when the repo is overridden.
                label=f"44.1kHz reconstruction ({FT_REPO})",
                show_download_button=True,
                format="wav",
            )
    # Click action points to the single output
    run_btn.click(run, inputs=[inp, max_sec], outputs=[out_ft])
# Explicit launch is optional on Spaces; this guard starts the app when the
# script is run directly. The queue holds at most 8 pending requests.
if __name__ == "__main__":
    demo.queue(max_size=8).launch()