Spaces:
Sleeping
Sleeping
File size: 1,777 Bytes
a5fd608 14f6839 a5fd608 14f6839 a5fd608 14f6839 a5fd608 14f6839 a5fd608 14f6839 a5fd608 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import keras
import tensorflow as tf
from unittest.mock import Mock
from models.mini_gpt import GptModelBuilder
from pipeline import pipeline as pipeline_module
from pipeline.specs.text_pipeline import text_custom_objects
from test.pipeline.helpers import create_pipeline, save_training_checkpoint
def test_save_inference_model_runs_save_flow(tmp_path, monkeypatch):
# 构造最小 GPT 模型,保留真实保存与加载链路
builder = GptModelBuilder(
hidden_dim=8,
intermediate_dim=16,
num_heads=2,
num_layers=1
)
pipeline = create_pipeline(tmp_path / "task", builder)
log_config = Mock()
monkeypatch.setattr("pipeline.pipeline.Pipeline.log_config", lambda self: log_config())
# 先写入训练权重,作为后续导出推理模型的输入检查点
save_training_checkpoint(
builder,
pipeline.checkpoint_dir / "model_epoch_005.weights.h5"
)
# 将保存目录重定向到临时目录,避免污染仓库默认路径
monkeypatch.setattr(
pipeline_module,
"resolve_saved",
lambda path=None: tmp_path / path if path else tmp_path
)
# 执行推理模型导出,并重新加载验证文件可用
model_path = pipeline.save_inference_model()
loaded_model = keras.models.load_model(
str(model_path),
custom_objects=text_custom_objects()
)
outputs = loaded_model(tf.constant([[2, 3, 4]], dtype="int32"), training=False)
# 验证保存模型流程启动时会先打印 config
log_config.assert_called_once_with()
# 验证导出文件名、文件存在性和前向输出形状
assert model_path.name == "model_epoch_005.keras"
assert model_path.exists()
assert outputs.shape == (1, 3, 32)
|