Spaces:
Sleeping
Sleeping
from unittest.mock import Mock, patch

from pipeline.base.configs import CheckpointRules
from pipeline.base.generation import GenerationResult, TextGenerator
from pipeline.base.generation_runner import BaseGenerationRunner
from pipeline.base.model_builder import ModelArtifact
from test.pipeline.helpers import create_pipeline
class DummyGenerationRunner(BaseGenerationRunner):
    """Minimal generation runner used to exercise the inherited run flow.

    Only the generator construction is supplied here; the fixed-prompt
    flow itself comes from ``BaseGenerationRunner``.
    """

    title = "测试生成器"
    fixed_prompts = ["白日依山尽", "床前明月光"]

    def _build_generator(self) -> TextGenerator:
        """Build a ``TextGenerator`` from the testing checkpoint rule.

        The rule is resolved with the pipeline's checkpoint directory as
        the default search location, then handed to ``self.loader`` together
        with the pipeline to obtain the model artifact and tokenizer info.
        """
        # Resolve which checkpoint to load for testing.
        testing_rule = self.pipeline.checkpoint_rules.resolve_testing_rule(
            default_dirs=[self.pipeline.checkpoint_dir],
        )

        # The loader returns a (model artifact, tokenizer info) pair.
        model_artifact, tokenizer_bundle = self.loader(self.pipeline, testing_rule)

        # Collect the generator settings first, then construct it in one go.
        generator_kwargs = {
            "artifact": model_artifact,
            "tokenizer": tokenizer_bundle.tokenizer,
            "decode": tokenizer_bundle.decode,
            "end_of_text": tokenizer_bundle.end_of_text,
            "max_length": 16,
            "sample_fn": self.pipeline.generation_rule.sample_strategy,
        }
        return TextGenerator(**generator_kwargs)
def test_generation_runner_runs_generation_flow(tmp_path, capsys):
    """Check that ``run_fixed`` assembles the generator and prints results.

    The model loader and the model's ``generate`` callable are mocked so no
    real checkpoint is loaded and no real inference runs. The loader override
    is applied with ``patch.object`` so it is removed again when the run
    finishes (even on failure) and cannot leak into other tests that use
    ``DummyGenerationRunner``.
    """
    # Build a minimal pipeline under the temporary directory with default
    # checkpoint rules, and replace its config logging with a mock.
    pipeline = create_pipeline(tmp_path / "task", Mock(), CheckpointRules())
    log_config = Mock()
    pipeline.log_config = log_config

    # Controlled inference artifact: every generate() call yields the same
    # token ids and stop marker.
    artifact = ModelArtifact(
        model=Mock(),
        generate=Mock(return_value=GenerationResult([7, 8], "<|stop|>")),
    )
    expected_tokenizer_info = pipeline.dataset.tokenizer_bundle()
    loader = Mock(return_value=(artifact, expected_tokenizer_info))

    # Run the fixed-prompt flow with the loader swapped in only for the
    # duration of the run. create=True mirrors the original plain assignment,
    # which worked whether or not the class already defined ``loader``.
    with patch.object(DummyGenerationRunner, "loader", loader, create=True):
        runner = DummyGenerationRunner(lambda: pipeline)
        runner.run_fixed()

    # The config must have been logged exactly once, with no arguments.
    log_config.assert_called_once_with()

    # The generator must have been assembled from the testing checkpoint rule.
    loader.assert_called_once()
    loader_pipeline, checkpoint_rule = loader.call_args.args
    assert loader_pipeline is pipeline
    assert checkpoint_rule == {
        "dirs": [pipeline.checkpoint_dir],
        "path": None,
        "epoch": None,
        "suffix": None,
    }
    assert artifact.generate.call_count == 2

    # Both fixed prompts and the decoded result must appear on stdout.
    output = capsys.readouterr().out
    assert "白日依山尽" in output
    assert "床前明月光" in output
    assert "78<|stop|>" in output