Spaces:
Runtime error
Runtime error
# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Sequence to sequence model.""" | |
| from typing import Any, Callable, Dict, Tuple | |
| from absl import logging | |
| from flax import linen as nn | |
| from flax.training import common_utils | |
| import gin | |
| import jax | |
| import jax.numpy as jnp | |
| import metrics_summary | |
| from transformer import decoder_stack | |
| from transformer import metric_utils | |
| from transformer import text_dataset | |
| import numpy as np | |
| import seqio | |
| Array = jnp.ndarray | |
| MetricsSummary = metrics_summary.MetricsSummary | |
| # TODO(mrabe): Remove this function and find a better way to turn text metrics | |
| # into text on tensorboard. | |
def process_summaries(vocab: seqio.Vocabulary,
                      met_summary: MetricsSummary,
                      mode: str) -> MetricsSummary:
  """Compute some additional summaries, and convert tokens to text.

  Args:
    vocab: The vocabulary to detokenize generated text.
    met_summary: The summary object to process.
    mode: The mode of the summary (e.g. "test", "train")

  Returns:
    The modified summary dictionary.
  """
  current_metrics = met_summary.current_metric_dict()

  # Perplexity is derived here from the cross-replica average of
  # nats_per_token, because per-replica perplexities cannot simply be
  # averaged the way other metrics are.
  if "nats_per_token" in current_metrics:
    avg_nats = current_metrics["nats_per_token"].to_value()
    met_summary.add({"perplexity": np.exp(avg_nats)})

  if mode == "generate" and "gen_tokens" in current_metrics:
    # Detokenize generated output and publish it as example text, both to
    # the summary and pretty-printed to the log file.
    token_array = current_metrics["gen_tokens"].to_value()
    if np.ndim(token_array) != 2:
      raise ValueError(
          "Unsupported shape for gen_tokens: %s" % token_array.shape)
    seq_len = token_array.shape[-1]
    decoded = text_dataset.decode_tokens(token_array, vocab,
                                         max_length=seq_len)
    logging.info("Generated text = %s", decoded)
    met_summary.add_text({"gen_text": decoded})
    # Otherwise it will turn into a histogram.
    del current_metrics["gen_tokens"]

  return met_summary
def process_summaries_function(vocab: seqio.Vocabulary) -> Callable[
    [MetricsSummary, str], MetricsSummary]:
  """Return a function that processes summaries with the given vocabulary."""
  # For use with training_loop.process_summaries_function: binds `vocab`
  # so the training loop can call the result with just (summary, mode).
  return lambda met_summary, mode: process_summaries(vocab, met_summary, mode)
class DecoderOnlyLanguageModel(nn.Module):
  """Decoder only language modeling."""

  # Mode this module instance was built for, e.g. "train", "test", "generate".
  mode: str
  # Task parameters (batch size, sequence length, vocab size); gin-injected.
  task_config: decoder_stack.TransformerTaskConfig = gin.REQUIRED
  # Factory that constructs the decoder stack; gin-injected.
  decoder_factory: Callable[[], Any] = gin.REQUIRED
  sample_method: str = "sample"  # Can be {"sample", "greedy"}
  # If True, the per-token loss array is included in the returned metrics.
  output_token_losses: bool = False

  def get_fake_input(self):
    """Returns a fake input for initialization of the appropriate shape."""
    b = self.task_config.batch_size
    fake_input_dict = {
        "targets": jnp.ones([b, self.task_config.sequence_length],
                            dtype=jnp.int32),
        "start_of_sequence": jnp.ones([b], dtype=jnp.bool_),
        "epoch": jnp.ones([b], dtype=jnp.int32),
    }
    if text_dataset.get_loss_mask_tokens(split=self.mode) != (None, None):
      # We are not adding the loss mask to the dummy input by default as it can
      # cause a slowdown during evaluation and perhaps inference.
      fake_input_dict["loss_mask"] = jnp.ones(
          [b, self.task_config.sequence_length], dtype=jnp.bool_)
    return fake_input_dict

  def metrics_summary_operations(self, aggregate_over: str) -> Dict[str, str]:
    """Summary operation to use for recorded metrics.

    Args:
      aggregate_over: Either "steps" or "devices"; selects which aggregation
        axis the returned operations apply to.

    Returns:
      A dict mapping metric name to the reduction op ("mean" or "sum").

    Raises:
      ValueError: If aggregate_over is not one of the two supported values.
    """
    metric_ops = {
        "loss": "mean",
        "nats_per_token": "mean",
        "bits_per_token": "mean",
        "bits_per_char": "mean",
        "accuracy": "mean",
        "num_tokens": "mean",
        "num_chars_per_device": "mean",
        "num_chars_per_batch": "mean",
        "nonzero_tokens": "mean",
        "num_tokens_per_device": "mean",
        "num_tokens_per_batch": "mean",
        "epoch": "mean",
    }
    if aggregate_over == "steps":
      return metric_ops
    elif aggregate_over == "devices":
      # Ensure that statistics that refer to the total batch size stay constant
      # as TPU topologies change. For those we have to sum over devices, but
      # compute the mean over steps.
      metric_ops.update({
          "num_tokens_per_batch": "sum",
          "num_chars_per_batch": "sum",
          "loss": "sum"})
      return metric_ops
    else:
      raise ValueError("Don't know how to aggregate over: %s" % aggregate_over)

  def setup(self):
    # Build the decoder stack lazily via the gin-configured factory.
    self.decoder = self.decoder_factory(mode=self.mode,
                                        task_config=self.task_config)  # pytype: disable=wrong-keyword-args  # trace-all-classes

  def __call__(self, inputs: ...):
    """Compute loss and metrics for one batch of token sequences.

    Args:
      inputs: A dict with "targets" (int token ids, [b, seq_len]),
        "start_of_sequence" (bool, [b]), "epoch" ([b]), and optionally
        "loss_mask" (bool, [b, seq_len]), "num_chars", and "nonzero_tokens".

    Returns:
      A (loss, metrics) tuple, where loss is the total summed loss on this
      device and metrics is a dict of scalar metrics (some as
      (value, weight) tuples for weighted averaging).
    """
    task_config = self.task_config

    input_tokens = inputs["targets"]  # [b, seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]
    epochs = inputs["epoch"]  # [b]
    if "loss_mask" in inputs:
      loss_mask = inputs["loss_mask"]  # [b, seq_len]
    else:
      # Placeholder mask; broadcasts against [b, seq_len] in logical_and below.
      loss_mask = jnp.ones((1, 1), dtype=jnp.bool_)

    input_tokens = jnp.asarray(input_tokens)
    assert input_tokens.ndim == 2
    assert input_tokens.shape[0] == task_config.batch_size
    assert input_tokens.shape[1] == task_config.sequence_length
    assert start_of_sequence.shape[0] == task_config.batch_size

    # Sanity check to avoid out-of-bounds on token lookup.
    input_tokens = input_tokens % task_config.vocab_size

    # These logs appear only when the model is traced/compiled, not per step.
    logging.info("langmodel: Compiling model for mode %s", self.mode)
    logging.info("langmodel: input_tokens = %r", input_tokens)
    # NOTE(review): "sequece" is a typo in this log message — left unchanged.
    logging.info("langmodel: start_of_sequece = %r", start_of_sequence)
    logging.info("langmodel: epochs = %r", epochs)

    # The target outputs are the next character in each sequence.
    # Shift tokens left and pad with a zero at the end.
    # TODO(delesley): We don't predict the first token of each sequence.
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
    logging.info("langmodel: target_tokens = %r", target_tokens)

    # Invoke the decoder stack.
    # The decoder will return pre-softmax logits for the predicted targets.
    (logits, _, d_metrics) = self.decoder(input_tokens=input_tokens,
                                          target_tokens=target_tokens,
                                          start_of_sequence=start_of_sequence)

    # Softmax cross-entropy loss on target tokens.
    logits = nn.log_softmax(logits, axis=-1)  # (b, seq_len, vocab_size)
    logging.info("langmodel: logits = %r", logits)
    soft_targets = common_utils.onehot(target_tokens, task_config.vocab_size)
    logging.info("langmodel: soft_targets = %r", soft_targets)
    losses = -jnp.sum(soft_targets * logits, axis=-1)  # (b, seq_len)
    logging.info("langmodel: losses = %r", losses)

    # Don't predict null tokens which are past the end-of-sequence.
    # Also don't predict the 0 at the end of the sequence.
    # TODO(delesley): Predict the final end-of-sequence marker.
    loss_mask = jnp.logical_and(
        loss_mask,
        input_tokens > 0)
    loss_mask = jnp.logical_and(
        loss_mask,
        target_tokens > 0)
    logging.info("langmodel: loss_mask = %r", loss_mask)

    losses = jnp.where(loss_mask, losses, 0.0)  # (batch_size, seq_len)
    loss = jnp.sum(losses)  # total loss on device
    token_count = jnp.sum(loss_mask)  # tokens on device
    # Epsilon guards against division by zero when the mask is all-False.
    token_count_nz = token_count + 1.0e-6
    loss_per_token = loss / token_count_nz
    bits_per_token = loss_per_token * 1.442695  # log(e)/log(2)
    accuracy = metric_utils.compute_accuracy_sum(logits, target_tokens,
                                                 loss_mask)
    accuracy = accuracy / token_count_nz  # Percent correct.
    epoch = jnp.mean(epochs)

    if self.mode == "generate" and self.decoder.supports_generate():
      # Generate example text.
      logging.info("lang_model: text inference.")
      gen_tokens = self.generate(inputs, task_config.sequence_length)
      # Return generated text, along with vizualizations and histograms.
      metrics = {"gen_tokens": gen_tokens, **d_metrics}
      return (loss, metrics)

    # Just return metrics related to the loss.
    # Tuple-valued entries are (value, weight) pairs for weighted averaging.
    metrics = {
        "loss": loss,  # will be summed over devices
        "nats_per_token": (loss_per_token, token_count),
        "bits_per_token": (bits_per_token, token_count),
        "accuracy": (accuracy, token_count),
        "num_tokens_per_device": token_count,
        "num_tokens_per_batch": token_count,  # will be summed over devices
        "epoch": epoch,
    }

    # Compute bits per character if we have the number of characters.
    if "num_chars" in inputs:
      num_chars = jnp.sum(inputs["num_chars"])
      bits_per_char = loss / (num_chars + 1e-6) * 1.442695
      metrics["num_chars_per_device"] = num_chars
      metrics["num_chars_per_batch"] = num_chars  # will be summed over devices
      metrics["bits_per_char"] = (bits_per_char, num_chars)

    # Provided to make sure that the data pipeline and the the model agree
    # on the number of tokens with a loss.
    if "nonzero_tokens" in inputs:
      nonzero_tokens = jnp.sum(inputs["nonzero_tokens"])
      metrics["nonzero_tokens"] = nonzero_tokens

    if self.output_token_losses:
      metrics["token_losses"] = losses
    return (loss, metrics)

  def generate(self, inputs: ..., sequence_length: int) -> Array:
    """Generate an output sequence.

    Args:
      inputs: the same as argument to _call_.
      sequence_length: the length of sequence to generate.

    Returns:
      An array of generated tokens of shape (batch_size, sequence_length).
    """
    # TODO(delesley): Add support for passing the prefix as an argument.
    # TODO(delesley): Add support for temperature, gumbel softmax, beam search.
    batch_size = self.task_config.batch_size
    input_tokens = inputs["targets"]  # [b,seq_len]
    start_of_sequence = inputs["start_of_sequence"]  # [b]

    # Initialize decoder.
    dstate = self.decoder.init_decoder_state(sequence_length,
                                             start_of_sequence)

    # TODO(delesley): Handle start-of-sequence in a better way.
    # There is no special token for start of sequence, so we grab the first
    # one from the ground-truth input data.
    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    sample_method = self.sample_method
    sample_prng = self.make_rng("sample")

    # Greedy autoregressive decoder function.
    # One step of autoregressive decoding, driven by jax.lax.scan below.
    def loop_fn(scan_state: Any, i: Array) -> Tuple[Any, Array]:
      # Fold the step index into the PRNG so each step samples independently.
      prng = jax.random.fold_in(sample_prng, i)
      (dstate, input_token) = scan_state
      del i
      (logits, dstate, _) = self.decoder(input_tokens=input_token,
                                         target_tokens=None,
                                         start_of_sequence=no_start_of_seq,
                                         decoder_state=dstate)
      # This branch is resolved at trace time; sample_method is a python str.
      if sample_method == "sample":
        logging.info("Using categorical sampling.")
        output_token = jax.random.categorical(prng, logits, axis=-1)
      elif sample_method == "greedy":
        logging.info("Using greedy sampling.")
        output_token = jnp.argmax(logits, axis=-1)
      else:
        raise ValueError(f"Invalid sampling method: {sample_method}")
      logging.info("generate_loop_fn: output_token = %r", output_token)
      # Each step's output token becomes the next step's input token.
      return ((dstate, output_token), output_token)

    # Scan over the sequence length.
    iterations = jnp.arange(sequence_length)
    initial_scan_state = (dstate, first_token)
    (_, output_tokens) = jax.lax.scan(loop_fn, initial_scan_state, iterations)
    logging.info("generate: output_tokens = %r", output_tokens)

    # Output_tokens has shape (sequence_length, batch_size, 1)
    assert output_tokens.shape == (sequence_length, batch_size, 1)
    # Drop the trailing singleton axis and put batch first.
    output_tokens = jnp.reshape(
        output_tokens, (sequence_length, self.task_config.batch_size))
    output_tokens = output_tokens.transpose([1, 0])
    return output_tokens