general-deep-learning / test /models /checkpoint_test.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
"""检查点解析功能单元测试
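Covers resolve_checkpoint, extract_number_of_filename and the CheckpointConfig fields.

resolve_checkpoint scenarios:
├── path provided
│   ├── absolute path
│   │   ├── exists
│   │   │   ├── .keras file → returned successfully
│   │   │   └── .weights.h5 file → returned successfully
│   │   ├── exists (dirs also provided) → absolute path used, dirs ignored, warning emitted
│   │   ├── suffix mismatch → FileNotFoundError
│   │   └── does not exist → FileNotFoundError
│   └── relative path
│       ├── dirs provided
│       │   ├── single directory → resolved successfully
│       │   └── multiple directories searched in order → resolved successfully
│       └── dirs=None → ValueError
└── path not provided
    ├── dirs provided
    │   ├── epoch=None
    │   │   ├── single directory
    │   │   │   ├── contains .weights.h5 / .keras files → latest one returned
    │   │   │   ├── exists but is empty → returns (None, 0)
    │   │   │   └── directory does not exist → returns (None, 0)
    │   │   └── multiple directories → globally latest checkpoint returned
    │   └── epoch given
    │       ├── suffix not given
    │       │   ├── epoch exists → checkpoint of that epoch returned
    │       │   └── epoch does not exist → FileNotFoundError
    │       └── suffix given
    │           ├── matching suffix exists → matching checkpoint returned
    │           └── no matching suffix → FileNotFoundError
    └── dirs also None (both None) → ValueError

extract_number_of_filename scenarios:
├── normal extraction
│   ├── file name containing an epoch → returns the number
│   ├── file name with several numbers → returns the last number
│   └── .keras file name → returns the number
└── error cases
    ├── file name without any number → raises ValueError
    └── .weights.h5 file name without any number → raises ValueError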
"""
import pathlib
import tempfile
import pytest
from pipeline.base.configs import CheckpointConfig
from pipeline.base.checkpoint import (
extract_number_of_filename,
resolve_checkpoint
)
class TestCheckpointConfig:
    """Tests for the fields exposed by CheckpointConfig."""

    def test_default_values(self):
        """A config built without arguments has None in every checked field."""
        config = CheckpointConfig()

        for field_name in ("dirs", "path", "epoch", "suffix"):
            assert getattr(config, field_name) is None

    def test_custom_values(self):
        """Values passed to the constructor are readable back unchanged."""
        init_kwargs = {
            "dirs": [pathlib.Path("dir_a"), pathlib.Path("dir_b")],
            "path": pathlib.Path("model_epoch_005.weights.h5"),
            "epoch": 5,
            "suffix": ".weights.h5",
        }

        config = CheckpointConfig(**init_kwargs)

        # Compare against freshly built values rather than the objects
        # that were handed to the constructor.
        assert config.dirs == [pathlib.Path("dir_a"), pathlib.Path("dir_b")]
        assert config.path == pathlib.Path("model_epoch_005.weights.h5")
        assert config.epoch == 5
        assert config.suffix == ".weights.h5"
class TestExtractNumberOfFilename:
    """Tests for extract_number_of_filename."""

    def test_extract_from_epoch_filename(self):
        """A zero-padded epoch counter in the name is returned as an int."""
        cases = (
            ("model_epoch_001", 1),
            ("model_epoch_010", 10),
            ("model_epoch_100", 100),
        )
        for name, expected in cases:
            assert extract_number_of_filename(name) == expected

    def test_extract_last_number(self):
        """When the name holds several numbers, the last one is returned."""
        cases = (
            ("checkpoint_2024_06_30_epoch_002", 2),
            ("model_v1_epoch_005", 5),
        )
        for name, expected in cases:
            assert extract_number_of_filename(name) == expected

    def test_extract_from_keras_file(self):
        """Numbers are extracted from keras-style checkpoint names."""
        cases = (
            ("epoch_005_model", 5),
            ("model_epoch_003.keras", 3),
        )
        for name, expected in cases:
            assert extract_number_of_filename(name) == expected

    def test_no_number_raises_error(self):
        """A name without any digit raises ValueError."""
        expect_failure = pytest.raises(ValueError, match="No number found")
        with expect_failure:
            extract_number_of_filename("model_final")

    def test_no_number_in_weights_file_raises_error(self):
        """A digit-free name ending in ``.weights`` raises ValueError."""
        expect_failure = pytest.raises(ValueError, match="No number found")
        with expect_failure:
            extract_number_of_filename("model_final.weights")
class TestResolveCheckpoint:
    """Tests for resolve_checkpoint."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a fresh temporary directory as a pathlib.Path."""
        with tempfile.TemporaryDirectory() as raw_dir:
            scratch = pathlib.Path(raw_dir)
            yield scratch

    def test_absolute_path_exists_returns_path_and_epoch(self, temp_dir):
        """An existing file given via ``path`` is returned with its epoch."""
        target = temp_dir / "model_epoch_005.keras"
        target.write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(path=target)

        assert resolved == target
        assert found_epoch == 5

    def test_absolute_path_with_dirs_ignores_dir_and_warns(self, temp_dir):
        """With ``path`` and ``dirs`` both given, ``path`` wins and a UserWarning is emitted."""
        target = temp_dir / "model_epoch_005.keras"
        target.write_text("dummy")
        unrelated_dir = temp_dir / "other_dir"
        unrelated_dir.mkdir()

        expect_warning = pytest.warns(UserWarning, match="dirs 参数将被忽略")
        with expect_warning:
            resolved, found_epoch = resolve_checkpoint(
                path=target,
                dirs=[unrelated_dir],
            )

        assert resolved == target
        assert found_epoch == 5

    def test_absolute_path_not_exists_raises_error(self, temp_dir):
        """A ``path`` pointing at a missing file raises FileNotFoundError."""
        missing = temp_dir / "model_epoch_005.keras"

        expect_failure = pytest.raises(FileNotFoundError, match="检查点文件不存在")
        with expect_failure:
            resolve_checkpoint(path=missing)

    def test_relative_path_with_dirs_returns_path(self, temp_dir):
        """A relative ``path`` is resolved against the single entry in ``dirs``."""
        target = temp_dir / "model_epoch_010.weights.h5"
        target.write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(
            dirs=[temp_dir],
            path="model_epoch_010.weights.h5",
        )

        assert resolved == target
        assert found_epoch == 10

    def test_relative_path_without_dirs_raises_error(self):
        """A relative ``path`` without ``dirs`` raises ValueError."""
        expect_failure = pytest.raises(
            ValueError, match="path 是相对路径时,必须提供 dirs"
        )
        with expect_failure:
            resolve_checkpoint(path="model.keras")

    def test_resolve_latest_weights_h5(self, temp_dir):
        """With only ``dirs`` given, the checkpoint with the highest epoch is returned."""
        # Creation order is deliberately not sorted by epoch.
        file_names = (
            "model_epoch_001.weights.h5",
            "model_epoch_005.weights.h5",
            "model_epoch_003.weights.h5",
            "model_epoch_004.keras",
        )
        for file_name in file_names:
            (temp_dir / file_name).write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(dirs=[temp_dir])

        assert resolved.name == "model_epoch_005.weights.h5"
        assert found_epoch == 5

    def test_resolve_specific_epoch(self, temp_dir):
        """With ``dirs`` and ``epoch`` given, the checkpoint of that epoch is returned."""
        file_names = (
            "model_epoch_001.weights.h5",
            "model_epoch_005.weights.h5",
            "model_epoch_010.weights.h5",
        )
        for file_name in file_names:
            (temp_dir / file_name).write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(dirs=[temp_dir], epoch=5)

        assert resolved.name == "model_epoch_005.weights.h5"
        assert found_epoch == 5

    def test_resolve_nonexistent_epoch_raises_error(self, temp_dir):
        """Requesting an epoch that has no checkpoint raises FileNotFoundError."""
        (temp_dir / "model_epoch_001.weights.h5").write_text("dummy")

        expect_failure = pytest.raises(FileNotFoundError, match="未找到 epoch 5")
        with expect_failure:
            resolve_checkpoint(dirs=[temp_dir], epoch=5)

    def test_empty_dirs_returns_none(self, temp_dir):
        """An existing but empty directory yields ``(None, 0)``."""
        resolved, found_epoch = resolve_checkpoint(dirs=[temp_dir])

        assert resolved is None
        assert found_epoch == 0

    def test_nonexistent_dirs_returns_none(self):
        """A directory that does not exist yields ``(None, 0)``."""
        resolved, found_epoch = resolve_checkpoint(dirs=["/nonexistent/path"])

        assert resolved is None
        assert found_epoch == 0

    def test_both_none_raises_error(self):
        """Calling without ``dirs`` and without ``path`` raises ValueError."""
        expect_failure = pytest.raises(ValueError, match="必须提供 dirs 或 path")
        with expect_failure:
            resolve_checkpoint()

    def test_resolve_keras_file(self, temp_dir):
        """A ``.keras`` file is accepted and its epoch is read from the name."""
        target = temp_dir / "epoch_007_model.keras"
        target.write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(path=target)

        assert resolved == target
        assert found_epoch == 7

    def test_resolve_weights_h5_file(self, temp_dir):
        """A ``.weights.h5`` file is accepted and its epoch is read from the name."""
        target = temp_dir / "model_epoch_012.weights.h5"
        target.write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(path=target)

        assert resolved == target
        assert found_epoch == 12

    def test_relative_path_uses_checkpoint_dirs_in_order(self, temp_dir):
        """A relative ``path`` missing from the first directory is found in the second."""
        dir_one = temp_dir / "first"
        dir_two = temp_dir / "second"
        dir_one.mkdir()
        dir_two.mkdir()
        # Only the second directory actually holds the file.
        target = dir_two / "model_epoch_012.weights.h5"
        target.write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(
            dirs=[dir_one, dir_two],
            path="model_epoch_012.weights.h5",
        )

        assert resolved == target
        assert found_epoch == 12

    def test_resolve_latest_from_checkpoint_dirs(self, temp_dir):
        """Across several directories, the checkpoint with the highest epoch is returned."""
        dir_one = temp_dir / "first"
        dir_two = temp_dir / "second"
        dir_one.mkdir()
        dir_two.mkdir()
        (dir_one / "model_epoch_003.weights.h5").write_text("dummy")
        (dir_two / "model_epoch_008.weights.h5").write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(dirs=[dir_one, dir_two])

        assert resolved == dir_two / "model_epoch_008.weights.h5"
        assert found_epoch == 8

    def test_resolve_with_suffix(self, temp_dir):
        """With ``suffix`` given, the checkpoint carrying that suffix is returned."""
        for file_name in ("model_epoch_003.weights.h5", "model_epoch_005.keras"):
            (temp_dir / file_name).write_text("dummy")

        resolved, found_epoch = resolve_checkpoint(
            dirs=[temp_dir],
            suffix=".keras",
        )

        assert resolved == temp_dir / "model_epoch_005.keras"
        assert found_epoch == 5

    def test_resolve_with_missing_suffix_raises_error(self, temp_dir):
        """An epoch that exists only under a different suffix raises FileNotFoundError."""
        (temp_dir / "model_epoch_003.weights.h5").write_text("dummy")

        expect_failure = pytest.raises(FileNotFoundError, match="未找到 epoch 3")
        with expect_failure:
            resolve_checkpoint(
                dirs=[temp_dir],
                epoch=3,
                suffix=".keras",
            )

    def test_absolute_path_with_suffix_mismatch_raises_error(self, temp_dir):
        """An existing ``path`` whose suffix differs from ``suffix`` raises FileNotFoundError."""
        target = temp_dir / "model_epoch_005.keras"
        target.write_text("dummy")

        expect_failure = pytest.raises(
            FileNotFoundError, match="检查点文件后缀不匹配"
        )
        with expect_failure:
            resolve_checkpoint(
                path=target,
                suffix=".weights.h5",
            )