"""
ModelBuilder protocol definition.

The interface that every model builder should implement.
"""
from dataclasses import dataclass
from typing import Callable, Protocol

import keras
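

@dataclass
class GenerationContext:
    """Settings passed to a generation function."""

    end_of_text: int  # presumably the token id that marks end of text
    max_length: int  # upper bound on length during generation
    sample_fn: Callable  # sampling callable; its signature is not specified here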
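

@dataclass
class GenerationResult:
    """Output of a generation function."""

    token_ids: list[int]  # resulting token ids
    stop_reason: str  # why generation stopped


# A generation function maps (context, list of token ids) to a result.
GenerateFn = Callable[[GenerationContext, list[int]], GenerationResult]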


@dataclass
class ModelArtifact:
    model: keras.Model
    generate: GenerateFn


class ModelBuilder(Protocol):
    """Protocol for model builders."""

    def build_training_artifact(
        self,
        vocab_size: int,
        sequence_length: int,
    ) -> ModelArtifact:
        """Build the training artifact."""
        ...

    def build_inference_artifact(
        self,
        training_artifact: ModelArtifact,
    ) -> ModelArtifact:
        """Build the inference artifact from the training artifact."""
        ...