# JinrikiHelper / app.py
# Commit a84527e — fix: self-heal pkuseg model files before MFA tokenization
# -*- coding: utf-8 -*-
"""
人力V助手 (JinrikiHelper) - 云端部署入口
适用于 Hugging Face Spaces / 魔塔社区
"""
import os
import sys
import subprocess
import platform
import logging
import time
import zipfile
from pathlib import Path
# Configure root logging once at import time.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Project root directory (directory containing this file).
BASE_DIR = Path(__file__).parent.absolute()
# Cloud persistent model directory (/home/studio_service/ persists across
# restarts on ModelScope studio spaces).
PERSISTENT_MODELS_DIR = Path("/home/studio_service/models")
# Local model directory (used outside the cloud, or as a symlink target).
LOCAL_MODELS_DIR = BASE_DIR / "models"
# Choose the models directory according to the runtime environment.
def get_models_dir():
    """Return the models directory, preferring persistent storage in the cloud.

    On ModelScope studio (where the parent of PERSISTENT_MODELS_DIR exists)
    models are migrated to the persistent directory and a symlink is left at
    LOCAL_MODELS_DIR so relative-path code keeps working.

    Returns:
        Path: the directory that should hold all downloaded models.
    """
    if PERSISTENT_MODELS_DIR.parent.exists():
        if LOCAL_MODELS_DIR.is_symlink():
            if LOCAL_MODELS_DIR.exists():
                # Symlink already in place and resolvable: setup was done before.
                return LOCAL_MODELS_DIR
            # BUG FIX: a dangling symlink (persistent dir wiped) previously made
            # the original condition fail and returned the broken link path.
            # Remove the stale link and rebuild the setup below.
            LOCAL_MODELS_DIR.unlink(missing_ok=True)
        # ModelScope studio environment: use the persistent directory.
        PERSISTENT_MODELS_DIR.mkdir(parents=True, exist_ok=True)
        # If a real local models directory exists, migrate its contents first.
        if LOCAL_MODELS_DIR.exists() and LOCAL_MODELS_DIR.is_dir():
            import shutil
            for item in LOCAL_MODELS_DIR.iterdir():
                dest = PERSISTENT_MODELS_DIR / item.name
                if not dest.exists():
                    shutil.move(str(item), str(dest))
            shutil.rmtree(LOCAL_MODELS_DIR, ignore_errors=True)
        # Leave a symlink so code using the local path sees the persistent data.
        if not LOCAL_MODELS_DIR.exists():
            LOCAL_MODELS_DIR.symlink_to(PERSISTENT_MODELS_DIR)
        return PERSISTENT_MODELS_DIR
    return LOCAL_MODELS_DIR
MODELS_DIR = None  # deferred initialization: set by setup_environment()
MFA_DIR = None  # deferred initialization: MODELS_DIR / "mfa" once known
def ensure_ffmpeg():
    """Ensure ffmpeg is installed (needed for audio conversion, e.g. m4a).

    Returns:
        bool: True when an ffmpeg binary is available after this call.
    """
    import shutil
    if shutil.which("ffmpeg"):
        logger.info("ffmpeg 已安装")
        return True
    logger.info("ffmpeg 未安装,尝试安装...")
    try:
        # Best-effort apt-get install; output is captured and ignored.
        steps = (
            (["apt-get", "update"], 60),
            (["apt-get", "install", "-y", "ffmpeg"], 120),
        )
        for argv, limit in steps:
            subprocess.run(argv, capture_output=True, text=True, timeout=limit)
        if shutil.which("ffmpeg"):
            logger.info("ffmpeg 安装成功")
            return True
        logger.warning("ffmpeg 安装后仍未找到")
        return False
    except subprocess.TimeoutExpired:
        logger.warning("ffmpeg 安装超时")
        return False
    except Exception as e:
        logger.warning(f"ffmpeg 安装失败: {e}")
        return False
def setup_mfa_linux() -> bool:
    """Install MFA on Linux via conda-forge (MFA's officially recommended route).

    Tries, in order: verify an existing MFA install; bootstrap micromamba if no
    conda-family tool is present; install montreal-forced-aligner from
    conda-forge; install tokenization dependencies into the MFA env; re-verify.

    Returns:
        bool: whether MFA is usable after this call.
    """
    import shutil
    import importlib.util  # NOTE(review): appears unused in this function
    import tarfile
    import tempfile
    import urllib.request
    import stat

    def _run_cmd_ok(cmd, timeout=30):
        # Run *cmd*, returning (succeeded, stdout, stderr); never raises —
        # any exception (timeout, missing binary) is reported as failure.
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
            return result.returncode == 0, (result.stdout or ""), (result.stderr or "")
        except Exception:
            return False, "", ""

    def _get_callable_conda_tools() -> dict[str, str]:
        """Locate conda-family tools (conda/mamba/micromamba) that actually run."""
        tools = {}
        for tool_name in ["conda", "mamba", "micromamba"]:
            tool_path = shutil.which(tool_name)
            if not tool_path:
                continue
            # A tool on PATH may still be broken; probe it with --version.
            ok, _, _ = _run_cmd_ok([tool_path, "--version"], timeout=20)
            if ok:
                tools[tool_name] = tool_path
            else:
                logger.warning(f"检测到 {tool_name} 但不可调用,已跳过: {tool_path}")
        return tools

    def _ensure_micromamba_available() -> str | None:
        """Download and configure micromamba automatically (console-less setups)."""
        mm_path = shutil.which("micromamba")
        if mm_path:
            return mm_path
        # Prefer the persistent volume so the binary survives restarts.
        if PERSISTENT_MODELS_DIR.parent.exists():
            mamba_root = PERSISTENT_MODELS_DIR.parent / "micromamba"
        else:
            mamba_root = BASE_DIR / ".micromamba"
        mamba_bin_dir = mamba_root / "bin"
        mamba_bin_dir.mkdir(parents=True, exist_ok=True)
        target_bin = mamba_bin_dir / "micromamba"
        # micromamba reads MAMBA_ROOT_PREFIX for its env store.
        os.environ["MAMBA_ROOT_PREFIX"] = str(mamba_root)
        if str(mamba_bin_dir) not in os.environ.get("PATH", ""):
            os.environ["PATH"] = f"{mamba_bin_dir}:{os.environ.get('PATH', '')}"
        if target_bin.exists():
            # Reuse a previously downloaded binary; just ensure it is executable.
            target_bin.chmod(target_bin.stat().st_mode | stat.S_IEXEC)
            logger.info(f"检测到已存在 micromamba: {target_bin}")
            return str(target_bin)
        # Mirror first (cloud spaces often cannot reach GitHub), then direct.
        urls = [
            "https://micro.mamba.pm/api/micromamba/linux-64/latest",
            "https://gh-proxy.com/https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-linux-64",
            "https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-linux-64",
        ]
        for url in urls:
            for attempt in range(1, 3):
                try:
                    logger.info(f"下载 micromamba: {url} (第{attempt}次)")
                    # delete=False: only used to allocate a unique temp path.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as tmp:
                        tmp_path = Path(tmp.name)
                    try:
                        with urllib.request.urlopen(url, timeout=180) as resp:
                            data = resp.read()
                        tmp_path.write_bytes(data)
                        # Handle the tar.bz2 package format.
                        extracted = False
                        try:
                            with tarfile.open(tmp_path, mode="r:*") as tar:
                                member = None
                                for m in tar.getmembers():
                                    if m.name.endswith("/bin/micromamba") or m.name == "bin/micromamba":
                                        member = m
                                        break
                                if member is not None:
                                    src = tar.extractfile(member)
                                    if src is None:
                                        raise RuntimeError("无法从压缩包读取 micromamba")
                                    target_bin.write_bytes(src.read())
                                    extracted = True
                        except tarfile.TarError:
                            extracted = False
                        # Handle the single-file binary format.
                        if not extracted:
                            target_bin.write_bytes(tmp_path.read_bytes())
                        target_bin.chmod(target_bin.stat().st_mode | stat.S_IEXEC)
                        logger.info(f"micromamba 就绪: {target_bin}")
                        return str(target_bin)
                    finally:
                        # Always clean up the temp download.
                        if tmp_path.exists():
                            tmp_path.unlink(missing_ok=True)
                except Exception as e:
                    logger.warning(f"下载/安装 micromamba 失败: {e}")
        logger.error("自动配置 micromamba 失败")
        return None

    def verify_mfa_working() -> bool:
        # Try every plausible MFA entry point; any one succeeding means MFA works.
        mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
        commands = []
        conda_tools = _get_callable_conda_tools()
        # Prefer official conda/mamba entry points to avoid a leftover pip-installed
        # mfa on PATH that lacks the _kalpy extension.
        conda = conda_tools.get("conda")
        if conda:
            commands.append([conda, "run", "-n", mfa_env_name, "mfa", "--help"])
            commands.append([conda, "run", "-n", "base", "mfa", "--help"])
        micromamba = conda_tools.get("micromamba")
        if micromamba:
            commands.append([micromamba, "run", "-n", mfa_env_name, "mfa", "--help"])
            commands.append([micromamba, "run", "-n", "base", "mfa", "--help"])
        mamba = conda_tools.get("mamba")
        if mamba:
            commands.append([mamba, "run", "-n", mfa_env_name, "mfa", "--help"])
            commands.append([mamba, "run", "-n", "base", "mfa", "--help"])
        commands.extend([
            [sys.executable, "-m", "montreal_forced_aligner.command_line.mfa", "--help"],
            [sys.executable, "-m", "montreal_forced_aligner", "--help"],
            ["mfa", "--help"],
        ])
        # An mfa script next to the current interpreter is the most direct entry.
        py_bin_dir = Path(sys.executable).parent
        mfa_bin = py_bin_dir / "mfa"
        if mfa_bin.exists():
            commands.insert(0, [str(mfa_bin), "--help"])
        for cmd in commands:
            ok, stdout, stderr = _run_cmd_ok(cmd, timeout=120)
            if ok:
                logger.info(f"MFA 验证命令通过: {' '.join(cmd)}")
                return True
            output = f"{stdout}\n{stderr}"
            if "No module named '_kalpy'" in output:
                logger.warning(f"命令 {' '.join(cmd)} 缺少 _kalpy,跳过该入口")
        logger.warning("MFA 验证命令均未通过,可能缺少 kalpy/_kalpy 或入口脚本未加入 PATH")
        return False

    # Fast path: MFA already usable.
    if verify_mfa_working():
        logger.info("MFA 已安装且工作正常")
        return True
    logger.info("MFA 不可用,Linux 下将使用 conda/mamba 从 conda-forge 安装(官方推荐)...")
    try:
        mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
        conda_tools = _get_callable_conda_tools()
        # Console-less environments: bootstrap micromamba automatically.
        if not conda_tools:
            logger.info("未检测到 conda/mamba/micromamba,开始自动配置 micromamba...")
            _ensure_micromamba_available()
            conda_tools = _get_callable_conda_tools()
        # Build the ordered list of install strategies to attempt.
        install_attempts = []
        mamba = conda_tools.get("mamba")
        if mamba:
            install_attempts.append((
                "mamba",
                [mamba, "install", "-y", "-n", "base", "-c", "conda-forge", "montreal-forced-aligner"],
            ))
        micromamba = conda_tools.get("micromamba")
        if micromamba:
            install_attempts.append((
                f"micromamba(create:{mfa_env_name})",
                [micromamba, "create", "-y", "-n", mfa_env_name, "-c", "conda-forge", "montreal-forced-aligner"],
            ))
            install_attempts.append((
                f"micromamba(install:{mfa_env_name})",
                [micromamba, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge", "montreal-forced-aligner"],
            ))
            install_attempts.append((
                "micromamba(base)",
                [micromamba, "install", "-y", "-n", "base", "-c", "conda-forge", "montreal-forced-aligner"],
            ))
        conda = conda_tools.get("conda")
        if conda:
            install_attempts.append((
                "conda",
                [conda, "install", "-y", "-n", "base", "-c", "conda-forge", "montreal-forced-aligner"],
            ))
        if not install_attempts:
            logger.error("未找到 conda/mamba/micromamba,无法按官方推荐方法安装 MFA")
            return False
        installed = False
        for installer_name, install_cmd in install_attempts:
            logger.info(f"尝试使用 {installer_name} 安装 MFA...")
            result = subprocess.run(
                install_cmd,
                capture_output=True,
                text=True,
                timeout=1800,
                check=False,
            )
            if result.returncode == 0:
                logger.info(f"{installer_name} 安装 MFA 完成")
                installed = True
                break
            stderr_tail = (result.stderr or result.stdout or "")[-800:]
            logger.warning(f"{installer_name} 安装 MFA 失败: {stderr_tail}")
        if not installed:
            logger.error("所有 conda/mamba 安装尝试均失败")
            return False
        # Some cloud environments do not refresh PATH after installs; add the
        # usual conda bin directories so the mfa script can be found.
        if not shutil.which("mfa"):
            candidate_dirs = [Path("/opt/conda/bin"), Path.home() / "micromamba" / "bin"]
            conda_prefix = os.environ.get("CONDA_PREFIX")
            if conda_prefix:
                candidate_dirs.insert(0, Path(conda_prefix) / "bin")
            for candidate_dir in candidate_dirs:
                candidate_mfa = candidate_dir / "mfa"
                if candidate_mfa.exists() and str(candidate_dir) not in os.environ.get("PATH", ""):
                    os.environ["PATH"] = f"{candidate_dir}:{os.environ.get('PATH', '')}"
                    logger.info(f"已将 {candidate_dir} 加入 PATH")
                    break
        # Install Japanese/Chinese tokenization deps inside the MFA environment
        # (must be the MFA env, not the system Python, or MFA cannot import
        # spacy/sudachipy).
        pkuseg_home = PERSISTENT_MODELS_DIR / "pkuseg" if PERSISTENT_MODELS_DIR.parent.exists() else Path("/root/.pkuseg")
        pkuseg_home.mkdir(parents=True, exist_ok=True)
        os.environ["PKUSEG_HOME"] = str(pkuseg_home)
        logger.info("在 MFA 环境中安装日语/中文分词支持...")
        conda_tools = _get_callable_conda_tools()
        mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
        # Prefer conda-forge installs into the MFA env (most stable).
        deps_installed = False
        # Candidate dependency-install strategies, tried in order.
        install_attempts_deps = []
        if conda_tools.get("micromamba"):
            micromamba = conda_tools["micromamba"]
            install_attempts_deps.extend([
                ("micromamba(conda-forge)",
                 [micromamba, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge",
                  "spacy", "sudachipy", "sudachidict-core"]),
                ("micromamba(pip in mfa env)",
                 [micromamba, "run", "-n", mfa_env_name, "pip", "install", "--no-cache-dir",
                  "spacy-pkuseg", "dragonmapper", "hanziconv"]),
            ])
        if conda_tools.get("mamba"):
            mamba = conda_tools["mamba"]
            install_attempts_deps.append(
                ("mamba(conda-forge)",
                 [mamba, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge",
                  "spacy", "sudachipy", "sudachidict-core"])
            )
        if conda_tools.get("conda"):
            conda = conda_tools["conda"]
            install_attempts_deps.append(
                ("conda(conda-forge)",
                 [conda, "install", "-y", "-n", mfa_env_name, "-c", "conda-forge",
                  "spacy", "sudachipy", "sudachidict-core"])
            )
        for installer_name, dep_cmd in install_attempts_deps:
            logger.info(f"尝试用 {installer_name} 安装分词依赖...")
            dep_result = subprocess.run(dep_cmd, capture_output=True, text=True, timeout=600, check=False)
            if dep_result.returncode == 0:
                logger.info(f"{installer_name} 分词依赖安装完成")
                deps_installed = True
                break
            else:
                stderr_tail = (dep_result.stderr or "")[-500:]
                logger.warning(f"{installer_name} 分词依赖安装失败: {stderr_tail}")
        if not deps_installed:
            # Non-fatal: alignment for some languages may be unavailable.
            logger.warning("MFA 环境分词依赖安装失败,某些语言对齐可能不可用(继续运行)")
        if verify_mfa_working():
            logger.info("MFA 验证通过")
            return True
        logger.warning("MFA 安装后仍不可用,将以无 MFA 模式继续")
        return False
    except subprocess.TimeoutExpired as e:
        logger.error(f"MFA 安装超时: {e}")
        return False
    except Exception as e:
        logger.error(f"MFA 安装异常: {e}")
        return False
def setup_environment():
    """Initialize the runtime environment: model dirs, ffmpeg, MFA, model downloads."""
    global MODELS_DIR, MFA_DIR
    # Initialize the model directory (may create a symlink to persistent storage).
    MODELS_DIR = get_models_dir()
    MFA_DIR = MODELS_DIR / "mfa"
    # Detect whether we are running in a cloud space.
    is_cloud = any([
        os.environ.get("SPACE_ID"),
        os.environ.get("MODELSCOPE_SPACE"),
        os.environ.get("GRADIO_SERVER_NAME"),
        Path("/home/studio_service").exists(),
    ])
    if is_cloud or platform.system() != "Windows":
        ensure_ffmpeg()
        # ModelScope spaces cannot reach HuggingFace directly; use a mirror.
        if is_cloud and Path("/home/studio_service").exists():
            os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
            logger.info("已设置 HuggingFace 镜像: hf-mirror.com")
        # Configure the pkuseg model directory (Chinese tokenization dependency;
        # must be set before MFA runs).
        if platform.system() != "Windows":
            pkuseg_home = PERSISTENT_MODELS_DIR / "pkuseg" if PERSISTENT_MODELS_DIR.parent.exists() else Path("/root/.pkuseg")
            pkuseg_home.mkdir(parents=True, exist_ok=True)
            os.environ["PKUSEG_HOME"] = str(pkuseg_home)
            logger.info(f"设置 PKUSEG_HOME: {pkuseg_home}")
            logger.info("Linux 环境,检查并安装 MFA...")
            mfa_ok = setup_mfa_linux()
            if not mfa_ok:
                # Degrade gracefully: alignment features are skipped but the
                # service still starts.
                logger.warning("MFA 当前不可用,将跳过对齐功能但继续启动服务")
        if is_cloud:
            os.environ.setdefault("TMPDIR", "/tmp")
            download_all_models()
    else:
        logger.info("本地环境运行")
def download_all_models():
    """Download every required model.

    In strict mode (JINRIKI_STRICT_MODEL_DOWNLOAD=1) any failure exits the
    process; otherwise the app continues in degraded mode with warnings.
    """
    logger.info("=" * 50)
    logger.info("开始下载模型...")
    logger.info("=" * 50)
    # Make sure the target directories exist.
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(MFA_DIR, exist_ok=True)
    errors = []
    # 1. Silero VAD
    if not download_silero_vad_model():
        errors.append("Silero VAD")
    # 2. Whisper models
    if not download_whisper_models():
        errors.append("Whisper")
    # 3. MFA models (Mandarin and Japanese)
    if not download_mfa_models_all():
        errors.append("MFA")
    # 4. pkuseg models (required for Chinese tokenization; downloaded separately
    #    to guarantee the step runs)
    if not download_pkuseg_models():
        errors.append("pkuseg")
    if errors:
        logger.error("=" * 50)
        logger.error(f"以下模型加载失败: {', '.join(errors)}")
        strict_mode = os.environ.get("JINRIKI_STRICT_MODEL_DOWNLOAD", "0") == "1"
        if strict_mode:
            logger.error("严格模式开启,程序退出(可通过 JINRIKI_STRICT_MODEL_DOWNLOAD=0 关闭)")
            logger.error("=" * 50)
            sys.exit(1)
        logger.warning("将以降级模式继续启动(部分功能可能不可用)")
        logger.warning("如需下载成功后再启动,请设置 JINRIKI_STRICT_MODEL_DOWNLOAD=1")
        logger.error("=" * 50)
        # BUG FIX: previously fell through and also printed the success banner
        # below even when some downloads failed. Return so the "all models
        # downloaded" message is only logged when it is true.
        return
    logger.info("=" * 50)
    logger.info("所有模型下载完成")
    logger.info("=" * 50)
def download_pkuseg_models() -> bool:
    """Download the pkuseg Chinese tokenization model; return True on success.

    spacy-pkuseg's model discovery logic:
    1. It first checks whether PKUSEG_HOME/<model_name>.zip exists.
    2. If the zip exists, it extracts it into PKUSEG_HOME/<model_name>/.
    3. If the zip is missing, it downloads it from GitHub.
    Therefore the .zip file must be kept, or spacy_pkuseg will try to
    re-download it.
    Note: only spacy_ontonotes.zip is required; postag.zip does not exist in
    the GitHub releases.
    """
    logger.info("\n【下载 pkuseg 模型】")
    pkuseg_home = Path(os.environ.get("PKUSEG_HOME", "/root/.pkuseg"))
    pkuseg_home.mkdir(parents=True, exist_ok=True)
    pkuseg_model_dir = pkuseg_home / "spacy_ontonotes"
    postag_model_dir = pkuseg_home / "postag"
    # Key: check for spacy_ontonotes.zip (mirrors spacy_pkuseg's own check).
    # postag.zip does not exist in the GitHub releases, so it is not checked.
    spacy_ontonotes_zip = pkuseg_home / "spacy_ontonotes.zip"

    def _pkuseg_model_ready() -> bool:
        # Usable when the zip plus extracted feature + unigram files all exist.
        feature_candidates = [
            pkuseg_model_dir / "features.msgpack",
            pkuseg_model_dir / "features.pkl",
            pkuseg_model_dir / "features.json",
        ]
        has_feature = any(p.exists() for p in feature_candidates)
        has_unigram = (pkuseg_model_dir / "unigram_word.txt").exists()
        return spacy_ontonotes_zip.exists() and has_feature and has_unigram

    def _repair_extract_from_zip() -> bool:
        # Re-extract an existing zip to heal a partially extracted model dir.
        if not spacy_ontonotes_zip.exists():
            return False
        try:
            pkuseg_model_dir.mkdir(parents=True, exist_ok=True)
            with zipfile.ZipFile(spacy_ontonotes_zip, "r") as zf:
                zf.extractall(pkuseg_model_dir)
            # Handle zips that contain an extra nested spacy_ontonotes/ layer.
            nested_dir = pkuseg_model_dir / "spacy_ontonotes"
            if nested_dir.is_dir():
                import shutil
                for item in nested_dir.iterdir():
                    dst = pkuseg_model_dir / item.name
                    if dst.exists():
                        if dst.is_dir():
                            shutil.rmtree(dst, ignore_errors=True)
                        else:
                            dst.unlink(missing_ok=True)
                    shutil.move(str(item), str(dst))
                shutil.rmtree(nested_dir, ignore_errors=True)
            return _pkuseg_model_ready()
        except Exception as e:
            logger.warning(f"修复 pkuseg 解压失败: {e}")
            return False

    if _pkuseg_model_ready():
        logger.info(f"pkuseg 模型已就绪: {pkuseg_home}")
        return True
    # Zip exists but the model dir is incomplete: try re-extracting first.
    if spacy_ontonotes_zip.exists():
        logger.warning("检测到 spacy_ontonotes.zip 已存在,但模型目录不完整,尝试修复解压...")
        if _repair_extract_from_zip():
            logger.info("pkuseg 模型修复成功")
            return True
        logger.warning("pkuseg 模型修复失败,准备重新下载")
        try:
            spacy_ontonotes_zip.unlink(missing_ok=True)
        except Exception as e:
            logger.warning(f"删除损坏 zip 失败: {e}")
    # Detect files wrongly extracted into the root directory (legacy bug).
    root_msgpack = pkuseg_home / "features.msgpack"
    if root_msgpack.exists():
        logger.info("检测到模型文件在根目录,移动到正确位置...")
        pkuseg_model_dir.mkdir(parents=True, exist_ok=True)
        # Move spacy_ontonotes files into place.
        for filename in ["features.msgpack", "weights.npz"]:
            src = pkuseg_home / filename
            if src.exists():
                dst = pkuseg_model_dir / filename
                src.rename(dst)
                # BUG FIX: the log message had lost its {filename} placeholder
                # (it printed a literal "(unknown)").
                logger.info(f"移动 {filename} -> spacy_ontonotes/")
        # Move postag files into place.
        postag_model_dir.mkdir(parents=True, exist_ok=True)
        for filename in ["features.pkl"]:
            src = pkuseg_home / filename
            if src.exists():
                dst = postag_model_dir / filename
                src.rename(dst)
                # BUG FIX: same missing {filename} placeholder as above.
                logger.info(f"移动 {filename} -> postag/")
    # Only spacy_ontonotes.zip is required by spacy_pkuseg; the postag model is
    # absent from GitHub releases, so the built-in POS tagger is used instead.
    if _pkuseg_model_ready():
        logger.info(f"pkuseg 模型已就绪: {pkuseg_home}")
        return True
    # A download is required.
    logger.info("需要下载 pkuseg 模型: spacy_ontonotes")
    # Use the spacy-pkuseg model release (new msgpack format).
    # The .zip must be kept afterwards: spacy_pkuseg checks for its presence.
    # postag.zip does not exist in GitHub releases and is not downloaded.
    models = [
        {
            "name": "spacy_ontonotes",
            "urls": [
                "https://gitcode.com/gh_mirrors/sp/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
                "https://ghfast.top/https://github.com/explosion/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
                "https://gh-proxy.com/https://github.com/explosion/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
                "https://github.com/explosion/spacy-pkuseg/releases/download/v0.0.26/spacy_ontonotes.zip",
            ],
            "check_file": "features.msgpack",
        },
    ]
    for model in models:
        model_name = model["name"]
        model_dir = pkuseg_home / model_name
        zip_path = pkuseg_home / f"{model_name}.zip"
        check_file = model_dir / model["check_file"]
        downloaded = False
        timeout_seconds = int(os.environ.get("JINRIKI_PKUSEG_DOWNLOAD_TIMEOUT", "300"))
        max_rounds = int(os.environ.get("JINRIKI_PKUSEG_DOWNLOAD_ROUNDS", "3"))
        for round_idx in range(1, max_rounds + 1):
            logger.info(f"{model_name} 下载轮次: {round_idx}/{max_rounds}")
            for url in model["urls"]:
                logger.info(f"下载 {model_name}: {url}")
                try:
                    # Download via curl: fail on HTTP errors, follow redirects,
                    # retry twice with a short delay.
                    result = subprocess.run(
                        ["curl", "-fL", "--retry", "2", "--retry-delay", "2", "-o", str(zip_path), url],
                        capture_output=True, text=True, timeout=timeout_seconds
                    )
                    if result.returncode != 0:
                        logger.warning(f"curl 下载失败: {result.stderr}")
                        continue
                    if not zip_path.exists() or zip_path.stat().st_size < 1000:
                        logger.warning("下载文件无效或太小")
                        continue
                    logger.info(f"下载完成,文件大小: {zip_path.stat().st_size} bytes")
                    # Create the target dir and extract into it (prefer Python's
                    # zipfile to avoid system unzip behavioral differences).
                    model_dir.mkdir(parents=True, exist_ok=True)
                    try:
                        with zipfile.ZipFile(zip_path, "r") as zf:
                            zf.extractall(model_dir)
                    except Exception as zip_e:
                        logger.warning(f"zipfile 解压失败,尝试系统 unzip: {zip_e}")
                        result = subprocess.run(
                            ["unzip", "-o", "-q", str(zip_path), "-d", str(model_dir)],
                            capture_output=True, text=True, timeout=120
                        )
                        if result.returncode != 0:
                            logger.warning(f"unzip 解压失败: {result.stderr}")
                            continue
                    # Handle zips with an extra nested model_name/ layer.
                    nested_dir = model_dir / model_name
                    if nested_dir.is_dir():
                        import shutil
                        for item in nested_dir.iterdir():
                            dst = model_dir / item.name
                            if dst.exists():
                                if dst.is_dir():
                                    shutil.rmtree(dst, ignore_errors=True)
                                else:
                                    dst.unlink(missing_ok=True)
                            shutil.move(str(item), str(dst))
                        shutil.rmtree(nested_dir, ignore_errors=True)
                    # IMPORTANT: keep the zip file! spacy_pkuseg checks for it.
                    # Do not delete zip_path.
                    # Validate: at least check_file + unigram_word.txt must exist.
                    unigram_file = model_dir / "unigram_word.txt"
                    if check_file.exists() and unigram_file.exists():
                        logger.info(f"{model_name} 下载并解压成功(保留 zip 文件)")
                        files = [f.name for f in model_dir.iterdir()]
                        logger.info(f"{model_name} 目录内容: {files}")
                        downloaded = True
                        break
                    else:
                        logger.warning(f"解压后关键文件缺失: {check_file.name} / {unigram_file.name}")
                        if model_dir.exists():
                            files = [f.name for f in model_dir.iterdir()]
                            logger.info(f"{model_name} 目录内容: {files}")
                except subprocess.TimeoutExpired:
                    logger.warning(f"下载超时: {url}")
                except Exception as e:
                    logger.warning(f"下载异常: {e}")
            if downloaded:
                break
            # Back off before the next round (capped at 15 seconds).
            sleep_seconds = min(5 * round_idx, 15)
            logger.info(f"{model_name} 本轮未成功,{sleep_seconds} 秒后重试")
            time.sleep(sleep_seconds)
        if not downloaded:
            logger.error(f"{model_name} 所有镜像下载失败")
            return False
    logger.info("pkuseg 模型下载完成")
    return True
def download_silero_vad_model() -> bool:
    """Fetch the Silero VAD model if not already present; report success."""
    logger.info("\n【下载 Silero VAD 模型】")
    try:
        from src.silero_vad_downloader import download_silero_vad, is_vad_model_downloaded
        models_root = str(MODELS_DIR)
        # Skip the download entirely when the model is already on disk.
        if is_vad_model_downloaded(models_root):
            logger.info("Silero VAD 模型已存在,跳过下载")
            return True
        ok, detail = download_silero_vad(models_root, logger.info)
        if not ok:
            logger.error(f"Silero VAD 下载失败: {detail}")
            return False
        logger.info(f"Silero VAD 下载成功: {detail}")
        return True
    except Exception as e:
        logger.error(f"Silero VAD 下载异常: {e}")
        return False
def download_whisper_models() -> bool:
    """Download the Whisper small and medium checkpoints; report success."""
    logger.info("\n【下载 Whisper 模型】")
    try:
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        import torch
        cache_dir = str(MODELS_DIR / "whisper")
        os.makedirs(cache_dir, exist_ok=True)
        # Point both HF cache env vars at our models directory.
        os.environ["HF_HOME"] = cache_dir
        os.environ["TRANSFORMERS_CACHE"] = cache_dir
        for model_name in ("openai/whisper-small", "openai/whisper-medium"):
            logger.info(f"下载 {model_name}...")
            try:
                # Skip models whose HF cache directory already exists.
                cached = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
                if cached.exists():
                    logger.info(f"{model_name} 已存在,跳过下载")
                    continue
                # Pull both processor and weights into the cache.
                WhisperProcessor.from_pretrained(model_name, cache_dir=cache_dir)
                WhisperForConditionalGeneration.from_pretrained(
                    model_name,
                    cache_dir=cache_dir,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                )
                logger.info(f"{model_name} 下载完成")
            except Exception as e:
                logger.error(f"{model_name} 下载失败: {e}")
                return False
        return True
    except Exception as e:
        logger.error(f"Whisper 模型下载异常: {e}")
        return False
def download_mfa_models_all() -> bool:
    """Download the MFA Mandarin and Japanese models with integrity checks."""
    logger.info("\n【下载 MFA 模型】")
    try:
        from src.mfa_model_downloader import download_language_models, _verify_file_integrity, LANGUAGE_MODELS
        per_language_retries = int(os.environ.get("JINRIKI_MFA_MODEL_RETRIES", "3"))
        for lang in ("mandarin", "japanese"):
            logger.info(f"\n下载 {lang} 模型...")
            # Validate any pre-existing dictionary file for corruption.
            dict_config = LANGUAGE_MODELS[lang]["dictionary"]
            dict_file = MFA_DIR / dict_config["filename"]
            hash_file = MFA_DIR / (dict_config["filename"] + ".sha256")
            # No hash file means an older downloader fetched it — verify now.
            if dict_file.exists() and not hash_file.exists():
                logger.info(f"检测到旧版字典文件(无哈希),验证完整性...")
                ok_file, reason = _verify_file_integrity(
                    str(dict_file), dict_config.get("min_lines"), logger.info
                )
                if not ok_file:
                    logger.warning(f"字典文件损坏: {reason},删除并重新下载...")
                    try:
                        dict_file.unlink()
                    except Exception as e:
                        logger.error(f"删除损坏文件失败: {e}")
            success = False
            last_err = ""
            for attempt in range(1, per_language_retries + 1):
                logger.info(f"{lang} 下载尝试: {attempt}/{per_language_retries}")
                # On failure, the third element carries the error message.
                ok, _acoustic, message = download_language_models(
                    lang, str(MFA_DIR), logger.info
                )
                if ok:
                    success = True
                    break
                last_err = message
                wait_seconds = min(10 * attempt, 30)
                logger.warning(f"{lang} 下载失败(第{attempt}次): {last_err}")
                if attempt < per_language_retries:
                    logger.info(f"{wait_seconds} 秒后重试 {lang}...")
                    time.sleep(wait_seconds)
            if not success:
                logger.error(f"{lang} 模型下载失败")
                return False
            logger.info(f"{lang} 模型下载完成")
        return True
    except Exception as e:
        logger.error(f"MFA 模型下载异常: {e}")
        return False
def main():
    """Entry point: prepare the environment, then launch the cloud UI."""
    setup_environment()
    # Import lazily so environment setup (mirrors, model dirs) runs first.
    from src.gui_cloud import create_cloud_ui
    app = create_cloud_ui()
    # Enable the request queue; ModelScope allocates CPU on demand, so no
    # explicit concurrency cap is needed.
    app.queue()
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
    app.launch(**launch_options)


if __name__ == "__main__":
    main()