# rvc-cpu-trainer / scripts / prepare_training_data_v2.py
# ayf3's picture
# Upload scripts/prepare_training_data_v2.py with huggingface_hub
# bcaa58c verified
#!/usr/bin/env python3
"""
准备RVC v2训练数据 - 简化版
使用snapshot_download一次性下载整个Dataset
"""
import os
from pathlib import Path
from huggingface_hub import snapshot_download
import subprocess
import json
from tqdm import tqdm
# Configuration
DATASET_ID = "ayf3/numberblocks-audio"
OUTPUT_DIR = Path("data") / "training_data"
AUDIO_DIR = OUTPUT_DIR.joinpath("audio")
METADATA_FILE = OUTPUT_DIR.joinpath("metadata.json")
# Hugging Face token: HUGGINGFACE_TOKEN takes precedence over HF_TOKEN; the
# value is None when neither variable is set.
# NOTE(review): presumably huggingface_hub then falls back to its locally
# cached login -- confirm against the library documentation.
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", os.environ.get("HF_TOKEN"))
def create_directories():
    """Create the output directory tree used by the rest of the script."""
    # AUDIO_DIR is handled first, in the same order as before; parents=True
    # already creates OUTPUT_DIR on the way, so the second pass is a no-op.
    for folder in (AUDIO_DIR, OUTPUT_DIR):
        folder.mkdir(parents=True, exist_ok=True)
    print(f"✅ 目录创建完成: {AUDIO_DIR}")
def download_audio_files():
    """Download every file of the configured dataset repo into AUDIO_DIR.

    Returns:
        True when snapshot_download returned without raising, False when any
        exception was raised (the error is printed, not propagated).
    """
    print(f"📥 开始下载音频文件...")
    print(f"📦 Dataset: {DATASET_ID}")
    try:
        # A single snapshot_download call fetches the whole repo at once.
        request = dict(
            repo_id=DATASET_ID,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(AUDIO_DIR),
            local_dir_use_symlinks=False,
        )
        snapshot_download(**request)
        print(f"✅ 下载完成")
    except Exception as err:
        print(f"❌ 下载失败: {err}")
        return False
    return True
def _probe_audio_file(audio_file):
    """Return the metadata record for one audio file, as reported by ffprobe.

    Args:
        audio_file: Path of the audio file to inspect.

    Returns:
        A dict with the keys "filename", "duration", "sample_rate",
        "channels" and "size".

    Raises:
        subprocess.SubprocessError: ffprobe exited non-zero or exceeded the
            10 second timeout.
        OSError: ffprobe could not be started or the file could not be stat'ed.
        ValueError: ffprobe output was not valid JSON or a field was not numeric.
        LookupError: the file has no audio stream or an expected field is missing.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        # Restrict the report to the first *audio* stream. Without this,
        # stream 0 can be embedded cover art (a video stream) in MP3/M4A
        # files, which carries no sample_rate/channels.
        "-select_streams", "a:0",
        # One -show_entries option with ':'-separated sections.
        "-show_entries", "format=duration:stream=sample_rate,channels",
        "-of", "json",
        str(audio_file),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=10)
    info = json.loads(result.stdout)
    audio_stream = info["streams"][0]
    return {
        "filename": audio_file.name,
        "duration": float(info["format"]["duration"]),
        "sample_rate": int(audio_stream["sample_rate"]),
        "channels": int(audio_stream["channels"]),
        "size": audio_file.stat().st_size,
    }


def analyze_audio_files():
    """Analyze the audio files in AUDIO_DIR and write a metadata summary.

    Looks for *.wav, *.mp3 and *.m4a files directly inside AUDIO_DIR (not in
    subdirectories), probes each with ffprobe, and writes the per-file
    records plus totals to METADATA_FILE as JSON. Files that cannot be
    analyzed are reported and left out of both the records and the totals.

    Returns:
        The list of per-file metadata dicts (possibly empty when every file
        failed), or None when no audio file was found at all.
    """
    print(f"\n🔍 分析音频文件...")
    # Sorted within each extension so the manifest order does not depend on
    # filesystem enumeration order.
    audio_files = [
        path
        for pattern in ("*.wav", "*.mp3", "*.m4a")
        for path in sorted(AUDIO_DIR.glob(pattern))
    ]
    total_files = len(audio_files)
    print(f"📊 找到 {total_files} 个本地音频文件")
    if not audio_files:
        print("❌ 没有找到音频文件")
        return None
    metadata = []
    total_duration = 0
    print(f"\n处理中...")
    for i, audio_file in enumerate(audio_files, 1):
        try:
            record = _probe_audio_file(audio_file)
        except (subprocess.SubprocessError, OSError, ValueError, LookupError, TypeError) as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f" ❌ [{i}/{total_files}] 分析失败: {audio_file.name}, 错误: {e}")
            continue
        # Only count the duration once the record is complete, so the totals
        # always agree with the "files" list.
        metadata.append(record)
        total_duration += record["duration"]
        # Keep the console output short: first ten files and the last one.
        if i <= 10 or i == total_files:
            print(f" [{i:3d}/{total_files}] {audio_file.name}: {record['duration']:6.2f}s, {record['sample_rate']}Hz, {record['channels']}ch")
    # Save the metadata summary.
    with open(METADATA_FILE, 'w', encoding='utf-8') as f:
        json.dump({
            "total_files": len(metadata),
            "total_duration": total_duration,
            "total_duration_hours": round(total_duration / 3600, 2),
            "files": metadata
        }, f, indent=2, ensure_ascii=False)
    print(f"\n✅ 分析完成:")
    print(f" - 文件数: {len(metadata)}")
    print(f" - 总时长: {total_duration / 3600:.2f} 小时 ({total_duration / 60:.1f} 分钟)")
    print(f" - 元数据保存: {METADATA_FILE}")
    return metadata
def main():
    """Run the data preparation pipeline: create dirs, download, analyze.

    The download step is skipped when AUDIO_DIR already contains at least one
    *.wav, *.mp3 or *.m4a file.

    Returns:
        0 when analysis produced at least one metadata record, 1 when the
        download failed or analysis produced nothing. The value is meant to
        be used as the process exit status.
    """
    banner = "=" * 60
    print(banner)
    print("🎤 准备RVC v2训练数据(简化版)")
    print(banner)
    # Step 1: create the directories.
    create_directories()
    # Step 2: download the audio files (skipped if local files already exist).
    audio_files = list(AUDIO_DIR.glob("*.wav")) + list(AUDIO_DIR.glob("*.mp3")) + list(AUDIO_DIR.glob("*.m4a"))
    if audio_files:
        print(f"📂 本地已有 {len(audio_files)} 个音频文件,跳过下载")
    elif not download_audio_files():
        print("❌ 下载失败,退出")
        return 1
    # Step 3: analyze the audio files.
    metadata = analyze_audio_files()
    print("\n" + banner)
    if metadata:
        print("✅ 数据准备完成!")
    else:
        print("❌ 数据准备失败")
    print(banner)
    return 0 if metadata else 1


if __name__ == "__main__":
    # Propagate the result as the exit status so shell pipelines and CI can
    # detect a failed preparation run.
    raise SystemExit(main())