# Source: general-deep-learning/pipeline/base/model_builder.py
# Commit a5fd608 (yetrun) — ver1: implement a deep-learning training
# framework supporting two tasks, Wiki GPT and poetry generation.
"""
ModelBuilder 协议定义
所有模型构建器应实现的接口。
"""
from dataclasses import dataclass
from typing import Callable, Protocol
import keras
@dataclass
class GenerationContext:
    """Settings handed to a ``GenerateFn`` for a single generation call.

    Field order is part of the interface: the generated ``__init__`` accepts
    ``end_of_text``, ``max_length`` and ``sample_fn`` positionally, in that
    order.
    """

    # Presumably the token id that marks end-of-text — confirm with callers.
    end_of_text: int
    # Presumably an upper bound on the generated sequence length; whether it
    # counts prompt tokens is not visible in this file — TODO confirm.
    max_length: int
    # Sampling callable; its signature is not constrained by this file.
    sample_fn: Callable
@dataclass
class GenerationResult:
    """Value returned by a ``GenerateFn``.

    Field order is part of the interface: the generated ``__init__`` accepts
    ``token_ids`` then ``stop_reason`` positionally.
    """

    # Token ids produced by generation; whether the prompt tokens are included
    # is not visible in this file — TODO confirm.
    token_ids: list[int]
    # Label describing why generation stopped; the set of valid values is not
    # defined in this file.
    stop_reason: str
# Signature of a generation routine: it receives a GenerationContext and a
# list of int token ids (presumably the prompt — confirm with callers) and
# returns a GenerationResult.
GenerateFn = Callable[
    [GenerationContext, list[int]],
    GenerationResult,
]
@dataclass
class ModelArtifact:
    """Bundle pairing a Keras model with its generation function.

    Field order is part of the interface: the generated ``__init__`` accepts
    ``model`` then ``generate`` positionally.
    """

    # The Keras model belonging to this artifact.
    model: keras.Model
    # Generation callable that accompanies the model; see ``GenerateFn`` for
    # its signature.
    generate: GenerateFn
class ModelBuilder(Protocol):
    """Structural interface for model builders.

    Any object that provides both methods below with compatible signatures
    satisfies this protocol; explicit inheritance is not required.
    """

    def build_training_artifact(
        self,
        vocab_size: int,
        sequence_length: int,
    ) -> ModelArtifact:
        """Build the training artifact.

        Args:
            vocab_size: Vocabulary size (presumably the number of distinct
                token ids — confirm with implementations).
            sequence_length: Sequence length (presumably the number of
                tokens per training example — confirm with implementations).

        Returns:
            The ``ModelArtifact`` to be used for training.
        """
        ...

    def build_inference_artifact(
        self,
        training_artifact: ModelArtifact,
    ) -> ModelArtifact:
        """Build the inference artifact from a training artifact.

        Args:
            training_artifact: An artifact previously returned by
                ``build_training_artifact``.

        Returns:
            The ``ModelArtifact`` to be used for inference.
        """
        ...