Spaces:
Running
Running
| """ | |
| 与生成有关的组件 | |
| """ | |
| import pathlib | |
| from dataclasses import dataclass | |
| from typing import Any, Callable | |
| import keras | |
| import numpy as np | |
| from keras import callbacks, ops | |
| from env.vocab import PAD | |
| from env.logger import get_logger | |
| from pipeline.base.model_builder import GenerationContext, GenerationResult, ModelArtifact | |
def generate_with_training_model(
    model: keras.Model,
    context: GenerationContext,
    prompt_tokens: list[int]
) -> GenerationResult:
    """Generate tokens by repeatedly running the whole padded sequence through ``model``.

    The prompt is right-padded with ``PAD`` up to ``context.max_length``. At
    each step the full sequence is fed to ``model.predict`` and the prediction
    at index ``[0, position - 1]`` is handed to ``context.sample_fn`` to pick
    the token for ``position``.

    Args:
        model: Model whose ``predict`` is called on a batch of one sequence.
        context: Supplies ``max_length``, ``end_of_text`` and ``sample_fn``.
        prompt_tokens: Token ids of the prompt.

    Returns:
        A ``GenerationResult`` built from the tokens and one of these stop
        reasons:

        * ``"<|empty|>"`` - the prompt was empty; the token list is empty.
        * ``"<|endoftext|>"`` - ``context.end_of_text`` was sampled; the
          returned tokens stop just before it.
        * ``"<|pad|>"`` - ``PAD`` was sampled; the returned tokens stop just
          before it.
        * ``"<|maxlength|>"`` - every position up to ``context.max_length``
          was filled (or the prompt already reached that length).
    """
    n_prompt = len(prompt_tokens)
    if n_prompt == 0:
        return GenerationResult([], "<|empty|>")

    limit = context.max_length
    # A non-positive multiplier yields an empty padding list, so an over-long
    # prompt is passed through unchanged.
    padding = [PAD] * (limit - n_prompt)
    sequence = prompt_tokens + padding

    position = n_prompt
    while position < limit:
        batch = np.array([sequence])
        predictions = model.predict(batch, verbose=0)
        # The prediction at the previous position is used to sample the
        # token for the current one.
        sampled = context.sample_fn(predictions[0, position - 1])
        token_id = np.array(ops.convert_to_numpy(sampled)).item()
        sequence[position] = token_id

        if token_id == context.end_of_text:
            stop_reason = "<|endoftext|>"
        elif token_id == PAD:
            stop_reason = "<|pad|>"
        else:
            position += 1
            continue
        # The stop token itself is excluded from the returned tokens.
        return GenerationResult(sequence[:position], stop_reason)

    return GenerationResult(sequence, "<|maxlength|>")
def generate_with_stateful_model(
    model: keras.Model,
    context: GenerationContext,
    prompt_tokens: list[int],
    initial_states: list
) -> GenerationResult:
    """Generate tokens with a model that takes and returns its states explicitly.

    The prompt is fed once together with ``initial_states``. After that only
    the newly sampled token is fed, along with the states returned by the
    previous ``model.predict`` call.

    Args:
        model: Model whose ``predict`` takes ``[tokens, *states]`` and returns
            ``[logits, *states]``.
        context: Supplies ``max_length``, ``end_of_text`` and ``sample_fn``.
        prompt_tokens: Token ids of the prompt; the input list is not mutated.
        initial_states: States passed to the first ``predict`` call.

    Returns:
        A ``GenerationResult`` built from the tokens and one of these stop
        reasons:

        * ``"<|empty|>"`` - the prompt was empty; the token list is empty.
        * ``"<|endoftext|>"`` - ``context.end_of_text`` was sampled; it is not
          included in the returned tokens.
        * ``"<|pad|>"`` - a token id ``<= PAD`` was sampled; it is not
          included in the returned tokens (same convention as the
          end-of-text case).
        * ``"<|maxlength|>"`` - the sequence reached ``context.max_length``.
    """
    if not prompt_tokens:
        return GenerationResult([], "<|empty|>")

    tokens = list(prompt_tokens)
    remaining = context.max_length - len(tokens)
    if remaining <= 0:
        # Nothing can be generated, so skip the model call entirely.
        return GenerationResult(tokens, "<|maxlength|>")

    # list(...) lets callers pass the states as a tuple as well as a list.
    logits, *states = model.predict(
        [np.array([tokens])] + list(initial_states), verbose=0
    )
    for step in range(remaining):
        next_token = ops.convert_to_numpy(context.sample_fn(logits[0]))
        next_token_id = np.array(next_token).item()

        # Stop tokens are checked before appending so neither ends up in the
        # returned sequence.
        if next_token_id == context.end_of_text:
            return GenerationResult(tokens, "<|endoftext|>")
        if next_token_id <= PAD:
            return GenerationResult(tokens, "<|pad|>")
        tokens.append(next_token_id)

        # Only advance the model when another token will actually be sampled.
        if step + 1 < remaining:
            logits, *states = model.predict(
                [np.array([[next_token_id]])] + states, verbose=0
            )

    return GenerationResult(tokens, "<|maxlength|>")
@dataclass
class TextGenerationResult:
    """Decoded generation output paired with the reason generation stopped."""

    # Text produced by decoding the generated token ids.
    text: str
    # Stop-reason marker, e.g. "<|endoftext|>" or "<|maxlength|>".
    stop_reason: str
class TextGenerator:
    """Turn a text prompt into generated tokens or text via a model artifact.

    Holds the tokenizer, the decode function, and a default
    ``GenerationContext`` whose ``max_length`` and ``sample_fn`` can be
    overridden per call.
    """

    def __init__(
        self,
        artifact: ModelArtifact,
        tokenizer: Any,
        decode: Callable,
        end_of_text: int,
        sample_fn: Callable,
        max_length: int
    ):
        """Store the collaborators and build the default generation context.

        Args:
            artifact: Object whose ``generate(context, prompt_tokens)`` does
                the actual generation.
            tokenizer: Callable applied to the prompt string.
            decode: Callable applied to the generated token ids.
            end_of_text: Token id stored as the context's ``end_of_text``.
            sample_fn: Default sampling function.
            max_length: Default maximum length.
        """
        self.artifact = artifact
        self.tokenizer = tokenizer
        self.decode = decode
        default_context = GenerationContext(
            end_of_text=end_of_text,
            max_length=max_length,
            sample_fn=sample_fn
        )
        self.context = default_context

    def generate_tokens(
        self,
        prompt: str,
        max_length: int | None = None,
        sample_fn: Callable | None = None
    ) -> GenerationResult:
        """Tokenize ``prompt`` and delegate generation to the artifact.

        Args:
            prompt: Text to tokenize and continue.
            max_length: Overrides the default maximum length when not ``None``.
            sample_fn: Overrides the default sampling function when not ``None``.

        Returns:
            Whatever ``self.artifact.generate`` returns.
        """
        defaults = self.context
        # Fall back to the defaults captured at construction time.
        if max_length is None:
            max_length = defaults.max_length
        if sample_fn is None:
            sample_fn = defaults.sample_fn

        call_context = GenerationContext(
            end_of_text=defaults.end_of_text,
            max_length=max_length,
            sample_fn=sample_fn
        )
        token_ids = self._tokenize_prompt(prompt)
        return self.artifact.generate(call_context, token_ids)

    def generate_text(
        self,
        prompt: str,
        max_length: int | None = None,
        sample_fn: Callable | None = None
    ) -> TextGenerationResult:
        """Generate tokens for ``prompt`` and decode them to text.

        Args:
            prompt: Text to tokenize and continue.
            max_length: Optional per-call maximum length override.
            sample_fn: Optional per-call sampling function override.

        Returns:
            A ``TextGenerationResult`` with the decoded text and the stop
            reason reported by ``generate_tokens``.
        """
        outcome = self.generate_tokens(prompt, max_length, sample_fn)
        decoded = self.decode(outcome.token_ids)
        return TextGenerationResult(
            text=decoded,
            stop_reason=outcome.stop_reason
        )

    def _tokenize_prompt(self, prompt: str) -> list[int]:
        """Tokenize ``prompt`` and keep only token ids greater than ``PAD``."""
        raw_ids = ops.convert_to_numpy(self.tokenizer(prompt))
        kept = []
        for token_id in raw_ids:
            if token_id > PAD:
                kept.append(token_id)
        return kept
class GenerationCallback(callbacks.Callback):
    """Keras callback that logs sample generations at the end of every epoch.

    For each configured prompt, text is generated through a ``TextGenerator``
    wrapping ``training_artifact`` and written to a logger obtained from
    ``get_logger`` with ``log_file`` as its file path.
    """

    def __init__(
        self,
        prompts: list[str],
        log_file: pathlib.Path,
        tokenizer: Any,
        decode: Callable,
        end_of_text: int,
        max_length: int,
        sample_fn: Callable,
        training_artifact: ModelArtifact
    ):
        """Store the generation settings and set up the logger.

        Args:
            prompts: Prompts to generate from after each epoch.
            log_file: Path handed to the logger; its parent directory is
                created if missing.
            tokenizer: Callable applied to each prompt string.
            decode: Callable applied to the generated token ids.
            end_of_text: Token id forwarded to ``TextGenerator`` as
                ``end_of_text``.
            max_length: Maximum length forwarded to ``TextGenerator``.
            sample_fn: Sampling function forwarded to ``TextGenerator``.
            training_artifact: Artifact used to run generation.
        """
        super().__init__()
        self.prompts = prompts
        self.tokenizer = tokenizer
        self.decode = decode
        self.end_of_text = end_of_text
        self.max_length = max_length
        self.sample_fn = sample_fn
        self.training_artifact = training_artifact
        self.logger = self.init_logger(log_file)

    def on_epoch_end(self, epoch, logs=None):
        """Generate text for every prompt and log it.

        Args:
            epoch: Epoch index supplied by Keras; logged as ``epoch + 1``.
            logs: Metrics dict supplied by Keras; unused here.
        """
        generator = TextGenerator(
            artifact=self.training_artifact,
            tokenizer=self.tokenizer,
            decode=self.decode,
            end_of_text=self.end_of_text,
            max_length=self.max_length,
            sample_fn=self.sample_fn
        )
        self.logger.info(f"\nGenerated text after epoch {epoch + 1}:")
        for i, prompt in enumerate(self.prompts):
            result = generator.generate_text(prompt)
            self.logger.info(f"Prompt {i + 1:2}: {prompt}")
            self.logger.info(f"Generated: {result.text}{result.stop_reason}\n")

    # Declared static because it takes no ``self``; without the decorator the
    # ``self.init_logger(log_file)`` call in ``__init__`` would pass two
    # positional arguments to a one-parameter function.
    @staticmethod
    def init_logger(log_file: pathlib.Path):
        """Create the log directory if needed and return the callback's logger.

        Args:
            log_file: Path of the log file.

        Returns:
            The logger returned by ``get_logger``.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        log_file.parent.mkdir(parents=True, exist_ok=True)
        logger = get_logger("GenerationCallback", filepath=str(log_file))
        logger.info("Initialized GenerationCallback logger")
        return logger