#!/usr/bin/env python3
"""
NumberBlocks One Voice Cloning Space - VoxCPM V5
Fix: float32 on CPU + monkey-patch SDPA mask shape for CPU compatibility
Root cause of "Dimension out of range":
MiniCPM4's Attention.forward_step creates a 1D attn_mask but SDPA on CPU
expects at least 2D for proper broadcasting with GQA (Grouped Query Attention).
On GPU, the flash-attention backend handles this; on CPU the math backend does not.
"""
import os
import gradio as gr
import tempfile
import soundfile as sf
import traceback
from pathlib import Path
import torch
import torch.nn.functional as F
# Hugging Face access token taken from the environment. ``HF_TOKEN`` wins when
# it is set; otherwise ``HUGGINGFACE_TOKEN`` is used; ``None`` if neither exists.
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGINGFACE_TOKEN"))
# ──────────────────────────────────────────────────────────────────
# Monkey-patch: fix SDPA mask shape for CPU
# ──────────────────────────────────────────────────────────────────
# Keep a handle on the stock implementation so the wrapper can delegate to it.
_original_sdpa = F.scaled_dot_product_attention


def _cpu_safe_sdpa(query, key, value, attn_mask=None, *args, **kwargs):
    """Delegate to the stock SDPA after widening a 1-D ``attn_mask``.

    A 1-D mask of shape ``(S,)`` is reshaped and expanded to the leading
    shape of ``query`` plus ``S`` (``(B, H, L, S)`` for a 4-D query), which is
    what broadcasting the 1-D mask over the attention weights would mean
    anyway. The rewrite is applied only when ``query`` is not a CUDA tensor;
    every other call is forwarded untouched.

    Args:
        query, key, value: Attention inputs, forwarded as-is.
        attn_mask: Optional mask. Only a 1-D mask is modified.
        *args, **kwargs: Any remaining SDPA arguments (``dropout_p``,
            ``is_causal``, ...), accepted positionally or by keyword so that
            callers of the globally patched function keep working.

    Returns:
        Whatever the original ``scaled_dot_product_attention`` returns.
    """
    # Decide per call from the tensor's device rather than from whether CUDA
    # exists on the machine: tensors can live on the CPU either way.
    if attn_mask is not None and attn_mask.dim() == 1 and not query.is_cuda:
        lead_shape = query.shape[:-1]  # (..., L); (B, H, L) for a 4-D query
        widened = attn_mask.reshape((1,) * len(lead_shape) + (-1,))
        # expand() creates a broadcast view; -1 keeps the mask length S.
        attn_mask = widened.expand(*lead_shape, -1)
    return _original_sdpa(query, key, value, attn_mask, *args, **kwargs)


# Apply the patch globally: every later call through
# ``F.scaled_dot_product_attention`` goes via the wrapper.
F.scaled_dot_product_attention = _cpu_safe_sdpa
print("✅ Patched scaled_dot_product_attention for CPU mask shape fix")
def _kv_cache_to_float32(tts_model, lm_name):
    """Cast the KV cache of ``tts_model.<lm_name>`` to float32 if one exists.

    Returns True when a cache was found and converted, False otherwise
    (sub-model missing, no ``kv_cache`` attribute, or cache is None).
    """
    lm = getattr(tts_model, lm_name, None)
    cache = getattr(lm, "kv_cache", None)
    if cache is None:
        return False
    cache.kv_cache = cache.kv_cache.to(torch.float32)
    return True


def load_model():
    """Load the VoxCPM model, forcing float32 everywhere when running on CPU.

    Returns:
        A ``(model, device, error)`` tuple. On success ``error`` is None and
        ``device`` is "cuda" or "cpu". On any failure the exception is
        printed with its traceback and ``(None, "cpu", str(exc))`` is returned
        instead of raising.
    """
    try:
        # Imported lazily so a missing/broken package is reported through the
        # normal error return instead of crashing at module import.
        from voxcpm import VoxCPM

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading VoxCPM model on {device}...")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")

        # Load model (optimize=False to avoid torch.compile issues)
        model = VoxCPM.from_pretrained("openbmb/VoxCPM2", load_denoiser=False, optimize=False)

        # CRITICAL FIX: force float32 on CPU.
        if device == "cpu":
            print("Converting model to float32 for CPU compatibility...")

            # Step 1: change the config dtype (original author's note: so that
            # inference creates float32 tensors).
            config = getattr(model.tts_model, "config", None)
            if config is not None:
                old_dtype = getattr(config, "dtype", None)
                config.dtype = "float32"
                print(f" config.dtype: {old_dtype} -> float32")

            # Step 2: convert all model parameters and buffers to float32.
            model.tts_model = model.tts_model.to(torch.float32)

            # Step 3: convert the KV caches separately (original author's note:
            # they are created in __init__ with the old dtype).
            for lm_name in ("base_lm", "residual_lm"):
                if _kv_cache_to_float32(model.tts_model, lm_name):
                    print(f" {lm_name} KV cache -> float32")

            print("Model conversion to float32 complete!")

        print("Model loaded successfully!")
        return model, device, None
    except Exception as e:
        # Top-level boundary: report the failure to the caller as data.
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return None, "cpu", str(e)
# Global model state, filled in by ensure_model().
MODEL_STATE = {
    "model": None,     # model object returned by load_model(); None until loaded
    "device": "cpu",   # device string returned by load_model()
    "error": None,     # message of the last load failure, if any
    "loading": False,  # True while a load is in progress
}
def ensure_model():
    """Make sure the model has been loaded and return the shared state.

    A load is started only when no model is cached and no load is currently
    flagged as running. The result of ``load_model()`` is stored in
    ``MODEL_STATE``; an exception is recorded as the error string and its
    traceback printed. The ``loading`` flag is always cleared afterwards.

    Returns:
        ``(model, device, error)`` as currently held in ``MODEL_STATE``.

    NOTE(review): the check-and-set of ``loading`` is not synchronized, so
    two threads calling this at the same moment could both start a load.
    """
    load_needed = MODEL_STATE["model"] is None and not MODEL_STATE["loading"]
    if load_needed:
        MODEL_STATE["loading"] = True
        try:
            model, device, error = load_model()
            MODEL_STATE.update(model=model, device=device, error=error)
        except Exception as exc:
            MODEL_STATE["error"] = str(exc)
            traceback.print_exc()
        finally:
            MODEL_STATE["loading"] = False
    return tuple(MODEL_STATE[field] for field in ("model", "device", "error"))
def generate_audio(text, reference_audio, cfg_value=2.0, steps=10):
    """Synthesize ``text`` using ``reference_audio`` as the voice reference.

    Args:
        text: Text to synthesize. Empty or whitespace-only text is rejected.
        reference_audio: Path of the reference recording (the UI's audio
            input is configured with ``type="filepath"``).
        cfg_value: Guidance value, passed to the model as a float.
        steps: Number of inference timesteps, passed to the model as an int.

    Returns:
        ``(output_path, status_message)``. ``output_path`` is None whenever
        validation, model loading or generation fails; the message then
        describes the problem. No exception propagates to the caller.
    """
    import time

    if not text or not reference_audio:
        return None, "❌ 请输入文本和参考音频"
    if not text.strip():
        return None, "❌ 文本不能为空"

    ref_path = None  # temp copy of the reference audio; removed in ``finally``
    try:
        model, device, error = ensure_model()
        if error:
            return None, f"❌ 模型加载失败: {error}"
        if model is None:
            return None, "❌ 模型正在加载中，请稍后..."

        # Read the reference audio.
        ref_audio, sr = sf.read(reference_audio)

        # Multi-channel input: keep only the first channel.
        if len(ref_audio.shape) > 1:
            ref_audio = ref_audio[:, 0]

        # Reserve a temp file name, close the handle, then write to it.
        # Writing after the handle is closed also works on platforms that do
        # not allow a second open of a still-open temporary file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            ref_path = tmp.name
        sf.write(ref_path, ref_audio, sr)

        print(f"Generating with text: {text[:50]}...")
        print(f"Reference audio: {len(ref_audio)/sr:.2f}s at {sr}Hz")

        # Generate the audio and time the call.
        t0 = time.time()
        wav = model.generate(
            text=text,
            reference_wav_path=ref_path,
            cfg_value=float(cfg_value),
            inference_timesteps=int(steps),
        )
        elapsed = time.time() - t0

        # Save the output.
        sample_rate = model.tts_model.sample_rate
        output_path = "/tmp/voxcpm_output.wav"
        sf.write(output_path, wav, sample_rate)

        duration = len(wav) / sample_rate
        msg = f"✅ 生成成功! 时长: {duration:.2f}s, 耗时: {elapsed:.1f}s, 设备: {device}"
        print(msg)
        return output_path, msg
    except Exception as e:
        # UI boundary: surface the failure as a status message.
        error_msg = f"❌ 生成失败: {str(e)}"
        print(f"Error: {e}")
        traceback.print_exc()
        return None, error_msg
    finally:
        # Clean up the temp file on every path, including failed generation.
        if ref_path is not None:
            try:
                os.unlink(ref_path)
            except OSError:
                pass  # best effort: the file may already be gone
# Preset texts: button label -> sample sentence loaded into the text box.
PRESET_TEXTS = {
    "问候": "Hello! I am One! I am the first Numberblock, and I love being number one!",
    "计数": "One, two, three, four, five! Counting is so much fun! I can count all the way to ten!",
    "情感": "Sometimes I feel a little lonely being just one, but then I remember that one is the start of everything!",
}
# Build the Gradio interface.
# NOTE(review): the source lost its indentation, so the Row/Column nesting
# below is reconstructed from the statement order — confirm the layout.
with gr.Blocks(title="NumberBlocks One Voice Cloning") as demo:
    gr.Markdown("# 🎭 NumberBlocks One Voice Cloning (VoxCPM V5)")
    gr.Markdown("### 使用 VoxCPM 2 模型克隆 One 的声音")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="输入文本",
                placeholder="输入要合成的文本...",
                lines=3,
                # Default to the first preset (the greeting) without
                # depending on the spelling of its key.
                value=next(iter(PRESET_TEXTS.values())),
            )
            with gr.Row():
                for name, txt in PRESET_TEXTS.items():
                    # ``t=txt`` binds the current text at definition time;
                    # a plain closure would make every button return the
                    # last preset.
                    gr.Button(name).click(lambda t=txt: t, inputs=None, outputs=text_input)

        with gr.Column():
            ref_audio_input = gr.Audio(
                label="参考音频 (One 的声音)",
                type="filepath",
            )
            with gr.Row():
                cfg_slider = gr.Slider(
                    minimum=0.5,
                    maximum=5.0,
                    value=2.0,
                    step=0.1,
                    label="CFG Value (越高越像参考音色)",
                )
                steps_slider = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=10,
                    step=1,
                    label="推理步数 (越高质量越好但越慢)",
                )

    generate_btn = gr.Button("🎙️ 生成音频", variant="primary")

    with gr.Row():
        output_audio = gr.Audio(label="生成结果")
        status_msg = gr.Markdown(value="⏸️ 等待生成...")

    # Inputs are passed positionally to generate_audio in this order.
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, ref_audio_input, cfg_slider, steps_slider],
        outputs=[output_audio, status_msg],
    )

    gr.Markdown("---")
    gr.Markdown("### 说明")
    gr.Markdown("""
- **参考音频**: 上传 One 的声音片段（建议 5-15 秒清晰语音）
- **CFG Value**: 控制音色相似度，默认 2.0，越高越像参考音色
- **推理步数**: 默认 10，越高质量越好但生成越慢
- **模型**: VoxCPM 2 (openbmb/VoxCPM2)
- **V5**: CPU float32 + SDPA mask shape fix
""")
if __name__ == "__main__":
    import threading

    def preload():
        """Load the model in the background so the server can start at once."""
        print("Preloading VoxCPM model...")
        ensure_model()

    # Daemon thread: it must not keep the process alive once the server exits.
    preload_thread = threading.Thread(target=preload, daemon=True)
    preload_thread.start()

    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|