Spaces:
Running
Running
File size: 4,130 Bytes
from unittest.mock import Mock
import pytest
from pipeline.base.configs import CheckpointRules
from pipeline.base.generation import GenerationResult
from pipeline.base.generation_runner import BaseGenerationRunner
from pipeline.base.model_builder import ModelArtifact
from pipeline.specs.text_pipeline import TextInferenceBundle
from test.pipeline.helpers import DummyDataset, create_pipeline, sample_one
class DummyGenerationRunner(BaseGenerationRunner):
    """Minimal generation runner that reuses the real ``run_fixed`` flow."""

    # Class-level settings read by the base runner; the string values are
    # runtime data and are intentionally left untranslated.
    max_length = 16
    title = "测试生成器"
    fixed_prompts = [
        "白日依山尽",
        "床前明月光",
    ]
def test_generation_runner_runs_generation_flow(tmp_path, capsys, monkeypatch):
    """Run the fixed-prompt flow against stubbed inference artifacts.

    Checks that the config is logged exactly once, that the artifact loader
    receives the pipeline together with the expected checkpoint rule, and
    that both fixed prompts plus the stubbed generation text reach stdout.
    """
    # Build a minimal pipeline with fixed generation parameters and
    # checkpoint rules.
    task_pipeline = create_pipeline(tmp_path / "task", Mock(), CheckpointRules())

    # NOTE: ``monkeypatch`` is pytest's built-in fixture for temporarily
    # replacing attributes; the patch is undone when the test finishes.
    # Here ``Pipeline.log_config`` is routed to a spy so the call is countable.
    config_spy = Mock()
    monkeypatch.setattr(
        "pipeline.pipeline.Pipeline.log_config",
        lambda self: config_spy(),
    )

    # Controllable inference artifact, so no real model is loaded or run.
    stub_result = GenerationResult([7, 8], "<|stop|>")
    artifact = ModelArtifact(
        model=Mock(),
        generate=Mock(return_value=stub_result),
    )
    dummy_dataset = DummyDataset(data_dir="unused", sequence_length=16)
    inference_bundle = TextInferenceBundle(
        tokenizer_bundle=dummy_dataset.tokenizer_bundle(),
        docs_ds=dummy_dataset.doc_ds(),
        max_length=16,
        sample_fn=sample_one,
    )
    artifact_loader = Mock(return_value=(artifact, inference_bundle))
    monkeypatch.setattr(
        "pipeline.base.generation_runner.load_inference_artifact_from_pipeline",
        artifact_loader,
    )

    # Execute the generation flow for the fixed prompts.
    DummyGenerationRunner(lambda: task_pipeline).run_fixed()

    # The config must have been printed.
    config_spy.assert_called_once_with()

    # The flow must have assembled the generator according to the testing
    # checkpoint rule.
    artifact_loader.assert_called_once()
    received_pipeline, received_rule = artifact_loader.call_args.args
    assert received_pipeline is task_pipeline
    expected_rule = {
        "dirs": [task_pipeline.checkpoint_dir],
        "path": None,
        "epoch": None,
        "suffix": None,
    }
    assert received_rule == expected_rule
    assert artifact.generate.call_count == 2

    # Both fixed prompts were generated for and echoed to the console.
    captured = capsys.readouterr().out
    for expected_text in ("白日依山尽", "床前明月光", "78<|stop|>"):
        assert expected_text in captured
def test_text_pipeline_rejects_task_specific_attribute_write(tmp_path):
    """Verify the text pipeline rejects ad-hoc writes of task-specific config."""
    task_pipeline = create_pipeline(tmp_path / "task", Mock(), CheckpointRules())

    # Either exception type counts as a rejection of the attribute write.
    rejection_errors = (AttributeError, TypeError)
    for attribute_name in ("dataset", "generation_rule"):
        with pytest.raises(rejection_errors):
            setattr(task_pipeline, attribute_name, Mock())
def test_generation_runner_random_prompts_use_text_inference_bundle(tmp_path, capsys, monkeypatch):
    """Verify random-prompt generation uses the document stream from the text inference bundle."""
    task_pipeline = create_pipeline(tmp_path / "task", Mock(), CheckpointRules())

    # Config logging is irrelevant to this test, so it is replaced by a no-op.
    monkeypatch.setattr(
        "pipeline.pipeline.Pipeline.log_config",
        lambda self: None,
    )

    # Stubbed artifact and inference bundle returned by the patched loader.
    stub_result = GenerationResult([7, 8], "<|stop|>")
    artifact = ModelArtifact(
        model=Mock(),
        generate=Mock(return_value=stub_result),
    )
    dummy_dataset = DummyDataset(data_dir="unused", sequence_length=16)
    inference_bundle = TextInferenceBundle(
        tokenizer_bundle=dummy_dataset.tokenizer_bundle(),
        docs_ds=dummy_dataset.doc_ds(),
        max_length=16,
        sample_fn=sample_one,
    )
    artifact_loader = Mock(return_value=(artifact, inference_bundle))
    monkeypatch.setattr(
        "pipeline.base.generation_runner.load_inference_artifact_from_pipeline",
        artifact_loader,
    )

    def fake_random_prompts(**kwargs):
        # Accept any keyword configuration and hand back a picker that
        # always yields the single prompt "abc" for the given docs_ds.
        return lambda docs_ds: ["abc"]

    monkeypatch.setattr(
        "pipeline.base.generation_runner.random_prompts",
        fake_random_prompts,
    )

    DummyGenerationRunner(lambda: task_pipeline).run_random()

    # The stubbed prompt and the stubbed generation text must both be printed.
    captured = capsys.readouterr().out
    for expected_text in ("abc", "78<|stop|>"):
        assert expected_text in captured
|