general-deep-learning / pipeline /base /model_loader.py
yetrun's picture
适配 hf-space
8aa4d8f
"""
模型加载工具模块
"""
from pathlib import Path
from functools import partial
from typing import TYPE_CHECKING
import keras
from data.base import TokenizerBundle
from pipeline.base.checkpoint import describe_checkpoint_lookup, resolve_checkpoint
from pipeline.base.generation import generate_with_training_model
from pipeline.base.model_builder import ModelArtifact
if TYPE_CHECKING:
from pipeline import Pipeline
def load_training_artifact_from_pipeline(
    pipeline: "Pipeline",
    checkpoint_rule: dict
) -> tuple[ModelArtifact, TokenizerBundle]:
    """Load a training-model artifact and tokenizer bundle from a checkpoint.

    The checkpoint is located by calling ``resolve_checkpoint`` with
    ``checkpoint_rule`` expanded as keyword arguments. Two loading paths
    exist, chosen by the checkpoint file's suffix (case-insensitive):

    * ``.keras`` -- the model is restored with ``_load_keras_model`` and
      wrapped in a ``ModelArtifact`` whose ``generate`` callable is
      ``generate_with_training_model`` partially applied to that model.
    * anything else -- a fresh training artifact is built through
      ``pipeline.model_builder.build_training_artifact`` and the checkpoint
      is applied with ``model.load_weights``.

    Args:
        pipeline: Pipeline providing ``dataset`` and ``model_builder``.
        checkpoint_rule: Keyword arguments forwarded to
            ``resolve_checkpoint``; its ``dirs``, ``path`` and ``suffix``
            entries (if any) are also used to describe a failed lookup.

    Returns:
        A ``(training_artifact, tokenizer_bundle)`` tuple.

    Raises:
        FileNotFoundError: If ``resolve_checkpoint`` yields no checkpoint
            path.
    """
    # The tokenizer bundle is fetched first, before any checkpoint lookup.
    bundle = pipeline.dataset.tokenizer_bundle()

    ckpt, _ = resolve_checkpoint(**checkpoint_rule)
    if ckpt is None:
        lookup_fields = {
            key: checkpoint_rule.get(key) for key in ("dirs", "path", "suffix")
        }
        lookup_info = describe_checkpoint_lookup(**lookup_fields)
        raise FileNotFoundError(f"未找到任何检查点文件。查找信息: {lookup_info}")

    is_whole_model = ckpt.suffix.lower() == ".keras"
    if not is_whole_model:
        # Weights-only checkpoint: rebuild the architecture, then load weights.
        vocab = bundle.vocab_size
        artifact = pipeline.model_builder.build_training_artifact(
            vocab_size=vocab,
            sequence_length=pipeline.dataset.sequence_length
        )
        artifact.model.load_weights(str(ckpt))
    else:
        # Whole-model checkpoint: restore it directly and bind the generator.
        restored = _load_keras_model(ckpt)
        artifact = ModelArtifact(
            model=restored,
            generate=partial(generate_with_training_model, restored)
        )

    print(f"已加载推理检查点: {ckpt}")
    return artifact, bundle
def load_inference_artifact_from_pipeline(
    pipeline: "Pipeline",
    checkpoint_rule: dict
) -> tuple[ModelArtifact, TokenizerBundle]:
    """Load an inference artifact and tokenizer bundle from a checkpoint.

    Delegates to ``load_training_artifact_from_pipeline`` to obtain the
    training artifact, then converts it with
    ``pipeline.model_builder.build_inference_artifact``.

    Args:
        pipeline: Pipeline providing ``dataset`` and ``model_builder``.
        checkpoint_rule: Checkpoint lookup rule, passed through unchanged to
            ``load_training_artifact_from_pipeline``.

    Returns:
        An ``(inference_artifact, tokenizer_bundle)`` tuple.
    """
    loaded = load_training_artifact_from_pipeline(pipeline, checkpoint_rule)
    trained, bundle = loaded

    served = pipeline.model_builder.build_inference_artifact(
        training_artifact=trained
    )
    return served, bundle
def _load_keras_model(checkpoint_path: Path) -> keras.Model:
    """Restore a saved Keras model from ``checkpoint_path``.

    The path is converted to ``str`` and handed to
    ``keras.models.load_model`` together with the custom classes
    (``WarmupSchedule``, ``PositionalEmbedding``, ``TransformerDecoder``)
    registered under their own names as ``custom_objects``.

    Args:
        checkpoint_path: Location of the saved model file.

    Returns:
        The model returned by ``keras.models.load_model``.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably this avoids a circular import with the
    # pipeline package -- confirm before hoisting to module level.
    from pipeline.pipeline import WarmupSchedule
    from models.mini_gpt.gpt_components import PositionalEmbedding, TransformerDecoder

    # TODO: pulling model-specific components into this generic loader
    # should be improved.
    custom_classes = {
        "WarmupSchedule": WarmupSchedule,
        "PositionalEmbedding": PositionalEmbedding,
        "TransformerDecoder": TransformerDecoder
    }
    return keras.models.load_model(
        str(checkpoint_path),
        custom_objects=custom_classes
    )