# -*- coding: utf-8 -*-
"""
RVC AI 翻唱 - 主入口
"""
import os
import sys
import argparse
from pathlib import Path
# 添加项目根目录到路径
ROOT_DIR = Path(__file__).parent
sys.path.insert(0, str(ROOT_DIR))
from lib.logger import log
def check_environment():
    """Verify the runtime environment: Python version, PyTorch, and GPU backends.

    Logs the interpreter version, the PyTorch version, and whichever
    acceleration backend is detected (CUDA/ROCm, Intel XPU, DirectML,
    Apple MPS), falling back to a CPU warning.

    Returns:
        bool: True when PyTorch imports successfully; False otherwise.
    """
    log.header("RVC AI 翻唱系统")

    # Report the interpreter version and nudge toward >= 3.8.
    major, minor, micro = sys.version_info[:3]
    log.info(f"Python 版本: {major}.{minor}.{micro}")
    if (major, minor) < (3, 8):
        log.warning("建议使用 Python 3.8 或更高版本")

    # PyTorch is mandatory; everything below runs only if it imports.
    try:
        import torch
        log.info(f"PyTorch 版本: {torch.__version__}")

        from lib.device import get_device_info, _is_rocm, _has_xpu, _has_directml, _has_mps
        device_info = get_device_info()
        log.info(f"可用加速后端: {', '.join(device_info['backends'])}")

        # Probe backends in priority order; log the first one available.
        if torch.cuda.is_available():
            # torch.cuda covers both NVIDIA CUDA and AMD ROCm builds.
            if _is_rocm():
                log.info(f"ROCm 版本: {torch.version.hip}")
            else:
                log.info(f"CUDA 版本: {torch.version.cuda}")
            log.info(f"GPU: {torch.cuda.get_device_name(0)}")
        elif _has_xpu():
            log.info(f"Intel GPU: {torch.xpu.get_device_name(0)}")
        elif _has_directml():
            import torch_directml
            log.info(f"DirectML 设备: {torch_directml.device_name(0)}")
        elif _has_mps():
            log.info("Apple MPS 加速可用")
        else:
            log.warning("未检测到 GPU 加速,将使用 CPU")
    except ImportError:
        log.error("未安装 PyTorch")
        return False
    return True
def check_models():
    """Ensure every required model is present, downloading any that are missing.

    Returns:
        bool: True when all required models are available (possibly after a
        successful download); False when the download fails.
    """
    from tools.download_models import check_model, REQUIRED_MODELS

    # Collect the required models that are not available yet.
    missing = [name for name in REQUIRED_MODELS if not check_model(name)]
    if not missing:
        return True

    log.warning(f"缺少必需模型: {', '.join(missing)}")
    log.info("正在下载...")
    from tools.download_models import download_required_models
    if not download_required_models():
        log.error("模型下载失败,请检查网络连接")
        return False
    return True
def main():
    """Entry point: parse CLI arguments, run optional checks, launch the UI.

    Command-line flags:
        --host / --port     where the Gradio server binds
        --share             create a public Gradio link
        --skip-check        bypass environment and model validation
        --download-models   download all models and exit
    """
    parser = argparse.ArgumentParser(description="RVC AI 翻唱系统")
    parser.add_argument("--host", type=str, default="127.0.0.1",
                        help="服务器地址 (默认: 127.0.0.1)")
    parser.add_argument("--port", type=int, default=7860,
                        help="服务器端口 (默认: 7860)")
    parser.add_argument("--share", action="store_true",
                        help="创建公共链接")
    parser.add_argument("--skip-check", action="store_true",
                        help="跳过环境检查")
    parser.add_argument("--download-models", action="store_true",
                        help="仅下载模型")
    args = parser.parse_args()

    # Download-only mode: fetch everything and exit without launching the UI.
    if args.download_models:
        from tools.download_models import download_all_models
        download_all_models()
        return

    # Environment and model validation (bypassed with --skip-check).
    if not args.skip_check:
        if not check_environment():
            sys.exit(1)
        if not check_models():
            log.info("提示: 可以使用 --skip-check 跳过检查")
            sys.exit(1)

    # Hand off to the Gradio web interface.
    log.info(f"启动 Gradio 界面: http://{args.host}:{args.port}")
    from ui.app import launch
    launch(host=args.host, port=args.port, share=args.share)
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()