Phi2-Fine-Tuning
/
phivenv
/Lib
/site-packages
/torch
/fx
/experimental
/unification
/unification_tools.py
| # mypy: allow-untyped-defs | |
| import collections | |
| import operator | |
| from collections.abc import Mapping | |
| from functools import reduce | |
# Public API exported by ``from <module> import *``.
# NOTE: ``getter``, ``groupby`` and ``first`` are also defined in this file
# but are not listed here, so a star-import does not pick them up.
__all__ = [
    # merging
    "merge",
    "merge_with",
    # mapping over keys / values / items
    "valmap",
    "keymap",
    "itemmap",
    # filtering by keys / values / items
    "valfilter",
    "keyfilter",
    "itemfilter",
    # non-mutating single-level updates
    "assoc",
    "dissoc",
    # nested access and updates
    "assoc_in",
    "update_in",
    "get_in",
]
| def _get_factory(f, kwargs): | |
| factory = kwargs.pop("factory", dict) | |
| if kwargs: | |
| raise TypeError( | |
| f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" | |
| ) | |
| return factory | |
def merge(*dicts, **kwargs):
    """Combine several dictionaries into a single new one

    >>> merge({1: "one"}, {2: "two"})
    {1: 'one', 2: 'two'}

    When a key appears more than once, the rightmost dictionary wins

    >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
    {1: 2, 3: 3, 4: 4}

    The dictionaries may also be passed as one iterable of mappings.  The
    only accepted keyword is ``factory`` (default ``dict``), a callable that
    creates the empty result container.

    See Also:
        merge_with
    """
    # A lone argument that is not itself a Mapping is taken to be a
    # collection of mappings.
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        (dicts,) = dicts
    factory = _get_factory(merge, kwargs)

    merged = factory()
    for mapping in dicts:
        merged.update(mapping)
    return merged
def merge_with(func, *dicts, **kwargs):
    """Merge dictionaries, reducing the values of each key with ``func``

    Every value found under a given key, across all dictionaries and in the
    order encountered, is gathered into a list.  ``func`` is then called
    once per key with that list, i.e. ``func([val1, val2, ...])``.

    >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
    {1: 11, 2: 22}

    >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30})  # doctest: +SKIP
    {1: 1, 2: 2, 3: 30}

    The dictionaries may also be passed as one iterable of mappings.  The
    only accepted keyword is ``factory`` (default ``dict``).

    See Also:
        merge
    """
    # A lone argument that is not itself a Mapping is taken to be a
    # collection of mappings.
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        (dicts,) = dicts
    factory = _get_factory(merge_with, kwargs)

    collected = factory()
    for mapping in dicts:
        for key, value in mapping.items():
            if key in collected:
                collected[key].append(value)
            else:
                collected[key] = [value]
    return valmap(func, collected, factory)
def valmap(func, d, factory=dict):
    """Return a new mapping with ``func`` applied to every value of ``d``

    Keys are carried over unchanged and ``d`` is not modified.  The result
    is created with ``factory()`` and filled through its ``update`` method.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> valmap(sum, bills)  # doctest: +SKIP
    {'Alice': 65, 'Bob': 45}

    See Also:
        keymap
        itemmap
    """
    result = factory()
    keys = d.keys()
    new_values = map(func, d.values())  # lazy: func runs while update() consumes
    result.update(zip(keys, new_values))
    return result
def keymap(func, d, factory=dict):
    """Return a new mapping with ``func`` applied to every key of ``d``

    Values are carried over unchanged and ``d`` is not modified.  The result
    is created with ``factory()`` and filled through its ``update`` method,
    so if two keys map to the same new key the later one wins.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> keymap(str.lower, bills)  # doctest: +SKIP
    {'alice': [20, 15, 30], 'bob': [10, 35]}

    See Also:
        valmap
        itemmap
    """
    result = factory()
    new_keys = map(func, d.keys())  # lazy: func runs while update() consumes
    values = d.values()
    result.update(zip(new_keys, values))
    return result
def itemmap(func, d, factory=dict):
    """Return a new mapping built by applying ``func`` to every item of ``d``

    ``func`` receives each ``(key, value)`` pair and must return a
    two-element iterable that is used as the new ``(key, value)`` pair.
    ``d`` is not modified.

    >>> accountids = {"Alice": 10, "Bob": 20}
    >>> itemmap(reversed, accountids)  # doctest: +SKIP
    {10: "Alice", 20: "Bob"}

    See Also:
        keymap
        valmap
    """
    result = factory()
    new_items = map(func, d.items())  # lazy: func runs while update() consumes
    result.update(new_items)
    return result
def valfilter(predicate, d, factory=dict):
    """Keep only the items of ``d`` whose value satisfies ``predicate``

    Returns a new mapping created with ``factory()``; ``d`` is not modified.

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> valfilter(iseven, d)
    {1: 2, 3: 4}

    See Also:
        keyfilter
        itemfilter
        valmap
    """
    kept = factory()
    for key, value in d.items():
        if not predicate(value):
            continue
        kept[key] = value
    return kept
def keyfilter(predicate, d, factory=dict):
    """Keep only the items of ``d`` whose key satisfies ``predicate``

    Returns a new mapping created with ``factory()``; ``d`` is not modified.

    >>> iseven = lambda x: x % 2 == 0
    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> keyfilter(iseven, d)
    {2: 3, 4: 5}

    See Also:
        valfilter
        itemfilter
        keymap
    """
    kept = factory()
    for key, value in d.items():
        if not predicate(key):
            continue
        kept[key] = value
    return kept
def itemfilter(predicate, d, factory=dict):
    """Keep only the items of ``d`` for which ``predicate(item)`` is true

    ``predicate`` is called with each ``(key, value)`` pair.  Returns a new
    mapping created with ``factory()``; ``d`` is not modified.

    >>> def isvalid(item):
    ...     k, v = item
    ...     return k % 2 == 0 and v < 4

    >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> itemfilter(isvalid, d)
    {2: 3}

    See Also:
        keyfilter
        valfilter
        itemmap
    """
    kept = factory()
    for pair in d.items():
        if not predicate(pair):
            continue
        key, value = pair
        kept[key] = value
    return kept
def assoc(d, key, value, factory=dict):
    """Return a copy of ``d`` in which ``key`` is bound to ``value``

    The copy is created with ``factory()`` and filled from ``d``; the
    original dictionary is left untouched.

    >>> assoc({"x": 1}, "x", 2)
    {'x': 2}
    >>> assoc({"x": 1}, "y", 3)  # doctest: +SKIP
    {'x': 1, 'y': 3}
    """
    updated = factory()
    updated.update(d)
    updated[key] = value
    return updated
def dissoc(d, *keys, **kwargs):
    """Return a new dict with the given key(s) removed.

    New dict has d[key] deleted for each supplied key.
    Does not modify the initial dictionary.  The surviving items keep the
    order they have in ``d`` regardless of how many keys are removed.

    inputs:
        d - dictionary to copy from
        keys - the (hashable) keys to leave out; keys absent from ``d`` are
               ignored
        factory - keyword only; callable creating the empty result
                  container (default ``dict``)

    Raises ``TypeError`` for any keyword argument other than ``factory``.

    >>> dissoc({"x": 1, "y": 2}, "y")
    {'x': 1}
    >>> dissoc({"x": 1, "y": 2}, "y", "x")
    {}
    >>> dissoc({"x": 1}, "y")  # Ignores missing keys
    {'x': 1}
    """
    factory = _get_factory(dissoc, kwargs)
    d2 = factory()

    if len(keys) < len(d) * 0.6:
        # Few removals relative to the size of d: copy everything, then
        # delete the unwanted keys.
        d2.update(d)
        for key in keys:
            if key in d2:
                del d2[key]
    else:
        # Many removals: copy only the survivors.  Iterate d itself rather
        # than a set of its keys so the result keeps d's ordering instead of
        # the arbitrary (and, for str keys, run-dependent) set order.
        dropped = set(keys)
        for k in d:
            if k not in dropped:
                d2[k] = d[k]
    return d2
def assoc_in(d, keys, value, factory=dict):
    """Return a copy of ``d`` with ``value`` stored at a (possibly nested) path

    ``keys`` is the sequence of keys leading to the location to set.  The
    work is delegated to ``update_in`` with a function that ignores the
    existing value and always yields ``value``.

    >>> purchase = {
    ...     "name": "Alice",
    ...     "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00])  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
    """

    def _always_value(_existing):
        # Whatever is currently stored at the path is discarded.
        return value

    return update_in(d, keys, _always_value, value, factory)
def update_in(d, keys, func, default=None, factory=dict):
    """Apply ``func`` to the value at a (possibly nested) path of ``d``

    inputs:
        d - dictionary on which to operate
        keys - list or tuple giving the location of the value to be changed in d
        func - function to operate on that value
        default - value handed to ``func`` when the path does not exist
        factory - callable creating each new (nested) container

    If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
    original dictionary with v replaced by func(v), but does not mutate the
    original dictionary.  Only the containers along the path are copied.

    If k0 is not a key in d, update_in creates nested dictionaries to the depth
    specified by the keys, with the innermost value set to func(default).

    >>> inc = lambda x: x + 1
    >>> update_in({"a": 0}, ["a"], inc)
    {'a': 1}

    >>> transaction = {
    ...     "name": "Alice",
    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> update_in(transaction, ["purchase", "costs"], sum)  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}

    >>> # updating a value when k0 is not in d
    >>> update_in({}, [1, 2, 3], str, default="bar")
    {1: {2: {3: 'bar'}}}
    >>> update_in({1: "foo"}, [2, 3, 4], inc, 0)
    {1: 'foo', 2: {3: {4: 1}}}
    """
    remaining_keys = iter(keys)
    current_key = next(remaining_keys)

    # ``result`` is the top-level copy handed back to the caller; ``cursor``
    # walks down through the freshly created copies while ``d`` walks down
    # the corresponding levels of the input.
    result = factory()
    result.update(d)
    cursor = result

    for upcoming_key in remaining_keys:
        if current_key in d:
            # Existing level: descend into it and take a shallow copy.
            d = d[current_key]
            level_copy = factory()
            level_copy.update(d)
        else:
            # Missing level: start an empty container for both walkers.
            level_copy = factory()
            d = level_copy
        cursor[current_key] = level_copy
        cursor = level_copy
        current_key = upcoming_key

    # ``current_key`` is now the final key of the path.
    cursor[current_key] = func(d[current_key] if current_key in d else default)
    return result
def get_in(keys, coll, default=None, no_default=False):
    """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.

    The lookup indexes into ``coll`` one key at a time, so it works on any
    nesting of subscriptable objects such as dictionaries and lists.

    If the path cannot be followed (``KeyError``, ``IndexError`` or
    ``TypeError`` is raised along the way), ``default`` is returned, unless
    ``no_default`` is true, in which case the exception is re-raised.

    >>> transaction = {
    ...     "name": "Alice",
    ...     "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]},
    ...     "credit card": "5555-1234-1234-1234",
    ... }
    >>> get_in(["purchase", "items", 0], transaction)
    'Apple'
    >>> get_in(["name"], transaction)
    'Alice'
    >>> get_in(["purchase", "total"], transaction)
    >>> get_in(["purchase", "items", "apple"], transaction)
    >>> get_in(["purchase", "items", 10], transaction)
    >>> get_in(["purchase", "total"], transaction, 0)
    0
    >>> get_in(["y"], {}, no_default=True)
    Traceback (most recent call last):
        ...
    KeyError: 'y'

    See Also:
        itertoolz.get
        operator.getitem
    """
    node = coll
    try:
        for key in keys:
            node = node[key]
    except (KeyError, IndexError, TypeError):
        if no_default:
            raise
        return default
    return node
def getter(index):
    """Build a callable that extracts ``index`` from its argument

    A non-list ``index`` gives ``operator.itemgetter(index)``, which returns
    the single looked-up element.  A list ``index`` always gives a callable
    returning a tuple, with one element per entry of the list (so an empty
    list yields ``()`` and a one-element list yields a 1-tuple).

    >>> getter(1)(["a", "b", "c"])
    'b'
    >>> getter([1])(["a", "b", "c"])
    ('b',)
    >>> getter([0, 2])(["a", "b", "c"])
    ('a', 'c')
    >>> getter([])(["a", "b", "c"])
    ()
    """
    if not isinstance(index, list):
        return operator.itemgetter(index)
    if len(index) == 1:
        # itemgetter with a single argument would return a bare element,
        # so wrap the lookup in a tuple by hand.
        only = index[0]
        return lambda x: (x[only],)
    if index:
        return operator.itemgetter(*index)
    return lambda x: ()
def groupby(key, seq):
    """Partition the elements of ``seq`` into lists according to ``key``

    Returns a plain ``dict`` mapping each distinct key value to the list of
    elements that produced it, in the order they were encountered.

    >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}

    >>> iseven = lambda x: x % 2 == 0
    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}

    A ``key`` that is not callable is turned into a lookup via ``getter``,
    i.e. elements are grouped on one of their members.

    >>> groupby(
    ...     "gender",
    ...     [
    ...         {"name": "Alice", "gender": "F"},
    ...         {"name": "Bob", "gender": "M"},
    ...         {"name": "Charlie", "gender": "M"},
    ...     ],
    ... )  # doctest:+SKIP
    {'F': [{'gender': 'F', 'name': 'Alice'}],
     'M': [{'gender': 'M', 'name': 'Bob'},
           {'gender': 'M', 'name': 'Charlie'}]}

    Not to be confused with ``itertools.groupby``

    See Also:
        countby
    """
    if not callable(key):
        key = getter(key)

    groups = collections.defaultdict(list)  # type: ignore[var-annotated]
    for element in seq:
        groups[key(element)].append(element)
    # Hand back an ordinary dict so later lookups of missing keys raise.
    return dict(groups)
def first(seq):
    """Return the first element produced by iterating over ``seq``

    Raises ``StopIteration`` when ``seq`` yields nothing.

    >>> first("ABC")
    'A'
    """
    iterator = iter(seq)
    return next(iterator)