rvc-cpu-trainer / app.py
ayf3's picture
Upload app.py with huggingface_hub
3ae8666 verified
#!/usr/bin/env python3
"""
HuggingFace Space Gradio应用
RVC v2 训练界面
"""
import os
import subprocess
import sys
import threading
from pathlib import Path

import gradio as gr
# Shared training state: `training_status` is shown on the status page and
# rewritten by the background training thread.
training_status = "未开始"
training_log = []

# Make sure the working directories exist before anything reads or writes them.
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("data/training_data", exist_ok=True)
os.makedirs("logs", exist_ok=True)
def get_training_status():
    """Build a Markdown report describing the current training state.

    The report lists, in order:

    * the module-level ``training_status`` string;
    * every ``*.pth`` file in ``checkpoints/`` with its size in MB;
    * the number of ``*.txt`` files in ``logs/`` and the name of the most
      recently modified one;
    * the number of ``*.wav`` / ``*.mp3`` files in
      ``data/training_data/audio``.

    Returns:
        str: The report, one Markdown entry per line.
    """
    checkpoints_dir = Path("checkpoints")
    logs_dir = Path("logs")
    data_dir = Path("data/training_data/audio")

    status_lines = [f"## 当前状态: {training_status}"]

    # Checkpoints
    if checkpoints_dir.exists():
        # glob() yields entries in filesystem order, which is arbitrary;
        # sort so the list does not reshuffle between refreshes.
        checkpoints = sorted(checkpoints_dir.glob("*.pth"))
        if checkpoints:
            status_lines.append(f"### 模型文件: {len(checkpoints)} 个")
            for ckpt in checkpoints:
                size_mb = ckpt.stat().st_size / 1024 / 1024
                status_lines.append(f"- {ckpt.name}: {size_mb:.2f} MB")
        else:
            status_lines.append("### 模型文件: 尚未生成")
    else:
        status_lines.append("### 模型文件: 目录不存在")

    # Logs
    if logs_dir.exists():
        log_files = list(logs_dir.glob("*.txt"))
        if log_files:
            status_lines.append(f"### 日志文件: {len(log_files)} 个")
            latest_log = max(log_files, key=lambda p: p.stat().st_mtime)
            status_lines.append(f"最新日志: {latest_log.name}")
        else:
            status_lines.append("### 日志文件: 尚未生成")
    else:
        status_lines.append("### 日志文件: 目录不存在")

    # Training data
    if data_dir.exists():
        audio_files = list(data_dir.glob("*.wav")) + list(data_dir.glob("*.mp3"))
        status_lines.append(f"### 训练数据: {len(audio_files)} 个音频文件")
    else:
        status_lines.append("### 训练数据: 尚未下载")

    return "\n".join(status_lines)
# Guards the check-and-set of ``training_status`` so that two requests
# arriving at nearly the same time cannot both start a training run.
_training_lock = threading.Lock()


def _to_text(captured):
    """Return captured subprocess output as ``str`` (``None`` becomes ``""``)."""
    if captured is None:
        return ""
    if isinstance(captured, bytes):
        # TimeoutExpired may carry raw bytes even when text mode was requested.
        return captured.decode("utf-8", errors="replace")
    return captured


def _write_training_log(stdout, stderr):
    """Write the captured training output to ``logs/training_log.txt``."""
    Path("logs").mkdir(exist_ok=True)
    with open("logs/training_log.txt", "w", encoding="utf-8") as f:
        f.write("=== STDOUT ===\n")
        f.write(_to_text(stdout))
        f.write("\n=== STDERR ===\n")
        f.write(_to_text(stderr))


def _run_training():
    """Run the training script, save its output and record the outcome."""
    global training_status
    try:
        try:
            result = subprocess.run(
                # Use the interpreter running this app rather than whatever
                # "python" happens to resolve to on PATH.
                [sys.executable or "python", "scripts/train_rvc_v2_fixed.py"],
                capture_output=True,
                text=True,
                # A single undecodable byte must not discard the whole log.
                errors="replace",
                timeout=14400,  # 4 hour timeout
            )
        except subprocess.TimeoutExpired as e:
            # Keep whatever output was captured before the timeout, then let
            # the outer handler report the failure.
            _write_training_log(e.stdout, e.stderr)
            raise

        _write_training_log(result.stdout, result.stderr)

        if result.returncode == 0:
            training_status = "训练完成"
        else:
            training_status = f"训练失败: {result.returncode}"
    except Exception as e:
        # Top-level boundary of the worker thread: always leave the
        # "正在训练" state so a new run can be started afterwards.
        training_status = f"训练异常: {str(e)}"


def start_training():
    """Start the training script in a background daemon thread.

    If ``training_status`` already says a run is in progress, nothing is
    started. Otherwise the status is switched to "正在训练" and a daemon
    thread runs ``scripts/train_rvc_v2_fixed.py``. When the script ends, its
    output is written to ``logs/training_log.txt`` and ``training_status``
    is set to "训练完成", "训练失败: <returncode>" or "训练异常: <error>".

    Returns:
        str: A message telling the user whether training was started.
    """
    global training_status

    with _training_lock:
        if training_status == "正在训练":
            return "⚠️ 训练已在进行中"
        training_status = "正在训练"

    # Run the training in a background thread so the UI handler returns at once.
    thread = threading.Thread(target=_run_training)
    thread.daemon = True
    thread.start()

    return "✅ 训练已启动"
def download_data():
    """Run the data-preparation script and report the result as text.

    Executes ``scripts/prepare_training_data_v2.py`` with the interpreter
    that is running this app and waits up to 10 minutes for it to finish.

    Returns:
        str: A success message followed by the script's stdout when it exits
        with code 0; a failure message followed by its stderr (or its stdout
        when stderr is empty) on a non-zero exit; or an error message when
        the script could not be run or timed out.
    """
    try:
        result = subprocess.run(
            # Use the interpreter running this app rather than whatever
            # "python" happens to resolve to on PATH.
            [sys.executable or "python", "scripts/prepare_training_data_v2.py"],
            capture_output=True,
            text=True,
            # Undecodable bytes in the output must not turn a finished run
            # into an exception.
            errors="replace",
            timeout=600,  # 10 minute timeout
        )
    except Exception as e:
        # UI boundary: report the problem as text instead of raising.
        return f"❌ 数据下载异常: {str(e)}"

    if result.returncode == 0:
        return "✅ 数据下载完成\n\n" + result.stdout

    # Some scripts report errors on stdout only; do not return an empty
    # failure message in that case.
    return "❌ 数据下载失败\n\n" + (result.stderr or result.stdout)
# Build the Gradio interface: a status tab, a tab with the two action buttons,
# and a static tab describing the training configuration.
with gr.Blocks(title="NumberBlocks One RVC v2 训练") as demo:
    gr.Markdown("# 🎤 NumberBlocks One RVC v2 训练")
    gr.Markdown("## 使用9小时One语音数据训练RVC v2模型")

    with gr.Tab("训练状态"):
        # This initial value is computed once, when the module is imported.
        status_output = gr.Markdown(value=get_training_status())
        refresh_btn = gr.Button("刷新状态")
        refresh_btn.click(get_training_status, outputs=status_output)

    with gr.Tab("开始训练"):
        gr.Markdown("### 步骤1: 下载训练数据")
        download_btn = gr.Button("下载数据")
        download_output = gr.Textbox(label="下载结果", lines=10)
        download_btn.click(download_data, outputs=download_output)

        gr.Markdown("### 步骤2: 开始训练")
        train_btn = gr.Button("开始训练", variant="primary")
        train_output = gr.Textbox(label="训练状态")
        train_btn.click(start_training, outputs=train_output)

    with gr.Tab("训练配置"):
        gr.Markdown("### 训练参数")
        gr.Markdown("""
- **模型**: RVC v2
- **数据集**: ayf3/numberblocks-audio (9小时音频)
- **硬件**: T4 GPU (免费)
- **训练时长**: 2-3小时
- **Batch Size**: 8
- **Epochs**: 500
- **学习率**: 0.0001
- **优化器**: AdamW
""")

    # Recompute the status on every page load; otherwise visitors see the
    # snapshot taken at startup until they press the refresh button.
    demo.load(get_training_status, outputs=status_output)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)