JinrikiHelper / src /mfa_runner.py
TNOT's picture
fix: self-heal pkuseg model files before MFA tokenization
a84527e
# -*- coding: utf-8 -*-
"""
MFA 调用模块
支持 Windows (外挂模式) 和 Linux (系统安装) 双平台
"""
import os
import platform
import shutil
import subprocess
import logging
import time
import re
import zipfile
from pathlib import Path
from typing import Optional, Callable
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Matches a decimal number such as "0.5", or a bare "1", on word boundaries.
# _clean_dict_empty_lines uses it to recognise leading probability columns
# in dictionary lines.
PROB_PATTERN = re.compile(r"\b(\d+\.\d+|1)\b")
# Path resolution: BASE_DIR is the parent of the directory holding this file.
BASE_DIR = Path(__file__).parent.parent.absolute()
MFA_ENGINE_DIR = BASE_DIR / "tools" / "mfa_engine"
MFA_PYTHON = MFA_ENGINE_DIR / "python.exe"
# Default model paths
DEFAULT_DICT_PATH = BASE_DIR / "models" / "mandarin.dict"
DEFAULT_MODEL_PATH = BASE_DIR / "models" / "mandarin.zip"
DEFAULT_TEMP_DIR = BASE_DIR / "mfa_temp"
# Platform detection
IS_WINDOWS = platform.system() == "Windows"
def _get_callable_conda_tools() -> dict[str, str]:
    """Return the conda-like tools (conda/mamba/micromamba) that actually run.

    Each tool found on PATH is probed with ``--version`` (20 second timeout).
    Only tools whose probe exits with status 0 are included.

    Returns:
        Mapping of tool name to the executable path found by ``shutil.which``.
    """
    usable: dict[str, str] = {}
    for tool_name in ("conda", "mamba", "micromamba"):
        tool_path = shutil.which(tool_name)
        if tool_path:
            try:
                probe = subprocess.run(
                    [tool_path, "--version"],
                    capture_output=True,
                    text=True,
                    timeout=20,
                )
            except Exception as e:
                # A tool that cannot even be launched is skipped, not fatal.
                logger.warning(f"探测 {tool_name} 失败,已跳过: {e}")
            else:
                if probe.returncode == 0:
                    usable[tool_name] = tool_path
                else:
                    logger.warning(f"检测到 {tool_name} 但不可调用,已跳过: {tool_path}")
    return usable
def _resolve_linux_mfa_command() -> Optional[list]:
    """Find a working MFA entry point on Linux/macOS.

    Candidates are tried in this order: ``<tool> run -n <env> mfa`` for
    conda, then micromamba, then mamba, and finally a bare ``mfa`` from PATH.
    Each tool is tried against the environment named by
    ``JINRIKI_MFA_ENV_NAME`` (default ``"mfa"``) and then ``"base"``.
    The first candidate whose ``--help`` exits with status 0 wins.

    Returns:
        The command prefix as a list, or None when no candidate works.
    """
    preferred_env = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
    # dict.fromkeys drops duplicates while keeping first-seen order.
    env_names = list(
        dict.fromkeys(name for name in (preferred_env, "base") if name)
    )

    available = _get_callable_conda_tools()
    candidates = []
    for tool in ("conda", "micromamba", "mamba"):
        launcher = available.get(tool)
        if not launcher:
            continue
        candidates.extend(
            [launcher, "run", "-n", env_name, "mfa"] for env_name in env_names
        )
    # Fallback: whatever ``mfa`` is on PATH (possibly a pip install).
    candidates.append(["mfa"])

    for cmd in candidates:
        try:
            probe = subprocess.run(
                cmd + ["--help"],
                capture_output=True,
                text=True,
                timeout=120,
            )
            if probe.returncode == 0:
                logger.info(f"MFA 入口可用: {' '.join(cmd)}")
                return cmd
            combined = f"{probe.stdout}\n{probe.stderr}"
            if "No module named '_kalpy'" in combined:
                logger.warning(
                    f"MFA 入口不可用(缺少 _kalpy): {' '.join(cmd)},请使用 conda-forge 方式安装"
                )
        except Exception as e:
            logger.debug(f"探测 MFA 入口失败: {' '.join(cmd)} -> {e}")
    return None
def check_mfa_available() -> bool:
    """Report whether MFA can be used on this machine.

    Windows: the bundled engine directory and its ``python.exe`` must exist.
    Other platforms: an MFA entry point must resolve and ``<cmd> version``
    must exit with status 0. A timeout of that check is treated as available.

    Returns:
        True when MFA looks usable, False otherwise.
    """
    if IS_WINDOWS:
        required = (
            ("MFA 引擎目录不存在", MFA_ENGINE_DIR),
            ("MFA Python 不存在", MFA_PYTHON),
        )
        for label, location in required:
            if not location.exists():
                logger.warning(f"{label}: {location}")
                return False
        return True

    # Linux/macOS: resolve an entry point automatically, conda/mamba first.
    cmd = _resolve_linux_mfa_command()
    if not cmd:
        logger.warning("未找到可用 MFA 入口,请按官方推荐使用 conda-forge 安装: conda install -c conda-forge montreal-forced-aligner")
        return False

    # The first run in a cloud environment can initialise slowly, hence the
    # 120 second timeout.
    try:
        result = subprocess.run(
            cmd + ["version"],
            capture_output=True,
            text=True,
            timeout=120,
        )
    except subprocess.TimeoutExpired:
        logger.warning("MFA 验证超时(120秒),可能正在初始化,将尝试继续使用")
        # The command exists but was slow; assume it is usable.
        return True
    except Exception as e:
        logger.warning(f"MFA 验证异常: {e}")
        return False

    if result.returncode == 0:
        logger.info(f"MFA 可用: {result.stdout.strip()}")
        return True

    output = f"{result.stdout}\n{result.stderr}"
    if "No module named '_kalpy'" in output:
        logger.warning("MFA 运行失败:检测到 pip 版缺少 _kalpy,请改用 conda-forge 安装的 MFA")
    else:
        logger.warning(f"MFA 命令执行失败: {result.stderr or result.stdout}")
    return False
def _ensure_mfa_japanese_support() -> bool:
    """Install the Japanese tokenizer packages into the MFA environment.

    The packages are ``sudachipy`` and ``sudachidict-core``. The target
    environment is named by ``JINRIKI_MFA_ENV_NAME`` (default ``"mfa"``).
    Installers are tried in order: micromamba, mamba and conda from
    conda-forge, then ``pip`` inside the environment via micromamba. Each
    attempt has a 300 second timeout and the first success ends the search.

    NOTE(review): the installers are invoked on every call; there is no prior
    check for an existing installation.

    Returns:
        True on Windows (nothing is attempted) or when an installer exits
        with status 0; False when no tool is available or every attempt fails.
    """
    if IS_WINDOWS:
        # Bundled-engine mode: Japanese support normally ships with mfa_engine.
        return True

    mfa_env_name = os.environ.get("JINRIKI_MFA_ENV_NAME", "mfa")
    conda_tools = _get_callable_conda_tools()
    if not conda_tools:
        logger.warning("未找到 conda/mamba,无法在 MFA 环境安装日语支持")
        return False

    packages = ["sudachipy", "sudachidict-core"]
    # conda-forge installs first, one attempt per available tool.
    install_attempts = [
        (
            f"{tool}(conda-forge)",
            [conda_tools[tool], "install", "-y", "-n", mfa_env_name,
             "-c", "conda-forge", *packages],
        )
        for tool in ("micromamba", "mamba", "conda")
        if conda_tools.get(tool)
    ]
    # Last resort: pip inside the environment, only through micromamba.
    if conda_tools.get("micromamba"):
        install_attempts.append(
            (
                "micromamba(pip-in-env)",
                [conda_tools["micromamba"], "run", "-n", mfa_env_name,
                 "pip", "install", "--no-cache-dir", *packages],
            )
        )

    logger.info("检查 MFA 环境中的日语 tokenizer 支持...")
    for installer_name, cmd in install_attempts:
        try:
            logger.debug(f"尝试: {installer_name} | {' '.join(cmd)}")
            outcome = subprocess.run(
                cmd, capture_output=True, text=True, timeout=300, check=False
            )
            if outcome.returncode == 0:
                logger.info(f"日语支持安装完成: {installer_name}")
                return True
            # Only the tail of stderr is logged to keep the message short.
            stderr_tail = (outcome.stderr or "")[-300:]
            logger.debug(f"{installer_name} 失败: {stderr_tail}")
        except subprocess.TimeoutExpired:
            logger.warning(f"{installer_name} 超时(300秒)")
        except Exception as e:
            logger.debug(f"{installer_name} 异常: {e}")

    logger.warning("MFA 环境日语支持安装失败,日语对齐可能会失败(继续运行)")
    return False
def _get_mfa_command() -> list:
    """Return the command prefix used to invoke MFA.

    Windows: the bundled Python running ``-m montreal_forced_aligner``.
    Other platforms: the resolved entry point, or plain ``["mfa"]`` when
    resolution finds nothing.
    """
    if not IS_WINDOWS:
        resolved = _resolve_linux_mfa_command()
        return resolved if resolved else ["mfa"]
    return [str(MFA_PYTHON), "-m", "montreal_forced_aligner"]
def _repair_pkuseg_model_dir(pkuseg_home: Path) -> None:
    """Re-extract the pkuseg ``spacy_ontonotes`` model from its zip if incomplete.

    The model directory counts as complete when it holds ``unigram_word.txt``
    and at least one of ``features.msgpack`` / ``features.pkl`` /
    ``features.json``. Failures are logged and never raised.

    Args:
        pkuseg_home: Directory holding ``spacy_ontonotes.zip`` and the
            unpacked ``spacy_ontonotes`` directory.
    """
    archive = pkuseg_home / "spacy_ontonotes.zip"
    model_dir = pkuseg_home / "spacy_ontonotes"

    def is_complete() -> bool:
        has_feature = any(
            (model_dir / name).exists()
            for name in ("features.msgpack", "features.pkl", "features.json")
        )
        return has_feature and (model_dir / "unigram_word.txt").exists()

    if not archive.exists():
        logger.warning(f"pkuseg 模型 zip 不存在: {archive}")
        return

    logger.info(f"pkuseg 模型 zip 已存在: {archive}")
    if is_complete():
        return

    logger.warning("pkuseg 模型目录不完整,尝试从 zip 自动修复...")
    try:
        model_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(archive, "r") as zf:
            zf.extractall(model_dir)

        # The archive may wrap its content in an extra spacy_ontonotes/
        # directory; flatten that level into model_dir.
        nested_dir = model_dir / "spacy_ontonotes"
        if nested_dir.is_dir():
            for item in nested_dir.iterdir():
                dst = model_dir / item.name
                if dst.exists():
                    if dst.is_dir():
                        shutil.rmtree(dst, ignore_errors=True)
                    else:
                        dst.unlink(missing_ok=True)
                shutil.move(str(item), str(dst))
            shutil.rmtree(nested_dir, ignore_errors=True)

        if is_complete():
            logger.info("pkuseg 模型目录修复成功")
        else:
            logger.warning("pkuseg 模型目录修复后仍不完整,建议触发重新下载")
    except Exception as e:
        # Best effort: a failed repair must not block building the env.
        logger.warning(f"pkuseg 自动修复失败: {e}")


def _build_mfa_env(mfa_root: Optional[Path] = None) -> dict:
    """Build the environment variables for an MFA subprocess.

    The result starts as a copy of ``os.environ``, so any ``PKUSEG_HOME``
    already set by the caller's environment is inherited automatically.

    Windows: the bundled engine directories are prepended to ``PATH``.
    Other platforms: ``MFA_ROOT_DIR`` is set when ``mfa_root`` is given. When
    ``/home/studio_service/models`` exists, ``PKUSEG_HOME`` is pointed at its
    ``pkuseg`` subdirectory and the model files there are checked and repaired.

    Args:
        mfa_root: Session-specific MFA data directory used to isolate
            concurrent runs. Ignored on Windows.

    Returns:
        A new dict suitable for ``subprocess``'s ``env`` argument.
    """
    env = os.environ.copy()

    if IS_WINDOWS:
        # Library\bin must be on PATH, otherwise the Kaldi DLLs are not found.
        mfa_paths = [
            str(MFA_ENGINE_DIR),
            str(MFA_ENGINE_DIR / "Library" / "bin"),
            str(MFA_ENGINE_DIR / "Scripts"),
            str(MFA_ENGINE_DIR / "bin"),
        ]
        env["PATH"] = os.pathsep.join(mfa_paths) + os.pathsep + env.get("PATH", "")
        return env

    # A per-session MFA_ROOT_DIR avoids database clashes between runs.
    if mfa_root:
        env["MFA_ROOT_DIR"] = str(mfa_root)
        logger.info(f"设置会话独立 MFA_ROOT_DIR: {mfa_root}")

    # Cloud deployments keep pkuseg models under a persistent path.
    persistent_models = Path("/home/studio_service/models")
    if persistent_models.exists():
        pkuseg_home = persistent_models / "pkuseg"
        pkuseg_home.mkdir(parents=True, exist_ok=True)
        env["PKUSEG_HOME"] = str(pkuseg_home)
        logger.info(f"设置 PKUSEG_HOME: {pkuseg_home}")

        # Verify and self-heal the model directory so MFA does not crash
        # while normalising text.
        _repair_pkuseg_model_dir(pkuseg_home)

        # List the directory content to help debugging.
        files = list(pkuseg_home.iterdir())
        logger.info(f"pkuseg 目录内容: {[f.name for f in files]}")

    return env
def _clean_dict_empty_lines(dict_path: str) -> int:
    """Rewrite a pronunciation dictionary in place, dropping unusable lines.

    Blank lines, comment lines (starting with ``#``, ``;`` or ``//``), lines
    with fewer than two fields, and lines whose fields after the word are all
    probability columns are removed. Surviving lines are rewritten as
    ``word<TAB>rest`` with single spaces between the remaining fields. The
    file is always rewritten, even when nothing was removed.

    Args:
        dict_path: Path of the dictionary file to sanitise.

    Returns:
        Number of lines removed (blank lines included), or 0 when the file
        could not be processed.
    """
    try:
        # utf-8-sig transparently strips a leading BOM if there is one.
        with open(dict_path, 'r', encoding='utf-8-sig', errors='replace') as src:
            raw_lines = src.readlines()

        original_count = len(raw_lines)
        kept_lines = []
        comment_count = 0
        malformed_count = 0
        prob_only_count = 0

        for raw in raw_lines:
            text = raw.replace('\ufeff', '').strip()
            if not text:
                continue

            if text.startswith(('#', ';', '//')):
                comment_count += 1
                continue

            fields = text.split()
            if len(fields) < 2:
                malformed_count += 1
                continue

            word, *rest = fields
            # Up to four leading probability columns are allowed, but at
            # least one field must follow them.
            leading_probs = 0
            for field in rest[:4]:
                if not PROB_PATTERN.match(field):
                    break
                leading_probs += 1
            if leading_probs >= len(rest):
                prob_only_count += 1
                continue

            kept_lines.append(f"{word}\t{' '.join(rest)}\n")

        removed_count = original_count - len(kept_lines)

        # Always rewrite so every line uses the tab-separated layout.
        with open(dict_path, 'w', encoding='utf-8') as dst:
            dst.writelines(kept_lines)

        if removed_count > 0:
            logger.info(
                f"字典文件清理完成: 原 {original_count} 行, 现 {len(kept_lines)} 行, "
                f"移除 {removed_count} 行(注释 {comment_count} 行, 格式异常 {malformed_count} 行, 概率无音素 {prob_only_count} 行)"
            )
        else:
            logger.info(f"字典文件标准化完成: 共 {len(kept_lines)} 行(已统一为 tab 分隔)")

        return removed_count
    except PermissionError as e:
        logger.error(f"清理字典文件失败 - 权限不足: {e}")
        return 0
    except Exception as e:
        logger.error(f"清理字典文件失败: {e}")
        return 0
def _create_isolated_mfa_root(session_id: str) -> Path:
    """Create a per-session directory intended for use as ``MFA_ROOT_DIR``.

    The directory is ``mfa_session_<session_id>`` under the system temporary
    directory. It is created if missing and left untouched if present.

    Args:
        session_id: Identifier appended to the directory name.

    Returns:
        Path of the session directory.
    """
    import tempfile

    session_root = Path(tempfile.gettempdir(), f"mfa_session_{session_id}")
    session_root.mkdir(parents=True, exist_ok=True)
    logger.info(f"创建会话独立 MFA 目录: {session_root}")
    return session_root
def _cleanup_isolated_mfa_root(mfa_root: Path):
    """Delete a per-session MFA directory.

    Nothing happens unless the path is truthy, exists, and contains
    ``mfa_session_`` in its string form. Deletion errors are logged, not
    raised.
    """
    if not mfa_root or not mfa_root.exists() or "mfa_session_" not in str(mfa_root):
        return
    try:
        shutil.rmtree(mfa_root)
    except Exception as e:
        logger.warning(f"清理会话 MFA 目录失败: {e}")
    else:
        logger.info(f"已清理会话 MFA 目录: {mfa_root}")
def _terminate_process(process: "subprocess.Popen") -> None:
    """Stop *process*: terminate, allow 5 seconds, then kill, draining its pipes."""
    process.terminate()
    try:
        # communicate() rather than wait(): it keeps draining the pipes, so a
        # child blocked on a full pipe can still exit.
        process.communicate(timeout=5)
        return
    except subprocess.TimeoutExpired:
        process.kill()
    try:
        process.communicate(timeout=5)
    except subprocess.TimeoutExpired:
        # Descendants may still hold the pipes open; do not block on them.
        logger.warning("MFA 进程被强制结束后管道仍未关闭,放弃等待")


def run_mfa_alignment(
    corpus_dir: str,
    output_dir: str,
    dict_path: Optional[str] = None,
    model_path: Optional[str] = None,
    temp_dir: Optional[str] = None,
    single_speaker: bool = True,
    clean: bool = True,
    num_jobs: Optional[int] = None,
    progress_callback: Optional[Callable[[str], None]] = None,
    cancel_checker: Optional[Callable[[], bool]] = None,
    timeout_seconds: Optional[int] = None
) -> tuple[bool, str]:
    """Run ``mfa align`` on a corpus.

    Args:
        corpus_dir: Input directory holding wav and lab/txt files.
        output_dir: Directory that receives the TextGrid output.
        dict_path: Dictionary file; defaults to ``models/mandarin.dict``.
        model_path: Acoustic model; defaults to ``models/mandarin.zip``.
        temp_dir: MFA temporary directory. When omitted, a unique
            ``mfa_temp/session_<id>`` directory is used. Any temp directory
            whose path contains ``session_`` is deleted afterwards.
        single_speaker: Pass ``--single_speaker`` to MFA.
        clean: Pass ``--clean`` to MFA to discard old cache data.
        num_jobs: Number of parallel jobs; defaults to the CPU count.
        progress_callback: Receives every progress message that is logged.
        cancel_checker: Polled about once per second; returning True stops
            the MFA process and ends the run as cancelled.
        timeout_seconds: Maximum run time. Defaults to the
            ``JINRIKI_MFA_TIMEOUT_SECONDS`` environment variable, or 1800.

    Returns:
        ``(success, text)`` where text is MFA's stdout on success and an
        error message otherwise.
    """
    import multiprocessing
    import uuid

    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    if not check_mfa_available():
        platform_hint = "tools/mfa_engine 目录" if IS_WINDOWS else "conda install -c conda-forge montreal-forced-aligner"
        return False, f"MFA 环境不可用,请检查 {platform_hint}"

    # Make sure the MFA environment has the Japanese tokenizer packages.
    if not IS_WINDOWS:
        _ensure_mfa_japanese_support()

    dict_path = dict_path or str(DEFAULT_DICT_PATH)
    model_path = model_path or str(DEFAULT_MODEL_PATH)

    # A unique temp directory keeps concurrent users from colliding.
    if temp_dir is None:
        temp_dir = str(DEFAULT_TEMP_DIR / f"session_{uuid.uuid4().hex[:8]}")

    # Validate inputs before anything is created on disk, so these early
    # returns leave nothing behind.
    if not os.path.isdir(corpus_dir):
        return False, f"输入目录不存在: {corpus_dir}"
    if not os.path.isfile(dict_path):
        return False, f"字典文件不存在: {dict_path}"
    if not os.path.isfile(model_path):
        return False, f"声学模型不存在: {model_path}"

    isolated_mfa_root = None
    process = None
    try:
        # Per-run MFA data directory (non-Windows only), created inside the
        # try block so the finally clause always removes it.
        if not IS_WINDOWS:
            isolated_mfa_root = _create_isolated_mfa_root(uuid.uuid4().hex[:8])

        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(temp_dir, exist_ok=True)

        # Sanitise a private copy so the shared dictionary is never modified.
        temp_dict_path = os.path.join(temp_dir, f"dict_{uuid.uuid4().hex[:8]}.dict")
        shutil.copy2(dict_path, temp_dict_path)
        log(f"检查字典文件: {temp_dict_path}")
        removed = _clean_dict_empty_lines(temp_dict_path)
        if removed > 0:
            log(f"已清理字典文件中的 {removed} 个无效行")

        cmd = _get_mfa_command() + [
            "align",
            str(corpus_dir),
            str(temp_dict_path),
            str(model_path),
            str(output_dir),
            "--temp_directory", str(temp_dir),
        ]

        # Multiprocessing is left enabled on Windows as well; callers can
        # pass num_jobs=1 if it causes trouble there.
        if num_jobs is None:
            num_jobs = max(1, multiprocessing.cpu_count())
        cmd.extend(["--num_jobs", str(num_jobs)])

        if clean:
            cmd.append("--clean")
        if single_speaker:
            cmd.append("--single_speaker")

        log("正在启动 MFA 对齐引擎...")
        log(f"运行平台: {'Windows (外挂模式)' if IS_WINDOWS else 'Linux (系统安装)'}")
        log(f"并行进程数: {num_jobs}")
        log(f"输入目录: {corpus_dir}")
        log(f"输出目录: {output_dir}")

        if timeout_seconds is None:
            timeout_seconds = int(os.environ.get("JINRIKI_MFA_TIMEOUT_SECONDS", "1800"))

        try:
            env = _build_mfa_env(isolated_mfa_root)
            process = subprocess.Popen(
                cmd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding='utf-8',
                errors='replace'
            )

            started = time.monotonic()
            while True:
                if cancel_checker and cancel_checker():
                    _terminate_process(process)
                    log("MFA 任务已取消")
                    return False, "任务已取消"

                if time.monotonic() - started > timeout_seconds:
                    _terminate_process(process)
                    msg = f"MFA 执行超时(>{timeout_seconds}秒)"
                    log(msg)
                    return False, msg

                try:
                    # communicate() drains both pipes while waiting, so a
                    # chatty MFA run cannot block on a full pipe buffer.
                    # Retrying after a timeout loses no output.
                    stdout, stderr = process.communicate(timeout=1)
                    break
                except subprocess.TimeoutExpired:
                    pass

            if process.returncode == 0:
                log("MFA 对齐完成!")
                # Only session-specific temp directories are removed.
                if "session_" in str(temp_dir) and os.path.exists(temp_dir):
                    try:
                        shutil.rmtree(temp_dir)
                        log(f"已清理临时目录: {temp_dir}")
                    except Exception as e:
                        logger.warning(f"清理临时目录失败: {e}")
                return True, stdout

            error_msg = stderr or stdout or "未知错误"
            if "parse_dictionary_file" in error_msg and "IndexError: list index out of range" in error_msg:
                error_msg = (
                    "字典解析失败(IndexError)。已尝试自动清理字典,"
                    "请检查字典中是否存在异常符号行或损坏内容。\n\n" + error_msg
                )
            log(f"MFA 运行出错: {error_msg}")
            return False, error_msg

        except FileNotFoundError as e:
            msg = f"找不到 MFA 命令: {e}"
            log(msg)
            return False, msg
        except Exception as e:
            msg = f"MFA 执行异常: {e}"
            log(msg)
            return False, msg
    finally:
        # Never leave an MFA process running behind an error path.
        if process is not None and process.poll() is None:
            _terminate_process(process)
        # Remove the session temp directory even after a failure.
        if "session_" in str(temp_dir) and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                logger.warning(f"清理临时目录失败: {e}")
        if isolated_mfa_root:
            _cleanup_isolated_mfa_root(isolated_mfa_root)
def run_mfa_validate(
    corpus_dir: str,
    dict_path: Optional[str] = None,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """Run ``mfa validate`` on a corpus.

    Args:
        corpus_dir: Corpus directory to validate.
        dict_path: Dictionary file; defaults to ``models/mandarin.dict``.
        progress_callback: Receives every progress message that is logged.

    Returns:
        ``(success, text)`` where success means exit status 0 and text is
        stdout and stderr joined by a newline, or the exception text when
        the command could not be run.
    """
    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    if not check_mfa_available():
        return False, "MFA 环境不可用"

    dictionary = dict_path or str(DEFAULT_DICT_PATH)
    cmd = [*_get_mfa_command(), "validate", str(corpus_dir), str(dictionary)]

    log("正在验证语料库...")

    try:
        completed = subprocess.run(
            cmd,
            env=_build_mfa_env(),
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace'
        )
        report = completed.stdout + "\n" + completed.stderr
        log("验证完成")
        return completed.returncode == 0, report
    except Exception as e:
        return False, str(e)
def install_mfa_model(
    model_type: str,
    model_name: str,
    progress_callback: Optional[Callable[[str], None]] = None
) -> tuple[bool, str]:
    """Download a pretrained MFA model via ``mfa model download``.

    Not supported on Windows, where the call fails immediately.

    Args:
        model_type: Model kind, e.g. ``"acoustic"`` or ``"dictionary"``.
        model_name: Model name, e.g. ``"mandarin_mfa"``.
        progress_callback: Receives every progress message that is logged.

    Returns:
        ``(success, text)`` where text is stdout on success and an error
        message otherwise.
    """
    def log(msg: str):
        logger.info(msg)
        if progress_callback:
            progress_callback(msg)

    if IS_WINDOWS:
        return False, "Windows 平台请手动下载模型文件"

    if not check_mfa_available():
        return False, "MFA 环境不可用"

    cmd = [*_get_mfa_command(), "model", "download", model_type, model_name]

    log(f"正在下载 MFA 模型: {model_type}/{model_name}")

    try:
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace'
        )
    except Exception as e:
        return False, str(e)

    if completed.returncode != 0:
        error_msg = completed.stderr or completed.stdout or "未知错误"
        log(f"模型下载失败: {error_msg}")
        return False, error_msg

    log(f"模型下载完成: {model_name}")
    return True, completed.stdout
def get_mfa_model_path(model_type: str, model_name: str) -> Optional[str]:
    """Resolve an MFA model reference for the current platform.

    Windows: returns the path of ``models/mfa/<model_name>.zip`` for
    ``"acoustic"`` models, or ``<model_name>.dict`` for any other type, and
    None when that file does not exist.
    Other platforms: returns ``model_name`` unchanged.

    Args:
        model_type: Model kind, e.g. ``"acoustic"`` or ``"dictionary"``.
        model_name: Model name.

    Returns:
        A file path or model name, or None when the Windows file is missing.
    """
    if not IS_WINDOWS:
        # The name is passed through as-is on non-Windows platforms.
        return model_name

    suffix = ".zip" if model_type == "acoustic" else ".dict"
    candidate = BASE_DIR / "models" / "mfa" / f"{model_name}{suffix}"
    return str(candidate) if candidate.exists() else None