yetrun's picture
适配 hf-space
8aa4d8f
"""
模型工具模块
包含模型构建、检查点管理等通用功能。
"""
import pathlib
import re
import warnings
from env.resolve import resolve_path
def extract_number_of_filename(filename: str) -> int:
"""
从文件名中提取数字,无论数字出现在文件名的哪个位置。
例如:
- "model_epoch_001.weights.h5" -> 1
- "checkpoint_2024_06_30_epoch_002.weights.h5" -> 2
- "model_epoch_final.weights.h5" -> 抛出异常
:param filename: 包含数字的文件名字符串
:return: 提取的数字,如果没有数字则返回0
"""
numbers = re.findall(r"\d+", filename)
if numbers:
return int(numbers[-1]) # 返回最后一个数字,假设它是代数
else:
raise ValueError(f"No number found in filename: {filename}")
def resolve_checkpoint(
dirs: list[pathlib.Path | str] | None = None,
path: pathlib.Path | str | None = None,
epoch: int | None = None,
suffix: str | None = None
):
"""统一解析模型检查点路径
支持直接指定检查点文件路径或在目录中查找检查点文件。
参数:
dirs: 检查点目录列表
path: 直接指定的检查点文件路径(支持绝对路径和相对路径)
epoch: 指定的 epoch,用于查找对应的 .weights.h5 文件
suffix: 指定检查点文件后缀
返回:
(resolved_path, epoch): 绝对路径和 epoch 数
抛出:
FileNotFoundError: 当指定的路径不存在或未找到检查点文件时
ValueError: 当参数无效时
"""
resolved_dirs = _resolve_checkpoint_dirs(dirs)
if path is not None:
path = pathlib.Path(path)
if not path.is_absolute():
if not resolved_dirs:
raise ValueError("path 是相对路径时,必须提供 dirs")
path = _resolve_relative_checkpoint_path(path, resolved_dirs)
else:
if dirs is not None:
warnings.warn(
"警告:path 是绝对路径,dirs 参数将被忽略",
UserWarning
)
if not path.exists():
raise FileNotFoundError(f"检查点文件不存在: {path}")
if suffix is not None and not path.name.endswith(suffix):
raise FileNotFoundError(f"检查点文件后缀不匹配: {path}")
try:
epoch_num = extract_number_of_filename(path.stem)
except ValueError:
epoch_num = 0
return path, epoch_num
if not resolved_dirs:
raise ValueError("必须提供 dirs 或 path")
files_with_number = _collect_checkpoint_files(
checkpoint_dirs=resolved_dirs,
suffix=suffix
)
if epoch is not None:
matches = [(f, num) for f, num in files_with_number if num == epoch]
if not matches:
raise FileNotFoundError(f"未找到 epoch {epoch} 对应的检查点文件")
if len(matches) > 1:
raise RuntimeError(
f"找到多个 epoch {epoch} 对应的检查点文件: {[match[0].name for match in matches]}"
)
return matches[0]
if not files_with_number:
return None, 0
return max(files_with_number, key=lambda item: item[1])
def describe_checkpoint_lookup(
dirs: list[pathlib.Path | str] | None = None,
path: pathlib.Path | str | None = None,
suffix: str | None = None
) -> str:
resolved_dirs = _resolve_checkpoint_dirs(dirs)
parts = []
if path is not None:
parts.append(f"path={resolve_path(path)}")
if suffix is not None:
parts.append(f"suffix={suffix}")
if resolved_dirs:
dir_infos = []
for checkpoint_dir in resolved_dirs:
dir_infos.append(_describe_checkpoint_dir(checkpoint_dir))
parts.append("dirs=[" + "; ".join(dir_infos) + "]")
else:
parts.append("dirs=[]")
return ",".join(parts)
def _resolve_checkpoint_dirs(
dirs: list[pathlib.Path | str] | None
) -> list[pathlib.Path]:
if dirs is None:
return []
return [resolve_path(path) for path in dirs]
def _resolve_relative_checkpoint_path(
checkpoint_path: pathlib.Path,
checkpoint_dirs: list[pathlib.Path]
) -> pathlib.Path:
for checkpoint_dir in checkpoint_dirs:
candidate = checkpoint_dir / checkpoint_path
if candidate.exists():
return candidate
return checkpoint_dirs[0] / checkpoint_path
def _collect_checkpoint_files(
checkpoint_dirs: list[pathlib.Path],
suffix: str | None
) -> list[tuple[pathlib.Path, int]]:
files_with_number = []
for checkpoint_dir in checkpoint_dirs:
if not checkpoint_dir.exists():
continue
for file_path in sorted(checkpoint_dir.iterdir()):
if not file_path.is_file():
continue
if suffix is not None and not file_path.name.endswith(suffix):
continue
if suffix is None and not _is_checkpoint_file(file_path):
continue
files_with_number.append((file_path, extract_number_of_filename(file_path.stem)))
return files_with_number
def _describe_checkpoint_dir(checkpoint_dir: pathlib.Path) -> str:
exists = checkpoint_dir.exists()
if not exists:
return f"{checkpoint_dir} (missing)"
entries = []
for file_path in sorted(checkpoint_dir.iterdir()):
if file_path.is_file():
entries.append(file_path.name)
shown_entries = entries[:5]
if len(entries) > 5:
shown_entries.append("...")
files_text = ", ".join(shown_entries) if shown_entries else "<empty>"
return f"{checkpoint_dir} (files: {files_text})"
def _is_checkpoint_file(file_path: pathlib.Path) -> bool:
return file_path.name.endswith(".keras") or file_path.name.endswith(".weights.h5")