general-deep-learning / test /pipeline /generation_test.py
yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
from unittest.mock import Mock
from pipeline.base.configs import CheckpointRules
from pipeline.base.generation import GenerationResult, TextGenerator
from pipeline.base.generation_runner import BaseGenerationRunner
from pipeline.base.model_builder import ModelArtifact
from test.pipeline.helpers import create_pipeline
class DummyGenerationRunner(BaseGenerationRunner):
    """Smallest possible generation runner, used to drive the real ``run_fixed`` flow.

    This class does not define ``loader`` itself. ``_build_generator`` calls
    ``self.loader(pipeline, checkpoint_rule)`` and unpacks an
    ``(artifact, tokenizer_info)`` pair from it, so callers must attach such a
    callable (e.g. a ``Mock``) before running.
    """

    title = "测试生成器"
    fixed_prompts = ["白日依山尽", "床前明月光"]

    def _build_generator(self) -> TextGenerator:
        """Assemble a ``TextGenerator`` the same way a real runner would.

        Resolves the pipeline's testing checkpoint rule (passing the pipeline's
        checkpoint directory as ``default_dirs``), and passes that rule to
        ``self.loader``. It then wires the returned artifact and tokenizer
        bundle into a generator with ``max_length=16`` and the pipeline's
        configured sampling strategy.
        """
        rules = self.pipeline.checkpoint_rules
        rule = rules.resolve_testing_rule(default_dirs=[self.pipeline.checkpoint_dir])
        loaded_artifact, bundle = self.loader(self.pipeline, rule)
        return TextGenerator(
            artifact=loaded_artifact,
            tokenizer=bundle.tokenizer,
            decode=bundle.decode,
            end_of_text=bundle.end_of_text,
            max_length=16,
            sample_fn=self.pipeline.generation_rule.sample_strategy,
        )
def test_generation_runner_runs_generation_flow(tmp_path, capsys, monkeypatch):
    """``run_fixed`` logs the config and loads via the testing checkpoint rule.

    It must also generate once per fixed prompt and print each result to stdout.
    The model artifact and loader are mocks, so no real model is loaded or run.
    """
    # Minimal pipeline with fixed generation settings and default checkpoint rules.
    pipeline = create_pipeline(tmp_path / "task", Mock(), CheckpointRules())
    log_config = Mock()
    pipeline.log_config = log_config

    # Controllable inference artifact: every generate() call yields tokens [7, 8]
    # followed by the "<|stop|>" marker, avoiding real model loading/inference.
    artifact = ModelArtifact(
        model=Mock(),
        generate=Mock(return_value=GenerationResult([7, 8], "<|stop|>")),
    )
    expected_tokenizer_info = pipeline.dataset.tokenizer_bundle()
    loader = Mock(return_value=(artifact, expected_tokenizer_info))
    # Patch via monkeypatch rather than assigning to the class directly, so the
    # attribute is restored (or removed) after this test. The original version
    # leaked the Mock into every later use of DummyGenerationRunner.
    # raising=False: the attribute may not exist on the class beforehand.
    monkeypatch.setattr(DummyGenerationRunner, "loader", loader, raising=False)

    # Run generation over the fixed prompts.
    runner = DummyGenerationRunner(lambda: pipeline)
    runner.run_fixed()

    # The configuration is logged exactly once.
    log_config.assert_called_once_with()

    # The generator was assembled from the testing checkpoint rule.
    loader.assert_called_once()
    loader_pipeline, checkpoint_rule = loader.call_args.args
    assert loader_pipeline is pipeline
    assert checkpoint_rule == {
        "dirs": [pipeline.checkpoint_dir],
        "path": None,
        "epoch": None,
        "suffix": None,
    }
    # One generate() call per fixed prompt.
    assert artifact.generate.call_count == 2

    # Both prompts and the generated text appear on the console.
    output = capsys.readouterr().out
    assert "白日依山尽" in output
    assert "床前明月光" in output
    assert "78<|stop|>" in output