Spaces:
Runtime error
Runtime error
| # Copyright 2022 Google. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Flax modules and functions for using external memory.""" | |
| from typing import Any, Optional, Tuple | |
| from absl import logging | |
| from flax import linen | |
| import gin | |
| import jax | |
| from transformer import memory_layer | |
# Loose type aliases used in annotations throughout this module.
PRNGKey = Any
# A variable-length tuple of dimension sizes (Tuple[int] would mean exactly
# one int, which is wrong for array shapes of arbitrary rank).
Shape = Tuple[int, ...]
Dtype = Any
Array = Any
# Opaque handle for an off-device memory backend.
MemoryResource = Any
class MemoryManager:
  """Manages any external resources that may be required by external memory.

  MemoryManager also functions as a factory, to create Flax modules that will
  read and write to whatever external memory has been configured.
  """

  def __init__(self,
               batch_size: int,
               mode: str,
               num_heads: int,
               key_size: int,
               value_size: int,
               database_size: Optional[int] = None,
               dtype: Dtype = "float32",
               off_device_memory: Optional[MemoryResource] = None):
    """Create a MemoryManager object.

    A MemoryManager configures external memory, and is used as a factory to
    construct flax modules that read or write to the memory.

    Args:
      batch_size: The number of separate documents in a batch.
      mode: e.g. ("train", or "test"). Stored on the instance; not read by
        create_memory_layer.
      num_heads: The number of transformer heads.
      key_size: The length of the key vectors.
      value_size: The length of the value vectors.
      database_size: The total number of tokens in the database. Must be set
        (not None) when using on-device memory.
      dtype: The datatype used for keys and values.
      off_device_memory: An object which manages underlying SCAM memory.
        If None, then the model will use on-device memory. Off-device memory
        is not currently supported: create_memory_layer raises ValueError
        when this is set.
    """
    self.batch_size = batch_size
    self.mode = mode
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size
    self.database_size = database_size
    self.dtype = dtype
    self.off_device_memory = off_device_memory

  def create_memory_layer(self) -> linen.Module:
    """Create a flax Module that implements external memory.

    Returns:
      A memory_layer.BatchedMemory wrapping a memory_layer.MemoryOnTpu that
      holds batch_size * num_heads independent datasets.

    Raises:
      ValueError: If off_device_memory was configured (not supported yet), or
        if database_size is None when using on-device memory.
    """
    if self.off_device_memory is not None:
      # No off-device backend exists yet. Presumably, once one is added, it
      # would hold num_heads datasets and be batched over the head axis only
      # (split_dimensions=(-2,)) -- confirm when implementing.
      raise ValueError("Off-device memory is not supported at this time.")

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would otherwise let None reach MemoryOnTpu.
    if self.database_size is None:
      raise ValueError(
          "database_size must be set when using on-device memory.")

    # On-device memory keeps a separate dataset per (batch element, head).
    num_datasets = self.batch_size * self.num_heads
    mem_layer = memory_layer.MemoryOnTpu(num_datasets=num_datasets,
                                         key_features=self.key_size,
                                         value_features=self.value_size,
                                         database_size=self.database_size,
                                         dtype=self.dtype)
    # Handle queries of shape [batch_size, seq_len, num_heads, kv_features]
    return memory_layer.BatchedMemory(mem_layer, split_dimensions=(0, -2))
def memory_on_tpu_factory(batch_size: int,
                          mode: str,
                          num_heads: int = gin.REQUIRED,
                          key_size: int = gin.REQUIRED,
                          value_size: int = gin.REQUIRED,
                          database_size: int = gin.REQUIRED,
                          dtype: Dtype = gin.REQUIRED) -> MemoryManager:
  """Build a MemoryManager that keeps SCAM memory on the device itself.

  Args:
    batch_size: Number of separate documents in a batch.
    mode: Run mode label, e.g. "train" or "test".
    num_heads: Number of transformer heads.
    key_size: Length of each key vector.
    value_size: Length of each value vector.
    database_size: Total number of tokens in the database.
    dtype: Datatype used for keys and values.

  Returns:
    A MemoryManager configured with off_device_memory=None.
  """
  # off_device_memory=None is what selects the on-device memory path.
  on_device_manager = MemoryManager(
      off_device_memory=None,
      batch_size=batch_size,
      mode=mode,
      num_heads=num_heads,
      key_size=key_size,
      value_size=value_size,
      database_size=database_size,
      dtype=dtype,
  )
  return on_device_manager