ayf3's picture
Upload app.py with huggingface_hub
44ebeaa verified
#!/usr/bin/env python3
"""
NumberBlocks One Voice Cloning Space - VoxCPM V5
Fix: float32 on CPU + monkey-patch SDPA mask shape for CPU compatibility
Root cause of "Dimension out of range":
MiniCPM4's Attention.forward_step creates a 1D attn_mask but SDPA on CPU
expects at least 2D for proper broadcasting with GQA (Grouped Query Attention).
On GPU, the flash-attention backend handles this; on CPU the math backend does not.
"""
import os
import tempfile
import threading
import traceback
from pathlib import Path

import gradio as gr
import soundfile as sf
import torch
import torch.nn.functional as F
# Hugging Face access token: HF_TOKEN takes precedence, HUGGINGFACE_TOKEN is the
# fallback, and None when neither is set.
# NOTE(review): not referenced anywhere else in this file — confirm whether it
# was meant to be passed to VoxCPM.from_pretrained.
HF_TOKEN = (
    os.environ["HF_TOKEN"]
    if "HF_TOKEN" in os.environ
    else os.environ.get("HUGGINGFACE_TOKEN")
)
# ──────────────────────────────────────────────────────────────────
# Monkey-patch: fix SDPA mask shape for CPU
# ──────────────────────────────────────────────────────────────────
# Keep a handle on the real kernel so the wrapper can delegate to it.
_original_sdpa = F.scaled_dot_product_attention


def _cpu_safe_sdpa(query, key, value, attn_mask=None, *args, **kwargs):
    """Drop-in wrapper for ``F.scaled_dot_product_attention`` that lifts a 1-D mask.

    MiniCPM4's attention builds a 1-D ``attn_mask`` of length ``S`` (the key
    length), which the non-CUDA SDPA path rejects with "Dimension out of range".
    When the tensors are not on CUDA, such a mask is reshaped to
    ``(1, ..., 1, S)`` and expanded (a view, no copy) to
    ``(*query.shape[:-2], L, S)``. The attention result is unchanged; the mask
    just has full rank. Every other argument, including positionally passed
    ``dropout_p`` / ``is_causal`` / ``scale``, is forwarded untouched.
    """
    src_len = key.shape[-2]
    if (
        attn_mask is not None
        and attn_mask.dim() == 1
        and query.device.type != "cuda"
        # A mismatched mask is passed through so SDPA reports its own error.
        and attn_mask.shape[0] == src_len
    ):
        tgt_len = query.shape[-2]
        # Leading singleton dims so the mask has the same rank as the query.
        ones = (1,) * (query.dim() - 1)
        attn_mask = attn_mask.view(*ones, src_len).expand(
            *query.shape[:-2], tgt_len, src_len
        )
    return _original_sdpa(query, key, value, attn_mask, *args, **kwargs)


# Apply the patch globally. Code that looks up F.scaled_dot_product_attention at
# call time goes through the wrapper. Names bound earlier via
# `from torch.nn.functional import ...` would not, which is why this runs before
# voxcpm is imported in load_model().
F.scaled_dot_product_attention = _cpu_safe_sdpa
print("✅ Patched scaled_dot_product_attention for CPU mask shape fix")
def load_model():
    """Load the VoxCPM2 checkpoint and, on CPU, force it to float32.

    On CPU the conversion touches the config ``dtype``, all parameters/buffers
    of ``tts_model``, and the ``base_lm`` / ``residual_lm`` KV caches.

    Returns:
        ``(model, device, None)`` on success, where device is ``"cuda"`` or
        ``"cpu"``. If anything raises, the traceback is printed and
        ``(None, "cpu", message)`` is returned instead.
    """
    try:
        from voxcpm import VoxCPM

        has_cuda = torch.cuda.is_available()
        device = "cuda" if has_cuda else "cpu"
        for line in (
            f"Loading VoxCPM model on {device}...",
            f"PyTorch version: {torch.__version__}",
            f"CUDA available: {has_cuda}",
        ):
            print(line)

        # optimize=False avoids torch.compile issues.
        model = VoxCPM.from_pretrained(
            "openbmb/VoxCPM2", load_denoiser=False, optimize=False
        )

        if device == "cpu":
            print("Converting model to float32 for CPU compatibility...")
            # Switch the config dtype so tensors created later in _inference are float32.
            if hasattr(model.tts_model, "config"):
                previous_dtype = model.tts_model.config.dtype
                model.tts_model.config.dtype = "float32"
                print(f" config.dtype: {previous_dtype} -> float32")
            # Cast every parameter and buffer.
            model.tts_model = model.tts_model.to(torch.float32)
            # The KV caches are allocated in __init__ with the old dtype,
            # so the module cast above does not reach them.
            kv_targets = (
                ("base_lm", " base_lm KV cache -> float32"),
                ("residual_lm", " residual_lm KV cache -> float32"),
            )
            for attr_name, log_line in kv_targets:
                if not hasattr(model.tts_model, attr_name):
                    continue
                lm = getattr(model.tts_model, attr_name)
                if not hasattr(lm, "kv_cache") or lm.kv_cache is None:
                    continue
                lm.kv_cache.kv_cache = lm.kv_cache.kv_cache.to(torch.float32)
                print(log_line)
            print("Model conversion to float32 complete!")

        print("Model loaded successfully!")
        return model, device, None
    except Exception as exc:
        print(f"Error loading model: {exc}")
        traceback.print_exc()
        return None, "cpu", str(exc)
# Global model state shared by ensure_model() and the request handler:
#   model   - the loaded VoxCPM instance, or None until a load succeeds
#   device  - "cuda" or "cpu" as reported by load_model()
#   error   - last load error message, or None
#   loading - True while a load is in progress
MODEL_STATE = dict(model=None, device="cpu", error=None, loading=False)
# Ensures only one thread (the background preload or a request handler) runs
# load_model() at a time.
_MODEL_LOAD_LOCK = threading.Lock()


def ensure_model():
    """Return ``(model, device, error)``, loading the model on first use.

    Exactly one thread performs a given load. A caller that arrives while
    another thread is loading does not wait: it returns the current state
    immediately, with ``model`` still ``None``. If a load fails, ``error`` is
    recorded and the next call retries.
    """
    # Non-blocking acquire replaces the old unguarded check-then-set on
    # MODEL_STATE["loading"], which let two threads both start a load.
    if MODEL_STATE["model"] is None and _MODEL_LOAD_LOCK.acquire(blocking=False):
        try:
            # Re-check under the lock: a load may have finished in between.
            if MODEL_STATE["model"] is None:
                MODEL_STATE["loading"] = True
                try:
                    model, device, error = load_model()
                    MODEL_STATE["model"] = model
                    MODEL_STATE["device"] = device
                    MODEL_STATE["error"] = error
                except Exception as e:
                    MODEL_STATE["error"] = str(e)
                    traceback.print_exc()
                finally:
                    MODEL_STATE["loading"] = False
        finally:
            _MODEL_LOAD_LOCK.release()
    return MODEL_STATE["model"], MODEL_STATE["device"], MODEL_STATE["error"]
def generate_audio(text, reference_audio, cfg_value=2.0, steps=10):
    """Synthesize ``text`` in the voice of ``reference_audio`` (Gradio handler).

    Args:
        text: Text to synthesize.
        reference_audio: Path to the uploaded reference clip. The UI's
            ``gr.Audio`` input uses ``type="filepath"``.
        cfg_value: Forwarded to ``model.generate`` as ``cfg_value``.
        steps: Forwarded to ``model.generate`` as ``inference_timesteps``.

    Returns:
        ``(output_path, status_message)`` on success, otherwise
        ``(None, error_message)``. Failures are reported through the message
        and never raised.
    """
    if not text or not reference_audio:
        return None, "❌ 请输入文本和参考音频"
    if not text.strip():
        return None, "❌ 文本不能为空"

    import time

    ref_path = None
    try:
        model, device, error = ensure_model()
        if error:
            return None, f"❌ 模型加载失败: {error}"
        if model is None:
            return None, "❌ 模型正在加载中,请稍候..."

        # Read the reference clip; for multi-channel audio keep only channel 0.
        ref_audio, sr = sf.read(reference_audio)
        if len(ref_audio.shape) > 1:
            ref_audio = ref_audio[:, 0]

        # Reserve a unique temp path, then close the handle before soundfile
        # reopens the path for writing (writing to a still-open
        # NamedTemporaryFile by name is not portable).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            ref_path = tmp.name
        sf.write(ref_path, ref_audio, sr)

        print(f"Generating with text: {text[:50]}...")
        print(f"Reference audio: {len(ref_audio)/sr:.2f}s at {sr}Hz")

        t0 = time.time()
        wav = model.generate(
            text=text,
            reference_wav_path=ref_path,
            cfg_value=float(cfg_value),
            inference_timesteps=int(steps),
        )
        elapsed = time.time() - t0

        sample_rate = model.tts_model.sample_rate
        # NOTE(review): fixed shared path, overwritten by every call. This is
        # presumably only safe while this handler runs one request at a time.
        output_path = "/tmp/voxcpm_output.wav"
        sf.write(output_path, wav, sample_rate)
        duration = len(wav) / sample_rate
        msg = f"✅ 生成成功! 时长: {duration:.2f}s, 耗时: {elapsed:.1f}s, 设备: {device}"
        print(msg)
        return output_path, msg
    except Exception as e:
        error_msg = f"❌ 生成失败: {str(e)}"
        print(f"Error: {e}")
        traceback.print_exc()
        return None, error_msg
    finally:
        # Remove the temporary reference copy on every exit path, failures
        # included. Cleanup is best-effort and must not mask the result.
        if ref_path is not None:
            try:
                os.unlink(ref_path)
            except OSError as cleanup_err:
                print(f"Warning: could not remove temp file {ref_path}: {cleanup_err}")
# Preset texts offered as one-click buttons in the UI. Keys are the button labels.
PRESET_TEXTS = {
    "问候": (
        "Hello! I am One! I am the first Numberblock, "
        "and I love being number one!"
    ),
    "计数": (
        "One, two, three, four, five! "
        "Counting is so much fun! I can count all the way to ten!"
    ),
    "情感": (
        "Sometimes I feel a little lonely being just one, "
        "but then I remember that one is the start of everything!"
    ),
}
# Build the Gradio interface. Components are created inside the `with`
# contexts, so statement order here is also the on-screen layout order.
with gr.Blocks(title="NumberBlocks One Voice Cloning") as demo:
    gr.Markdown("# 🎭 NumberBlocks One Voice Cloning (VoxCPM V5)")
    gr.Markdown("### 使用 VoxCPM 2 模型克隆 One 的声音")
    with gr.Row():
        # Left column: text to synthesize plus one button per preset.
        with gr.Column():
            text_input = gr.Textbox(
                label="输入文本",
                placeholder="输入要合成的文本...",
                lines=3,
                value=PRESET_TEXTS["问候"]
            )
            with gr.Row():
                for name, txt in PRESET_TEXTS.items():
                    # `t=txt` binds the current preset at definition time.
                    # A bare closure over `txt` would make every button insert
                    # the last preset. With inputs=None the lambda is called
                    # without arguments and returns the bound default.
                    gr.Button(name).click(lambda t=txt: t, inputs=None, outputs=text_input)
        # Right column: reference clip. type="filepath" means generate_audio
        # receives a path string rather than raw samples.
        with gr.Column():
            ref_audio_input = gr.Audio(
                label="参考音频 (One 的声音)",
                type="filepath"
            )
    # Generation controls, forwarded to generate_audio as cfg_value / steps.
    with gr.Row():
        cfg_slider = gr.Slider(
            minimum=0.5,
            maximum=5.0,
            value=2.0,
            step=0.1,
            label="CFG Value (越高越像参考音色)"
        )
        steps_slider = gr.Slider(
            minimum=5,
            maximum=50,
            value=10,
            step=1,
            label="推理步数 (越高质量越好但越慢)"
        )
    generate_btn = gr.Button("🎙️ 生成音频", variant="primary")
    with gr.Row():
        output_audio = gr.Audio(label="生成结果")
        status_msg = gr.Markdown(value="⏸️ 等待生成...")
    # generate_audio returns (audio_path_or_None, status_markdown), matching
    # the two outputs below.
    generate_btn.click(
        fn=generate_audio,
        inputs=[text_input, ref_audio_input, cfg_slider, steps_slider],
        outputs=[output_audio, status_msg]
    )
    # Static usage notes rendered below the controls.
    gr.Markdown("---")
    gr.Markdown("### 说明")
    gr.Markdown("""
- **参考音频**: 上传 One 的声音片段(建议 5-15 秒清晰语音)
- **CFG Value**: 控制音色相似度,默认 2.0,越高越像参考音色
- **推理步数**: 默认 10,越高质量越好但生成越慢
- **模型**: VoxCPM 2 (openbmb/VoxCPM2)
- **V5**: CPU float32 + SDPA mask shape fix
""")
if __name__ == "__main__":
    import threading

    def _warm_model():
        """Start loading the model in the background so the UI can come up immediately."""
        print("Preloading VoxCPM model...")
        ensure_model()

    # Daemon thread: it will not keep the process alive on shutdown.
    preload_worker = threading.Thread(target=_warm_model, daemon=True)
    preload_worker.start()
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)