Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Functions on vector spaces.""" | |
| import abc | |
| import dataclasses | |
| from typing import Callable, Sequence | |
| import numpy as np | |
| from tracr.craft import bases | |
# Short module-level aliases for the `bases` types used throughout this file.
VectorSpaceWithBasis, VectorInBasis, BasisDirection = (
    bases.VectorSpaceWithBasis,
    bases.VectorInBasis,
    bases.BasisDirection,
)
class VectorFunction(abc.ABC):
    """A function that acts on vectors.

    Concrete subclasses are expected to provide ``input_space`` and
    ``output_space`` and to implement ``__call__``.
    """

    input_space: VectorSpaceWithBasis
    output_space: VectorSpaceWithBasis

    # Marked abstract so the ABC cannot be instantiated directly and every
    # subclass is forced to define how it evaluates vectors.
    @abc.abstractmethod
    def __call__(self, x: VectorInBasis) -> VectorInBasis:
        """Evaluates the function.

        Args:
          x: The vector to evaluate the function on.

        Returns:
          The image of ``x``.
        """
class Linear(VectorFunction):
    """A linear function between two vector spaces.

    The map is stored as ``matrix`` with shape
    ``[input_space.num_dims, output_space.num_dims]`` and is applied to row
    vectors, i.e. ``f(x) = x.magnitudes @ matrix``.
    """

    def __init__(
        self,
        input_space: VectorSpaceWithBasis,
        output_space: VectorSpaceWithBasis,
        matrix: np.ndarray,
    ):
        """Initialises.

        Args:
          input_space: The input vector space.
          output_space: The output vector space.
          matrix: a [input, output] matrix acting in a (sorted) basis.

        Raises:
          ValueError: If ``matrix`` does not have shape
            ``[input_space.num_dims, output_space.num_dims]``.
        """
        self.input_space = input_space
        self.output_space = output_space
        self.matrix = matrix
        # This class defines its own __init__, so nothing else would invoke
        # __post_init__; run the shape check explicitly.
        self.__post_init__()

    def __post_init__(self) -> None:
        """Checks that ``matrix`` has shape ``[input_dims, output_dims]``.

        Raises:
          ValueError: If the matrix shape does not match the two spaces.
        """
        # Rows index the input space and columns the output space: __call__
        # computes ``x @ matrix`` and from_action fills ``matrix[i, :]`` with
        # the image of the i-th input direction.
        expected = (self.input_space.num_dims, self.output_space.num_dims)
        actual = tuple(np.shape(self.matrix))
        if actual != expected:
            raise ValueError(
                f"matrix has shape {actual}, expected {expected} "
                "([input_space.num_dims, output_space.num_dims]).")

    def __call__(self, x: VectorInBasis) -> VectorInBasis:
        """Applies the linear map to ``x``.

        Args:
          x: A vector in ``input_space``.

        Returns:
          ``x.magnitudes @ matrix`` as a vector over the sorted basis of
          ``output_space``.

        Raises:
          TypeError: If ``x`` is not in ``input_space``.
        """
        if x not in self.input_space:
            raise TypeError(f"{x=} not in {self.input_space=}.")
        return VectorInBasis(
            basis_directions=sorted(self.output_space.basis),
            magnitudes=x.magnitudes @ self.matrix,
        )

    @classmethod
    def from_action(
        cls,
        input_space: VectorSpaceWithBasis,
        output_space: VectorSpaceWithBasis,
        action: Callable[[BasisDirection], VectorInBasis],
    ) -> "Linear":
        """Builds a Linear from its action on each input basis direction.

        Args:
          input_space: The input vector space.
          output_space: The output vector space.
          action: Maps each basis direction of ``input_space`` to its image,
            which must be a vector in ``output_space``.

        Returns:
          A Linear whose i-th matrix row holds the magnitudes of the image of
          the i-th basis direction of ``input_space``.

        Raises:
          TypeError: If some image is not in ``output_space``.
        """
        matrix = np.zeros((input_space.num_dims, output_space.num_dims))
        for i, direction in enumerate(input_space.basis):
            out_vector = action(direction)
            if out_vector not in output_space:
                raise TypeError(f"image of {direction} from {input_space=} "
                                f"is not in {output_space=}")
            matrix[i, :] = out_vector.magnitudes
        return Linear(input_space, output_space, matrix)

    @classmethod
    def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear":
        """Combines multiple parallel linear functions into a single one.

        The result acts on the join of all input spaces and maps into the join
        of all output spaces. Each input basis direction is sent to the sum,
        over every ``fn`` whose input space contains that direction, of
        ``fn``'s image projected into the joint output space.

        Args:
          fns: The linear functions to combine.

        Returns:
          The combined Linear.
        """
        # ``fns`` is traversed several times (two joins plus once per input
        # direction inside ``action``), so materialise it once up front.
        fns = list(fns)
        joint_input_space = bases.join_vector_spaces(
            *[fn.input_space for fn in fns])
        joint_output_space = bases.join_vector_spaces(
            *[fn.output_space for fn in fns])

        def action(x: bases.BasisDirection) -> bases.VectorInBasis:
            out = joint_output_space.null_vector()
            for fn in fns:
                if x in fn.input_space:
                    x_vec = fn.input_space.vector_from_basis_direction(x)
                    out += fn(x_vec).project(joint_output_space)
            return out

        return cls.from_action(joint_input_space, joint_output_space, action)
def project(
    from_space: VectorSpaceWithBasis,
    to_space: VectorSpaceWithBasis,
) -> Linear:
    """Creates the linear projection from ``from_space`` onto ``to_space``.

    A basis direction of ``from_space`` that is also in ``to_space`` maps to
    ``to_space.vector_from_basis_direction(direction)``. Every other direction
    maps to ``to_space.null_vector()``.
    """

    def _image_of(direction: bases.BasisDirection) -> VectorInBasis:
        # Guard clause: directions outside the target space are discarded.
        if direction not in to_space:
            return to_space.null_vector()
        return to_space.vector_from_basis_direction(direction)

    return Linear.from_action(from_space, to_space, action=_image_of)
@dataclasses.dataclass
class ScalarBilinear:
    """A scalar-valued bilinear operator.

    Evaluates ``B(x, y) = x.magnitudes @ matrix @ y.magnitudes``. The rows of
    ``matrix`` index the basis of ``left_space`` and its columns index the
    basis of ``right_space``.
    """

    left_space: VectorSpaceWithBasis
    right_space: VectorSpaceWithBasis
    matrix: np.ndarray

    def __post_init__(self):
        """Typechecks that ``matrix`` has shape ``[left_dims, right_dims]``.

        Raises:
          ValueError: If the matrix shape does not match the two spaces.
        """
        # Explicit exception rather than ``assert`` so the check survives
        # ``python -O``.
        expected = (self.left_space.num_dims, self.right_space.num_dims)
        actual = tuple(np.shape(self.matrix))
        if actual != expected:
            raise ValueError(
                f"matrix has shape {actual}, expected {expected} "
                "([left_space.num_dims, right_space.num_dims]).")

    def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
        """Describes the action of the operator on vectors.

        Args:
          x: A vector in ``left_space``.
          y: A vector in ``right_space``.

        Returns:
          The scalar ``x.magnitudes @ matrix @ y.magnitudes``.

        Raises:
          TypeError: If ``x`` is not in ``left_space`` or ``y`` is not in
            ``right_space``.
        """
        if x not in self.left_space:
            raise TypeError(f"{x=} not in {self.left_space=}.")
        if y not in self.right_space:
            raise TypeError(f"{y=} not in {self.right_space=}.")
        return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()

    @classmethod
    def from_action(
        cls,
        left_space: VectorSpaceWithBasis,
        right_space: VectorSpaceWithBasis,
        action: Callable[[BasisDirection, BasisDirection], float],
    ) -> "ScalarBilinear":
        """Builds a ScalarBilinear from its value on pairs of basis directions.

        Args:
          left_space: The left vector space.
          right_space: The right vector space.
          action: Returns the operator's value for a (left, right) pair of
            basis directions.

        Returns:
          A ScalarBilinear with ``matrix[i, j] = action(left_i, right_j)``,
          where ``left_i`` and ``right_j`` are the i-th and j-th basis
          directions of ``left_space`` and ``right_space``.
        """
        matrix = np.zeros((left_space.num_dims, right_space.num_dims))
        for i, left_direction in enumerate(left_space.basis):
            for j, right_direction in enumerate(right_space.basis):
                matrix[i, j] = action(left_direction, right_direction)
        return ScalarBilinear(left_space, right_space, matrix)