# yetrun's picture
# ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
# a5fd608
from typing import cast
from unittest.mock import Mock
import env.keras as keras_env
import pipeline.pipeline as pipeline_module
from pipeline.base.model_builder import ModelArtifact, ModelBuilder
from test.pipeline.helpers import create_pipeline
def _assert_fit_kwargs(fit_kwargs):
    """Check the keyword arguments that were passed to ``model.fit``.

    Verifies the scalar training settings, the number of callbacks, and the
    contents of the first batch produced by the validation dataset, so that a
    wrongly orchestrated ``execute`` is caught.

    Args:
        fit_kwargs: Mapping of the keyword arguments captured from the
            ``fit`` call. ``validation_data`` must expose
            ``as_numpy_iterator()``.

    Raises:
        AssertionError: If any captured value differs from the expected one.
    """
    expected_settings = (
        ("initial_epoch", 0),
        ("epochs", 1),
        ("steps_per_epoch", 1),
    )
    for setting_name, expected_value in expected_settings:
        assert fit_kwargs[setting_name] == expected_value

    callbacks = fit_kwargs["callbacks"]
    assert len(callbacks) == 4

    # Only the first validation batch is inspected; it is indexed rather than
    # unpacked so that extra trailing elements in the batch are tolerated.
    batch_iterator = fit_kwargs["validation_data"].as_numpy_iterator()
    first_batch = next(batch_iterator)
    expected_rows = ([[1, 2, 3]], [[2, 3, 4]])
    for position, rows in enumerate(expected_rows):
        assert first_batch[position].tolist() == rows
def test_execute_runs_training_flow(tmp_path, monkeypatch):
    """Training-flow test: ``execute`` orchestrates training and calls ``fit``."""
    # Build the smallest possible training artifact from mocks so that no
    # real training cost is incurred.
    fake_model = Mock()
    builder_mock = Mock()
    builder_mock.build_training_artifact.return_value = ModelArtifact(
        model=fake_model,
        generate=Mock(),
    )
    # The mock is cast through ``object`` before being cast to ModelBuilder.
    typed_builder = cast(ModelBuilder, cast(object, builder_mock))
    pipeline = create_pipeline(tmp_path / "task", typed_builder)

    # Neutralise the mixed-precision and config-logging side effects so the
    # test focuses on the main training flow only.
    mixed_precision_spy = Mock()
    log_config_spy = Mock()

    def _no_checkpoint(**kwargs):
        # Stub for resolve_checkpoint: always returns the pair (None, 0).
        return (None, 0)

    monkeypatch.setattr(keras_env, "enable_mixed_precision", mixed_precision_spy)
    monkeypatch.setattr(pipeline, "log_config", log_config_spy)
    monkeypatch.setattr(pipeline_module, "resolve_checkpoint", _no_checkpoint)

    # Run the training flow.
    pipeline.execute()

    # The pre-training preparation and the model assembly must have happened.
    mixed_precision_spy.assert_called_once_with()
    log_config_spy.assert_called_once_with()
    builder_mock.build_training_artifact.assert_called_once_with(
        vocab_size=32,
        sequence_length=16,
    )
    fake_model.compile.assert_called_once()
    fake_model.summary.assert_called_once_with()
    fake_model.fit.assert_called_once()

    # Check the key training parameters and validation data given to fit;
    # index 1 of call_args holds the keyword arguments.
    captured_kwargs = fake_model.fit.call_args[1]
    _assert_fit_kwargs(captured_kwargs)

    # Both output directories must exist once execute has finished.
    assert pipeline.log_dir.exists()
    assert pipeline.checkpoint_dir.exists()