Spaces:
Sleeping
Sleeping
| """ | |
| 模型工具模块 | |
| 包含模型构建、检查点管理等通用功能。 | |
| """ | |
| import pathlib | |
| import re | |
| import warnings | |
| from env.resolve import resolve_path | |
def extract_number_of_filename(filename: str) -> int:
    """
    Extract the last run of digits from a checkpoint file name.

    A trailing checkpoint extension (``.weights.h5``, ``.h5``, ``.hdf5`` or
    ``.keras``) is removed before scanning, so the "5" of ``.h5`` is never
    mistaken for the epoch number. Both full file names and stems are accepted.

    Examples:
    - "model_epoch_001.weights.h5" -> 1
    - "checkpoint_2024_06_30_epoch_002.weights.h5" -> 2
    - "model_epoch_final.weights.h5" -> raises ValueError

    :param filename: file name (or stem) expected to contain a number
    :return: the last number found, assumed to be the epoch
    :raises ValueError: if the name contains no digits outside the extension
    """
    # Strip the extension first; otherwise "x_001.weights.h5" would yield 5.
    name_without_ext = re.sub(
        r"(?:\.weights)?\.(?:h5|hdf5|keras)\Z",
        "",
        filename,
        flags=re.IGNORECASE,
    )
    numbers = re.findall(r"\d+", name_without_ext)
    if not numbers:
        raise ValueError(f"No number found in filename: {filename}")
    # The last number is assumed to be the epoch (dates etc. come earlier).
    return int(numbers[-1])
def resolve_checkpoint(
    dirs: list[pathlib.Path | str] | None = None,
    path: pathlib.Path | str | None = None,
    epoch: int | None = None,
    suffix: str | None = None
) -> tuple[pathlib.Path | None, int]:
    """Resolve a model checkpoint to a concrete file path and an epoch number.

    Either an explicit checkpoint file is given via ``path``, or the
    directories in ``dirs`` are searched for checkpoint files.

    Args:
        dirs: Checkpoint directories. Each entry is normalised through
            ``_resolve_checkpoint_dirs``. They are searched for checkpoint
            files and also serve as base directories for a relative ``path``.
        path: Explicit checkpoint file. An absolute path is used as is (a
            ``UserWarning`` is issued if ``dirs`` was also passed); a relative
            path is looked up under ``dirs``.
        epoch: Only used when ``path`` is None: select the file whose
            extracted number equals this value.
        suffix: Required file-name ending. With ``path`` it is validated
            against the file name; otherwise it is forwarded to the
            directory scan.

    Returns:
        ``(checkpoint_path, epoch)``. With an explicit ``path`` the epoch is
        the number extracted from the file stem, or 0 if the stem has no
        digits. In directory mode it is the entry matching ``epoch``, or the
        entry with the highest number when ``epoch`` is None. ``(None, 0)`` is
        returned when the directories contain no candidate files and no
        ``epoch`` was requested.

    Raises:
        FileNotFoundError: ``path`` does not exist, its name does not end
            with ``suffix``, or no file matches the requested ``epoch``.
        ValueError: ``path`` is relative and no ``dirs`` were given, or
            neither ``dirs`` nor ``path`` was provided.
        RuntimeError: More than one file matches the requested ``epoch``.
    """
    resolved_dirs = _resolve_checkpoint_dirs(dirs)

    if path is not None:
        path = pathlib.Path(path)
        if not path.is_absolute():
            if not resolved_dirs:
                raise ValueError("path 是相对路径时,必须提供 dirs")
            path = _resolve_relative_checkpoint_path(path, resolved_dirs)
        elif dirs is not None:
            # stacklevel=2 attributes the warning to the caller's line
            # instead of this module.
            warnings.warn(
                "警告:path 是绝对路径,dirs 参数将被忽略",
                UserWarning,
                stacklevel=2,
            )

        if not path.exists():
            raise FileNotFoundError(f"检查点文件不存在: {path}")
        if suffix is not None and not path.name.endswith(suffix):
            raise FileNotFoundError(f"检查点文件后缀不匹配: {path}")

        # An explicitly named file without a number is still usable; it is
        # reported as epoch 0.
        try:
            epoch_num = extract_number_of_filename(path.stem)
        except ValueError:
            epoch_num = 0
        return path, epoch_num

    if not resolved_dirs:
        raise ValueError("必须提供 dirs 或 path")

    files_with_number = _collect_checkpoint_files(
        checkpoint_dirs=resolved_dirs,
        suffix=suffix
    )

    if epoch is not None:
        matches = [(f, num) for f, num in files_with_number if num == epoch]
        if not matches:
            raise FileNotFoundError(f"未找到 epoch {epoch} 对应的检查点文件")
        if len(matches) > 1:
            raise RuntimeError(
                f"找到多个 epoch {epoch} 对应的检查点文件: {[match[0].name for match in matches]}"
            )
        return matches[0]

    if not files_with_number:
        return None, 0
    # No epoch requested: take the checkpoint with the highest number.
    return max(files_with_number, key=lambda item: item[1])
def describe_checkpoint_lookup(
    dirs: list[pathlib.Path | str] | None = None,
    path: pathlib.Path | str | None = None,
    suffix: str | None = None
) -> str:
    """Summarise the inputs of a checkpoint lookup as a single string.

    The result concatenates, in this order and only when applicable:
    ``path=<resolved path>``, ``suffix=<suffix>`` and a ``dirs=[...]`` section
    holding one description per directory (``dirs=[]`` when none were given).

    Args:
        dirs: Checkpoint directories; normalised via ``_resolve_checkpoint_dirs``.
        path: Optional checkpoint path; shown after passing through ``resolve_path``.
        suffix: Optional file suffix; shown verbatim.

    Returns:
        The assembled description.
    """
    search_dirs = _resolve_checkpoint_dirs(dirs)

    fragments = []
    if path is not None:
        fragments.append(f"path={resolve_path(path)}")
    if suffix is not None:
        fragments.append(f"suffix={suffix}")

    if not search_dirs:
        fragments.append("dirs=[]")
    else:
        described = "; ".join(
            [_describe_checkpoint_dir(directory) for directory in search_dirs]
        )
        fragments.append("dirs=[" + described + "]")

    return ",".join(fragments)
| def _resolve_checkpoint_dirs( | |
| dirs: list[pathlib.Path | str] | None | |
| ) -> list[pathlib.Path]: | |
| if dirs is None: | |
| return [] | |
| return [resolve_path(path) for path in dirs] | |
def _resolve_relative_checkpoint_path(
    checkpoint_path: pathlib.Path,
    checkpoint_dirs: list[pathlib.Path]
) -> pathlib.Path:
    """Locate a relative checkpoint path under the given directories.

    Returns the first ``directory / checkpoint_path`` that exists, trying the
    directories in order. If none exists, the candidate under the first
    directory is returned so the caller can report a concrete missing path.
    ``checkpoint_dirs`` must not be empty.
    """
    candidates = (base / checkpoint_path for base in checkpoint_dirs)
    existing = next((item for item in candidates if item.exists()), None)
    if existing is None:
        return checkpoint_dirs[0] / checkpoint_path
    return existing
def _collect_checkpoint_files(
    checkpoint_dirs: list[pathlib.Path],
    suffix: str | None
) -> list[tuple[pathlib.Path, int]]:
    """Collect candidate checkpoint files together with their epoch numbers.

    Args:
        checkpoint_dirs: Directories to scan (non-recursively). Entries that
            are missing or are not directories are skipped.
        suffix: When given, only files whose name ends with it are kept.
            When None, ``_is_checkpoint_file`` decides.

    Returns:
        ``(file_path, number)`` tuples, directory by directory, with files in
        sorted order inside each directory. ``number`` is extracted from the
        file stem; files whose stem has no digits get 0, the same fallback
        ``resolve_checkpoint`` applies to an explicitly given path.
    """
    files_with_number = []
    for checkpoint_dir in checkpoint_dirs:
        # is_dir() rather than exists(): iterdir() raises on a regular file.
        if not checkpoint_dir.is_dir():
            continue
        for file_path in sorted(checkpoint_dir.iterdir()):
            if not file_path.is_file():
                continue
            if suffix is not None:
                if not file_path.name.endswith(suffix):
                    continue
            elif not _is_checkpoint_file(file_path):
                continue
            # One unnumbered file (e.g. "best.keras") must not abort the
            # whole scan; treat it as epoch 0 instead.
            try:
                number = extract_number_of_filename(file_path.stem)
            except ValueError:
                number = 0
            files_with_number.append((file_path, number))
    return files_with_number
def _describe_checkpoint_dir(checkpoint_dir: pathlib.Path) -> str:
    """Describe a checkpoint directory for diagnostic messages.

    Returns:
        - ``"<dir> (missing)"`` if the path does not exist;
        - ``"<dir> (not a directory)"`` if it exists but is not a directory;
        - ``"<dir> (files: a, b, ...)"`` otherwise, listing at most the first
          five regular files in sorted order, followed by ``"..."`` when more
          exist, or ``"<empty>"`` when there are none.
    """
    if not checkpoint_dir.exists():
        return f"{checkpoint_dir} (missing)"
    # iterdir() raises NotADirectoryError on a regular file; a helper that
    # builds diagnostics should report that instead of crashing.
    if not checkpoint_dir.is_dir():
        return f"{checkpoint_dir} (not a directory)"

    entries = [
        file_path.name
        for file_path in sorted(checkpoint_dir.iterdir())
        if file_path.is_file()
    ]
    shown_entries = entries[:5]
    if len(entries) > 5:
        shown_entries.append("...")
    files_text = ", ".join(shown_entries) if shown_entries else "<empty>"
    return f"{checkpoint_dir} (files: {files_text})"
def _is_checkpoint_file(file_path: pathlib.Path) -> bool:
    """Return True when the file name ends with ``.keras`` or ``.weights.h5``."""
    recognised_endings = (".keras", ".weights.h5")
    return file_path.name.endswith(recognised_endings)