| | |
| | """ |
| | RVC AI 翻唱 - 主入口 |
| | """ |
| | import os |
| | import sys |
| | import argparse |
| | from pathlib import Path |
| |
|
| | |
| | ROOT_DIR = Path(__file__).parent |
| | sys.path.insert(0, str(ROOT_DIR)) |
| |
|
| | from lib.logger import log |
| |
|
| |
|
| | def check_environment(): |
| | """检查运行环境""" |
| | log.header("RVC AI 翻唱系统") |
| |
|
| | |
| | py_version = sys.version_info |
| | log.info(f"Python 版本: {py_version.major}.{py_version.minor}.{py_version.micro}") |
| |
|
| | if py_version.major < 3 or (py_version.major == 3 and py_version.minor < 8): |
| | log.warning("建议使用 Python 3.8 或更高版本") |
| |
|
| | |
| | try: |
| | import torch |
| | log.info(f"PyTorch 版本: {torch.__version__}") |
| |
|
| | from lib.device import get_device_info, _is_rocm, _has_xpu, _has_directml, _has_mps |
| | info = get_device_info() |
| | log.info(f"可用加速后端: {', '.join(info['backends'])}") |
| |
|
| | if torch.cuda.is_available(): |
| | backend = "ROCm" if _is_rocm() else "CUDA" |
| | log.info(f"{backend} 版本: {torch.version.hip if _is_rocm() else torch.version.cuda}") |
| | log.info(f"GPU: {torch.cuda.get_device_name(0)}") |
| | elif _has_xpu(): |
| | log.info(f"Intel GPU: {torch.xpu.get_device_name(0)}") |
| | elif _has_directml(): |
| | import torch_directml |
| | log.info(f"DirectML 设备: {torch_directml.device_name(0)}") |
| | elif _has_mps(): |
| | log.info("Apple MPS 加速可用") |
| | else: |
| | log.warning("未检测到 GPU 加速,将使用 CPU") |
| | except ImportError: |
| | log.error("未安装 PyTorch") |
| | return False |
| |
|
| | return True |
| |
|
| |
|
| | def check_models(): |
| | """检查必需模型""" |
| | from tools.download_models import check_model, REQUIRED_MODELS |
| |
|
| | missing = [] |
| | for name in REQUIRED_MODELS: |
| | if not check_model(name): |
| | missing.append(name) |
| |
|
| | if missing: |
| | log.warning(f"缺少必需模型: {', '.join(missing)}") |
| | log.info("正在下载...") |
| | from tools.download_models import download_required_models |
| | if not download_required_models(): |
| | log.error("模型下载失败,请检查网络连接") |
| | return False |
| |
|
| | return True |
| |
|
| |
|
| | def main(): |
| | """主函数""" |
| | parser = argparse.ArgumentParser(description="RVC AI 翻唱系统") |
| | parser.add_argument( |
| | "--host", |
| | type=str, |
| | default="127.0.0.1", |
| | help="服务器地址 (默认: 127.0.0.1)" |
| | ) |
| | parser.add_argument( |
| | "--port", |
| | type=int, |
| | default=7860, |
| | help="服务器端口 (默认: 7860)" |
| | ) |
| | parser.add_argument( |
| | "--share", |
| | action="store_true", |
| | help="创建公共链接" |
| | ) |
| | parser.add_argument( |
| | "--skip-check", |
| | action="store_true", |
| | help="跳过环境检查" |
| | ) |
| | parser.add_argument( |
| | "--download-models", |
| | action="store_true", |
| | help="仅下载模型" |
| | ) |
| |
|
| | args = parser.parse_args() |
| |
|
| | |
| | if args.download_models: |
| | from tools.download_models import download_all_models |
| | download_all_models() |
| | return |
| |
|
| | |
| | if not args.skip_check: |
| | if not check_environment(): |
| | sys.exit(1) |
| |
|
| | |
| | if not check_models(): |
| | log.info("提示: 可以使用 --skip-check 跳过检查") |
| | sys.exit(1) |
| |
|
| | |
| | log.info(f"启动 Gradio 界面: http://{args.host}:{args.port}") |
| | from ui.app import launch |
| | launch(host=args.host, port=args.port, share=args.share) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|