yetrun's picture
ver1: 实现深度学习训练框架,支持 Wiki GPT 与诗歌生成双任务
a5fd608
"""
Components related to text generation.
"""
import pathlib
from dataclasses import dataclass
from typing import Any, Callable
import keras
import numpy as np
from keras import callbacks, ops
from env.vocab import PAD
from env.logger import get_logger
from pipeline.base.model_builder import GenerationContext, GenerationResult, ModelArtifact
def generate_with_training_model(
        model: keras.Model,
        context: GenerationContext,
        prompt_tokens: list[int]
) -> GenerationResult:
    """Generate tokens by re-running the model on the full padded sequence.

    The prompt is right-padded with ``PAD`` up to ``context.max_length``. At
    every step the whole sequence is fed to ``model.predict`` and the output
    at the position just before the slot being filled is handed to
    ``context.sample_fn`` to pick the next token id.

    Args:
        model: Model whose ``predict`` output is indexed as
            ``[batch, position]``.
        context: Supplies ``max_length``, ``end_of_text`` and ``sample_fn``.
        prompt_tokens: Token ids of the prompt; not modified.

    Returns:
        A ``GenerationResult`` holding the token ids and one of the stop
        reasons ``"<|empty|>"``, ``"<|endoftext|>"``, ``"<|pad|>"`` or
        ``"<|maxlength|>"``. The end-of-text / pad token that stopped
        generation is not included in the ids. A prompt that is already at
        least ``context.max_length`` long is returned unchanged with
        ``"<|maxlength|>"``.
    """
    n_prompt = len(prompt_tokens)
    if n_prompt == 0:
        return GenerationResult([], "<|empty|>")

    limit = context.max_length
    # Concatenation builds a new list, so the caller's prompt stays intact.
    sequence = prompt_tokens + [PAD] * (limit - n_prompt)

    position = n_prompt
    while position < limit:
        batch_output = model.predict(np.array([sequence]), verbose=0)
        # The output at position - 1 predicts the token for slot `position`.
        sampled = context.sample_fn(batch_output[0, position - 1])
        token_id = np.array(ops.convert_to_numpy(sampled)).item()
        sequence[position] = token_id

        if token_id == context.end_of_text:
            return GenerationResult(sequence[:position], "<|endoftext|>")
        if token_id == PAD:
            return GenerationResult(sequence[:position], "<|pad|>")
        position += 1

    return GenerationResult(sequence, "<|maxlength|>")
def generate_with_stateful_model(
        model: keras.Model,
        context: GenerationContext,
        prompt_tokens: list[int],
        initial_states: list
) -> GenerationResult:
    """Generate tokens incrementally with a model that takes explicit states.

    The whole prompt is fed once together with ``initial_states``; afterwards
    only the newly sampled token is fed, together with the states returned by
    the previous call. ``model.predict`` is expected to return the logits
    first, followed by the updated states.

    Args:
        model: Model called as ``predict([tokens, *states])``.
        context: Supplies ``max_length``, ``end_of_text`` and ``sample_fn``.
        prompt_tokens: Token ids of the prompt; not modified.
        initial_states: States passed alongside the prompt on the first call.
            Any sequence is accepted.

    Returns:
        A ``GenerationResult`` holding the token ids and one of the stop
        reasons ``"<|empty|>"``, ``"<|endoftext|>"``, ``"<|pad|>"`` or
        ``"<|maxlength|>"``. The token that stopped generation (end-of-text
        or a pad-range id) is never part of the returned ids, matching
        ``generate_with_training_model``.
    """
    if not prompt_tokens:
        return GenerationResult([], "<|empty|>")

    tokens = list(prompt_tokens)
    logits, *states = model.predict(
        [np.array([tokens])] + list(initial_states), verbose=0
    )

    for _ in range(len(tokens), context.max_length):
        next_token = ops.convert_to_numpy(context.sample_fn(logits[0]))
        next_token_id = np.array(next_token).item()

        # Check the stop conditions before appending so that neither stop
        # token leaks into the result (previously the pad id was returned).
        if next_token_id == context.end_of_text:
            return GenerationResult(tokens, "<|endoftext|>")
        if next_token_id <= PAD:
            return GenerationResult(tokens, "<|pad|>")

        tokens.append(next_token_id)
        # Feed only the new token; the states carry the earlier context.
        logits, *states = model.predict(
            [np.array([[next_token_id]])] + states, verbose=0
        )

    return GenerationResult(tokens, "<|maxlength|>")
@dataclass
class TextGenerationResult:
    """Decoded outcome of one text generation run."""

    # Decoded text produced for the prompt.
    text: str
    # Marker describing why generation stopped, e.g. "<|endoftext|>".
    stop_reason: str
class TextGenerator:
    """Generate tokens or text for a string prompt through a model artifact.

    The generator stores default generation settings in a
    ``GenerationContext``; ``max_length`` and ``sample_fn`` can be overridden
    per call.
    """

    def __init__(
            self,
            artifact: ModelArtifact,
            tokenizer: Any,
            decode: Callable,
            end_of_text: int,
            sample_fn: Callable,
            max_length: int
    ):
        """Store the collaborators and build the default generation context.

        Args:
            artifact: Object whose ``generate(context, prompt_tokens)`` does
                the actual generation.
            tokenizer: Callable applied to the prompt string.
            decode: Callable turning generated token ids into text.
            end_of_text: Token id that terminates generation.
            sample_fn: Default sampling callable.
            max_length: Default maximum sequence length.
        """
        self.artifact = artifact
        self.tokenizer = tokenizer
        self.decode = decode
        self.context = GenerationContext(
            end_of_text=end_of_text,
            max_length=max_length,
            sample_fn=sample_fn,
        )

    def generate_tokens(
            self,
            prompt: str,
            max_length: int | None = None,
            sample_fn: Callable | None = None
    ) -> GenerationResult:
        """Tokenize ``prompt`` and delegate generation to the artifact.

        ``max_length`` and ``sample_fn`` fall back to the defaults given at
        construction time when they are ``None``.
        """
        defaults = self.context
        effective_length = defaults.max_length if max_length is None else max_length
        effective_sampler = defaults.sample_fn if sample_fn is None else sample_fn

        call_context = GenerationContext(
            end_of_text=defaults.end_of_text,
            max_length=effective_length,
            sample_fn=effective_sampler,
        )
        token_ids = self._tokenize_prompt(prompt)
        return self.artifact.generate(call_context, token_ids)

    def generate_text(
            self,
            prompt: str,
            max_length: int | None = None,
            sample_fn: Callable | None = None
    ) -> TextGenerationResult:
        """Generate tokens for ``prompt`` and decode them into text."""
        generated = self.generate_tokens(prompt, max_length, sample_fn)
        decoded = self.decode(generated.token_ids)
        return TextGenerationResult(text=decoded, stop_reason=generated.stop_reason)

    def _tokenize_prompt(self, prompt: str) -> list[int]:
        """Tokenize ``prompt`` and keep only ids greater than ``PAD``."""
        encoded = list(ops.convert_to_numpy(self.tokenizer(prompt)))
        kept = []
        for token in encoded:
            if token > PAD:
                kept.append(token)
        return kept
class GenerationCallback(callbacks.Callback):
    """Keras callback that logs sample generations after every epoch.

    At the end of each epoch a ``TextGenerator`` is built around the training
    artifact, every configured prompt is run through it, and the prompt and
    the generated text (followed by the stop reason) are written to the
    logger created by ``init_logger``.
    """

    def __init__(
            self,
            prompts: list[str],
            log_file: pathlib.Path,
            tokenizer: Any,
            decode: Callable,
            end_of_text: int,
            max_length: int,
            sample_fn: Callable,
            training_artifact: ModelArtifact
    ):
        """Store the generation settings and set up the log file.

        Args:
            prompts: Prompts to generate from after each epoch.
            log_file: File the generations are logged to; its parent
                directory is created if needed.
            tokenizer: Callable applied to each prompt string.
            decode: Callable turning generated token ids into text.
            end_of_text: Token id that terminates generation.
            max_length: Maximum sequence length for generation.
            sample_fn: Sampling callable.
            training_artifact: Artifact used to generate during training.
        """
        super().__init__()
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.decode = decode
        self.end_of_text = end_of_text
        self.max_length = max_length
        self.sample_fn = sample_fn
        self.training_artifact = training_artifact
        self.logger = self.init_logger(log_file)

    def on_epoch_end(self, epoch, logs=None):
        """Generate from every prompt and log the results for this epoch."""
        generator = TextGenerator(
            artifact=self.training_artifact,
            tokenizer=self.tokenizer,
            decode=self.decode,
            end_of_text=self.end_of_text,
            max_length=self.max_length,
            sample_fn=self.sample_fn
        )

        # `epoch` is zero-based; log it one-based.
        self.logger.info(f"\nGenerated text after epoch {epoch + 1}:")
        for i, prompt in enumerate(self.prompts):
            result = generator.generate_text(prompt)
            self.logger.info(f"Prompt {i + 1:2}: {prompt}")
            self.logger.info(f"Generated: {result.text}{result.stop_reason}\n")

    @staticmethod
    def init_logger(log_file: pathlib.Path):
        """Create the log directory if necessary and return the file logger."""
        # exist_ok=True avoids the check-then-create race of a separate
        # exists() test when the directory appears in between.
        log_file.parent.mkdir(parents=True, exist_ok=True)

        logger = get_logger("GenerationCallback", filepath=str(log_file))
        logger.info("Initialized GenerationCallback logger")
        return logger