File size: 1,777 Bytes
a5fd608
14f6839
a5fd608
 
 
 
14f6839
 
a5fd608
 
 
 
 
 
 
 
 
 
 
 
 
14f6839
a5fd608
 
 
 
 
 
 
 
 
14f6839
a5fd608
 
 
 
 
 
14f6839
 
 
 
a5fd608
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

import keras
import tensorflow as tf
from unittest.mock import Mock

from models.mini_gpt import GptModelBuilder
from pipeline import pipeline as pipeline_module
from pipeline.specs.text_pipeline import text_custom_objects
from test.pipeline.helpers import create_pipeline, save_training_checkpoint


def test_save_inference_model_runs_save_flow(tmp_path, monkeypatch):
    """Export an inference model from a training checkpoint and reload it.

    Builds a minimal GPT model, writes an epoch-5 training checkpoint,
    redirects the pipeline's ``resolve_saved`` into ``tmp_path``, then checks
    that ``Pipeline.save_inference_model``:

    * calls ``log_config`` exactly once,
    * writes ``model_epoch_005.keras`` somewhere under ``tmp_path``, and
    * produces a file Keras can reload and run a forward pass with.
    """
    # Build a minimal GPT model while keeping the real save/load path intact.
    builder = GptModelBuilder(
        hidden_dim=8,
        intermediate_dim=16,
        num_heads=2,
        num_layers=1,
    )
    pipeline = create_pipeline(tmp_path / "task", builder)

    # Route log_config through a Mock so the call can be asserted. The lambda
    # absorbs the bound `self` and forwards a zero-argument call, which is
    # what `assert_called_once_with()` below expects.
    log_config = Mock()
    monkeypatch.setattr(
        "pipeline.pipeline.Pipeline.log_config",
        lambda self: log_config(),
    )

    # Write training weights first; they are the input checkpoint the export
    # step reads from.
    save_training_checkpoint(
        builder,
        pipeline.checkpoint_dir / "model_epoch_005.weights.h5",
    )

    # Redirect the save directory into tmp_path so the repository's default
    # location is never written to.
    monkeypatch.setattr(
        pipeline_module,
        "resolve_saved",
        lambda path=None: tmp_path / path if path else tmp_path,
    )

    model_path = pipeline.save_inference_model()

    # Check the exported file before loading it, so a missing export fails
    # with a clear message rather than an opaque Keras load error.
    assert model_path.name == "model_epoch_005.keras"
    assert model_path.exists(), f"exported model not found at {model_path}"
    # Guard against the export bypassing the redirect above and writing
    # outside the temporary directory.
    assert tmp_path.resolve() in model_path.resolve().parents, (
        f"exported model {model_path} is outside {tmp_path}"
    )

    # Reload with the project's custom objects and run one forward pass.
    loaded_model = keras.models.load_model(
        str(model_path),
        custom_objects=text_custom_objects(),
    )
    outputs = loaded_model(tf.constant([[2, 3, 4]], dtype="int32"), training=False)

    # The save flow must have logged the config exactly once.
    log_config.assert_called_once_with()

    # Output is (batch=1, seq_len=3, 32). NOTE(review): 32 is presumably the
    # vocabulary size configured by create_pipeline — TODO confirm.
    assert outputs.shape == (1, 3, 32)