Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """RASP programs only using the subset of RASP supported by the compiler.""" | |
| from typing import Sequence | |
| from tracr.rasp import rasp | |
| ### Programs that work only under non-causal evaluation. | |
def make_length() -> rasp.SOp:
    """Builds the `length` SOp from the selector-width primitive.

    A selector that is true for every (key, query) pair is as wide as the
    input is long, so its width at each position is the sequence length.

    Example usage:
      length = make_length()
      length("abcdefg")
      >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]

    Returns:
      length: SOp mapping an input to a sequence, where every element
        is the length of that sequence.
    """
    select_everything = rasp.Select(
        rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    select_everything = select_everything.named("all_true_selector")
    sequence_length = rasp.SelectorWidth(select_everything)
    return sequence_length.named("length")
# Module-level `length` SOp, built once at import time and reused by the
# builders in this file that need the sequence length (`make_reverse`,
# `make_shuffle_dyck`).
length = make_length()
def make_reverse(sop: rasp.SOp) -> rasp.SOp:
    """Builds an SOp that outputs `sop` in reversed order.

    Each position i reads the value found at position `length - i - 1`, where
    the sequence length comes from the module-level `length` SOp.

    Example usage:
      reverse = make_reverse(rasp.tokens)
      reverse("Hello")
      >> ['o', 'l', 'l', 'e', 'H']

    Args:
      sop: an SOp

    Returns:
      reverse : SOp that reverses the input sequence.
    """
    distance_from_end = (length - rasp.indices).named("opp_idx")
    mirrored_index = (distance_from_end - 1).named("opp_idx-1")
    # Each query carries its mirrored index and selects the key position
    # whose own index equals it.
    mirror_selector = rasp.Select(
        rasp.indices, mirrored_index, rasp.Comparison.EQ)
    mirror_selector = mirror_selector.named("reverse_selector")
    return rasp.Aggregate(mirror_selector, sop).named("reverse")
def make_pair_balance(sop: rasp.SOp, open_token: str,
                      close_token: str) -> rasp.SOp:
    """Builds the running open-minus-close balance for one pair of tokens.

    (As implemented in the RASP paper.)

    At every position the output is the fraction of tokens so far that equal
    `open_token` minus the fraction that equal `close_token`. If the outputs
    are always non-negative and end in 0, that implies the input has balanced
    parentheses.

    Example usage:
      num_l = make_pair_balance(rasp.tokens, "(", ")")
      num_l("a()b(c))")
      >> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8]

    Args:
      sop: Input SOp.
      open_token: Token that counts positive.
      close_token: Token that counts negative.

    Returns:
      pair_balance: SOp mapping an input to a sequence, where every element
        is the fraction of previous open tokens minus previous close tokens.
    """
    is_open = rasp.numerical(sop == open_token).named("bools_open")
    open_frac = rasp.numerical(make_frac_prevs(is_open)).named("opens")

    is_close = rasp.numerical(sop == close_token).named("bools_close")
    close_frac = rasp.numerical(make_frac_prevs(is_close)).named("closes")

    # Weights 1 and -1 combine the two fractions as opens minus closes, as in
    # the example above.
    difference = rasp.LinearSequenceMap(open_frac, close_frac, 1, -1)
    return rasp.numerical(difference).named("pair_balance")
def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp:
    """Returns 1 if a set of parentheses are balanced, 0 else.

    (As implemented in the RASP paper.)

    Every bracket type is checked independently: its running balance must
    never be negative and must be zero at the last position.

    Example usage:
      shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"])
      shuffle_dyck2("({)}")
      >> [1, 1, 1, 1]
      shuffle_dyck2("(){)}")
      >> [0, 0, 0, 0, 0]

    Args:
      pairs: List of pairs of open and close tokens that each should be
        balanced. Must contain at least one pair, and each pair must have
        exactly two elements (open token, then close token).

    Returns:
      SOp named "shuffle_dyck" holding the verdict at every position.
    """
    assert len(pairs) >= 1

    # One running balance per bracket type.
    balances = []
    for pair in pairs:
        assert len(pair) == 2
        open_token, close_token = pair
        balance = make_pair_balance(
            rasp.tokens, open_token=open_token, close_token=close_token)
        balances.append(balance.named(f"balance_{pair}"))

    # A balance that was negative anywhere means a bracket was closed before
    # it was opened -> parentheses not balanced.
    went_negative = balances[0] < 0
    for balance in balances[1:]:
        went_negative = went_negative | (balance < 0)

    # Re-encode the flag as a numerical SOp so it can be aggregated below.
    went_negative = rasp.numerical(
        rasp.Map(lambda x: x, went_negative)).named("any_negative")

    # Aggregate the flag over all positions so that every position learns
    # whether any position was flagged.
    attend_everywhere = rasp.Select(
        rasp.indices, rasp.indices, rasp.Comparison.TRUE).named("select_all")
    has_neg = rasp.numerical(
        rasp.Aggregate(attend_everywhere, went_negative, default=0))
    has_neg = has_neg.named("has_neg")

    # All balances being 0 at the end means every bracket was closed.
    balanced_here = balances[0] == 0
    for balance in balances[1:]:
        balanced_here = balanced_here & (balance == 0)

    # Broadcast the value at the final position (index length - 1).
    attend_last = rasp.Select(
        rasp.indices, length - 1, rasp.Comparison.EQ).named("select_last")
    last_zero = rasp.Aggregate(attend_last, balanced_here).named("last_zero")

    not_has_neg = (~has_neg).named("not_has_neg")
    return (last_zero & not_has_neg).named("shuffle_dyck")
def make_shuffle_dyck2() -> rasp.SOp:
    """Builds the shuffle-Dyck checker for the bracket pairs "()" and "{}"."""
    checker = make_shuffle_dyck(pairs=["()", "{}"])
    return checker.named("shuffle_dyck2")
def make_hist() -> rasp.SOp:
    """Builds an SOp counting, per position, how often its token occurs.

    (As implemented in the RASP paper.)

    Example usage:
      hist = make_hist()
      hist("abac")
      >> [2, 1, 2, 1]

    Returns:
      hist: SOp whose value at each position is the number of input
        positions holding the same token.
    """
    equal_tokens = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    equal_tokens = equal_tokens.named("same_tok")
    return rasp.SelectorWidth(equal_tokens).named("hist")
def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
    """Returns vals sorted by < relation on keys.

    Only supports unique keys: each element is moved to the position given by
    the number of keys strictly smaller than its own, so elements with equal
    keys would be sent to the same position.

    Example usage:
      sort = make_sort_unique(rasp.tokens, rasp.tokens)
      sort([2, 4, 3, 1])
      >> [1, 2, 3, 4]

    Args:
      vals: Values to sort.
      keys: Keys for sorting.

    Returns:
      SOp named "sort" holding `vals` reordered by ascending `keys`.
    """
    # For every position, select the keys strictly smaller than its own key;
    # the number of such keys is the position the element has to move to.
    smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
    target_pos = rasp.SelectorWidth(smaller).named("target_pos")
    # Output position i reads from the input position whose target is i.
    sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ)
    return rasp.Aggregate(sel_new, vals).named("sort")
def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
    """Returns vals sorted by < relation on keys, which don't need to be unique.

    The implementation differs from the RASP paper, as it avoids using
    compositions of selectors to break ties. Instead, it uses the arguments
    max_seq_len and min_key to ensure the keys are unique: every key is
    increased by `min_key * index / max_seq_len` before sorting, so elements
    with equal keys keep their input order.

    Note that this approach only works for numerical keys.

    Example usage:
      sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
      sort([2, 4, 3, 1])
      >> [1, 2, 3, 4]
      sort([2, 4, 1, 2])
      >> [1, 2, 2, 4]

    Args:
      vals: Values to sort.
      keys: Keys for sorting.
      max_seq_len: Maximum sequence length (used to ensure keys are unique).
        Keyword-only.
      min_key: Minimum key value (used to ensure keys are unique).
        Keyword-only.

    Returns:
      Output SOp of sort program.
    """
    # The tie-breaking offset grows with the position. For a positive
    # `min_key` and indices below `max_seq_len` it stays smaller than
    # `min_key`.
    keys = rasp.SequenceMap(lambda x, i: x + min_key * i / max_seq_len, keys,
                            rasp.indices)
    return make_sort_unique(vals, keys)
def make_sort_freq(max_seq_len: int) -> rasp.SOp:
    """Returns tokens sorted by the frequency they appear in the input.

    More frequent tokens come first. Tokens that appear the same number of
    times are output in the same order as in the input.

    Example usage:
      sort = make_sort_freq(max_seq_len=5)
      sort([2, 4, 2, 1])
      >> [2, 2, 4, 1]

    Args:
      max_seq_len: Maximum sequence length (used to ensure keys are unique)

    Returns:
      SOp named "sort_freq" holding the input tokens in frequency order.
    """
    # Negate the counts so the ascending sort puts frequent tokens first.
    hist = -1 * make_hist().named("hist")
    sorted_tokens = make_sort(
        rasp.tokens, hist, max_seq_len=max_seq_len, min_key=1)
    return sorted_tokens.named("sort_freq")
| ### Programs that work under both causal and regular evaluation. | |
def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp:
    """Builds the running fraction of positions where a condition was True.

    (As implemented in the RASP paper.)

    The fraction at position i is taken over positions 0..i, i.e. it includes
    the current position.

    Example usage:
      num_l = make_frac_prevs(rasp.tokens=="l")
      num_l("hello")
      >> [0, 0, 1/3, 1/2, 2/5]

    Args:
      bools: SOp mapping a sequence to a sequence of booleans.

    Returns:
      frac_prevs: SOp mapping an input to a sequence, where every element
        is the fraction of previous "True" tokens.
    """
    bools = rasp.numerical(bools)
    # Each query selects the keys whose index is <= its own index.
    up_to_here = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    running_fraction = rasp.Aggregate(up_to_here, bools, default=0)
    return rasp.numerical(running_fraction).named("frac_prevs")
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
    """Returns the sop, shifted by `offset`, None-padded.

    The output at query position q is the value of `sop` at the key position
    k satisfying `q == k + offset`. Positions with no such key get the
    aggregate's default, None.

    Args:
      offset: Number of positions to shift by (positional-only).
      sop: The SOp to shift.

    Returns:
      SOp named "shift_by(<offset>)".
    """
    offset_selector = rasp.Select(
        rasp.indices, rasp.indices, lambda k, q: q == k + offset)
    shifted = rasp.Aggregate(offset_selector, sop, default=None)
    return shifted.named(f"shift_by({offset})")
def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
    """Builds an SOp that is True where an occurrence of `pattern` ends.

    The first len(pattern) - 1 elements of the output SOp are None-padded.

    detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

    Args:
      sop: the SOp in which to look for patterns.
      pattern: a sequence of values to look for.

    Returns:
      a sop which detects the pattern.

    Raises:
      ValueError: if `pattern` is empty.
    """
    if len(pattern) < 1:
        raise ValueError(
            f"Length of `pattern` must be at least 1. Got {pattern}")

    # Walk the pattern back to front. The detector for the element that sits
    # `distance` places before the end of the pattern is shifted by
    # `distance`, so that all detectors line up on the pattern's last
    # position.
    detectors = []
    for distance, element in enumerate(reversed(pattern)):
        matches = sop == element
        if distance == 0:
            detectors.append(matches)
        else:
            detectors.append(shift_by(distance, matches))

    # AND all detectors together, starting from the most-shifted one.
    combined, *remaining = reversed(detectors)
    for detector in remaining:
        combined = combined & detector
    return combined.named(f"detect_pattern({pattern})")
def make_count_less_freq(n: int) -> rasp.SOp:
    """Returns how many tokens appear at most n times in the input.

    Every position whose token occurs `n` times or fewer is counted, so a
    token that occurs twice contributes 2 to the count. The output sequence
    contains this count in each position.

    Example usage:
      count_less_freq = make_count_less_freq(2)
      count_less_freq(["a", "a", "a", "b", "b", "c"])
      >> [3, 3, 3, 3, 3, 3]
      count_less_freq(["a", "a", "c", "b", "b", "c"])
      >> [6, 6, 6, 6, 6, 6]

    Args:
      n: Integer to compare token frequencies to (inclusive upper bound).

    Returns:
      SOp named "count_less_freq".
    """
    hist = make_hist().named("hist")
    # The predicate only looks at the key (first argument), so every query
    # selects the same positions: those whose token count is <= n.
    select_less = rasp.Select(hist, hist,
                              lambda x, y: x <= n).named("select_less")
    return rasp.SelectorWidth(select_less).named("count_less_freq")
def make_count(sop, token):
    """Builds an SOp holding the number of occurrences of `token` in `sop`.

    The output sequence contains this count in each position.

    Example usage:
      count = make_count(tokens, "a")
      count(["a", "a", "a", "b", "b", "c"])
      >> [3, 3, 3, 3, 3, 3]
      count(["c", "a", "b", "c"])
      >> [1, 1, 1, 1]

    Args:
      sop: Sop to count tokens in.
      token: Token to count.

    Returns:
      SOp named "count_<token>".
    """
    # Only the key is compared, so every query selects the same positions.
    token_positions = rasp.Select(sop, sop, lambda k, q: k == token)
    return rasp.SelectorWidth(token_positions).named(f"count_{token}")
def make_nary_sequencemap(f, *sops):
    """Returns an SOp that simulates an n-ary SequenceMap.

    Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n
    into a single SOp arguments that takes n-tuples as value. The n-ary
    sequence map implementing f is then a Map on this resulting SOp.

    Note that the intermediate variables representing tuples of varying length
    will be encoded categorically, and can become very high-dimensional. So,
    using this function might lead to very large compiled models.

    Values of the input SOps are always passed to `f` as they are, one per
    argument, even if a value is itself a tuple.

    Args:
      f: Function with n arguments.
      *sops: Sequence of SOps, one for each argument of f.

    Returns:
      SOp applying `f` elementwise to the values of `sops`.

    Raises:
      ValueError: if no SOp is given.
    """
    if not sops:
        raise ValueError("make_nary_sequencemap requires at least one SOp.")

    first, *remaining = sops
    if not remaining:
        # A single argument needs no tupling; apply f directly.
        return rasp.Map(f, first)

    # Pair the first two SOps explicitly, so a tuple-valued first SOp is not
    # mistaken for an already-built argument tuple.
    second, *remaining = remaining
    values = rasp.SequenceMap(lambda x, y: (x, y), first, second)
    for sop in remaining:
        # x is always an argument tuple built above; append the next value.
        values = rasp.SequenceMap(lambda x, y: (*x, y), values, sop)
    return rasp.Map(lambda args: f(*args), values)