Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| HuggingFace Space Gradio应用 | |
| RVC v2 训练界面 | |
| """ | |
| import os | |
| import gradio as gr | |
| import subprocess | |
| import threading | |
| from pathlib import Path | |
| # 创建必要的目录 | |
| Path("checkpoints").mkdir(exist_ok=True) | |
| Path("data/training_data").mkdir(parents=True, exist_ok=True) | |
| Path("logs").mkdir(exist_ok=True) | |
| # 训练状态 | |
| training_status = "未开始" | |
| training_log = [] | |
| def get_training_status(): | |
| """获取训练状态""" | |
| checkpoints_dir = Path("checkpoints") | |
| logs_dir = Path("logs") | |
| status_lines = [] | |
| status_lines.append(f"## 当前状态: {training_status}") | |
| # 检查checkpoints | |
| if checkpoints_dir.exists(): | |
| checkpoints = list(checkpoints_dir.glob("*.pth")) | |
| if checkpoints: | |
| status_lines.append(f"### 模型文件: {len(checkpoints)} 个") | |
| for ckpt in checkpoints: | |
| size_mb = ckpt.stat().st_size / 1024 / 1024 | |
| status_lines.append(f"- {ckpt.name}: {size_mb:.2f} MB") | |
| else: | |
| status_lines.append("### 模型文件: 尚未生成") | |
| else: | |
| status_lines.append("### 模型文件: 目录不存在") | |
| # 检查日志 | |
| if logs_dir.exists(): | |
| log_files = list(logs_dir.glob("*.txt")) | |
| if log_files: | |
| status_lines.append(f"### 日志文件: {len(log_files)} 个") | |
| latest_log = max(log_files, key=lambda p: p.stat().st_mtime) | |
| status_lines.append(f"最新日志: {latest_log.name}") | |
| else: | |
| status_lines.append("### 日志文件: 尚未生成") | |
| else: | |
| status_lines.append("### 日志文件: 目录不存在") | |
| # 检查数据 | |
| data_dir = Path("data/training_data/audio") | |
| if data_dir.exists(): | |
| audio_files = list(data_dir.glob("*.wav")) + list(data_dir.glob("*.mp3")) | |
| status_lines.append(f"### 训练数据: {len(audio_files)} 个音频文件") | |
| else: | |
| status_lines.append("### 训练数据: 尚未下载") | |
| return "\n".join(status_lines) | |
| def start_training(): | |
| """启动训练""" | |
| global training_status | |
| if training_status == "正在训练": | |
| return "⚠️ 训练已在进行中" | |
| training_status = "正在训练" | |
| # 在后台线程中运行训练 | |
| def train_thread(): | |
| global training_status | |
| try: | |
| # 运行训练脚本 | |
| result = subprocess.run( | |
| ["python", "scripts/train_rvc_v2_fixed.py"], | |
| capture_output=True, | |
| text=True, | |
| timeout=14400 # 4小时超时 | |
| ) | |
| # 保存日志 | |
| with open("logs/training_log.txt", "w", encoding="utf-8") as f: | |
| f.write("=== STDOUT ===\n") | |
| f.write(result.stdout) | |
| f.write("\n=== STDERR ===\n") | |
| f.write(result.stderr) | |
| if result.returncode == 0: | |
| training_status = "训练完成" | |
| else: | |
| training_status = f"训练失败: {result.returncode}" | |
| except Exception as e: | |
| training_status = f"训练异常: {str(e)}" | |
| thread = threading.Thread(target=train_thread) | |
| thread.daemon = True | |
| thread.start() | |
| return "✅ 训练已启动" | |
| def download_data(): | |
| """下载训练数据""" | |
| try: | |
| result = subprocess.run( | |
| ["python", "scripts/prepare_training_data_v2.py"], | |
| capture_output=True, | |
| text=True, | |
| timeout=600 # 10分钟超时 | |
| ) | |
| if result.returncode == 0: | |
| return "✅ 数据下载完成\n\n" + result.stdout | |
| else: | |
| return "❌ 数据下载失败\n\n" + result.stderr | |
| except Exception as e: | |
| return f"❌ 数据下载异常: {str(e)}" | |
| # 创建Gradio界面 | |
| with gr.Blocks(title="NumberBlocks One RVC v2 训练") as demo: | |
| gr.Markdown("# 🎤 NumberBlocks One RVC v2 训练") | |
| gr.Markdown("## 使用9小时One语音数据训练RVC v2模型") | |
| with gr.Tab("训练状态"): | |
| status_output = gr.Markdown(value=get_training_status()) | |
| refresh_btn = gr.Button("刷新状态") | |
| refresh_btn.click(get_training_status, outputs=status_output) | |
| with gr.Tab("开始训练"): | |
| gr.Markdown("### 步骤1: 下载训练数据") | |
| download_btn = gr.Button("下载数据") | |
| download_output = gr.Textbox(label="下载结果", lines=10) | |
| download_btn.click(download_data, outputs=download_output) | |
| gr.Markdown("### 步骤2: 开始训练") | |
| train_btn = gr.Button("开始训练", variant="primary") | |
| train_output = gr.Textbox(label="训练状态") | |
| train_btn.click(start_training, outputs=train_output) | |
| with gr.Tab("训练配置"): | |
| gr.Markdown("### 训练参数") | |
| gr.Markdown(""" | |
| - **模型**: RVC v2 | |
| - **数据集**: ayf3/numberblocks-audio (9小时音频) | |
| - **硬件**: T4 GPU (免费) | |
| - **训练时长**: 2-3小时 | |
| - **Batch Size**: 8 | |
| - **Epochs**: 500 | |
| - **学习率**: 0.0001 | |
| - **优化器**: AdamW | |
| """) | |
| # 启动Gradio应用 | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |