diff --git a/.venv/lib/python3.11/site-packages/annotated_types/py.typed b/.venv/lib/python3.11/site-packages/annotated_types/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/annotated_types/test_cases.py b/.venv/lib/python3.11/site-packages/annotated_types/test_cases.py new file mode 100644 index 0000000000000000000000000000000000000000..d9164d6883d2dd47cb766b483592ca3730f6f09d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/annotated_types/test_cases.py @@ -0,0 +1,151 @@ +import math +import sys +from datetime import date, datetime, timedelta, timezone +from decimal import Decimal +from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Set, Tuple + +if sys.version_info < (3, 9): + from typing_extensions import Annotated +else: + from typing import Annotated + +import annotated_types as at + + +class Case(NamedTuple): + """ + A test case for `annotated_types`. + """ + + annotation: Any + valid_cases: Iterable[Any] + invalid_cases: Iterable[Any] + + +def cases() -> Iterable[Case]: + # Gt, Ge, Lt, Le + yield Case(Annotated[int, at.Gt(4)], (5, 6, 1000), (4, 0, -1)) + yield Case(Annotated[float, at.Gt(0.5)], (0.6, 0.7, 0.8, 0.9), (0.5, 0.0, -0.1)) + yield Case( + Annotated[datetime, at.Gt(datetime(2000, 1, 1))], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + [datetime(2000, 1, 1), datetime(1999, 12, 31)], + ) + yield Case( + Annotated[datetime, at.Gt(date(2000, 1, 1))], + [date(2000, 1, 2), date(2000, 1, 3)], + [date(2000, 1, 1), date(1999, 12, 31)], + ) + yield Case( + Annotated[datetime, at.Gt(Decimal('1.123'))], + [Decimal('1.1231'), Decimal('123')], + [Decimal('1.123'), Decimal('0')], + ) + + yield Case(Annotated[int, at.Ge(4)], (4, 5, 6, 1000, 4), (0, -1)) + yield Case(Annotated[float, at.Ge(0.5)], (0.5, 0.6, 0.7, 0.8, 0.9), (0.4, 0.0, -0.1)) + yield Case( + Annotated[datetime, at.Ge(datetime(2000, 1, 1))], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + [datetime(1998, 1, 1), datetime(1999, 12, 31)], + ) + + yield Case(Annotated[int, at.Lt(4)], (0, -1), (4, 5, 6, 1000, 4)) + yield Case(Annotated[float, at.Lt(0.5)], (0.4, 0.0, -0.1), (0.5, 0.6, 0.7, 0.8, 0.9)) + yield Case( + Annotated[datetime, at.Lt(datetime(2000, 1, 1))], + [datetime(1999, 12, 31), datetime(1999, 12, 31)], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + ) + + yield Case(Annotated[int, at.Le(4)], (4, 0, -1), (5, 6, 1000)) + yield Case(Annotated[float, at.Le(0.5)], (0.5, 0.0, -0.1), (0.6, 0.7, 0.8, 0.9)) + yield Case( + Annotated[datetime, at.Le(datetime(2000, 1, 1))], + [datetime(2000, 1, 1), datetime(1999, 12, 31)], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + ) + + # Interval + yield Case(Annotated[int, at.Interval(gt=4)], (5, 6, 1000), (4, 0, -1)) + yield Case(Annotated[int, at.Interval(gt=4, lt=10)], (5, 6), (4, 10, 1000, 0, -1)) + yield Case(Annotated[float, at.Interval(ge=0.5, le=1)], (0.5, 0.9, 1), (0.49, 1.1)) + yield Case( + Annotated[datetime, at.Interval(gt=datetime(2000, 1, 1), le=datetime(2000, 1, 3))], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + [datetime(2000, 1, 1), datetime(2000, 1, 4)], + ) + + yield Case(Annotated[int, at.MultipleOf(multiple_of=3)], (0, 3, 9), (1, 2, 4)) + yield Case(Annotated[float, at.MultipleOf(multiple_of=0.5)], (0, 0.5, 1, 1.5), (0.4, 1.1)) + + # lengths + + yield Case(Annotated[str, at.MinLen(3)], ('123', '1234', 'x' * 10), ('', '1', '12')) + yield Case(Annotated[str, at.Len(3)], ('123', '1234', 'x' * 10), ('', '1', '12')) + 
yield Case(Annotated[List[int], at.MinLen(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2])) + yield Case(Annotated[List[int], at.Len(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2])) + + yield Case(Annotated[str, at.MaxLen(4)], ('', '1234'), ('12345', 'x' * 10)) + yield Case(Annotated[str, at.Len(0, 4)], ('', '1234'), ('12345', 'x' * 10)) + yield Case(Annotated[List[str], at.MaxLen(4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10)) + yield Case(Annotated[List[str], at.Len(0, 4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10)) + + yield Case(Annotated[str, at.Len(3, 5)], ('123', '12345'), ('', '1', '12', '123456', 'x' * 10)) + yield Case(Annotated[str, at.Len(3, 3)], ('123',), ('12', '1234')) + + yield Case(Annotated[Dict[int, int], at.Len(2, 3)], [{1: 1, 2: 2}], [{}, {1: 1}, {1: 1, 2: 2, 3: 3, 4: 4}]) + yield Case(Annotated[Set[int], at.Len(2, 3)], ({1, 2}, {1, 2, 3}), (set(), {1}, {1, 2, 3, 4})) + yield Case(Annotated[Tuple[int, ...], at.Len(2, 3)], ((1, 2), (1, 2, 3)), ((), (1,), (1, 2, 3, 4))) + + # Timezone + + yield Case( + Annotated[datetime, at.Timezone(None)], [datetime(2000, 1, 1)], [datetime(2000, 1, 1, tzinfo=timezone.utc)] + ) + yield Case( + Annotated[datetime, at.Timezone(...)], [datetime(2000, 1, 1, tzinfo=timezone.utc)], [datetime(2000, 1, 1)] + ) + yield Case( + Annotated[datetime, at.Timezone(timezone.utc)], + [datetime(2000, 1, 1, tzinfo=timezone.utc)], + [datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))], + ) + yield Case( + Annotated[datetime, at.Timezone('Europe/London')], + [datetime(2000, 1, 1, tzinfo=timezone(timedelta(0), name='Europe/London'))], + [datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))], + ) + + # Quantity + + yield Case(Annotated[float, at.Unit(unit='m')], (5, 4.2), ('5m', '4.2m')) + + # predicate types + + yield Case(at.LowerCase[str], ['abc', 'foobar'], ['', 'A', 'Boom']) + yield Case(at.UpperCase[str], ['ABC', 'DEFO'], ['', 'a', 'abc', 'AbC']) + yield Case(at.IsDigit[str], ['123'], ['', 'ab', 'a1b2']) + yield Case(at.IsAscii[str], ['123', 'foo bar'], ['£100', '😊', 'whatever 👀']) + + yield Case(Annotated[int, at.Predicate(lambda x: x % 2 == 0)], [0, 2, 4], [1, 3, 5]) + + yield Case(at.IsFinite[float], [1.23], [math.nan, math.inf, -math.inf]) + yield Case(at.IsNotFinite[float], [math.nan, math.inf], [1.23]) + yield Case(at.IsNan[float], [math.nan], [1.23, math.inf]) + yield Case(at.IsNotNan[float], [1.23, math.inf], [math.nan]) + yield Case(at.IsInfinite[float], [math.inf], [math.nan, 1.23]) + yield Case(at.IsNotInfinite[float], [math.nan, 1.23], [math.inf]) + + # check stacked predicates + yield Case(at.IsInfinite[Annotated[float, at.Predicate(lambda x: x > 0)]], [math.inf], [-math.inf, 1.23, math.nan]) + + # doc + yield Case(Annotated[int, at.doc("A number")], [1, 2], []) + + # custom GroupedMetadata + class MyCustomGroupedMetadata(at.GroupedMetadata): + def __iter__(self) -> Iterator[at.Predicate]: + yield at.Predicate(lambda x: float(x).is_integer()) + + yield Case(Annotated[float, MyCustomGroupedMetadata()], [0, 2.0], [0.01, 1.5]) diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dbb34b9b8247f0550f82aa8702698622b462ca0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/__init__.cpython-311.pyc 
differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/__main__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/__main__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63bc6054e03ae7d800adb6bc243243fe3586f44 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/__main__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_abc.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_abc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60a34f25bacf1990f3327c82934b048dc6f1918e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_abc.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_check.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_check.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d359f810d001285b4edbb7e5a91aee184e93aae7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_check.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_classdef.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_classdef.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af9d8fa7c9d5e99161810460291a82ec052626ae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_classdef.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_dataclasses.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_dataclasses.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99106da6c2c3d3032d6af8ed4e9c199d1c32838c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_dataclasses.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_detect.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_detect.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..855d8d4c35aa483dd2b0765af745eb2ff757e56f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_detect.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_dictviews.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_dictviews.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79c0dd71cb13ae088dccf537de1f5dc9a38b1dc1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_dictviews.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_diff.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_diff.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..074252f85d2d6c820b7717e13f9712e507d292de Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_diff.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_extendpickle.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_extendpickle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64817eed215b5027b113b7d71ea903749a4426ab Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_extendpickle.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_fglobals.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_fglobals.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d25c255f8b18e1eea2f584175b1f9030e74bd4d1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_fglobals.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_file.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_file.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..617f590ec98c8f27f4c931319b46446ebab16db4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_file.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_functions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_functions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..829621c9bc2b1c79fa94f4deaf49e2d202d63d41 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_functions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_functors.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_functors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d80ff3a18ab2c1ee42b06ca45aa3553515e38c37 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_functors.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_logger.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0968b37e8b2fc6e9c6140ecd3304318fd9d3dd73 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_logger.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_mixins.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_mixins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed4be30f1ef386b14fb9bd8eb19cd91fbc45286f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_mixins.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_module.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c8d76f82ca528dd1f4655cadc2de797bd5b34db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_module.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_nested.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_nested.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..dce26c5d953e8c2be700425bff529de24e953f27 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_nested.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_objects.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_objects.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc17677d3af2fadb769e5ae3921769cd8f6a262c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_objects.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_properties.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_properties.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63060f67dbf3fc42a0ffede2d61a240523adfe6c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_properties.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_pycapsule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_pycapsule.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56be1f27d344a7cf6f024516eb0e58de0637cc3a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_pycapsule.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_recursive.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_recursive.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d932ae126c795a1a22096429d06e9c11a65fcc0d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_recursive.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_registered.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_registered.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1170beac0e6cf1e6b3e7d8448b4850ca3e809b62 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_registered.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_restricted.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_restricted.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ff3dd531ef66b30bba1f28d4c7a27b174c0d299 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_restricted.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_selected.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_selected.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd947064829103d6f6a3813d48c03a7ee1ff1a2c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_selected.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_session.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_session.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26db6a6b49c7eec7bba71a472676a0ca62a93c1f Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_session.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_source.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_source.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9877a4156a6c0aba99183ec07f49c14e62950bea Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_source.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_sources.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_sources.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c11dab4d273a10d45897a0d233ede00e43f1fade Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_sources.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_temp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_temp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f87e85c79a243cb3696fdc1586a4360c30319602 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_temp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_threads.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_threads.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2848884fb8e61668804807d97175bbe1f6772198 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_threads.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_weakref.cpython-311.pyc b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_weakref.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb451e9d029206d5501d3d573798600408da896 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/dill/tests/__pycache__/test_weakref.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/INSTALLER b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/LICENSE b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..aaf210b1d01da693fda15995ec973ed21cd968c3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/LICENSE @@ -0,0 +1,18 @@ +Copyright © 2017 Erez Shinan + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/METADATA b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..6d5f341b993998b289ea5f9093de3ca1488652c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/METADATA @@ -0,0 +1,47 @@ +Metadata-Version: 2.1 +Name: lark +Version: 1.2.2 +Summary: a modern parsing library +Author-email: Erez Shinan +License: MIT +Project-URL: Homepage, https://github.com/lark-parser/lark +Project-URL: Download, https://github.com/lark-parser/lark/tarball/master +Keywords: Earley,LALR,parser,parsing,ast +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python :: 3 +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Classifier: Topic :: Text Processing :: General +Classifier: Topic :: Text Processing :: Linguistic +Classifier: License :: OSI Approved :: MIT License +Requires-Python: >=3.8 +Description-Content-Type: text/markdown +License-File: LICENSE +Provides-Extra: atomic_cache +Requires-Dist: atomicwrites ; extra == 'atomic_cache' +Provides-Extra: interegular +Requires-Dist: interegular <0.4.0,>=0.3.1 ; extra == 'interegular' +Provides-Extra: nearley +Requires-Dist: js2py ; extra == 'nearley' +Provides-Extra: regex +Requires-Dist: regex ; extra == 'regex' + +Lark is a modern general-purpose parsing library for Python. +With Lark, you can parse any context-free grammar, efficiently, with very little code. +Main Features: +- Builds a parse-tree (AST) automagically, based on the structure of the grammar +- Earley parser +- Can parse all context-free grammars +- Full support for ambiguous grammars +- LALR(1) parser +- Fast and light, competitive with PLY +- Can generate a stand-alone parser +- CYK parser, for highly ambiguous grammars +- EBNF grammar +- Unicode fully supported +- Automatic line & column tracking +- Standard library of terminals (strings, numbers, names, etc.) +- Import grammars from Nearley.js +- Extensive test suite +- And much more! +Since version 1.2, only Python versions 3.8 and up are supported. 
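As a point of reference for the lark wheel recorded below: the feature list above boils down to a grammar-in, parse-tree-out workflow. A minimal sketch using only Lark's documented top-level API (Lark, parse, Tree.pretty); the toy grammar is illustrative, not something shipped in this package:

    from lark import Lark

    # A tiny EBNF grammar; Lark builds the parse tree automatically
    # from the rule structure, as the feature list above describes.
    grammar = r'''
        start: sum
        sum: product ("+" product)*
        product: NUMBER ("*" NUMBER)*
        NUMBER: /[0-9]+/
        %ignore " "
    '''

    parser = Lark(grammar)            # Earley parser by default
    tree = parser.parse("1 + 2 * 3")
    print(tree.pretty())              # inspect the automatically built tree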
diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/RECORD b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..aa8522672e4c221e72d9a9d7db457d8f63aa0a89 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/RECORD @@ -0,0 +1,82 @@ +lark-1.2.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +lark-1.2.2.dist-info/LICENSE,sha256=Lu5g9S1OETV7-J5ysDTQUOKF5H_aE2HlZi-zIu4n13E,1055 +lark-1.2.2.dist-info/METADATA,sha256=S-69HuNJr0ktlvb7J5XE48ghb_6ahYn8ksdW9HcB-d0,1831 +lark-1.2.2.dist-info/RECORD,, +lark-1.2.2.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91 +lark-1.2.2.dist-info/entry_points.txt,sha256=WXYg_uCUdFlxQDPUhli3HFah37bNNFQfXLdzCqsacGI,61 +lark-1.2.2.dist-info/top_level.txt,sha256=dyS6jg8hCHHkXWvsfcIMO8rjlv_bdzAxiE0lkkzJ5hk,5 +lark/__init__.py,sha256=bc0tK7h7XwHA-Y4vVeJoNIqSMA-MHVTihq8yy795WXo,744 +lark/__pycache__/__init__.cpython-311.pyc,, +lark/__pycache__/ast_utils.cpython-311.pyc,, +lark/__pycache__/common.cpython-311.pyc,, +lark/__pycache__/exceptions.cpython-311.pyc,, +lark/__pycache__/grammar.cpython-311.pyc,, +lark/__pycache__/indenter.cpython-311.pyc,, +lark/__pycache__/lark.cpython-311.pyc,, +lark/__pycache__/lexer.cpython-311.pyc,, +lark/__pycache__/load_grammar.cpython-311.pyc,, +lark/__pycache__/parse_tree_builder.cpython-311.pyc,, +lark/__pycache__/parser_frontends.cpython-311.pyc,, +lark/__pycache__/reconstruct.cpython-311.pyc,, +lark/__pycache__/tree.cpython-311.pyc,, +lark/__pycache__/tree_matcher.cpython-311.pyc,, +lark/__pycache__/tree_templates.cpython-311.pyc,, +lark/__pycache__/utils.cpython-311.pyc,, +lark/__pycache__/visitors.cpython-311.pyc,, +lark/__pyinstaller/__init__.py,sha256=_PpFm44f_mwHlCpvYgv9ZgubLfNDc3PlePVir4sxRfI,182 +lark/__pyinstaller/__pycache__/__init__.cpython-311.pyc,, +lark/__pyinstaller/__pycache__/hook-lark.cpython-311.pyc,, +lark/__pyinstaller/hook-lark.py,sha256=5aFHiZWVHPRdHT8qnb4kW4JSOql5GusHodHR25_q9sU,599 +lark/ast_utils.py,sha256=jwn44ocNQhZGbfcFsEZnwi_gGvPbNgzjQ-0RuEtwDzI,2117 +lark/common.py,sha256=M9-CFAUP3--OkftyyWjke-Kc1-pQMczT1MluHCFwdy4,3008 +lark/exceptions.py,sha256=g76ygMPfSMl6ukKqFAZVpR2EAJTOOdyfJ_ALXc_MCR8,10939 +lark/grammar.py,sha256=DR17QSLSKCRhMOqx2UQh4n-Ywu4CD-wjdQxtuM8OHkY,3665 +lark/grammars/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +lark/grammars/__pycache__/__init__.cpython-311.pyc,, +lark/grammars/common.lark,sha256=FV9xGIPiPqHRM4ULAxP6jApXRTVsSwbOe697I9s7DLs,885 +lark/grammars/lark.lark,sha256=nq1NTZYqm_DPI2mjRIlpd3ZcxPjGhapA4GUzkcfBTQs,1541 +lark/grammars/python.lark,sha256=WMakTkpzCqOd0jUjYONI3LOnSy2KRN9NoL9pFtAZYCI,10641 +lark/grammars/unicode.lark,sha256=d9YCz0XWimdl4F8M5YCptavBcFG9D58Yd4aMwxjYtEI,96 +lark/indenter.py,sha256=L5uNDYUMNrk4ZTWKmW0Tu-H-3GGErLOHygMC32N_twE,4221 +lark/lark.py,sha256=_IHWmTxt43kfd9eYVtwx58zEWWSFAq9_gKH7Oeu5PZs,28184 +lark/lexer.py,sha256=OwgQPCpQ-vUi-2aeZztsydd4DLkEgCbZeucvEPvHFi4,24037 +lark/load_grammar.py,sha256=WYZDxyO6omhA8NKyMjSckfAMwVKuIMF3liiYXE_-kHo,53946 +lark/parse_tree_builder.py,sha256=jT_3gCEkBGZoTXAWSnhMn1kRuJILWB-E7XkUciYNHI4,14412 +lark/parser_frontends.py,sha256=mxMXxux2hkfTfE859wuVp4-Fr1no6YVEUt8toDjEdPQ,10165 +lark/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +lark/parsers/__pycache__/__init__.cpython-311.pyc,, +lark/parsers/__pycache__/cyk.cpython-311.pyc,, +lark/parsers/__pycache__/earley.cpython-311.pyc,, 
+lark/parsers/__pycache__/earley_common.cpython-311.pyc,, +lark/parsers/__pycache__/earley_forest.cpython-311.pyc,, +lark/parsers/__pycache__/grammar_analysis.cpython-311.pyc,, +lark/parsers/__pycache__/lalr_analysis.cpython-311.pyc,, +lark/parsers/__pycache__/lalr_interactive_parser.cpython-311.pyc,, +lark/parsers/__pycache__/lalr_parser.cpython-311.pyc,, +lark/parsers/__pycache__/lalr_parser_state.cpython-311.pyc,, +lark/parsers/__pycache__/xearley.cpython-311.pyc,, +lark/parsers/cyk.py,sha256=c3GLk3kq23Xwb8MqUOjvivwP488KJY6NUWgxqeR5980,12192 +lark/parsers/earley.py,sha256=03sW9vfBkcH4NR72EBt8HkndDKSVSH3IdRnDulXWy24,15117 +lark/parsers/earley_common.py,sha256=e2e6NrNucw-WMiNV8HqQ_TpGx6P7v_S8f5aEcF0Tkqo,1620 +lark/parsers/earley_forest.py,sha256=w4JTb4tVMewue8dL-gCO96-Uo0wd4BbQUfSfIhr7txY,31332 +lark/parsers/grammar_analysis.py,sha256=rQ4Sn9EP8gjXGTZXEiWLW0KByPPpeKpN5hSIQZgNl3I,7141 +lark/parsers/lalr_analysis.py,sha256=DGHFk2tIluIyeFEVFfsMRU77DVbd598IJnUUOXO04yo,12207 +lark/parsers/lalr_interactive_parser.py,sha256=LsgfT1gdne8pXHTCsN6bl6zD6Pdh2dDqp1rIWOzp7Yw,5757 +lark/parsers/lalr_parser.py,sha256=6U8jP1AlUsuGxgJBWMq15WuGuyaolsLPevcf8HZ_zZk,4586 +lark/parsers/lalr_parser_state.py,sha256=QZ12p4CtvcvFAIKIqkeDBJYgEU3ntQllBJDYXb419ls,3793 +lark/parsers/xearley.py,sha256=DboXMNtuN0G-SXrrDm5zgUDUekz85h0Rih2PRvcf1LM,7825 +lark/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +lark/reconstruct.py,sha256=s7CevBXchUG_fe2otdAITxIaSXCEIiSjy4Sbh5QC0hs,3763 +lark/tools/__init__.py,sha256=FeKYmVUjXSt-vlQm2ktyWkcxaOCTOkZnHD_kOUWjUuA,2469 +lark/tools/__pycache__/__init__.cpython-311.pyc,, +lark/tools/__pycache__/nearley.cpython-311.pyc,, +lark/tools/__pycache__/serialize.cpython-311.pyc,, +lark/tools/__pycache__/standalone.cpython-311.pyc,, +lark/tools/nearley.py,sha256=QaLYdW6mYQdDq8JKMisV3lvPqzF0wPgu8q8BtsSA33g,6265 +lark/tools/serialize.py,sha256=nwt46LNxkDm0T_Uh9k2wS4fcfgvZQ2dy4-YC_aKhTQk,965 +lark/tools/standalone.py,sha256=6eXDqBuzZSpE5BGZm_Fh6X5yRhAPYxNVyl2aUU3ABzA,5627 +lark/tree.py,sha256=aWWHMazid8bbJanhmCjK9XK2jRFJ6N6WmlwXJGTsz28,8522 +lark/tree_matcher.py,sha256=jHdZJggn405SXmPpGf9U9HLrrsfP4eNNZaj267UTB00,6003 +lark/tree_templates.py,sha256=sSnfw1m8txAkJOYhcQrooG7xajVyVplunzTnNsxY720,6139 +lark/utils.py,sha256=3qd1-c0YgHYklvx1hA28qF7N_Ty1Zz6TbtCFMzQanNk,11270 +lark/visitors.py,sha256=VJ3T1m8p78MwXJotpOAvn06mYEqKyuIlhsAF51U-a3w,21422 diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/WHEEL b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..71360e028d9f29e8cf66c9737e4ab9a7a4d352e6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (72.2.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/entry_points.txt b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec317d7da483edbd9ff23577367c3a6fa5de9525 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[pyinstaller40] +hook-dirs = lark.__pyinstaller:get_hook_dirs diff --git a/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/top_level.txt b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..bc30e96adc75f793da4efdd21e3aff9b501659ed --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/lark-1.2.2.dist-info/top_level.txt @@ -0,0 +1 @@ +lark diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__init__.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..410d5142dcde94e27b16d2f7783a8e3bf0fb51f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/__init__.py @@ -0,0 +1,23 @@ +__all__ = ['CharacterLevelParser', + 'CharacterLevelParserConfig', + 'StringParser', + 'RegexParser', + 'UnionParser', + 'SequenceParser', + 'JsonSchemaParser', + 'TokenEnforcer', + 'TokenEnforcerTokenizerData', + 'LMFormatEnforcerException', + 'FormatEnforcerAnalyzer',] + +from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig, StringParser, UnionParser, SequenceParser +from .regexparser import RegexParser +from .jsonschemaparser import JsonSchemaParser +from .tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData +from .exceptions import LMFormatEnforcerException +try: + from .analyzer import FormatEnforcerAnalyzer +except ImportError as e: + import logging + logging.warning(e) + FormatEnforcerAnalyzer = None diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c68d83c312641d3420f1149a7d46f994cf702b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/analyzer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/analyzer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dab669068f00c8fb5e8ee2f7ba016d1a02e388e4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/analyzer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/characterlevelparser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/characterlevelparser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60eb83c0b8ddfa1805b02aba9676b4233b4b5511 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/characterlevelparser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/consts.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/consts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c9b19fc23096633de431c3b96e3e19b6009a805 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/consts.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/exceptions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d7e28d2234a340ddd70094e97630e12c5ed6b93 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/jsonschemaparser.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/jsonschemaparser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77f25191450f8bad1faa1c43c2b80193c59abf2d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/jsonschemaparser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/regexparser.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/regexparser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69fb1d38c642fcf510b2a80f8ca65d9dba5ab69a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/regexparser.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/tokenenforcer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/tokenenforcer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ff4b68cb5b75c69a6bca984d4f8727cc81757cb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/tokenenforcer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/tokenizerprefixtree.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/tokenizerprefixtree.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f78c9403ee383465a505e2eaf9232e74bccfe82 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/__pycache__/tokenizerprefixtree.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/analyzer.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5fbda0125f6c244a1bab194e43922119b5a32c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/analyzer.py @@ -0,0 +1,77 @@ +from typing import Dict, Hashable, List +try: + import numpy as np + import numpy.typing as npt +except ImportError as e: + class FormatEnforcerAnalyzer: # type: ignore + def __init__(self, *args, **kwargs): + pass + def report_raw_logits(self, *args, **kwargs): + pass + def generate_report_dict(self, *args, **kwargs): + return {} + raise ImportError('FormatEnforcerAnalyzer not available because numpy is not installed. Please install it with "pip install numpy"') from e + +from . import TokenEnforcer + +class FormatEnforcerAnalyzer: + """A helper class to help analyze the format enforcer's behavior.""" + def __init__(self, token_enforcer: TokenEnforcer): + self.token_enforcer = token_enforcer + self.raw_logits: Dict[Hashable, npt.ArrayLike] = {} + + def report_raw_logits(self, output_tokens: List[int], logits: npt.ArrayLike): + """Report what logits were generated for a specific token sequence. 
The logits must be before any processing / filtering.""" + self.raw_logits[tuple(output_tokens)] = logits + + def generate_report_dict(self, output_tokens: List[int]) -> dict: + """Generate a report dict containing the analysis results for a specific output token sequence.""" + scores_matrix: List[npt.ArrayLike] = [] + allowed_tokens_matrix: List[List[int]] = [] + for idx in range(len(output_tokens)): + prefix = output_tokens[:idx] + prefix_tuple = tuple(prefix) + if prefix_tuple in self.raw_logits: + scores_matrix.append(self.raw_logits[prefix_tuple]) + allowed_tokens_matrix.append(self.token_enforcer.get_allowed_tokens(prefix)) + + logits = np.array(scores_matrix) # n_tokens * vocab_size + softmax_logits = _softmax(logits) # n_tokens * vocab_size + original_indices = softmax_logits.argmax(axis=1) # n_tokens + original_scores = _select_array(softmax_logits, original_indices) # n_tokens + + single_token_dict: Dict[int, str] = {token_id: token_str for token_id, token_str, _ in self.token_enforcer.regular_tokens} + def single_token_decoder(token_id: int) -> str: + if token_id in single_token_dict: + return single_token_dict[token_id] + return self.token_enforcer.decoder([token_id]) + + original_tokens = [single_token_decoder(idx) for idx in original_indices] + + penalty_matrix = np.full_like(softmax_logits, -np.inf) + for row in range(penalty_matrix.shape[0]): + penalty_matrix[row][allowed_tokens_matrix[row]] = 0 + enforced_softmax_logits = softmax_logits + penalty_matrix + + enforced_indices = enforced_softmax_logits.argmax(axis=1) + enforced_scores = _select_array(enforced_softmax_logits, enforced_indices) + + enforced_tokens = [single_token_decoder(idx) for idx in enforced_indices] + df_dict = {} # In order to minimize the package's dependencies, we don't create a dataframe, but create a dataframe-like dictionary instead.
+ df_dict['generated_token'] = enforced_tokens + df_dict['generated_token_idx'] = enforced_indices.tolist() + df_dict['generated_score'] = enforced_scores.tolist() + df_dict['leading_token'] = original_tokens + df_dict['leading_token_idx'] = original_indices.tolist() + df_dict['leading_score'] = original_scores.tolist() + + return df_dict + +def _softmax(arr: np.ndarray) -> np.ndarray: + """Compute softmax values for each set of scores in arr.""" + e_arr = np.exp(arr) + return e_arr / np.sum(e_arr, axis=1, keepdims=True) + +def _select_array(arr: np.ndarray, index_array: np.ndarray) -> np.ndarray: + # https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + return np.take_along_axis(arr, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1) \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/characterlevelparser.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/characterlevelparser.py new file mode 100644 index 0000000000000000000000000000000000000000..186eab9fcde8a5b45cc228846706cfa39b8930e4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/characterlevelparser.py @@ -0,0 +1,187 @@ +import abc +import os +from dataclasses import dataclass, field +from typing import Hashable, List, Optional, TypeVar +from .consts import (COMPLETE_ALPHABET, WHITESPACE_CHARACTERS, DEFAULT_MAX_CONSECUTIVE_WHITESPACES, + DEFAULT_FORCE_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, + CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH, + DEFAULT_MAX_JSON_ARRAY_LENGTH) + + +def _parse_bool(s: str) -> bool: + return bool(s) and s.strip().lower() in ['true', '1'] + + +def _env_or_default_field(env_var: str, default_val): + default_val_type = type(default_val) + parser_func = _parse_bool if default_val_type == bool else default_val_type + def factory_func(): + return parser_func(os.environ.get(env_var, str(default_val))) + return field(default_factory=factory_func) + + +@dataclass +class CharacterLevelParserConfig: + alphabet: str = COMPLETE_ALPHABET + max_consecutive_whitespaces: int = _env_or_default_field(CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, + DEFAULT_MAX_CONSECUTIVE_WHITESPACES) + """How many consecutive whitespaces the JsonSchemaParser will allow""" + force_json_field_order: bool = _env_or_default_field(CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, + DEFAULT_FORCE_JSON_FIELD_ORDER) + """Whether the JsonSchemaParser will force fields to appear in the + order of the 'required' field in the schema""" + max_json_array_length: int = _env_or_default_field(CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH, + DEFAULT_MAX_JSON_ARRAY_LENGTH) + """The maximum JSON array length if not specified by the schema. Helps the LLM + avoid infinite loops.""" + + +class CharacterLevelParser(abc.ABC): + """CharacterLevelParser is an interface for classes that can parse strings one character at a time, and determine which characters are allowed at any specific time""" + + def __init__(self, config: Optional[CharacterLevelParserConfig] = None): + self._config = config or CharacterLevelParserConfig() + + @abc.abstractmethod + def add_character(self, new_character: str) -> 'CharacterLevelParser': + """Add a character to the parser, and return a new parser that represents the state of the parser after the character has been added.
This has to be + an immutable operation - the original CharacterLevelParser (self) must not be modified.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_allowed_characters(self) -> str: + """Return a string containing all characters that are allowed at the current point in the parsing process.""" + raise NotImplementedError() + + @abc.abstractmethod + def can_end(self) -> bool: + """Return True if the parser is in a state where it can end (potentially finished parsing the desired structure), and False otherwise.""" + raise NotImplementedError() + + def shortcut_key(self) -> Optional[Hashable]: + """Optional. Return a key that denotes that this state is a repeating state, so full tree traversal should be avoided.""" + return None + + def cache_key(self) -> Optional[Hashable]: + """Optional. Return a key that denotes that this state is a repeating state, and if it is visited again, results can be cached.""" + return None + + @property + def config(self) -> CharacterLevelParserConfig: + return self._config + + @config.setter + def config(self, new_config: CharacterLevelParserConfig): + self._config = new_config + return self + + +class StringParser(CharacterLevelParser): + """StringParser is an example CharacterLevelParser that only allows an exact string. It is a debugging / learning tool + to show how CharacterLevelParser works together with TokenizerPrefixTree to filter the allowed tokens (some of which may contain multiple characters)""" + def __init__(self, string: str): + self.target_str = string + + def add_character(self, new_character: str) -> CharacterLevelParser: + if self.target_str.startswith(new_character): + return StringParser(self.target_str[len(new_character):]) + else: + raise ValueError(f"Expected '{self.target_str[0]}' but got '{new_character}'") + + def get_allowed_characters(self) -> str: + return self.target_str[0] if self.target_str else "" + + def can_end(self) -> bool: + return not self.target_str + + +class ForceStopParser(CharacterLevelParser): + """A simple parser that forbids any characters except the stop token. Used to force stop LM operation""" + def __init__(self, allow_whitespace: bool = False): + self.allow_whitespace = allow_whitespace + def add_character(self, new_character: str) -> CharacterLevelParser: + return self + def get_allowed_characters(self) -> str: + return WHITESPACE_CHARACTERS if self.allow_whitespace else "" + def can_end(self) -> bool: + return True + + +class UnionParser(CharacterLevelParser): + """A parser that allows a string that would be allowed by any of several different parsers""" + def __init__(self, parsers: List[CharacterLevelParser]): + self.parsers = parsers + + def add_character(self, new_character: str) -> CharacterLevelParser: + # This is a bit of a performance hit, as it means get_allowed_characters() is called twice.
+ relevant_parsers = [parser for parser in self.parsers if new_character in parser.get_allowed_characters()] + next_parsers = [parser.add_character(new_character) for parser in relevant_parsers] + if len(next_parsers) == 1: + return next_parsers[0] + return UnionParser(next_parsers) + + def get_allowed_characters(self) -> str: + allowed = "".join([parser.get_allowed_characters() for parser in self.parsers]) + return "".join(set(allowed)) + + def can_end(self) -> bool: + return any([parser.can_end() for parser in self.parsers]) + + def shortcut_key(self) -> Optional[Hashable]: + unique_shortcut_keys = set(parser.shortcut_key() for parser in self.parsers) + if len(unique_shortcut_keys) == 1: + return next(iter(unique_shortcut_keys)) + return None + + def cache_key(self) -> Optional[Hashable]: + all_cache_keys = tuple(parser.cache_key() for parser in self.parsers) + if all(key is not None for key in all_cache_keys): + return ('union', all_cache_keys) + return None + + +class SequenceParser(CharacterLevelParser): + """A parser that is a sequence of multiple parsers.""" + def __init__(self, parsers: List[CharacterLevelParser]): + self.parsers = parsers + + def add_character(self, new_character: str) -> CharacterLevelParser: + legal_parsers = [] + # Tricky edge case: if the first parser can both end and accept the character, + # and the second parser can also accept, we don't know which scenario we are dealing + # with, so we need to return a UnionParser. + for idx, parser in enumerate(self.parsers): + if new_character in parser.get_allowed_characters(): + updated_parser = parser.add_character(new_character) + next_parsers = [updated_parser] + self.parsers[idx+1:] + if len(next_parsers) == 1: + legal_parsers.append(next_parsers[0]) + else: + legal_parsers.append(SequenceParser(next_parsers)) + if not parser.can_end(): + break + if len(legal_parsers) == 1: + return legal_parsers[0] + return UnionParser(legal_parsers) + + def get_allowed_characters(self) -> str: + allowed_characters = set() + for parser in self.parsers: + allowed_characters.update(parser.get_allowed_characters()) + if not parser.can_end(): + break + return "".join(allowed_characters) + + def can_end(self) -> bool: + return all([parser.can_end() for parser in self.parsers]) + + def shortcut_key(self) -> Optional[str]: + return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None + + def cache_key(self) -> Optional[Hashable]: + all_cache_keys = tuple(parser.cache_key() for parser in self.parsers) + if all(key is not None for key in all_cache_keys): + return ('sequence', all_cache_keys) + return None + + diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/consts.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/consts.py new file mode 100644 index 0000000000000000000000000000000000000000..620ad737b734a92c6a6a67cbae4a5dacad02fbbf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/consts.py @@ -0,0 +1,20 @@ +COMPLETE_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+-=[]{};:,./<>? 
`'\"" +DEFAULT_MAX_CONSECUTIVE_WHITESPACES = 12 +DEFAULT_FORCE_JSON_FIELD_ORDER = False +DEFAULT_MAX_JSON_ARRAY_LENGTH = 20 +WHITESPACE_CHARACTERS = " \t\n\r" +BACKSLASH = "\\" +BACKSLASH_ESCAPING_CHARACTERS = '"\\/bfnrt' # Characters allowed after an escaping backslash, except unicode +BACKSLACH_UNICODE_ESCAPE = "u" + +CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES = 'LMFE_MAX_CONSECUTIVE_WHITESPACES' +"""Environment variable for externally controlling how many consecutive whitespaces the +JsonSchemaParser will allow. Default: 12""" + +CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER = 'LMFE_STRICT_JSON_FIELD_ORDER' +"""Environment variable for externally controlling whether the JsonSchemaParser will force +fields to appear in the order of the 'required' field in the schema. Default: false""" + +CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH = 'LMFE_MAX_JSON_ARRAY_LENGTH' +"""Environment variable for externally controlling the maximum JSON array length, +if not specified by the schema. Default: 20""" diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/exceptions.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..c90e28fc10575f15eb3a38121963593595f1cb6e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/exceptions.py @@ -0,0 +1,3 @@ +class LMFormatEnforcerException(Exception): + """Base class for exceptions in this module.""" + pass \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__init__.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..213db9fea78925705a0e14153176f9f3f10b71f1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/jsonschemaobject.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/jsonschemaobject.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d86ca1cf820d739c0ca4f0ae41b5f3515ac87b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/jsonschemaobject.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/jsonschemaobjectutil.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/jsonschemaobjectutil.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6148a3a46f758d0a6aff630a640a7ab9b2be8d58 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/__pycache__/jsonschemaobjectutil.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/external/jsonschemaobject.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/jsonschemaobject.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecfd9bc8f5aca78d8d93c0510b6829f1b89285a --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/jsonschemaobject.py @@ -0,0 +1,345 @@ +# https://github.com/koxudaxi/datamodel-code-generator/blob/master/datamodel_code_generator/parser/jsonschema.py +# MIT License + +# Copyright (c) 2019 Koudai Aono + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +from __future__ import annotations + +import enum as _enum +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Set, + Union, +) +from warnings import warn + + +from pydantic import ( + Field, +) + +from .jsonschemaobjectutil import ( + PYDANTIC_V2, + BaseModel, + cached_property, + field_validator, + model_validator, + is_url, + Types, + UnionIntFloat +) + +if PYDANTIC_V2: + from pydantic import ConfigDict + + +def get_model_by_path( + schema: Union[Dict[str, Any], List[Any]], keys: Union[List[str], List[int]] +) -> Dict[Any, Any]: + model: Union[Dict[Any, Any], List[Any]] + if not keys: + model = schema + elif len(keys) == 1: + if isinstance(schema, dict): + model = schema.get(keys[0], {}) # type: ignore + else: # pragma: no cover + model = schema[int(keys[0])] + elif isinstance(schema, dict): + model = get_model_by_path(schema[keys[0]], keys[1:]) # type: ignore + else: + model = get_model_by_path(schema[int(keys[0])], keys[1:]) + if isinstance(model, dict): + return model + raise NotImplementedError( # pragma: no cover + f'Does not support json pointer to array. 
schema={schema}, key={keys}' + ) + + +json_schema_data_formats: Dict[str, Dict[str, Types]] = { + 'integer': { + 'int32': Types.int32, + 'int64': Types.int64, + 'default': Types.integer, + 'date-time': Types.date_time, + 'unix-time': Types.int64, + }, + 'number': { + 'float': Types.float, + 'double': Types.double, + 'decimal': Types.decimal, + 'date-time': Types.date_time, + 'time': Types.time, + 'default': Types.number, + }, + 'string': { + 'default': Types.string, + 'byte': Types.byte, # base64 encoded string + 'binary': Types.binary, + 'date': Types.date, + 'date-time': Types.date_time, + 'time': Types.time, + 'password': Types.password, + 'email': Types.email, + 'idn-email': Types.email, + 'uuid': Types.uuid, + 'uuid1': Types.uuid1, + 'uuid2': Types.uuid2, + 'uuid3': Types.uuid3, + 'uuid4': Types.uuid4, + 'uuid5': Types.uuid5, + 'uri': Types.uri, + 'uri-reference': Types.string, + 'hostname': Types.hostname, + 'ipv4': Types.ipv4, + 'ipv4-network': Types.ipv4_network, + 'ipv6': Types.ipv6, + 'ipv6-network': Types.ipv6_network, + 'decimal': Types.decimal, + 'integer': Types.integer, + }, + 'boolean': {'default': Types.boolean}, + 'object': {'default': Types.object}, + 'null': {'default': Types.null}, + 'array': {'default': Types.array}, +} + + +class JSONReference(_enum.Enum): + LOCAL = 'LOCAL' + REMOTE = 'REMOTE' + URL = 'URL' + + +class Discriminator(BaseModel): + propertyName: str + mapping: Optional[Dict[str, str]] = None + + +class JsonSchemaObject(BaseModel): + if not TYPE_CHECKING: + if PYDANTIC_V2: + + @classmethod + def get_fields(cls) -> Dict[str, Any]: + return cls.model_fields + + else: + + @classmethod + def get_fields(cls) -> Dict[str, Any]: + return cls.__fields__ + + @classmethod + def model_rebuild(cls) -> None: + cls.update_forward_refs() + + __constraint_fields__: Set[str] = { + 'exclusiveMinimum', + 'minimum', + 'exclusiveMaximum', + 'maximum', + 'multipleOf', + 'minItems', + 'maxItems', + 'minLength', + 'maxLength', + 'pattern', + 'uniqueItems', + } + __extra_key__: str = 'extras' + + @model_validator(mode='before') + def validate_exclusive_maximum_and_exclusive_minimum( + cls, values: Dict[str, Any] + ) -> Any: + + # LMFE addition: support "additionalProperties": bool option + if isinstance(values, bool): + return values + + exclusive_maximum: Union[float, bool, None] = values.get('exclusiveMaximum') + exclusive_minimum: Union[float, bool, None] = values.get('exclusiveMinimum') + + if exclusive_maximum is True: + values['exclusiveMaximum'] = values['maximum'] + del values['maximum'] + elif exclusive_maximum is False: + del values['exclusiveMaximum'] + if exclusive_minimum is True: + values['exclusiveMinimum'] = values['minimum'] + del values['minimum'] + elif exclusive_minimum is False: + del values['exclusiveMinimum'] + return values + + @field_validator('ref') + def validate_ref(cls, value: Any) -> Any: + if isinstance(value, str) and '#' in value: + if value.endswith('#/'): + return value[:-1] + elif '#/' in value or value[0] == '#' or value[-1] == '#': + return value + return value.replace('#', '#/') + return value + + items: Union[List[JsonSchemaObject], JsonSchemaObject, bool, None] = None + uniqueItems: Optional[bool] = None + type: Union[str, List[str], None] = None + format: Optional[str] = None + pattern: Optional[str] = None + minLength: Optional[int] = None + maxLength: Optional[int] = None + minimum: Optional[UnionIntFloat] = None + maximum: Optional[UnionIntFloat] = None + minItems: Optional[int] = None + maxItems: Optional[int] = None + multipleOf: 
Optional[float] = None + exclusiveMaximum: Union[float, bool, None] = None + exclusiveMinimum: Union[float, bool, None] = None + additionalProperties: Union[JsonSchemaObject, bool, None] = None + patternProperties: Optional[Dict[str, JsonSchemaObject]] = None + oneOf: List[JsonSchemaObject] = [] + anyOf: List[JsonSchemaObject] = [] + allOf: List[JsonSchemaObject] = [] + enum: List[Any] = [] + writeOnly: Optional[bool] = None + properties: Optional[Dict[str, Union[JsonSchemaObject, bool]]] = None + required: List[str] = [] + ref: Optional[str] = Field(default=None, alias='$ref') + nullable: Optional[bool] = False + x_enum_varnames: List[str] = Field(default=[], alias='x-enum-varnames') + description: Optional[str] = None + title: Optional[str] = None + example: Any = None + examples: Any = None + default: Any = None + id: Optional[str] = Field(default=None, alias='$id') + custom_type_path: Optional[str] = Field(default=None, alias='customTypePath') + custom_base_path: Optional[str] = Field(default=None, alias='customBasePath') + extras: Dict[str, Any] = Field(alias=__extra_key__, default_factory=dict) + discriminator: Union[Discriminator, str, None] = None + if PYDANTIC_V2: + model_config = ConfigDict( + arbitrary_types_allowed=True, + ignored_types=(cached_property,), + ) + else: + + class Config: + arbitrary_types_allowed = True + keep_untouched = (cached_property,) + smart_casts = True + + if not TYPE_CHECKING: + + def __init__(self, **data: Any) -> None: + super().__init__(**data) + self.extras = {k: v for k, v in data.items() if k not in EXCLUDE_FIELD_KEYS} + + @cached_property + def is_object(self) -> bool: + return ( + self.properties is not None + or self.type == 'object' + and not self.allOf + and not self.oneOf + and not self.anyOf + and not self.ref + ) + + @cached_property + def is_array(self) -> bool: + return self.items is not None or self.type == 'array' + + @cached_property + def ref_object_name(self) -> str: # pragma: no cover + return self.ref.rsplit('/', 1)[-1] # type: ignore + + @field_validator('items', mode='before') + def validate_items(cls, values: Any) -> Any: + # this condition expects empty dict + return values or None + + @cached_property + def has_default(self) -> bool: + return 'default' in self.__fields_set__ or 'default_factory' in self.extras + + @cached_property + def has_constraint(self) -> bool: + return bool(self.__constraint_fields__ & self.__fields_set__) + + @cached_property + def ref_type(self) -> Optional[JSONReference]: + if self.ref: + return get_ref_type(self.ref) + return None # pragma: no cover + + @cached_property + def type_has_null(self) -> bool: + return isinstance(self.type, list) and 'null' in self.type + + +@lru_cache() +def get_ref_type(ref: str) -> JSONReference: + if ref[0] == '#': + return JSONReference.LOCAL + elif is_url(ref): + return JSONReference.URL + return JSONReference.REMOTE + + +def _get_type(type_: str, format__: Optional[str] = None) -> Types: + if type_ not in json_schema_data_formats: + return Types.any + data_formats: Optional[Types] = json_schema_data_formats[type_].get( + 'default' if format__ is None else format__ + ) + if data_formats is not None: + return data_formats + + warn( + 'format of {!r} not understood for {!r} - using default' + ''.format(format__, type_) + ) + return json_schema_data_formats[type_]['default'] + + +JsonSchemaObject.model_rebuild() + +DEFAULT_FIELD_KEYS: Set[str] = { + 'example', + 'examples', + 'description', + 'discriminator', + 'title', + 'const', + 'default_factory', +} + 
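+# Added commentary (illustrative, not part of the upstream file): a JsonSchemaObject +# is built directly from a raw JSON-schema dict; a minimal, hypothetical example: +#     schema = JsonSchemaObject(**{'type': 'array', 'items': {'type': 'integer'}, 'minItems': 1}) +#     schema.is_array        # True +#     schema.items.type      # 'integer' +#     schema.has_constraint  # True, since 'minItems' is one of __constraint_fields__ +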
+EXCLUDE_FIELD_KEYS = (set(JsonSchemaObject.get_fields()) - DEFAULT_FIELD_KEYS) | { + '$id', + '$ref', + JsonSchemaObject.__extra_key__, +} diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/external/jsonschemaobjectutil.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/jsonschemaobjectutil.py new file mode 100644 index 0000000000000000000000000000000000000000..f89609c9b6f093ac5e4c5eaccb2efdc453a9173f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/external/jsonschemaobjectutil.py @@ -0,0 +1,231 @@ +# https://github.com/koxudaxi/datamodel-code-generator/blob/master/datamodel_code_generator/util.py +# MIT License + +# Copyright (c) 2019 Koudai Aono + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
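+# Added commentary (not upstream code): this module is a small pydantic v1/v2 +# compatibility layer -- it exposes the PYDANTIC_V2 flag, model_validator / +# field_validator shims, a cached_property fallback, and the UnionIntFloat +# numeric type consumed by jsonschemaobject.py.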
+from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, TypeVar +from enum import Enum, auto + +import pydantic +from packaging import version +from pydantic import BaseModel as _BaseModel + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + TypeVar, + Union, +) + +PYDANTIC_VERSION = version.parse( + pydantic.VERSION if isinstance(pydantic.VERSION, str) else str(pydantic.VERSION) +) + +PYDANTIC_V2: bool = PYDANTIC_VERSION >= version.parse('2.0b3') + +if PYDANTIC_V2: + from pydantic import GetCoreSchemaHandler + from pydantic_core import core_schema + +if TYPE_CHECKING: + cached_property = property + from yaml import SafeLoader + + Protocol = object + runtime_checkable: Callable[..., Any] + + from typing_extensions import Literal +else: + try: + from typing import Protocol + except ImportError: + from typing_extensions import Protocol # noqa + try: + from typing import runtime_checkable + except ImportError: + from typing_extensions import runtime_checkable # noqa + try: + from yaml import CSafeLoader as SafeLoader + except ImportError: # pragma: no cover + from yaml import SafeLoader + + try: + from functools import cached_property + except ImportError: + _NOT_FOUND = object() + + class cached_property: + def __init__(self, func: Callable) -> None: + self.func: Callable = func + self.__doc__: Any = func.__doc__ + + def __get__(self, instance: Any, owner: Any = None) -> Any: + value = instance.__dict__.get(self.func.__name__, _NOT_FOUND) + if value is _NOT_FOUND: # pragma: no cover + value = instance.__dict__[self.func.__name__] = self.func(instance) + return value + + +SafeLoader.yaml_constructors[ + 'tag:yaml.org,2002:timestamp' +] = SafeLoader.yaml_constructors['tag:yaml.org,2002:str'] + + +Model = TypeVar('Model', bound=_BaseModel) + + +def model_validator( + mode: Literal['before', 'after'] = 'after', +) -> Callable[[Callable[[Model, Any], Any]], Callable[[Model, Any], Any]]: + def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]: + if PYDANTIC_V2: + from pydantic import model_validator as model_validator_v2 + + return model_validator_v2(mode=mode)(method) # type: ignore + else: + from pydantic import root_validator + + return root_validator(method, pre=mode == 'before') # type: ignore + + return inner + + +def field_validator( + field_name: str, + *fields: str, + mode: Literal['before', 'after'] = 'after', +) -> Callable[[Any], Callable[[Model, Any], Any]]: + def inner(method: Callable[[Model, Any], Any]) -> Callable[[Model, Any], Any]: + if PYDANTIC_V2: + from pydantic import field_validator as field_validator_v2 + + return field_validator_v2(field_name, *fields, mode=mode)(method) # type: ignore + else: + from pydantic import validator + + return validator(field_name, *fields, pre=mode == 'before')(method) # type: ignore + + return inner + + +if PYDANTIC_V2: + from pydantic import ConfigDict as ConfigDict +else: + ConfigDict = dict # type: ignore + + +class BaseModel(_BaseModel): + if PYDANTIC_V2: + model_config = ConfigDict(strict=False) + + +def is_url(ref: str) -> bool: + return ref.startswith(('https://', 'http://')) + + +class Types(Enum): + integer = auto() + int32 = auto() + int64 = auto() + number = auto() + float = auto() + double = auto() + decimal = auto() + time = auto() + string = auto() + byte = auto() + binary = auto() + date = auto() + date_time = auto() + password = auto() + email = auto() + uuid = auto() + uuid1 = auto() + uuid2 = auto() + uuid3 = auto() + uuid4 = auto() + uuid5 = auto() 
+ uri = auto() + hostname = auto() + ipv4 = auto() + ipv4_network = auto() + ipv6 = auto() + ipv6_network = auto() + boolean = auto() + object = auto() + null = auto() + array = auto() + any = auto() + +class UnionIntFloat: + def __init__(self, value: Union[int, float]) -> None: + self.value: Union[int, float] = value + + def __int__(self) -> int: + return int(self.value) + + def __float__(self) -> float: + return float(self.value) + + def __str__(self) -> str: + return str(self.value) + + @classmethod + def __get_validators__(cls) -> Iterator[Callable[[Any], Any]]: + yield cls.validate + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + from_int_schema = core_schema.chain_schema( + [ + core_schema.union_schema( + [core_schema.int_schema(), core_schema.float_schema()] + ), + core_schema.no_info_plain_validator_function(cls.validate), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=core_schema.no_info_plain_validator_function(cls.validate), + python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(UnionIntFloat), + from_int_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.value + ), + ) + + @classmethod + def validate(cls, v: Any) -> UnionIntFloat: + if isinstance(v, UnionIntFloat): + return v + elif not isinstance(v, (int, float)): # pragma: no cover + raise TypeError(f'{v} is not int or float') + return cls(v) diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__init__.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c832f288d933c271dc41d32d8253c5d0280bb6fe Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/exllamav2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/exllamav2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acced03f3518ec6639482fd36ca797b308629cd6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/exllamav2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/haystackv1.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/haystackv1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d893d859d8b8e87ce32d9c6b09abb49fbb2ea07d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/haystackv1.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/haystackv2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/haystackv2.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d4b474af09e0c063b76dee517ffcb5206c9f0a15 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/haystackv2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/llamacpp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/llamacpp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc08f7e15307b61baa00d1c7fd2ff687a1438b11 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/llamacpp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/transformers.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/transformers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f4cb78218f2868d9b866de9dd2e52112ca7cea1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/transformers.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/trtllm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/trtllm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fef5b1ac25bcbe13c2822b79fb86e241dd24dbed Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/trtllm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/vllm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/vllm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50c1ea83ae37e4558f2b69c73a1a4c4ce82719f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/__pycache__/vllm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/transformers.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdfd2f69a56aa727bb4557b47a03f5e1391e9ca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/transformers.py @@ -0,0 +1,146 @@ +import functools +from typing import Any, Callable, List, Optional, Tuple, Union +try: + from transformers import AutoModelForCausalLM + from transformers.generation.logits_process import LogitsWarper, PrefixConstrainedLogitsProcessor + from transformers.tokenization_utils import PreTrainedTokenizerBase +except ImportError: + raise ImportError('transformers is not installed. Please install it with "pip install transformers[torch]"') + +try: + import torch +except ImportError: + raise ImportError('pytorch is not installed. 
See https://pytorch.org/get-started/locally/ for installation instructions.') + +from ..characterlevelparser import CharacterLevelParser +from ..tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData +from ..analyzer import FormatEnforcerAnalyzer + +class LogitsSaverWarper(LogitsWarper): + def __init__(self, analyzer: FormatEnforcerAnalyzer) -> None: + self.analyzer = analyzer + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + cpu_inputs = input_ids.tolist() + cpu_scores = scores.tolist() + for single_batch_inputs, single_batch_scores in zip(cpu_inputs, cpu_scores): + self.analyzer.report_raw_logits(single_batch_inputs, single_batch_scores) + return scores + +class LogitsSaverManager: + warper: LogitsSaverWarper + + def __init__(self, model: AutoModelForCausalLM, analyzer: FormatEnforcerAnalyzer): + self.model = model + self.warper = None + self.old_warper = None + self.analyzer = analyzer + + def replace_logits_warper(self, filter_func = None): + self.old_warper = self.model._get_logits_warper + + def new_logits_warper(generation_config): + warpers = self.old_warper(generation_config) + self.warper = LogitsSaverWarper(self.analyzer) + warpers.insert(0, self.warper) + if filter_func is not None: + processor = PrefixConstrainedLogitsProcessor(filter_func, 1) + warpers.insert(1, processor) + return warpers + self.model._get_logits_warper = new_logits_warper + + def unreplace_logits_warper(self): + self.model._get_logits_warper = self.old_warper + + +def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase, vocab_size: int) -> List[Tuple[int, str, bool]]: + token_0 = tokenizer.encode("0")[-1] + regular_tokens = [] + for token_idx in range(vocab_size): + if token_idx in tokenizer.all_special_ids: + continue + # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. + decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:] + decoded_regular = tokenizer.decode([token_idx]) + is_word_start_token = len(decoded_after_0) > len(decoded_regular) + regular_tokens.append((token_idx, decoded_after_0, is_word_start_token)) + return regular_tokens + + +def _decode_function(tokenizer: PreTrainedTokenizerBase, tokens: List[int]) -> str: + decoded = tokenizer.decode(tokens) + cleaned = decoded.rstrip('�') + return cleaned + + +def build_token_enforcer_tokenizer_data(tokenizer: PreTrainedTokenizerBase, + vocab_size: Optional[int] = None) -> TokenEnforcerTokenizerData: + vocab_size = vocab_size or len(tokenizer) + regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size) + decode_fn = functools.partial(_decode_function, tokenizer) + return TokenEnforcerTokenizerData(regular_tokens, decode_fn, tokenizer.eos_token_id) + + +class TransformersPrefixAllowedTokensFn: + def __init__(self, token_enforcer: TokenEnforcer): + self.token_enforcer = token_enforcer + + def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]: + token_sequence = sent.tolist() + return self.token_enforcer.get_allowed_tokens(token_sequence) + + +def build_transformers_prefix_allowed_tokens_fn(tokenizer_data: Union[PreTrainedTokenizerBase, TokenEnforcerTokenizerData], + character_level_parser: CharacterLevelParser) -> TransformersPrefixAllowedTokensFn: + """Build the prefix allowed tokens function that transformers will use to filter the tokens generated by the model.
The result + can be passed to the prefix_allowed_tokens_fn parameter of the generate() method of transformers models or pipeline configurations.""" + if isinstance(tokenizer_data, PreTrainedTokenizerBase): + tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer_data) + token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser) + return TransformersPrefixAllowedTokensFn(token_enforcer) + + +def generate_enforced(model: AutoModelForCausalLM, + tokenizer: Union[PreTrainedTokenizerBase, TokenEnforcerTokenizerData], + character_level_parser: CharacterLevelParser, + **kwargs: dict) -> Union[str, dict]: + """Generate text from a model while enforcing a given format, optionally generating enforcement diagnostic information. + This can be used instead of calling model.generate(). + If the return_dict_in_generate and output_scores parameters are True, diagnostic information will be returned in the result. + If you don't need this, consider using prefix_allowed_tokens_fn + build_transformers_prefix_allowed_tokens_fn() instead""" + + transformers_filter_allowed_tokens = build_transformers_prefix_allowed_tokens_fn(tokenizer, character_level_parser) + + is_multi_inputs = kwargs['input_ids'].shape[0] > 1 + is_multi_beams = kwargs.get('num_beams', 1) > 1 + support_diagnostics = not (is_multi_inputs or is_multi_beams) # TODO: Support diagnostics in these cases as well. + return_dict_in_generate = kwargs.get('return_dict_in_generate', False) + output_scores = kwargs.get('output_scores', None) + + # We do some internals hacking in order to extract the data needed for diagnostics. If we weren't asked for them, + # we are better off simply using the prefix_allowed_tokens_fn parameter. + should_run_in_advanced_mode = return_dict_in_generate and output_scores and support_diagnostics + + if should_run_in_advanced_mode: + analyzer = FormatEnforcerAnalyzer(transformers_filter_allowed_tokens.token_enforcer) + logits_saver = LogitsSaverManager(model, analyzer) + logits_saver.replace_logits_warper(transformers_filter_allowed_tokens) + generate_kwargs = kwargs + + try: + output = model.generate(**generate_kwargs) + finally: + logits_saver.unreplace_logits_warper() + + df_dict = analyzer.generate_report_dict(output['sequences'][0].tolist()) + output.enforced_scores = df_dict + else: + output = model.generate(**kwargs, prefix_allowed_tokens_fn=transformers_filter_allowed_tokens) + + return output + +__all__ = [ + 'build_transformers_prefix_allowed_tokens_fn', + 'generate_enforced', + 'build_token_enforcer_tokenizer_data' +] \ No newline at end of file diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/trtllm.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/trtllm.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1fca7050bdda23c90dfdfc348b7754dfba5bf2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/integrations/trtllm.py @@ -0,0 +1,84 @@ +import math +from typing import List, Optional, Tuple, Union +import torch +from transformers import PreTrainedTokenizerBase +from lmformatenforcer import CharacterLevelParser, FormatEnforcerAnalyzer +from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData + + +class TRTLLMLogitsProcessor: + def __init__(self, token_enforcer: TokenEnforcer, eos_token_id, analyze): + self.token_enforcer = token_enforcer + self.analyzer = FormatEnforcerAnalyzer(token_enforcer) if analyze else None + self.mask: Optional[torch.Tensor] = None + self.mask_val = -math.inf + 
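# Added commentary: the mask is additive -- it is filled with self.mask_val (-inf) + # and the allowed token ids are zeroed, so disallowed tokens receive zero + # probability after the softmax. +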
self.eos_token_id = eos_token_id + + def _trim(self, input): + return [x for x in input.tolist() if x not in \ + (self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id])] + + def __call__(self, step: int, batch_input_ids: List[List[int]], logits: torch.Tensor) -> torch.Tensor: + for idx in range(len(batch_input_ids)): + if self.analyzer: + self.analyzer.report_raw_logits(batch_input_ids[idx], logits[idx].tolist()) + + allowed_tokens = self.token_enforcer.get_allowed_tokens(self._trim(batch_input_ids[idx])) + + if self.mask is not None: + self.mask.fill_(self.mask_val) + else: + # We create it here because full_like() also copies the device and dtype + self.mask = torch.full_like(logits[idx], self.mask_val) + self.mask[allowed_tokens] = 0 + logits[idx] = logits[idx] + self.mask + + return logits + + +def _build_regular_tokens_list(tokenizer) -> List[Tuple[int, str, bool]]: + # There are many classes that can be passed here, this logic should work on all of them. + if hasattr(tokenizer, 'get_tokenizer'): + tokenizer = tokenizer.get_tokenizer() + if hasattr(tokenizer, 'tokenizer'): + tokenizer = tokenizer.tokenizer + token_0 = [tokenizer.encode("0")[-1]] + regular_tokens = [] + vocab_size = tokenizer.vocab_size + for token_idx in range(vocab_size): + if token_idx in tokenizer.all_special_ids: + continue + # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word. + tensor_after_0 = torch.tensor(token_0 + [token_idx], dtype=torch.long) + decoded_after_0 = tokenizer.decode(tensor_after_0)[1:] + decoded_regular = tokenizer.decode(token_0) + is_word_start_token = len(decoded_after_0) > len(decoded_regular) + regular_tokens.append((token_idx, decoded_after_0, is_word_start_token)) + return regular_tokens + + +def build_trtlmm_tokenizer_data(tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + """Build the TokenEnforcerTokenizerData from a tokenizer in order to cache it between instances""" + regular_tokens = _build_regular_tokens_list(tokenizer) + + def _decode(tokens: List[int]) -> str: + tensor = torch.tensor(tokens, dtype=torch.long) + return tokenizer.decode(tensor) + + tokenizer_data = TokenEnforcerTokenizerData(regular_tokens, _decode, tokenizer.eos_token_id) + return tokenizer_data + + +def build_trtllm_logits_processor(tokenizer: Union[PreTrainedTokenizerBase, TokenEnforcerTokenizerData], + character_level_parser: CharacterLevelParser, + analyze: bool = False) -> TRTLLMLogitsProcessor: + """ + Build logits processor for feeding it into generate function (use_py_session should be True) + """ + if isinstance(tokenizer, TokenEnforcerTokenizerData): + tokenizer_data = tokenizer + else: + tokenizer_data = build_trtlmm_tokenizer_data(tokenizer) + + token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser) + return TRTLLMLogitsProcessor(token_enforcer, tokenizer.eos_token_id, analyze) diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/jsonschemaparser.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/jsonschemaparser.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d67a947b0ac5a65ce86fcdadb897f3ddbf1707 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/jsonschemaparser.py @@ -0,0 +1,710 @@ +from copy import deepcopy +import enum +import sys +from typing import Dict, Hashable, List, Optional, Union, cast + + +from .external.jsonschemaobject import JsonSchemaObject, json_schema_data_formats +from .exceptions import 
LMFormatEnforcerException +from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig, ForceStopParser, SequenceParser, StringParser, UnionParser +from .consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, WHITESPACE_CHARACTERS +from .regexparser import RegexParser + +# No need to include the 'integer' option in the anyOf, as it is a subset of 'number' +_ANY_JSON_SCHEMA_DICT = {'anyOf': [{'type': type} for type in json_schema_data_formats.keys() if type != 'integer']} + +class JsonSchemaParser(CharacterLevelParser): + ANY_JSON_OBJECT_SCHEMA: JsonSchemaObject = JsonSchemaObject(**_ANY_JSON_SCHEMA_DICT) + class _Context: + model_class: JsonSchemaObject + # We store the active parser in the context, so that if a node adds to the stack, it knows + # to which parser's stack to add. + active_parser: "JsonSchemaParser" + alphabet_without_quotes: str + regex_parser_cache: Dict[str, RegexParser] = {} + + object_stack: List[CharacterLevelParser] + context: _Context + num_consecutive_whitespaces: int + last_parsed_string: str # Slight hack to allow communicating the parsed key to the object parser + last_non_whitespace_character: str # Slight hack to allow list parser to know if there is an item on top + + def __init__(self, + json_schema: Union[dict, _Context, None], + config: Optional[CharacterLevelParserConfig] = None, + existing_stack: Optional[List[CharacterLevelParser]] = None, + num_consecutive_whitespaces: int = 0): + """Create a CharacterLevelParser for parsing JSON. + :param json_schema: The json schema to parse. Can be a dict of a JSON schema, or None if any json output is allowed.""" + super().__init__(config) + if isinstance(json_schema, JsonSchemaParser._Context): + self.context = json_schema + else: + self.context = JsonSchemaParser._Context() + json_schema = json_schema or _ANY_JSON_SCHEMA_DICT + self.context.model_class = JsonSchemaObject(**json_schema) + self.context.active_parser = self + self.context.alphabet_without_quotes = self.config.alphabet.replace('"', '') + + self.num_consecutive_whitespaces = num_consecutive_whitespaces + if existing_stack is None: + self.object_stack = [get_parser(self, self.context.model_class)] + else: + self.object_stack = existing_stack + self.last_parsed_string = "" + self.last_non_whitespace_character = "" + + def add_character(self, new_character: str) -> CharacterLevelParser: + self.context.active_parser = self + # Assumption: The top-most parser that can accept the character is the one that should accept it. + # This is different from the SequenceParser, in which we need to split (union) into all options. 
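+ # Added commentary: e.g. while parsing {"a": [1, 2]} the stack may hold + # [ObjectParsingState, ListParsingState, NumberParsingState]; a ']' is not legal + # for the number parser, so we walk down the stack until the list parser accepts it.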
+ receiving_idx = len(self.object_stack) - 1 + last_parsed_string = self.last_parsed_string + while receiving_idx >= 0 and new_character not in self.object_stack[receiving_idx].get_allowed_characters(): + finished_receiver = self.object_stack[receiving_idx] + if isinstance(finished_receiver, StringParsingState): + last_parsed_string = finished_receiver.parsed_string + receiving_idx -= 1 + + updated_stack = self.object_stack[:receiving_idx + 1] + updated_parser = JsonSchemaParser(self.context, self.config, updated_stack, self.num_consecutive_whitespaces) + updated_parser.context.active_parser = updated_parser + updated_parser.last_parsed_string = last_parsed_string + if receiving_idx >= 0: + updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character) + if new_character in WHITESPACE_CHARACTERS: + updated_parser.num_consecutive_whitespaces += 1 + updated_parser.last_non_whitespace_character = self.last_non_whitespace_character + else: + updated_parser.num_consecutive_whitespaces = 0 + updated_parser.last_non_whitespace_character = new_character + + if updated_parser.object_stack and isinstance(updated_parser.object_stack[-1], UnionParser) and \ + any(isinstance(parser, (ObjectParsingState, ListParsingState)) for parser in updated_parser.object_stack[-1].parsers): + # If the top parser is a union parser with "advanced" (=parsers that modify the object stack) parsers inside, + # we need to split the top level parser into the different options, + # As each "fork" can live with a different object stack, and we need to make sure they have their own ones. + option_json_schema_parsers = [] + for option_parser in updated_parser.object_stack[-1].parsers: + option_stack = updated_parser.object_stack[:-1] + [option_parser] + option_parser = JsonSchemaParser(self.context, self.config, option_stack, updated_parser.num_consecutive_whitespaces) + option_parser.context.active_parser = option_parser + option_parser.last_parsed_string = last_parsed_string + option_parser.last_non_whitespace_character = updated_parser.last_non_whitespace_character + option_json_schema_parsers.append(option_parser) + return UnionParser(option_json_schema_parsers) + + # For some performance optimizations to work, we want to make sure we don't leave irrelevant + # objects at the top of the stack, which we know will be passed over next timestep + new_object_stack = updated_parser.object_stack + while new_object_stack and new_object_stack[-1].can_end() and new_object_stack[-1].get_allowed_characters() == '': + finished_receiver = new_object_stack[-1] + if isinstance(finished_receiver, StringParsingState): + updated_parser.last_parsed_string = finished_receiver.parsed_string + del new_object_stack[-1] + if new_object_stack: + new_top_parser = new_object_stack[-1] + if isinstance(new_top_parser, ListParsingState): + new_top_parser = new_top_parser._clone() + new_top_parser.num_items_seen += 1 + new_object_stack[-1] = new_top_parser + + + return updated_parser + + def get_allowed_characters(self) -> str: + self.context.active_parser = self + + allowed_character_strs = [] + for parser in reversed(self.object_stack): + # Similar to SequenceParser, if the top object can end, we need to know to accept the next character of parser below, etc. 
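+ # Added commentary: e.g. when the top of the stack is a string value that can + # already end, both the string's own continuations and the enclosing object's + # ',' / '}' are offered as legal next characters.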
+ allowed_character_strs.append(parser.get_allowed_characters()) + if not parser.can_end(): + break + if len(allowed_character_strs) > 0: + allowed_characters = "".join(allowed_character_strs) + else: + # In certain cases, beam search / sample crashes when there are less legal + # continuation tokens than there are beams. Therefore, we allow whitespace + # characters when the object stack is empty (= we are done parsing) + allowed_characters = WHITESPACE_CHARACTERS + + if self.num_consecutive_whitespaces >= self.config.max_consecutive_whitespaces: + # print("Filtering whitespace characters") + allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS) + return allowed_characters + + def can_end(self) -> bool: + return all(parser.can_end() for parser in self.object_stack) + + def shortcut_key(self) -> Optional[Hashable]: + if self.object_stack: + current_parser = self.object_stack[-1] + if isinstance(current_parser, StringParsingState): + if not current_parser.allowed_strings and current_parser.seen_opening_quote and not current_parser.seen_closing_quote and not current_parser.regex_parser: + # Performance optimization: When we are parsing a string that is not from a list of allowed strings, most tokens + # are legal. The exploration can be more costly than the LM itself for large tokenizers (because this is pure python), + # so we signal that we are in a "freetext" mode, and reuse the allowed token list throughout the run. + cur_len = len(current_parser.parsed_string) + min_len = current_parser.min_length or 0 + max_len = current_parser.max_length or sys.maxsize + assert min_len <= max_len, "Invalid schema for str: min length is larger than max length" + if cur_len < max_len: + return ('json_freetext', cur_len, min_len, max_len) + return None + + +class BaseParsingState(CharacterLevelParser): + def __init__(self, root: JsonSchemaParser): + self.root = root + + +def _merge_object_schemas(base_schema: JsonSchemaObject, option_schema: JsonSchemaObject) -> JsonSchemaObject: + base_schema_properties = base_schema.properties or {} + for property_name, property_value in base_schema_properties.items(): + # We assume that if a property exists in both base and option, the option version will be + # more specific, therefore we only take missing entries + if property_name not in option_schema.properties: + option_schema.properties[property_name] = property_value + for required_property in base_schema.required: + if required_property not in option_schema.required: + option_schema.required.append(required_property) + return option_schema + + +def get_parser( + parsing_state: JsonSchemaParser, + value_schema: JsonSchemaObject +) -> CharacterLevelParser: + if value_schema is None: + raise Exception("JsonSchemaParser: Value schema is None") + if value_schema.anyOf: + parsers = [get_parser(parsing_state, schema) for schema in value_schema.anyOf] + return UnionParser(parsers) + if value_schema.allOf: + merged_schema = value_schema.allOf[0] + for schema in value_schema.allOf[1:]: + merged_schema = _merge_object_schemas(merged_schema, schema) + return get_parser(parsing_state, merged_schema) + if value_schema.extras and 'const' in value_schema.extras: + allowed_value = value_schema.extras['const'] + is_string = type(allowed_value) == str + return StringParsingState(parsing_state, + [allowed_value], + require_opening_quote=is_string, + require_closing_quote=is_string) + if value_schema.type == "string": + return StringParsingState( + parsing_state, + value_schema.enum, + 
require_opening_quote=True, + min_length=value_schema.minLength, + max_length=value_schema.maxLength, + pattern=value_schema.pattern, + ) + if value_schema.oneOf: + # We create a combined object schema for each option that includes the information from the parent + # And then create a UnionParser based on the combined options + merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf] + object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas] + return UnionParser(object_parsing_options) + elif value_schema.type == "object": + return ObjectParsingState(value_schema, parsing_state) + elif value_schema.type == None and value_schema.ref: + value_class_name = value_schema.ref.split('/')[-1] + extras = parsing_state.context.model_class.extras + # Pydantic V1 and V2 have different names for the definitions field + if 'definitions' in extras: + definitions = extras['definitions'] + elif '$defs' in extras: + definitions = extras['$defs'] + else: + raise ValueError("No definitions found in schema") + class_dict = definitions[value_class_name] + value_schema = JsonSchemaObject(**class_dict) + return get_parser(parsing_state, value_schema) + elif value_schema.enum: + is_numeric = all(isinstance(i, (int, float)) for i in value_schema.enum) + is_string = all(isinstance(i, (str)) for i in value_schema.enum) + if is_string: + return StringParsingState( + parsing_state, + value_schema.enum, + require_opening_quote=True, + ) + elif is_numeric: + return StringParsingState( + parsing_state, + [str(i) for i in value_schema.enum], + require_opening_quote=False, + require_closing_quote=False, + ) + else: + raise Exception("Unsupported enum type " + str(value_schema.enum)) + elif value_schema.type == "integer": + return NumberParsingState(parsing_state, False) + elif value_schema.type == "boolean": + return StringParsingState( + parsing_state, + ["true", "false"], + require_opening_quote=False, + require_closing_quote=False, + ) + elif value_schema.type == "null": + return StringParsingState( + parsing_state, + ["null"], + require_opening_quote=False, + require_closing_quote=False, + ) + elif value_schema.type == "number": + return NumberParsingState(parsing_state, True) + elif value_schema.type == "array": + item_schema = value_schema.items or JsonSchemaParser.ANY_JSON_OBJECT_SCHEMA + return ListParsingState(parsing_state, item_schema, value_schema.minItems, value_schema.maxItems) + else: + raise Exception("Unsupported type " + str(value_schema.type)) + + +class ObjectParsingStage(enum.Enum): + START_OBJECT = "StartObject" + PARSING_KEY_OR_END = "ParsingKey" + PARSING_KEY_VALUE_SEPARATOR = "ParsingKeyValueSeparator" + PARSING_VALUE = "ParsingValue" + PARSING_SEPARATOR_OR_END = "ParsingSeparatorOrEnd" + END_OBJECT = "EndObject" + + +class ObjectParsingState(BaseParsingState): + schema_object: JsonSchemaObject + current_stage: ObjectParsingStage + existing_keys: List[str] + current_key: Optional[str] + is_dictionary: bool + + def __init__(self, schema_object: JsonSchemaObject, root: JsonSchemaParser): + super().__init__(root) + self.schema_object = schema_object + self.current_stage = ObjectParsingStage.START_OBJECT + self.root = root + self.existing_keys = [] + self.current_key = None + # Javascript objects represent both classes and dictionaries, so we need to know which one we are parsing + self.is_dictionary = self.schema_object.properties is None + + def clone(self) -> 'ObjectParsingState': + clone = 
ObjectParsingState(self.schema_object, self.root) + clone.current_stage = self.current_stage + clone.existing_keys = self.existing_keys[:] + clone.current_key = self.current_key + clone.is_dictionary = self.is_dictionary + return clone + + def add_character(self, new_character: str) -> CharacterLevelParser: + if new_character.strip() == "": + # In object scope, whitespaces can be ignored + return self + self = self.clone() # Immutability requirement + if ( + self.current_stage == ObjectParsingStage.START_OBJECT + and new_character == "{" + ): + self.current_stage = ObjectParsingStage.PARSING_KEY_OR_END + elif self.current_stage == ObjectParsingStage.PARSING_KEY_OR_END: + if new_character == "}": + self.current_stage = ObjectParsingStage.END_OBJECT + if new_character == '"': + possible_keys = None + if not self.is_dictionary: + required_keys = self.schema_object.required or [] + next_required_key = next((key for key in required_keys if key not in self.existing_keys), None) + if self.root.config.force_json_field_order and next_required_key: + possible_keys = [next_required_key] + else: + possible_keys = list(self.schema_object.properties.keys()) + possible_keys = list( + set(possible_keys).difference(self.existing_keys) + ) + # We send require_opening_quote=True and then add_character('"') instead of require_opening_quote=False + # Because there is a difference between "don't need a quote" and "received it before creating the parser" + key_parser = StringParsingState( + self.root, possible_keys, require_opening_quote=True, require_closing_quote=True + ) + key_parser = key_parser.add_character('"') + self.root.context.active_parser.object_stack.append(key_parser) + self.current_stage = ObjectParsingStage.PARSING_KEY_VALUE_SEPARATOR + elif self.current_stage == ObjectParsingStage.PARSING_KEY_VALUE_SEPARATOR: + if new_character == ":": + self.current_stage = ObjectParsingStage.PARSING_VALUE + self.current_key = self.root.context.active_parser.last_parsed_string + self.existing_keys.append(self.current_key) + if self.is_dictionary: + if self.schema_object.additionalProperties: + value_schema = self.schema_object.additionalProperties + else: + value_schema = JsonSchemaParser.ANY_JSON_OBJECT_SCHEMA + else: + value_schema = self.schema_object.properties[self.current_key] + self.current_key_parser = get_parser( + self.root, value_schema + ) + self.root.context.active_parser.object_stack.append(self.current_key_parser) + self.current_key_parser = None + elif self.current_stage == ObjectParsingStage.PARSING_VALUE: + # If we receive a character during parsing value, it means that it's the finishing character + # of the value parser + if new_character == '"': + self.current_stage = ObjectParsingStage.PARSING_SEPARATOR_OR_END + elif new_character == ",": + self.current_stage = ObjectParsingStage.PARSING_KEY_OR_END + elif new_character == "}": + self.current_stage = ObjectParsingStage.END_OBJECT + elif self.current_stage == ObjectParsingStage.PARSING_SEPARATOR_OR_END: + if new_character == ",": + self.current_stage = ObjectParsingStage.PARSING_KEY_OR_END + elif new_character == "}": + self.current_stage = ObjectParsingStage.END_OBJECT + return self + + def get_allowed_characters(self) -> str: + possible_keys = ( + list(self.schema_object.properties.keys()) + if not self.is_dictionary + else None + ) + required_keys = self.schema_object.required or [] + can_end = set(self.existing_keys).issuperset(required_keys) + can_parse_key = self.is_dictionary or set(possible_keys).difference( + self.existing_keys + ) 
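+ # Added commentary: can_end means every required key has already been parsed; + # can_parse_key means candidate keys remain (always true for dictionary-style objects).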
+ + possible_characters = [c for c in WHITESPACE_CHARACTERS] + if self.current_stage == ObjectParsingStage.START_OBJECT: + possible_characters.append('{') + elif self.current_stage == ObjectParsingStage.PARSING_KEY_OR_END: + if can_end: + possible_characters.append('}') + if can_parse_key: + possible_characters.append('"') + elif self.current_stage == ObjectParsingStage.PARSING_KEY_VALUE_SEPARATOR: + possible_characters.append(':') + elif self.current_stage == ObjectParsingStage.PARSING_VALUE: + # Sometimes the value parser considers finishing, so it needs to know which continuations are possible + if can_end: + possible_characters.append('}') + if can_parse_key: + possible_characters.append(',') + elif self.current_stage == ObjectParsingStage.PARSING_SEPARATOR_OR_END: + if can_end: + possible_characters.append('}') + if can_parse_key: + possible_characters.append(',') + return "".join(possible_characters) + + def can_end(self) -> bool: + return self.current_stage == ObjectParsingStage.END_OBJECT + + +class StringParsingStage: + START_TOKEN = "StartToken" + PARSING_STRING = "ParsingString" + END_TOKEN = "EndToken" + + +class PrimitiveParsingState(BaseParsingState): + def __init__(self, root: JsonSchemaParser): + super().__init__(root) + self.stage = StringParsingStage.START_TOKEN + self.parsed_string = "" + + def _clone(self) -> "PrimitiveParsingState": + raise NotImplementedError() + + def add_character(self, new_character: str) -> "PrimitiveParsingState": + new = self._clone() + new.parsed_string += new_character + return new + + def can_end(self) -> bool: + return True + + +class NumberParsingState(PrimitiveParsingState): + def __init__( + self, + root: JsonSchemaParser, + allow_floating_point: bool, + ): + super().__init__(root) + self.allow_floating_point = allow_floating_point + self.seen_decimal_point = False + self.seen_whitespace_after_digits = False + self.seen_exponent = False + self.seen_digit = False + + def _clone(self) -> "NumberParsingState": + clone = NumberParsingState(self.root, self.allow_floating_point) + clone.parsed_string = self.parsed_string + clone.seen_decimal_point = self.seen_decimal_point + clone.seen_whitespace_after_digits = self.seen_whitespace_after_digits + clone.seen_exponent = self.seen_exponent + clone.seen_digit = self.seen_digit + return clone + + def add_character(self, new_character: str) -> CharacterLevelParser: + if not self.parsed_string and new_character in WHITESPACE_CHARACTERS: + return self + self = cast(NumberParsingState, super().add_character(new_character)) + if new_character in WHITESPACE_CHARACTERS: + if self.parsed_string: + self.seen_whitespace_after_digits = True + return self + if new_character == ".": + if not self.parsed_string or len(self.parsed_string) == 1: + raise LMFormatEnforcerException("Numbers cannot start with a decimal point.") + if self.seen_decimal_point: + raise LMFormatEnforcerException("Numbers cannot contain more than one decimal point.") + self.seen_decimal_point = True + elif new_character in "eE": + if self.seen_exponent or not self.seen_digit: + raise LMFormatEnforcerException("Invalid number format") + self.seen_exponent = True + elif new_character.isdigit(): + self.seen_digit = True + return self + + def get_allowed_characters(self) -> str: + if self.seen_whitespace_after_digits: + return WHITESPACE_CHARACTERS + allowed_characters = "0123456789" + if not self.parsed_string: + allowed_characters += "-" + WHITESPACE_CHARACTERS + if self.parsed_string and len(self.parsed_string) == 1 and self.parsed_string[0] 
== "0": + allowed_characters = WHITESPACE_CHARACTERS + if self.parsed_string and len(self.parsed_string) == 2 and self.parsed_string == "-0": + allowed_characters = "." + WHITESPACE_CHARACTERS + if self.parsed_string and self.parsed_string[-1] in "eE": + allowed_characters += "-+" + if self.seen_digit and not self.seen_exponent: + allowed_characters += "eE" + if self.allow_floating_point and not self.seen_decimal_point and self.seen_digit and not self.seen_exponent: + allowed_characters += "." + if self.parsed_string and self.parsed_string[-1].isdigit(): + allowed_characters += WHITESPACE_CHARACTERS + return allowed_characters + + def can_end(self) -> bool: + if self.seen_exponent and self.parsed_string[-1] in "eE+-": + return False + return bool(self.parsed_string) and (self.parsed_string[-1].isdigit() or self.seen_whitespace_after_digits) + + +class StringParsingState(PrimitiveParsingState): + allowed_strings: List[str] + parsed_string: str + seen_closing_quote: bool + seen_opening_quote: bool + min_length: Optional[int] + max_length: Optional[int] + pattern: Optional[str] + regex_parser: Optional[RegexParser] + + def __init__( + self, + root: JsonSchemaParser, + allowed_strings: List[str], + require_opening_quote: bool, + require_closing_quote: bool = True, + min_length: Optional[int]=None, + max_length: Optional[int]=None, + pattern: Optional[str]=None, + regex_parser: Optional[RegexParser]=None, + ): + super().__init__(root) + self.allowed_strings = allowed_strings + self.seen_closing_quote = False + self.seen_opening_quote = not require_opening_quote + self.require_closing_quote = require_closing_quote + self.require_opening_quote = require_opening_quote + self.min_length = min_length + self.max_length = max_length + self.pattern = pattern + if self.pattern and (self.min_length or self.max_length): + raise LMFormatEnforcerException("String schema contains both a pattern and a min/max length, which is not currently supported") + self.regex_parser = regex_parser + if self.pattern and not regex_parser: + if self.pattern not in self.root.context.regex_parser_cache: + self.root.context.regex_parser_cache[self.pattern] = RegexParser(self.pattern, self.root.config) + self.regex_parser = self.root.context.regex_parser_cache[self.pattern] + + + def _clone(self) -> "StringParsingState": + clone = StringParsingState( + self.root, + self.allowed_strings, + self.require_opening_quote, + self.require_closing_quote, + self.min_length, + self.max_length, + self.pattern, + self.regex_parser + ) + clone.parsed_string = self.parsed_string + clone.seen_closing_quote = self.seen_closing_quote + clone.seen_opening_quote = self.seen_opening_quote + return clone + + def add_character(self, new_character: str): + if (not self.parsed_string or self.seen_closing_quote) and new_character in WHITESPACE_CHARACTERS: + return self + self = cast(StringParsingState, super().add_character(new_character)) + if new_character == '"': + if not self.seen_opening_quote: + self.seen_opening_quote = True + self.parsed_string = "" + else: + self.seen_closing_quote = True + self.parsed_string = self.parsed_string[:-1] + if self.regex_parser and new_character != '"' and self.seen_opening_quote and not self.seen_closing_quote: + self.regex_parser = self.regex_parser.add_character(new_character) + if new_character == BACKSLASH: + # After a backslack we immediately have the escaping character, and if its 'u', we have 4 hex digits + escaping_character_parsers: List[CharacterLevelParser] = [StringParser(c) for c in 
BACKSLASH_ESCAPING_CHARACTERS] + hex_digit_parser: CharacterLevelParser = UnionParser([StringParser(c) for c in "0123456789abcdefABCDEF"]) + unicode_components: List[CharacterLevelParser] = list([StringParser("u")] + [hex_digit_parser] * 4) + unicode_escape_parser: CharacterLevelParser = SequenceParser(unicode_components) + json_escaping_parser = UnionParser(escaping_character_parsers + [unicode_escape_parser]) + self.root.context.active_parser.object_stack.append(json_escaping_parser) + return self + + def get_allowed_characters(self) -> str: + if not self.seen_opening_quote: + return '"' + WHITESPACE_CHARACTERS + if self.seen_closing_quote: + return WHITESPACE_CHARACTERS + if self.regex_parser: + regex_chars = self.regex_parser.get_allowed_characters() + # We don't currently support regexes with quotes or escaping backslashes, so we remove them from the allowed characters + regex_chars = regex_chars.replace('"', '').replace(BACKSLASH, '') + if self.regex_parser.can_end(): + regex_chars += '"' + return regex_chars + if self.allowed_strings: + allowed_continuations = [ + s[len(self.parsed_string) :] + for s in self.allowed_strings + if s.startswith(self.parsed_string) + ] + allowed_next_characters = [allowed_continuation[0] for allowed_continuation in allowed_continuations if len(allowed_continuation) > 0] + allowed_next_characters = list(set(allowed_next_characters)) + if self.parsed_string in self.allowed_strings and self.require_closing_quote: + allowed_next_characters.append('"') + if (not self.parsed_string) and (not self.seen_opening_quote or not self.require_opening_quote): + allowed_next_characters.extend(WHITESPACE_CHARACTERS) + return "".join(allowed_next_characters) + else: + if self.min_length is not None and len(self.parsed_string) < self.min_length: + return self.root.context.alphabet_without_quotes + BACKSLASH + if self.max_length is not None and len(self.parsed_string) >= self.max_length: + return '"' + return self.root.config.alphabet + BACKSLASH + + def can_end(self) -> bool: + if self.require_closing_quote: + return self.seen_closing_quote + else: + if self.allowed_strings: + return self.parsed_string in self.allowed_strings + else: + return bool(self.parsed_string) + + +class ListParsingState(PrimitiveParsingState): + list_member_type: JsonSchemaObject + seen_list_opener: bool = False + seen_list_closer: bool = False + num_items_seen: int = 0 + + def __init__( + self, + root: JsonSchemaParser, + list_member_type: JsonSchemaObject, + min_items: Optional[int], + max_items: Optional[int], + ): + super().__init__(root) + self.list_member_type = list_member_type + self.min_items = min_items + self.max_items = max_items + default_max = root.config.max_json_array_length + if self.max_items is None and default_max > 0 and (min_items is None or min_items < default_max): + self.max_items = default_max + + def _clone(self) -> PrimitiveParsingState: + new = ListParsingState(self.root, self.list_member_type, self.min_items, self.max_items) + new.parsed_string = self.parsed_string + new.num_items_seen = self.num_items_seen + new.seen_list_opener = self.seen_list_opener + new.seen_list_closer = self.seen_list_closer + return new + + def add_character(self, new_character: str) -> "ListParsingState": + self = cast(ListParsingState, super().add_character(new_character)) + if new_character == "[": + self.seen_list_opener = True + item_parser = get_parser(self.root, self.list_member_type) + requires_items = self.min_items is not None and self.min_items > 0 + if requires_items: + 
parser_to_push = item_parser + else: + # If we don't require items, we can also end immediately, the Union + ForceStopParser combination achieves this + empty_list_parser = ForceStopParser(allow_whitespace=True) + if isinstance(item_parser, UnionParser): + item_parser.parsers.append(empty_list_parser) + parser_to_push = item_parser + else: + parser_to_push = UnionParser([item_parser, empty_list_parser]) + self.root.context.active_parser.object_stack.append(parser_to_push) + elif new_character == "]": + self.seen_list_closer = True + elif new_character == ",": + if not self.seen_list_closer: + self.num_items_seen += 1 + + self.root.context.active_parser.object_stack.append( + get_parser( + self.root, + self.list_member_type, + ) + ) + return self + + def get_allowed_characters(self) -> str: + if not self.seen_list_opener: + return "[" + WHITESPACE_CHARACTERS + elif not self.seen_list_closer: + return self.get_allowed_control_characters() + WHITESPACE_CHARACTERS + else: + return "" + + def can_end(self) -> bool: + return self.seen_list_closer + + def get_allowed_control_characters(self): + num_items = self.num_items_seen + top_parser = self.root.context.active_parser.object_stack[-1] + is_on_top = top_parser == self or isinstance(top_parser, UnionParser) and self in top_parser.parsers + if (not is_on_top) and self.root.context.active_parser.last_non_whitespace_character != "[": + # If there is an active parser above us, and the last character is not [, + # there is an active item parser on the stack that we did not count yet. + num_items += 1 + control_characters = "" + has_enough_items = self.min_items is None or num_items >= self.min_items + can_add_another_item = self.max_items is None or num_items < self.max_items + + if num_items > 0 and can_add_another_item: + control_characters += "," + if has_enough_items: + control_characters += "]" + return control_characters + diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/regexparser.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/regexparser.py new file mode 100644 index 0000000000000000000000000000000000000000..ca40fa6eee562b5d033edb30e25781984ab94f91 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/regexparser.py @@ -0,0 +1,85 @@ +from typing import Dict, Hashable, Optional, Union, List +import interegular +from interegular.fsm import anything_else + +from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig + +class RegexParser(CharacterLevelParser): + """RegexParser is an example CharacterLevelParser that only allows strings that match a given regular expression.""" + + UNINITIALIZED_STATE = -1 + INVALID_STATE = -2 + + class _Context: + pattern: interegular.FSM + anything_else_characters: str + state_character_cache: Dict[int, str] + + context: _Context + current_state: int + + def __init__(self, pattern: Union[str, _Context], config: Optional[CharacterLevelParserConfig] = None, current_state: int = UNINITIALIZED_STATE): + super().__init__(config) + if isinstance(pattern, str): + self.context = RegexParser._Context() + self.context.pattern = interegular.parse_pattern(pattern).to_fsm() + self.context.state_character_cache = {} + self._update_alphabet(self.config.alphabet) + else: + self.context = pattern + self.current_state: int = self.context.pattern.initial if current_state == RegexParser.UNINITIALIZED_STATE else current_state + + def add_character(self, new_character: str) -> 'RegexParser': + if self.current_state == RegexParser.INVALID_STATE: + return self + + 
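# Added commentary: each character advances the compiled interegular FSM by one + # transition; characters outside the pattern's alphabet map to the anything_else + # symbol, and a missing transition moves the parser to the sticky INVALID_STATE. +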
state = self.current_state + fsm = self.context.pattern + # Mostly taken from FSM.accept() + symbol = new_character + if anything_else in fsm.alphabet and not symbol in fsm.alphabet: + symbol = anything_else + transition = fsm.alphabet[symbol] + + try: + # Prefer try-catch to checking if transition exists to avoid double lookup perf hit in valid case + state = fsm.map[state][transition] # type: ignore + return RegexParser(self.context, self.config, state) + except KeyError: + # Missing transition = transition to dead state + return RegexParser(self.context, self.config, RegexParser.INVALID_STATE) + + def can_end(self) -> bool: + return self.current_state in self.context.pattern.finals or self.current_state == RegexParser.INVALID_STATE + + def get_allowed_characters(self) -> str: + if self.current_state not in self.context.pattern.map: + return '' + if self.current_state not in self.context.state_character_cache: + allowed_characters = [] + state_map = self.context.pattern.map[self.current_state] + for symbol_idx in state_map: + symbols: List[str] = self.context.pattern.alphabet.by_transition[symbol_idx] + for symbol in symbols: + if symbol == anything_else: + allowed_characters.append(self.context.anything_else_characters) + else: + allowed_characters.append(symbol) + self.context.state_character_cache[self.current_state] = "".join(allowed_characters) + return self.context.state_character_cache[self.current_state] + + def cache_key(self) -> Optional[Hashable]: + # If we are in the same regex fsm state, the allowed next tokens are the same ones + return self.current_state + + def _update_alphabet(self, new_alphabet: str): + if self.context: + not_anything_else_characters = set([c for c in self.context.pattern.alphabet.keys() if c != anything_else]) + self.context.anything_else_characters = "".join([c for c in new_alphabet if c not in not_anything_else_characters]) + + @CharacterLevelParser.config.setter + def config(self, new_config: CharacterLevelParserConfig): + CharacterLevelParser.config.fset(self, new_config) # Original set + self._update_alphabet(new_config.alphabet) + + diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/tokenenforcer.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/tokenenforcer.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2534ed494ee517bf14cf2fe1b353dc76f23726 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/tokenenforcer.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass, field +import sys +from typing import Callable, Dict, Hashable, List, Optional, Tuple, Union +import logging + +from .exceptions import LMFormatEnforcerException +from .characterlevelparser import CharacterLevelParser, ForceStopParser, CharacterLevelParserConfig +from .tokenizerprefixtree import TokenizerPrefixTree, TokenizerPrefixTreeNode + + +class TokenEnforcerTokenizerData: + """TokenEnforcerTokenizerData contains all of the preprocessing for preparing the TokenEnforcer to work with a + specific tokenizer. It does some calculations, so it is recommended to reuse it for multiple TokenEnforcers""" + def __init__(self, + regular_tokens: List[Tuple[int, str, bool]], + decoder: Callable[[List[int]], str], + eos_token_id: Union[int, List[int]]): + """ + Create the tokenizer data that the TokenEnforcer needs. This can be reused for multiple TokenEnforcers if they work with the same tokenizer. 
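+ + Illustrative example (added commentary; the token tuples and my_decode_fn are hypothetical): + + data = TokenEnforcerTokenizerData([(5, 'hello', True), (6, 'lo', False)], my_decode_fn, eos_token_id=2) +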
+        :param regular_tokens: A list of tuples (token_id, token_string, is_new_word_token) for all the regular (not special) tokens in the tokenizer vocabulary.
+        Note that token_string is expected to include leading / trailing whitespace if relevant.
+        :param decoder: A function that decodes a list of token ids into a string.
+        :param eos_token_id: The token id(s) of the end-of-string token(s).
+        """
+        self.regular_tokens = regular_tokens
+        self.tokenizer_tree = TokenizerPrefixTree(regular_tokens)
+        self.decoder = decoder
+        self.eos_token_id = eos_token_id
+        self.tokenizer_alphabet = "".join(token_str for token_str in self.tokenizer_tree.root.children.keys() if len(token_str) == 1)
+
+
+class TokenEnforcer:
+    """TokenEnforcer provides a token filtering mechanism, given a CharacterLevelParser and some information about the tokenizer.
+    It is the main entry point for extending lm-format-enforcer to new inference libraries. See __init__() and get_allowed_tokens()."""
+    @dataclass
+    class OutputTensorState:
+        parser: CharacterLevelParser
+        allowed_tokens: List[int] = field(default_factory=list)
+        current_word_tokens: List[int] = field(default_factory=list)
+
+    def __init__(self, tokenizer_data: TokenEnforcerTokenizerData, parser: CharacterLevelParser):
+        """
+        Create a new TokenEnforcer.
+        :param tokenizer_data: Per-tokenizer data that the token enforcer needs in order to operate.
+        :param parser: A CharacterLevelParser that defines the allowed strings.
+        """
+        self.prefix_states: Dict[Tuple, TokenEnforcer.OutputTensorState] = {}
+        self.root_parser = parser
+        self.tokenizer_tree = tokenizer_data.tokenizer_tree
+        self.decoder = tokenizer_data.decoder
+        self.eos_token_id = tokenizer_data.eos_token_id
+        self.regular_tokens = tokenizer_data.regular_tokens
+        self.allowed_token_cache: Dict[Hashable, List[int]] = {}
+
+        config = CharacterLevelParserConfig(alphabet=tokenizer_data.tokenizer_alphabet)
+        parser.config = config
+
+    def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]:
+        """
+        Get a list of allowed tokens, given a list of tokens that were already generated.
+        :param token_sequence: The tokens that were already generated, for which the next token will be chosen.
+        :return: A list of token ids that are allowed to be selected next.
+        """
+        # In order to elegantly support beam search and batching, we don't store per-batch information.
+        # Instead, we store a hash of all the states (unique token tensors) we encountered so far.
+        # When we encounter a new unique token tensor, we find the token tensor that led to it, and continue from there.
+        sent_tuple = tuple(token_sequence)
+        prev_step_tuple = sent_tuple[:-1]
+
+        if sent_tuple in self.prefix_states:
+            # We already calculated for this node; return the cached list
+            return self.prefix_states[sent_tuple].allowed_tokens
+        elif prev_step_tuple not in self.prefix_states:
+            # We have not encountered the tensor up to the next-to-last entry. This means that this is the first call, i.e. the instruction / prompt tensor.
+            # Initialize the root node
+            state = TokenEnforcer.OutputTensorState(parser=self.root_parser)
+            self.prefix_states[sent_tuple] = state
+            self._compute_allowed_tokens(sent_tuple, state)
+            return state.allowed_tokens
+        else:
+            # Find the state that led to this node. We explicitly don't use the concept of "timestep" because of beam search
+            prev_step_state = self.prefix_states[prev_step_tuple]
+            new_state = self._apply_new_characters(prev_step_state, token_sequence)
+            self.prefix_states[sent_tuple] = new_state
+            self._compute_allowed_tokens(sent_tuple, new_state)
+            return new_state.allowed_tokens
+
+    def _compute_allowed_tokens(self, state_tokens: Tuple, state: 'TokenEnforcer.OutputTensorState'):
+        try:
+            allowed_tokens: List[int] = []
+            cache_key = state.parser.cache_key()
+            if cache_key is not None and cache_key in self.allowed_token_cache:
+                state.allowed_tokens = self.allowed_token_cache[cache_key]
+                return
+            shortcut_key = state.parser.shortcut_key()
+            self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
+            if state.parser.can_end():
+                allowed_tokens.extend(self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id])
+            if not allowed_tokens:
+                raise ValueError("Parser reached a state with no allowed tokens")
+            # root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
+            # print(f"Allowing {len(allowed_tokens)} tokens after {state.str_so_far[len(root_state.str_so_far):]}")
+            state.allowed_tokens = allowed_tokens
+            if cache_key is not None:
+                self.allowed_token_cache[cache_key] = allowed_tokens
+        except LMFormatEnforcerException:
+            # Getting an LMFormatEnforcerException means that we know what the user did wrong,
+            # and we can give a nice error message for them to fix.
+            raise
+        except Exception:
+            # Other exceptions are potential bugs and should be reported
+            logging.basicConfig(level=logging.ERROR)  # Initialize if no loggers
+            prefix = self.decoder(list(state_tokens))
+            logging.exception(f"Unknown LMFormatEnforcer Problem. Prefix: '{prefix}'\n"
+                              "Terminating the parser. Please open an issue at \n"
+                              "https://github.com/noamgat/lm-format-enforcer/issues with the prefix and "
+                              "CharacterLevelParser parameters")
+            state.allowed_tokens = self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id]
+
+    def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: TokenizerPrefixTreeNode, allowed_tokens: List[int], shortcut_key: Optional[Hashable]):
+        allowed_tokens.extend(tree_node.tokens)
+        allowed_characters = parser.get_allowed_characters()
+        relevant_characters = tree_node.children.keys()
+        # This next line is the heart of the traversal algorithm. We only explore paths that are shared by both the parser and the tokenizer.
+        characters_to_explore = set(relevant_characters).intersection(allowed_characters)
+
+        # Performance optimization: If we are in JSON freetext, all of the tokens that don't contain a quote, or only end with one, are legal, so we take
+        # their cached list. If the quote character is allowed, we only need to dynamically explore the cases where the token starts with a quote.
+        # This breaks the elegance of the API, but otherwise it is a huge performance hit.
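+        # The shortcut key for JSON freetext is the tuple ('json_freetext', cur_len, min_len, max_len):
+        # cur_len counts the string characters generated so far, while min_len / max_len are the
+        # schema's length bounds. From these we derive how many characters must still be produced
+        # before a closing quote becomes legal, and how many may be produced at most.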
+        if isinstance(shortcut_key, tuple) and shortcut_key[0] == 'json_freetext':
+            assert len(shortcut_key) == 4
+            _, cur_len, min_len, max_len = shortcut_key
+            cache = self.tokenizer_tree.json_freetext_tokens
+
+            min_remaining = min(cache.max_token_len, max(0, min_len - cur_len))  # no " allowed before this many chars
+            max_allowed_len = min(cache.max_token_len, max_len - cur_len)  # max new characters allowed (before ")
+
+            allowed_tokens.extend(cache.lookup_allowed_tokens(min_remaining, max_allowed_len))
+            characters_to_explore = characters_to_explore.intersection(['"'])
+
+        for character in characters_to_explore:
+            next_parser = parser.add_character(character)
+            next_tree_node = tree_node.children[character]
+            self._collect_allowed_tokens(next_parser, next_tree_node, allowed_tokens, None)
+
+    def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_sequence: List[int]):
+        new_state = TokenEnforcer.OutputTensorState(parser=state.parser)
+        new_token = token_sequence[-1]
+        if new_token in self.tokenizer_tree.new_word_tokens:
+            new_state.current_word_tokens = [new_token]
+            new_characters = self.tokenizer_tree.tokens_to_strs[new_token]
+        else:
+            new_state.current_word_tokens = state.current_word_tokens + [new_token]
+            prev_decoded = self.decoder(state.current_word_tokens)
+            new_decoded = self.decoder(new_state.current_word_tokens)
+            new_characters = new_decoded[len(prev_decoded):]
+        for character in new_characters:
+            try:
+                new_state.parser = new_state.parser.add_character(character)
+            except Exception as e:
+                # This can happen in beam / batch scenarios, when some of the batches have finished but others are continuing.
+                logging.debug(f"Received an invalid character '{character}', switching to ForceStopParser (Exception:{e})")
+                new_state.parser = ForceStopParser()
+        return new_state
+
diff --git a/.venv/lib/python3.11/site-packages/lmformatenforcer/tokenizerprefixtree.py b/.venv/lib/python3.11/site-packages/lmformatenforcer/tokenizerprefixtree.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa04f543613efb7233d8d6b29be5e98dcb6d8ebc
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/lmformatenforcer/tokenizerprefixtree.py
@@ -0,0 +1,132 @@
+from collections import OrderedDict
+from typing import Dict, List, Set, Tuple
+import json
+
+class TokenizerPrefixTreeNode:
+    def __init__(self) -> None:
+        self.tokens: List[int] = []
+        self.children: Dict[str, TokenizerPrefixTreeNode] = {}
+
+
+class JsonFreetextTokenCache:
+    """
+    A JSON string can contain almost any unicode character, so creating a list of allowed tokens is very expensive.
+    The list can be cached, but JSON Schema also allows 'minLength' and 'maxLength' constraints on the string,
+    which make some tokens illegal depending on how long the generated string already is. This class precalculates
+    a separate allowlist for all possible constraint states up to the maximum token length (16 in Llama, for example).
+    After deduplication, this results in roughly 75 lists for the Llama tokenizer.
+    """
+    class _StringLengthTokenCache:
+        """An internal data structure that, given a list of string+token pairs,
+        can quickly return all token ids of strings between certain lengths."""
+        def __init__(self):
+            self.tokens: List[int] = []
+            self.first_index_geq_than_length: List[int] = [0]
+
+        def build(self, token_strs_to_idx: List[Tuple[str, int]]):
+            # TODO: If this becomes a performance bottleneck, bucket sort instead.
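+            # After the sort below, first_index_geq_than_length[L] holds the index of the first
+            # token whose string is at least L characters long, so every length range maps to a
+            # contiguous slice of self.tokens and can be answered with a single slice operation.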
+            token_strs_to_idx = sorted(token_strs_to_idx, key=lambda p: len(p[0]))
+            self.tokens = [pair[1] for pair in token_strs_to_idx]
+            # self.token_strs = [pair[0] for pair in token_strs_to_idx]  # For debugging
+            token_lengths = [len(pair[0]) for pair in token_strs_to_idx]
+            for idx, token_length in enumerate(token_lengths):
+                while len(self.first_index_geq_than_length) <= token_length:
+                    self.first_index_geq_than_length.append(idx)
+            self.first_index_geq_than_length.append(len(token_lengths))
+
+        def get_indices_between_length(self, min_length=-1, max_length=-1) -> List[int]:
+            if min_length >= len(self.first_index_geq_than_length):
+                return []
+            start_index = self.first_index_geq_than_length[min_length] if min_length > 0 else 0
+            if max_length == 0:
+                end_index = 0
+            elif max_length + 1 < len(self.first_index_geq_than_length):
+                end_index = self.first_index_geq_than_length[max_length + 1]
+            else:
+                end_index = len(self.tokens)
+            return self.tokens[start_index:end_index]
+
+    def __init__(self) -> None:
+        self.token_num_to_str: Dict[int, str] = {}
+        self.allowlist_cache: Dict[Tuple[int, int], Tuple[int, ...]] = {}
+        self.max_token_len = 0
+        self.regular_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache()
+        self.quote_tokens_length_cache = JsonFreetextTokenCache._StringLengthTokenCache()
+
+    def add_token(self, token_str: str, token_int: int):
+        assert not self.allowlist_cache, "Cannot add more tokens after allowlists were precalculated"
+
+        has_non_trailing_backslash = "\\" in token_str[:-1]
+        has_quote_before_end = '"' in token_str[0:-1]
+        has_newline = "\n" in token_str or "\r" in token_str
+        if has_non_trailing_backslash or has_quote_before_end or has_newline:
+            try:
+                json.loads(f'"{token_str}"')
+            except json.decoder.JSONDecodeError:
+                return  # Illegal inside a JSON string, skip this token
+
+        if len(token_str) == 0:
+            # Tokens that don't decode to anything should be ignored; they will not be allowed in json freetext fields.
+            # TODO: Should we instead ALWAYS allow them?
+            return
+
+        self.token_num_to_str[token_int] = token_str
+
+    def lookup_allowed_tokens(self, min_remaining: int, max_len: int) -> Tuple[int, ...]:
+        """
+        Get the list of tokens that are allowed within a JSON string, such that:
+        1. all candidate tokens are at most `max_len` characters long (excluding the trailing quote), and
+        2. if a token ends with a quote, it's at least `min_remaining` chars long (excluding the quote).
+        """
+        cache_key = (min_remaining, max_len)
+        if cache_key not in self.allowlist_cache:
+            tokens_with_quote = self.quote_tokens_length_cache.get_indices_between_length(min_remaining + 1, max_len + 1)
+            tokens_without_quote = self.regular_tokens_length_cache.get_indices_between_length(-1, max_len)
+            combined = tokens_with_quote + tokens_without_quote
+            self.allowlist_cache[cache_key] = tuple(combined)
+        return self.allowlist_cache[cache_key]
+
+    def freeze(self) -> None:
+        """
+        Precalculate token allowlists for all valid combinations of `min_remaining` and `max_len`
+        based on the tokens that were added with `add_token()`.
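+        For example, with a schema minLength of 3 and two string characters already generated,
+        min_remaining is 1, so a token may only close the string (i.e. end with '"') if it
+        contributes at least one more character before the closing quote.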
+ """ + all_tokens: List[Tuple[str, int]] = list((s, n) for n,s in self.token_num_to_str.items()) + assert all_tokens, "Cannot precalculate allowlists for an empty token list" + assert not any(pair[0] == '' for pair in all_tokens), "Tokenizer must not contain empty tokens" + + regular_tokens: List[Tuple[str, int]] = [] + quote_tokens: List[Tuple[str, int]] = [] + for pair in all_tokens: + if pair[0].endswith('"'): + quote_tokens.append(pair) + else: + regular_tokens.append(pair) + + self.regular_tokens_length_cache.build(regular_tokens) + self.quote_tokens_length_cache.build(quote_tokens) + self.max_token_len = max(len(self.regular_tokens_length_cache.first_index_geq_than_length), + len(self.quote_tokens_length_cache.first_index_geq_than_length)) + del self.token_num_to_str + + +class TokenizerPrefixTree: + def __init__(self, regular_tokens: List[Tuple[int, str, bool]]): + self.root = TokenizerPrefixTreeNode() + self.json_freetext_tokens = JsonFreetextTokenCache() + self.new_word_tokens: Set[int] = set() + self.tokens_to_strs = {token_idx: token_str for token_idx, token_str, _ in regular_tokens} + for token_idx, decoded, is_new_word in regular_tokens: + self._add_token_to_tree(decoded, token_idx, self.root) + self.json_freetext_tokens.add_token(decoded, token_idx) + if is_new_word: + self.new_word_tokens.add(token_idx) + + self.json_freetext_tokens.freeze() + + def _add_token_to_tree(self, token_str: str, token_idx: int, node: TokenizerPrefixTreeNode): + for character in token_str: + if character not in node.children: + node.children[character] = TokenizerPrefixTreeNode() + node = node.children[character] + node.tokens.append(token_idx)