# -*- coding: utf-8 -*-
"""
JinrikiHelper - cloud deployment entry point.

Bootstrap script for Hugging Face Spaces / ModelScope Studio: prepares the
model directory (persistent on ModelScope), installs ffmpeg and Montreal
Forced Aligner (MFA), and triggers model downloads before launching the UI.
"""
import os
import sys
import subprocess
import platform
import logging
import time
import zipfile
from pathlib import Path

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Project root directory (directory containing this file).
BASE_DIR = Path(__file__).parent.absolute()

# Persistent model directory in the cloud (on ModelScope Studio,
# /home/studio_service/ survives restarts).
PERSISTENT_MODELS_DIR = Path("/home/studio_service/models")

# Local model directory (used outside the cloud, or as a symlink target).
LOCAL_MODELS_DIR = BASE_DIR / "models"


# Select the model directory based on the environment.
def get_models_dir() -> Path:
    """Return the model directory; in the cloud, use the persistent path.

    On ModelScope Studio (detected via the persistent parent directory),
    migrates any models already present in the local ``models`` dir into the
    persistent one, then replaces the local dir with a symlink so both paths
    resolve to the same storage. Otherwise returns the local directory.
    """
    # Note: when LOCAL_MODELS_DIR is already a symlink (set up on a previous
    # run), this condition is False and we simply return the local path,
    # which then resolves through the symlink to the persistent dir.
    if PERSISTENT_MODELS_DIR.parent.exists() and not LOCAL_MODELS_DIR.is_symlink():
        # ModelScope Studio environment: use the persistent directory.
        PERSISTENT_MODELS_DIR.mkdir(parents=True, exist_ok=True)
        # If the local models dir exists (and is a real dir, not a symlink),
        # migrate its contents into the persistent directory first.
        if LOCAL_MODELS_DIR.exists() and LOCAL_MODELS_DIR.is_dir():
            import shutil
            for item in LOCAL_MODELS_DIR.iterdir():
                dest = PERSISTENT_MODELS_DIR / item.name
                # Never overwrite something already migrated.
                if not dest.exists():
                    shutil.move(str(item), str(dest))
            shutil.rmtree(LOCAL_MODELS_DIR, ignore_errors=True)
        # Create the symlink so code using the local path keeps working.
        if not LOCAL_MODELS_DIR.exists():
            LOCAL_MODELS_DIR.symlink_to(PERSISTENT_MODELS_DIR)
        return PERSISTENT_MODELS_DIR
    return LOCAL_MODELS_DIR


# Lazily initialized in setup_environment() — None until then.
MODELS_DIR = None
MFA_DIR = None


def ensure_ffmpeg() -> bool:
    """Ensure ffmpeg is installed (needed for audio conversion, e.g. m4a).

    Returns True if ffmpeg is available (already present or freshly
    installed via apt-get), False otherwise. Install failures are logged
    but never raise — callers treat this as best-effort.
    """
    import shutil
    if shutil.which("ffmpeg"):
        logger.info("ffmpeg 已安装")
        return True
    logger.info("ffmpeg 未安装,尝试安装...")
    try:
        # NOTE(review): assumes root privileges in the container; apt-get
        # failures are not checked here — only the final `which` matters.
        subprocess.run(["apt-get", "update"], capture_output=True, text=True, timeout=60)
        subprocess.run(["apt-get", "install", "-y", "ffmpeg"], capture_output=True, text=True, timeout=120)
        if shutil.which("ffmpeg"):
            logger.info("ffmpeg 安装成功")
            return True
        logger.warning("ffmpeg 安装后仍未找到")
        return False
    except subprocess.TimeoutExpired:
        logger.warning("ffmpeg 安装超时")
        return False
    except Exception as e:
        logger.warning(f"ffmpeg 安装失败: {e}")
        return False


def setup_mfa_linux() -> bool:
    """Install MFA on Linux via conda-forge (MFA's officially recommended path).

    Tries, in order: verify an existing installation; otherwise install
    montreal-forced-aligner with mamba/micromamba/conda (auto-downloading
    micromamba when no conda tool exists), patch PATH for cloud images that
    don't refresh it, install Japanese/Chinese tokenization deps into the
    MFA environment, and re-verify.

    Returns:
        bool: whether MFA is usable after this call.
    """
    import shutil
    import importlib.util  # NOTE(review): appears unused in this function — candidate for removal
    import tarfile
    import tempfile
    import urllib.request
    import stat

    def _run_cmd_ok(cmd, timeout=30) -> tuple[bool, str, str]:
        # Run a command and report (succeeded, stdout, stderr); any failure
        # (including timeout / missing binary) is reported as not-ok.
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
            return result.returncode == 0, (result.stdout or ""), (result.stderr or "")
        except Exception:
            return False, "", ""

    def _get_callable_conda_tools() -> dict[str, str]:
        """Find callable conda-family tools (conda/mamba/micromamba).

        A tool is only accepted when it exists on PATH *and* responds to
        ``--version`` — broken leftovers are skipped with a warning.
        """
        tools = {}
        for tool_name in ["conda", "mamba", "micromamba"]:
            tool_path = shutil.which(tool_name)
            if not tool_path:
                continue
            ok, _, _ = _run_cmd_ok([tool_path, "--version"], timeout=20)
            if ok:
                tools[tool_name] = tool_path
            else:
                logger.warning(f"检测到 {tool_name} 但不可调用,已跳过: {tool_path}")
        return tools

    def _ensure_micromamba_available() -> str | None:
        """Download and configure micromamba (for console-less environments).

        Returns the path to the micromamba binary, or None on failure.
        Side effects: sets MAMBA_ROOT_PREFIX and prepends the bin dir to PATH.
        """
        mm_path = shutil.which("micromamba")
        if mm_path:
            return mm_path
        # Prefer persistent storage for the micromamba root when available.
        if PERSISTENT_MODELS_DIR.parent.exists():
            mamba_root = PERSISTENT_MODELS_DIR.parent / "micromamba"
        else:
            mamba_root = BASE_DIR / ".micromamba"
        mamba_bin_dir = mamba_root / "bin"
        mamba_bin_dir.mkdir(parents=True, exist_ok=True)
        target_bin = mamba_bin_dir / "micromamba"
        os.environ["MAMBA_ROOT_PREFIX"] = str(mamba_root)
        if str(mamba_bin_dir) not in os.environ.get("PATH", ""):
            os.environ["PATH"] = f"{mamba_bin_dir}:{os.environ.get('PATH', '')}"
        if target_bin.exists():
            # Re-apply the executable bit in case it was lost.
            target_bin.chmod(target_bin.stat().st_mode | stat.S_IEXEC)
            logger.info(f"检测到已存在 micromamba: {target_bin}")
            return str(target_bin)
        # Official endpoint first, then GitHub (direct and via proxy mirror).
        urls = [
            "https://micro.mamba.pm/api/micromamba/linux-64/latest",
            "https://gh-proxy.com/https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-linux-64",
            "https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-linux-64",
        ]
        for url in urls:
            for attempt in range(1, 3):
                try:
                    logger.info(f"下载 micromamba: {url} (第{attempt}次)")
                    # delete=False so the temp file survives the `with`;
                    # cleanup happens in the `finally` below.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as tmp:
                        tmp_path = Path(tmp.name)
                    try:
                        with urllib.request.urlopen(url, timeout=180) as resp:
                            data = resp.read()
                        tmp_path.write_bytes(data)
                        # Case 1: a tar.bz2 archive containing bin/micromamba.
                        extracted = False
                        try:
                            with tarfile.open(tmp_path, mode="r:*") as tar:
                                member = None
                                for m in tar.getmembers():
                                    if m.name.endswith("/bin/micromamba") or m.name == "bin/micromamba":
                                        member = m
                                        break
                                if member is not None:
                                    src = tar.extractfile(member)
                                    if src is None:
                                        raise RuntimeError("无法从压缩包读取 micromamba")
                                    target_bin.write_bytes(src.read())
                                    extracted = True
                        except tarfile.TarError:
                            extracted = False
                        # Case 2: a plain single-file binary.
                        if not extracted:
                            target_bin.write_bytes(tmp_path.read_bytes())
                        target_bin.chmod(target_bin.stat().st_mode | stat.S_IEXEC)
                        logger.info(f"micromamba 就绪: {target_bin}")
                        return str(target_bin)
                    finally:
                        if tmp_path.exists():
                            tmp_path.unlink(missing_ok=True)
                except Exception as e:
                    logger.warning(f"下载/安装 micromamba 失败: {e}")
        logger.error("自动配置 micromamba 失败")
        return None

    def verify_mfa_working() -> bool:
        # Probe a list of candidate entry points until one answers `--help`.
        mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
        commands = []
        conda_tools = _get_callable_conda_tools()
        # Prefer official conda/mamba entry points, to avoid a stale pip-installed
        # `mfa` lingering on PATH (which lacks the _kalpy extension module).
        conda = conda_tools.get("conda")
        if conda:
            commands.append([conda, "run", "-n", mfa_env_name, "mfa", "--help"])
            commands.append([conda, "run", "-n", "base", "mfa", "--help"])
        micromamba = conda_tools.get("micromamba")
        if micromamba:
            commands.append([micromamba, "run", "-n", mfa_env_name, "mfa", "--help"])
            commands.append([micromamba, "run", "-n", "base", "mfa", "--help"])
        mamba = conda_tools.get("mamba")
        if mamba:
            commands.append([mamba, "run", "-n", mfa_env_name, "mfa", "--help"])
            commands.append([mamba, "run", "-n", "base", "mfa", "--help"])
        commands.extend([
            [sys.executable, "-m", "montreal_forced_aligner.command_line.mfa", "--help"],
            [sys.executable, "-m", "montreal_forced_aligner", "--help"],
            ["mfa", "--help"],
        ])
        # An `mfa` script next to the current interpreter gets top priority.
        py_bin_dir = Path(sys.executable).parent
        mfa_bin = py_bin_dir / "mfa"
        if mfa_bin.exists():
            commands.insert(0, [str(mfa_bin), "--help"])
        for cmd in commands:
            ok, stdout, stderr = _run_cmd_ok(cmd, timeout=120)
            if ok:
                logger.info(f"MFA 验证命令通过: {' '.join(cmd)}")
                return True
            output = f"{stdout}\n{stderr}"
            if "No module named '_kalpy'" in output:
                logger.warning(f"命令 {' '.join(cmd)} 缺少 _kalpy,跳过该入口")
        logger.warning("MFA 验证命令均未通过,可能缺少 kalpy/_kalpy 或入口脚本未加入 PATH")
        return False

    # Fast path: MFA already works.
    if verify_mfa_working():
        logger.info("MFA 已安装且工作正常")
        return True

    logger.info("MFA 不可用,Linux 下将使用 conda/mamba 从 conda-forge 安装(官方推荐)...")
    try:
        mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
        conda_tools = _get_callable_conda_tools()
        # Console-less environments: bootstrap micromamba automatically.
        if not conda_tools:
            logger.info("未检测到 conda/mamba/micromamba,开始自动配置 micromamba...")
            _ensure_micromamba_available()
            conda_tools = _get_callable_conda_tools()
        # Build the ordered list of install attempts: mamba(base) first,
        # then micromamba (dedicated env create/install, then base), then conda.
        install_attempts = []
        mamba = conda_tools.get("mamba")
        if mamba:
            install_attempts.append((
                "mamba",
                [mamba, "install", "-y", "-n", "base", "-c", "conda-forge", "montreal-forced-aligner"],
            ))
        micromamba = conda_tools.get("micromamba")
        if micromamba:
            install_attempts.append((
                f"micromamba(create:{mfa_env_name})",
                [micromamba, "create", "-y", "-n", mfa_env_name, "-c", "conda-forge", "montreal-forced-aligner"],
            ))
            install_attempts.append((
                f"micromamba(install:{mfa_env_name})",
                [micromamba, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge", "montreal-forced-aligner"],
            ))
            install_attempts.append((
                "micromamba(base)",
                [micromamba, "install", "-y", "-n", "base", "-c", "conda-forge", "montreal-forced-aligner"],
            ))
        conda = conda_tools.get("conda")
        if conda:
            install_attempts.append((
                "conda",
                [conda, "install", "-y", "-n", "base", "-c", "conda-forge", "montreal-forced-aligner"],
            ))
        if not install_attempts:
            logger.error("未找到 conda/mamba/micromamba,无法按官方推荐方法安装 MFA")
            return False
        installed = False
        for installer_name, install_cmd in install_attempts:
            logger.info(f"尝试使用 {installer_name} 安装 MFA...")
            result = subprocess.run(
                install_cmd,
                capture_output=True,
                text=True,
                timeout=1800,
                check=False,
            )
            if result.returncode == 0:
                logger.info(f"{installer_name} 安装 MFA 完成")
                installed = True
                break
            # Log only the tail of the output — full solver logs are huge.
            stderr_tail = (result.stderr or result.stdout or "")[-800:]
            logger.warning(f"{installer_name} 安装 MFA 失败: {stderr_tail}")
        if not installed:
            logger.error("所有 conda/mamba 安装尝试均失败")
            return False
        # Some cloud images don't refresh PATH automatically; add the first
        # common conda bin directory that actually contains an `mfa` script.
        if not shutil.which("mfa"):
            candidate_dirs = [Path("/opt/conda/bin"), Path.home() / "micromamba" / "bin"]
            conda_prefix = os.environ.get("CONDA_PREFIX")
            if conda_prefix:
                candidate_dirs.insert(0, Path(conda_prefix) / "bin")
            for candidate_dir in candidate_dirs:
                candidate_mfa = candidate_dir / "mfa"
                if candidate_mfa.exists() and str(candidate_dir) not in os.environ.get("PATH", ""):
                    os.environ["PATH"] = f"{candidate_dir}:{os.environ.get('PATH', '')}"
                    logger.info(f"已将 {candidate_dir} 加入 PATH")
                    break
        # Install Japanese/Chinese tokenization deps INTO the MFA environment
        # (not the system Python) — otherwise MFA cannot import spacy/sudachipy.
        pkuseg_home = PERSISTENT_MODELS_DIR / "pkuseg" if PERSISTENT_MODELS_DIR.parent.exists() else Path("/root/.pkuseg")
        pkuseg_home.mkdir(parents=True, exist_ok=True)
        os.environ["PKUSEG_HOME"] = str(pkuseg_home)
        logger.info("在 MFA 环境中安装日语/中文分词支持...")
        conda_tools = _get_callable_conda_tools()
        mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
        # Prefer conda-forge packages into the MFA env (most stable).
        deps_installed = False
        # Candidate installation strategies, tried in order.
        install_attempts_deps = []
        if conda_tools.get("micromamba"):
            micromamba = conda_tools["micromamba"]
            install_attempts_deps.extend([
                ("micromamba(conda-forge)", [micromamba, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge", "spacy", "sudachipy", "sudachidict-core"]),
                ("micromamba(pip in mfa env)", [micromamba, "run", "-n", mfa_env_name, "pip", "install", "--no-cache-dir", "spacy-pkuseg", "dragonmapper", "hanziconv"]),
            ])
        if conda_tools.get("mamba"):
            mamba = conda_tools["mamba"]
            install_attempts_deps.append(
                ("mamba(conda-forge)", [mamba, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge", "spacy", "sudachipy", "sudachidict-core"])
            )
        if conda_tools.get("conda"):
            conda = conda_tools["conda"]
            install_attempts_deps.append(
                ("conda(conda-forge)", [conda, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge", "spacy", "sudachipy", "sudachidict-core"])
            )
        for installer_name, dep_cmd in install_attempts_deps:
            logger.info(f"尝试用 {installer_name} 安装分词依赖...")
            dep_result = subprocess.run(dep_cmd, capture_output=True, text=True, timeout=600, check=False)
            if dep_result.returncode == 0:
                logger.info(f"{installer_name} 分词依赖安装完成")
                deps_installed = True
                break
            else:
                stderr_tail = (dep_result.stderr or "")[-500:]
                logger.warning(f"{installer_name} 分词依赖安装失败: {stderr_tail}")
        if not deps_installed:
            # Non-fatal: alignment for some languages may be unavailable.
            logger.warning("MFA 环境分词依赖安装失败,某些语言对齐可能不可用(继续运行)")
        if verify_mfa_working():
            logger.info("MFA 验证通过")
            return True
        logger.warning("MFA 安装后仍不可用,将以无 MFA 模式继续")
        return False
    except subprocess.TimeoutExpired as e:
        logger.error(f"MFA 安装超时: {e}")
        return False
    except Exception as e:
        logger.error(f"MFA 安装异常: {e}")
        return False


def setup_environment() -> None:
    """Initialize the cloud runtime environment.

    Resolves MODELS_DIR/MFA_DIR, detects whether we run in a cloud space,
    ensures ffmpeg, configures the HuggingFace mirror on ModelScope, sets up
    PKUSEG_HOME and MFA on non-Windows hosts, and kicks off model downloads
    in the cloud. Mutates module globals MODELS_DIR and MFA_DIR.
    """
    global MODELS_DIR, MFA_DIR
    # Initialize the model directory (may create a symlink).
    MODELS_DIR = get_models_dir()
    MFA_DIR = MODELS_DIR / "mfa"
    # Detect the runtime environment via well-known cloud markers.
    is_cloud = any([
        os.environ.get("SPACE_ID"),
        os.environ.get("MODELSCOPE_SPACE"),
        os.environ.get("GRADIO_SERVER_NAME"),
        Path("/home/studio_service").exists(),
    ])
    if is_cloud or platform.system() != "Windows":
        ensure_ffmpeg()
    # ModelScope Studio cannot reach HuggingFace directly — use the mirror.
    if is_cloud and Path("/home/studio_service").exists():
        os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
        logger.info("已设置 HuggingFace 镜像: hf-mirror.com")
    # Configure the pkuseg model directory (Chinese tokenization dependency;
    # must be set before MFA runs).
    if platform.system() != "Windows":
        pkuseg_home = PERSISTENT_MODELS_DIR / "pkuseg" if PERSISTENT_MODELS_DIR.parent.exists() else Path("/root/.pkuseg")
        pkuseg_home.mkdir(parents=True, exist_ok=True)
        os.environ["PKUSEG_HOME"] = str(pkuseg_home)
        logger.info(f"设置 PKUSEG_HOME: {pkuseg_home}")
        logger.info("Linux 环境,检查并安装 MFA...")
        mfa_ok = setup_mfa_linux()
        if not mfa_ok:
            # Degrade gracefully: the service still starts without alignment.
            logger.warning("MFA 当前不可用,将跳过对齐功能但继续启动服务")
    if is_cloud:
        os.environ.setdefault("TMPDIR", "/tmp")
        download_all_models()
    else:
        logger.info("本地环境运行")
def download_all_models():
    """Download every required model; on failure, exit only in strict mode.

    Runs the four downloaders (Silero VAD, Whisper, MFA, pkuseg) and collects
    the names of those that failed. In strict mode
    (``JINRIKI_STRICT_MODEL_DOWNLOAD=1``) any failure terminates the process;
    otherwise the app continues in degraded mode.

    Fix: previously the "所有模型下载完成" success banner was also logged after
    a degraded-mode failure; now we return early so the success banner is only
    emitted when everything downloaded.
    """
    logger.info("=" * 50)
    logger.info("开始下载模型...")
    logger.info("=" * 50)
    # Make sure the target directories exist (MODELS_DIR/MFA_DIR are set by
    # setup_environment before this is called).
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(MFA_DIR, exist_ok=True)
    errors = []
    # 1. Silero VAD
    if not download_silero_vad_model():
        errors.append("Silero VAD")
    # 2. Whisper models
    if not download_whisper_models():
        errors.append("Whisper")
    # 3. MFA models (Mandarin and Japanese)
    if not download_mfa_models_all():
        errors.append("MFA")
    # 4. pkuseg models (required for Chinese tokenization; downloaded
    #    separately to guarantee it runs)
    if not download_pkuseg_models():
        errors.append("pkuseg")
    if errors:
        logger.error("=" * 50)
        logger.error(f"以下模型加载失败: {', '.join(errors)}")
        strict_mode = os.environ.get("JINRIKI_STRICT_MODEL_DOWNLOAD", "0") == "1"
        if strict_mode:
            logger.error("严格模式开启,程序退出(可通过 JINRIKI_STRICT_MODEL_DOWNLOAD=0 关闭)")
            logger.error("=" * 50)
            sys.exit(1)
        logger.warning("将以降级模式继续启动(部分功能可能不可用)")
        logger.warning("如需下载成功后再启动,请设置 JINRIKI_STRICT_MODEL_DOWNLOAD=1")
        logger.error("=" * 50)
        # Bug fix: do not fall through to the success banner when some
        # downloads failed — degraded mode is not "all models downloaded".
        return
    logger.info("=" * 50)
    logger.info("所有模型下载完成")
    logger.info("=" * 50)


def download_pkuseg_models() -> bool:
    """Download the pkuseg Chinese tokenization model; return success.

    spacy-pkuseg's model-check logic:
      1. It first checks whether ``PKUSEG_HOME/<name>.zip`` exists.
      2. If the zip exists, it extracts it into ``PKUSEG_HOME/<name>/``.
      3. If the zip does not exist, it downloads from GitHub.
    Therefore the ``.zip`` file must be KEPT, or spacy_pkuseg will try to
    re-download. Note: only ``spacy_ontonotes.zip`` is needed —
    ``postag.zip`` does not exist in the GitHub releases.
    """
    logger.info("\n【下载 pkuseg 模型】")
    pkuseg_home = Path(os.environ.get("PKUSEG_HOME", "/root/.pkuseg"))
    pkuseg_home.mkdir(parents=True, exist_ok=True)
    pkuseg_model_dir = pkuseg_home / "spacy_ontonotes"
    postag_model_dir = pkuseg_home / "postag"
    # Key point: spacy_pkuseg checks for the presence of spacy_ontonotes.zip;
    # postag.zip is absent from the GitHub releases and is not checked.
    spacy_ontonotes_zip = pkuseg_home / "spacy_ontonotes.zip"

    def _pkuseg_model_ready() -> bool:
        # Ready = zip kept + some features file + unigram vocabulary present.
        feature_candidates = [
            pkuseg_model_dir / "features.msgpack",
            pkuseg_model_dir / "features.pkl",
            pkuseg_model_dir / "features.json",
        ]
        has_feature = any(p.exists() for p in feature_candidates)
        has_unigram = (pkuseg_model_dir / "unigram_word.txt").exists()
        return spacy_ontonotes_zip.exists() and has_feature and has_unigram

    def _repair_extract_from_zip() -> bool:
        # Re-extract an existing zip to repair an incomplete model directory.
        if not spacy_ontonotes_zip.exists():
            return False
        try:
            pkuseg_model_dir.mkdir(parents=True, exist_ok=True)
            with zipfile.ZipFile(spacy_ontonotes_zip, "r") as zf:
                zf.extractall(pkuseg_model_dir)
            # Handle zips that carry an extra top-level spacy_ontonotes/ dir.
            nested_dir = pkuseg_model_dir / "spacy_ontonotes"
            if nested_dir.is_dir():
                import shutil
                for item in nested_dir.iterdir():
                    dst = pkuseg_model_dir / item.name
                    if dst.exists():
                        if dst.is_dir():
                            shutil.rmtree(dst, ignore_errors=True)
                        else:
                            dst.unlink(missing_ok=True)
                    shutil.move(str(item), str(dst))
                shutil.rmtree(nested_dir, ignore_errors=True)
            return _pkuseg_model_ready()
        except Exception as e:
            logger.warning(f"修复 pkuseg 解压失败: {e}")
            return False

    if _pkuseg_model_ready():
        logger.info(f"pkuseg 模型已就绪: {pkuseg_home}")
        return True
    # Zip exists but the model dir is incomplete: try re-extracting first.
    if spacy_ontonotes_zip.exists():
        logger.warning("检测到 spacy_ontonotes.zip 已存在,但模型目录不完整,尝试修复解压...")
        if _repair_extract_from_zip():
            logger.info("pkuseg 模型修复成功")
            return True
        logger.warning("pkuseg 模型修复失败,准备重新下载")
        try:
            spacy_ontonotes_zip.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"删除损坏 zip 失败: {e}")
    # Legacy layout: older versions extracted files straight into the root.
    root_msgpack = pkuseg_home / "features.msgpack"
    if root_msgpack.exists():
        logger.info("检测到模型文件在根目录,移动到正确位置...")
        pkuseg_model_dir.mkdir(parents=True, exist_ok=True)
        # Move the spacy_ontonotes files into place.
        for filename in ["features.msgpack", "weights.npz"]:
            src = pkuseg_home / filename
            if src.exists():
                dst = pkuseg_model_dir / filename
                src.rename(dst)
                # Bug fix: log message previously had no placeholder, so the
                # moved filename was never reported.
                logger.info(f"移动 {filename} -> spacy_ontonotes/")
        # Move the postag files into place.
        postag_model_dir.mkdir(parents=True, exist_ok=True)
        for filename in ["features.pkl"]:
            src = pkuseg_home / filename
            if src.exists():
                dst = postag_model_dir / filename
                src.rename(dst)
                logger.info(f"移动 {filename} -> postag/")
    # Only spacy_ontonotes.zip is required by spacy_pkuseg; the postag model
    # is not in the GitHub releases and spacy_pkuseg falls back internally.
    if _pkuseg_model_ready():
        logger.info(f"pkuseg 模型已就绪: {pkuseg_home}")
        return True
    # A download is needed.
    logger.info("需要下载 pkuseg 模型: spacy_ontonotes")
    # spacy-pkuseg models (new msgpack format). The .zip must be kept —
    # spacy_pkuseg checks for its existence. postag.zip is not downloadable.
    models = [
        {
            "name": "spacy_ontonotes",
            "urls": [
                "https://gitcode.com/gh_mirrors/sp/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
                "https://ghfast.top/https://github.com/explosion/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
                "https://gh-proxy.com/https://github.com/explosion/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
                "https://github.com/explosion/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
            ],
            "check_file": "features.msgpack",
        },
    ]
    for model in models:
        model_name = model["name"]
        model_dir = pkuseg_home / model_name
        zip_path = pkuseg_home / f"{model_name}.zip"
        check_file = model_dir / model["check_file"]
        downloaded = False
        timeout_seconds = int(os.environ.get("JINRIKI_PKUSEG_DOWNLOAD_TIMEOUT", "300"))
        max_rounds = int(os.environ.get("JINRIKI_PKUSEG_DOWNLOAD_ROUNDS", "3"))
        for round_idx in range(1, max_rounds + 1):
            logger.info(f"{model_name} 下载轮次: {round_idx}/{max_rounds}")
            for url in model["urls"]:
                logger.info(f"下载 {model_name}: {url}")
                try:
                    # Download via curl (follows redirects, retries twice).
                    result = subprocess.run(
                        ["curl", "-fL", "--retry", "2", "--retry-delay", "2", "-o", str(zip_path), url],
                        capture_output=True, text=True, timeout=timeout_seconds
                    )
                    if result.returncode != 0:
                        logger.warning(f"curl 下载失败: {result.stderr}")
                        continue
                    if not zip_path.exists() or zip_path.stat().st_size < 1000:
                        logger.warning("下载文件无效或太小")
                        continue
                    logger.info(f"下载完成,文件大小: {zip_path.stat().st_size} bytes")
                    # Create the target dir and extract into it (prefer the
                    # python zipfile module to avoid system-unzip variance).
                    model_dir.mkdir(parents=True, exist_ok=True)
                    try:
                        with zipfile.ZipFile(zip_path, "r") as zf:
                            zf.extractall(model_dir)
                    except Exception as zip_e:
                        logger.warning(f"zipfile 解压失败,尝试系统 unzip: {zip_e}")
                        result = subprocess.run(
                            ["unzip", "-o", "-q", str(zip_path), "-d", str(model_dir)],
                            capture_output=True, text=True, timeout=120
                        )
                        if result.returncode != 0:
                            logger.warning(f"unzip 解压失败: {result.stderr}")
                            continue
                    # Handle zips that carry an extra top-level <model_name>/ dir.
                    nested_dir = model_dir / model_name
                    if nested_dir.is_dir():
                        import shutil
                        for item in nested_dir.iterdir():
                            dst = model_dir / item.name
                            if dst.exists():
                                if dst.is_dir():
                                    shutil.rmtree(dst, ignore_errors=True)
                                else:
                                    dst.unlink(missing_ok=True)
                            shutil.move(str(item), str(dst))
                        shutil.rmtree(nested_dir, ignore_errors=True)
                    # IMPORTANT: keep the zip file! spacy_pkuseg checks for it.
                    # Verify: both check_file and unigram_word.txt must exist.
                    unigram_file = model_dir / "unigram_word.txt"
                    if check_file.exists() and unigram_file.exists():
                        logger.info(f"{model_name} 下载并解压成功(保留 zip 文件)")
                        files = [f.name for f in model_dir.iterdir()]
                        logger.info(f"{model_name} 目录内容: {files}")
                        downloaded = True
                        break
                    else:
                        logger.warning(f"解压后关键文件缺失: {check_file.name} / {unigram_file.name}")
                        if model_dir.exists():
                            files = [f.name for f in model_dir.iterdir()]
                            logger.info(f"{model_name} 目录内容: {files}")
                except subprocess.TimeoutExpired:
                    logger.warning(f"下载超时: {url}")
                except Exception as e:
                    logger.warning(f"下载异常: {e}")
            if downloaded:
                break
            # Back off between rounds (5s, 10s, capped at 15s).
            sleep_seconds = min(5 * round_idx, 15)
            logger.info(f"{model_name} 本轮未成功,{sleep_seconds} 秒后重试")
            time.sleep(sleep_seconds)
        if not downloaded:
            logger.error(f"{model_name} 所有镜像下载失败")
            return False
    logger.info("pkuseg 模型下载完成")
    return True


def download_silero_vad_model() -> bool:
    """Download the Silero VAD model; return whether it succeeded."""
    logger.info("\n【下载 Silero VAD 模型】")
    try:
        from src.silero_vad_downloader import download_silero_vad, is_vad_model_downloaded
        if is_vad_model_downloaded(str(MODELS_DIR)):
            logger.info("Silero VAD 模型已存在,跳过下载")
            return True
        success, result = download_silero_vad(str(MODELS_DIR), logger.info)
        if success:
            logger.info(f"Silero VAD 下载成功: {result}")
            return True
        else:
            logger.error(f"Silero VAD 下载失败: {result}")
            return False
    except Exception as e:
        logger.error(f"Silero VAD 下载异常: {e}")
        return False


def download_whisper_models() -> bool:
    """Download the Whisper models (small and medium); return success.

    Side effects: sets HF_HOME and TRANSFORMERS_CACHE to the whisper cache
    dir under MODELS_DIR. NOTE(review): TRANSFORMERS_CACHE is deprecated in
    recent transformers releases in favor of HF_HOME — kept for
    compatibility with older versions; confirm against the pinned version.
    """
    logger.info("\n【下载 Whisper 模型】")
    try:
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        import torch
        cache_dir = str(MODELS_DIR / "whisper")
        os.makedirs(cache_dir, exist_ok=True)
        os.environ["HF_HOME"] = cache_dir
        os.environ["TRANSFORMERS_CACHE"] = cache_dir
        models = ["openai/whisper-small", "openai/whisper-medium"]
        for model_name in models:
            logger.info(f"下载 {model_name}...")
            try:
                # Skip models already present in the HF cache layout
                # (models--openai--whisper-small, etc.).
                model_cache_name = model_name.replace("/", "--")
                model_cache_path = Path(cache_dir) / f"models--{model_cache_name}"
                if model_cache_path.exists():
                    logger.info(f"{model_name} 已存在,跳过下载")
                    continue
                # Download processor and model weights (discard the objects;
                # only the on-disk cache matters here).
                _ = WhisperProcessor.from_pretrained(model_name, cache_dir=cache_dir)
                _ = WhisperForConditionalGeneration.from_pretrained(
                    model_name,
                    cache_dir=cache_dir,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
                )
                logger.info(f"{model_name} 下载完成")
            except Exception as e:
                logger.error(f"{model_name} 下载失败: {e}")
                return False
        return True
    except Exception as e:
        logger.error(f"Whisper 模型下载异常: {e}")
        return False


def download_mfa_models_all() -> bool:
    """Download MFA Mandarin and Japanese models (with integrity checks).

    Returns whether every language's models downloaded successfully.

    Fix: the per-attempt unpack previously reassigned ``dict_path`` (the
    dictionary Path computed above) to the third return value of
    ``download_language_models`` — the error message on failure. The unpack
    target is renamed so the Path variable is no longer clobbered.
    """
    logger.info("\n【下载 MFA 模型】")
    try:
        from src.mfa_model_downloader import download_language_models, _verify_file_integrity, LANGUAGE_MODELS
        languages = ["mandarin", "japanese"]
        per_language_retries = int(os.environ.get("JINRIKI_MFA_MODEL_RETRIES", "3"))
        for lang in languages:
            logger.info(f"\n下载 {lang} 模型...")
            # Check whether an existing dictionary file is corrupted.
            dict_config = LANGUAGE_MODELS[lang]["dictionary"]
            dict_path = MFA_DIR / dict_config["filename"]
            hash_path = MFA_DIR / (dict_config["filename"] + ".sha256")
            if dict_path.exists():
                # No hash file means it was downloaded by an older version —
                # verify its integrity before trusting it.
                if not hash_path.exists():
                    logger.info(f"检测到旧版字典文件(无哈希),验证完整性...")
                    min_lines = dict_config.get("min_lines")
                    is_valid, reason = _verify_file_integrity(str(dict_path), min_lines, logger.info)
                    if not is_valid:
                        logger.warning(f"字典文件损坏: {reason},删除并重新下载...")
                        try:
                            dict_path.unlink()
                        except Exception as e:
                            logger.error(f"删除损坏文件失败: {e}")
            success = False
            last_err = ""
            for attempt in range(1, per_language_retries + 1):
                logger.info(f"{lang} 下载尝试: {attempt}/{per_language_retries}")
                # Third return value is the dictionary path on success and
                # the error message on failure (per downloader contract).
                ok, acoustic_path, result_or_err = download_language_models(
                    lang, str(MFA_DIR), logger.info
                )
                if ok:
                    success = True
                    break
                last_err = result_or_err
                wait_seconds = min(10 * attempt, 30)
                logger.warning(f"{lang} 下载失败(第{attempt}次): {last_err}")
                if attempt < per_language_retries:
                    logger.info(f"{wait_seconds} 秒后重试 {lang}...")
                    time.sleep(wait_seconds)
            if not success:
                logger.error(f"{lang} 模型下载失败")
                return False
            logger.info(f"{lang} 模型下载完成")
        return True
    except Exception as e:
        logger.error(f"MFA 模型下载异常: {e}")
        return False


def main():
    """Main entry point: set up the environment and launch the cloud UI."""
    setup_environment()
    # Import and start the cloud GUI (deferred import: heavy dependencies).
    from src.gui_cloud import create_cloud_ui
    app = create_cloud_ui()
    # Cloud configuration: enable the request queue. ModelScope CPU is
    # allocated on demand, so no concurrency cap is set.
    app.queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()