File size: 3,653 Bytes
4204217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# -*- coding: utf-8 -*-
"""
RVC AI 翻唱 - 主入口
"""
import os
import sys
import argparse
from pathlib import Path

# 添加项目根目录到路径
ROOT_DIR = Path(__file__).parent
sys.path.insert(0, str(ROOT_DIR))

from lib.logger import log


def check_environment():
    """检查运行环境"""
    log.header("RVC AI 翻唱系统")

    # 检查 Python 版本
    py_version = sys.version_info
    log.info(f"Python 版本: {py_version.major}.{py_version.minor}.{py_version.micro}")

    if py_version.major < 3 or (py_version.major == 3 and py_version.minor < 8):
        log.warning("建议使用 Python 3.8 或更高版本")

    # 检查 PyTorch
    try:
        import torch
        log.info(f"PyTorch 版本: {torch.__version__}")

        from lib.device import get_device_info, _is_rocm, _has_xpu, _has_directml, _has_mps
        info = get_device_info()
        log.info(f"可用加速后端: {', '.join(info['backends'])}")

        if torch.cuda.is_available():
            backend = "ROCm" if _is_rocm() else "CUDA"
            log.info(f"{backend} 版本: {torch.version.hip if _is_rocm() else torch.version.cuda}")
            log.info(f"GPU: {torch.cuda.get_device_name(0)}")
        elif _has_xpu():
            log.info(f"Intel GPU: {torch.xpu.get_device_name(0)}")
        elif _has_directml():
            import torch_directml
            log.info(f"DirectML 设备: {torch_directml.device_name(0)}")
        elif _has_mps():
            log.info("Apple MPS 加速可用")
        else:
            log.warning("未检测到 GPU 加速,将使用 CPU")
    except ImportError:
        log.error("未安装 PyTorch")
        return False

    return True


def check_models():
    """检查必需模型"""
    from tools.download_models import check_model, REQUIRED_MODELS

    missing = []
    for name in REQUIRED_MODELS:
        if not check_model(name):
            missing.append(name)

    if missing:
        log.warning(f"缺少必需模型: {', '.join(missing)}")
        log.info("正在下载...")
        from tools.download_models import download_required_models
        if not download_required_models():
            log.error("模型下载失败,请检查网络连接")
            return False

    return True


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="RVC AI 翻唱系统")
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="服务器地址 (默认: 127.0.0.1)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="服务器端口 (默认: 7860)"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="创建公共链接"
    )
    parser.add_argument(
        "--skip-check",
        action="store_true",
        help="跳过环境检查"
    )
    parser.add_argument(
        "--download-models",
        action="store_true",
        help="仅下载模型"
    )

    args = parser.parse_args()

    # 仅下载模型
    if args.download_models:
        from tools.download_models import download_all_models
        download_all_models()
        return

    # 环境检查
    if not args.skip_check:
        if not check_environment():
            sys.exit(1)

    # 模型检查
    if not check_models():
        log.info("提示: 可以使用 --skip-check 跳过检查")
        sys.exit(1)

    # 启动界面
    log.info(f"启动 Gradio 界面: http://{args.host}:{args.port}")
    from ui.app import launch
    launch(host=args.host, port=args.port, share=args.share)


if __name__ == "__main__":
    main()