koichi12 commited on
Commit
4d5d235
·
verified ·
1 Parent(s): 71e6673

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/attr/__init__.py +104 -0
  3. .venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/attr/_compat.py +94 -0
  8. .venv/lib/python3.11/site-packages/attr/_config.py +31 -0
  9. .venv/lib/python3.11/site-packages/attr/_funcs.py +468 -0
  10. .venv/lib/python3.11/site-packages/attr/_version_info.pyi +9 -0
  11. .venv/lib/python3.11/site-packages/attr/exceptions.pyi +17 -0
  12. .venv/lib/python3.11/site-packages/attr/validators.pyi +86 -0
  13. .venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so +3 -0
  14. .venv/lib/python3.11/site-packages/outlines/__init__.py +20 -0
  15. .venv/lib/python3.11/site-packages/outlines/_version.py +16 -0
  16. .venv/lib/python3.11/site-packages/outlines/base.py +299 -0
  17. .venv/lib/python3.11/site-packages/outlines/caching.py +179 -0
  18. .venv/lib/python3.11/site-packages/outlines/fsm/__init__.py +0 -0
  19. .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/outlines/fsm/guide.py +276 -0
  25. .venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py +83 -0
  26. .venv/lib/python3.11/site-packages/outlines/fsm/parsing.py +1127 -0
  27. .venv/lib/python3.11/site-packages/outlines/fsm/types.py +81 -0
  28. .venv/lib/python3.11/site-packages/outlines/function.py +117 -0
  29. .venv/lib/python3.11/site-packages/outlines/generate/__init__.py +8 -0
  30. .venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/outlines/generate/api.py +623 -0
  34. .venv/lib/python3.11/site-packages/outlines/generate/cfg.py +54 -0
  35. .venv/lib/python3.11/site-packages/outlines/generate/choice.py +59 -0
  36. .venv/lib/python3.11/site-packages/outlines/generate/format.py +47 -0
  37. .venv/lib/python3.11/site-packages/outlines/generate/fsm.py +31 -0
  38. .venv/lib/python3.11/site-packages/outlines/generate/generator.py +312 -0
  39. .venv/lib/python3.11/site-packages/outlines/generate/json.py +115 -0
  40. .venv/lib/python3.11/site-packages/outlines/generate/regex.py +59 -0
  41. .venv/lib/python3.11/site-packages/outlines/generate/text.py +50 -0
  42. .venv/lib/python3.11/site-packages/outlines/grammars.py +14 -0
  43. .venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark +18 -0
  44. .venv/lib/python3.11/site-packages/outlines/grammars/common.lark +83 -0
  45. .venv/lib/python3.11/site-packages/outlines/grammars/json.lark +19 -0
  46. .venv/lib/python3.11/site-packages/outlines/processors/__init__.py +7 -0
  47. .venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py +159 -0
.gitattributes CHANGED
@@ -249,3 +249,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
249
  .venv/lib/python3.11/site-packages/pycountry/locales/tr/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
250
  .venv/lib/python3.11/site-packages/pycountry/locales/kn/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
251
  .venv/lib/python3.11/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
249
  .venv/lib/python3.11/site-packages/pycountry/locales/tr/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
250
  .venv/lib/python3.11/site-packages/pycountry/locales/kn/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
251
  .venv/lib/python3.11/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
252
+ .venv/lib/python3.11/site-packages/torchvision.libs/libpng16.7f72a3c5.so.16 filter=lfs diff=lfs merge=lfs -text
253
+ .venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/attr/__init__.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+ """
4
+ Classes Without Boilerplate
5
+ """
6
+
7
+ from functools import partial
8
+ from typing import Callable, Literal, Protocol
9
+
10
+ from . import converters, exceptions, filters, setters, validators
11
+ from ._cmp import cmp_using
12
+ from ._config import get_run_validators, set_run_validators
13
+ from ._funcs import asdict, assoc, astuple, has, resolve_types
14
+ from ._make import (
15
+ NOTHING,
16
+ Attribute,
17
+ Converter,
18
+ Factory,
19
+ _Nothing,
20
+ attrib,
21
+ attrs,
22
+ evolve,
23
+ fields,
24
+ fields_dict,
25
+ make_class,
26
+ validate,
27
+ )
28
+ from ._next_gen import define, field, frozen, mutable
29
+ from ._version_info import VersionInfo
30
+
31
+
32
+ s = attributes = attrs
33
+ ib = attr = attrib
34
+ dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
35
+
36
+
37
+ class AttrsInstance(Protocol):
38
+ pass
39
+
40
+
41
+ NothingType = Literal[_Nothing.NOTHING]
42
+
43
+ __all__ = [
44
+ "NOTHING",
45
+ "Attribute",
46
+ "AttrsInstance",
47
+ "Converter",
48
+ "Factory",
49
+ "NothingType",
50
+ "asdict",
51
+ "assoc",
52
+ "astuple",
53
+ "attr",
54
+ "attrib",
55
+ "attributes",
56
+ "attrs",
57
+ "cmp_using",
58
+ "converters",
59
+ "define",
60
+ "evolve",
61
+ "exceptions",
62
+ "field",
63
+ "fields",
64
+ "fields_dict",
65
+ "filters",
66
+ "frozen",
67
+ "get_run_validators",
68
+ "has",
69
+ "ib",
70
+ "make_class",
71
+ "mutable",
72
+ "resolve_types",
73
+ "s",
74
+ "set_run_validators",
75
+ "setters",
76
+ "validate",
77
+ "validators",
78
+ ]
79
+
80
+
81
+ def _make_getattr(mod_name: str) -> Callable:
82
+ """
83
+ Create a metadata proxy for packaging information that uses *mod_name* in
84
+ its warnings and errors.
85
+ """
86
+
87
+ def __getattr__(name: str) -> str:
88
+ if name not in ("__version__", "__version_info__"):
89
+ msg = f"module {mod_name} has no attribute {name}"
90
+ raise AttributeError(msg)
91
+
92
+ from importlib.metadata import metadata
93
+
94
+ meta = metadata("attrs")
95
+
96
+ if name == "__version_info__":
97
+ return VersionInfo._from_version_string(meta["version"])
98
+
99
+ return meta["version"]
100
+
101
+ return __getattr__
102
+
103
+
104
+ __getattr__ = _make_getattr(__name__)
.venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc ADDED
Binary file (15.3 kB). View file
 
.venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc ADDED
Binary file (3.61 kB). View file
 
.venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc ADDED
Binary file (5.03 kB). View file
 
.venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc ADDED
Binary file (2.07 kB). View file
 
.venv/lib/python3.11/site-packages/attr/_compat.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+ import inspect
4
+ import platform
5
+ import sys
6
+ import threading
7
+
8
+ from collections.abc import Mapping, Sequence # noqa: F401
9
+ from typing import _GenericAlias
10
+
11
+
12
+ PYPY = platform.python_implementation() == "PyPy"
13
+ PY_3_9_PLUS = sys.version_info[:2] >= (3, 9)
14
+ PY_3_10_PLUS = sys.version_info[:2] >= (3, 10)
15
+ PY_3_11_PLUS = sys.version_info[:2] >= (3, 11)
16
+ PY_3_12_PLUS = sys.version_info[:2] >= (3, 12)
17
+ PY_3_13_PLUS = sys.version_info[:2] >= (3, 13)
18
+ PY_3_14_PLUS = sys.version_info[:2] >= (3, 14)
19
+
20
+
21
+ if PY_3_14_PLUS: # pragma: no cover
22
+ import annotationlib
23
+
24
+ _get_annotations = annotationlib.get_annotations
25
+
26
+ else:
27
+
28
+ def _get_annotations(cls):
29
+ """
30
+ Get annotations for *cls*.
31
+ """
32
+ return cls.__dict__.get("__annotations__", {})
33
+
34
+
35
+ class _AnnotationExtractor:
36
+ """
37
+ Extract type annotations from a callable, returning None whenever there
38
+ is none.
39
+ """
40
+
41
+ __slots__ = ["sig"]
42
+
43
+ def __init__(self, callable):
44
+ try:
45
+ self.sig = inspect.signature(callable)
46
+ except (ValueError, TypeError): # inspect failed
47
+ self.sig = None
48
+
49
+ def get_first_param_type(self):
50
+ """
51
+ Return the type annotation of the first argument if it's not empty.
52
+ """
53
+ if not self.sig:
54
+ return None
55
+
56
+ params = list(self.sig.parameters.values())
57
+ if params and params[0].annotation is not inspect.Parameter.empty:
58
+ return params[0].annotation
59
+
60
+ return None
61
+
62
+ def get_return_type(self):
63
+ """
64
+ Return the return type if it's not empty.
65
+ """
66
+ if (
67
+ self.sig
68
+ and self.sig.return_annotation is not inspect.Signature.empty
69
+ ):
70
+ return self.sig.return_annotation
71
+
72
+ return None
73
+
74
+
75
+ # Thread-local global to track attrs instances which are already being repr'd.
76
+ # This is needed because there is no other (thread-safe) way to pass info
77
+ # about the instances that are already being repr'd through the call stack
78
+ # in order to ensure we don't perform infinite recursion.
79
+ #
80
+ # For instance, if an instance contains a dict which contains that instance,
81
+ # we need to know that we're already repr'ing the outside instance from within
82
+ # the dict's repr() call.
83
+ #
84
+ # This lives here rather than in _make.py so that the functions in _make.py
85
+ # don't have a direct reference to the thread-local in their globals dict.
86
+ # If they have such a reference, it breaks cloudpickle.
87
+ repr_context = threading.local()
88
+
89
+
90
+ def get_generic_base(cl):
91
+ """If this is a generic class (A[str]), return the generic base for it."""
92
+ if cl.__class__ is _GenericAlias:
93
+ return cl.__origin__
94
+ return None
.venv/lib/python3.11/site-packages/attr/_config.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+ __all__ = ["get_run_validators", "set_run_validators"]
4
+
5
+ _run_validators = True
6
+
7
+
8
+ def set_run_validators(run):
9
+ """
10
+ Set whether or not validators are run. By default, they are run.
11
+
12
+ .. deprecated:: 21.3.0 It will not be removed, but it also will not be
13
+ moved to new ``attrs`` namespace. Use `attrs.validators.set_disabled()`
14
+ instead.
15
+ """
16
+ if not isinstance(run, bool):
17
+ msg = "'run' must be bool."
18
+ raise TypeError(msg)
19
+ global _run_validators
20
+ _run_validators = run
21
+
22
+
23
+ def get_run_validators():
24
+ """
25
+ Return whether or not validators are run.
26
+
27
+ .. deprecated:: 21.3.0 It will not be removed, but it also will not be
28
+ moved to new ``attrs`` namespace. Use `attrs.validators.get_disabled()`
29
+ instead.
30
+ """
31
+ return _run_validators
.venv/lib/python3.11/site-packages/attr/_funcs.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: MIT
2
+
3
+
4
+ import copy
5
+
6
+ from ._compat import PY_3_9_PLUS, get_generic_base
7
+ from ._make import _OBJ_SETATTR, NOTHING, fields
8
+ from .exceptions import AttrsAttributeNotFoundError
9
+
10
+
11
+ def asdict(
12
+ inst,
13
+ recurse=True,
14
+ filter=None,
15
+ dict_factory=dict,
16
+ retain_collection_types=False,
17
+ value_serializer=None,
18
+ ):
19
+ """
20
+ Return the *attrs* attribute values of *inst* as a dict.
21
+
22
+ Optionally recurse into other *attrs*-decorated classes.
23
+
24
+ Args:
25
+ inst: Instance of an *attrs*-decorated class.
26
+
27
+ recurse (bool): Recurse into classes that are also *attrs*-decorated.
28
+
29
+ filter (~typing.Callable):
30
+ A callable whose return code determines whether an attribute or
31
+ element is included (`True`) or dropped (`False`). Is called with
32
+ the `attrs.Attribute` as the first argument and the value as the
33
+ second argument.
34
+
35
+ dict_factory (~typing.Callable):
36
+ A callable to produce dictionaries from. For example, to produce
37
+ ordered dictionaries instead of normal Python dictionaries, pass in
38
+ ``collections.OrderedDict``.
39
+
40
+ retain_collection_types (bool):
41
+ Do not convert to `list` when encountering an attribute whose type
42
+ is `tuple` or `set`. Only meaningful if *recurse* is `True`.
43
+
44
+ value_serializer (typing.Callable | None):
45
+ A hook that is called for every attribute or dict key/value. It
46
+ receives the current instance, field and value and must return the
47
+ (updated) value. The hook is run *after* the optional *filter* has
48
+ been applied.
49
+
50
+ Returns:
51
+ Return type of *dict_factory*.
52
+
53
+ Raises:
54
+ attrs.exceptions.NotAnAttrsClassError:
55
+ If *cls* is not an *attrs* class.
56
+
57
+ .. versionadded:: 16.0.0 *dict_factory*
58
+ .. versionadded:: 16.1.0 *retain_collection_types*
59
+ .. versionadded:: 20.3.0 *value_serializer*
60
+ .. versionadded:: 21.3.0
61
+ If a dict has a collection for a key, it is serialized as a tuple.
62
+ """
63
+ attrs = fields(inst.__class__)
64
+ rv = dict_factory()
65
+ for a in attrs:
66
+ v = getattr(inst, a.name)
67
+ if filter is not None and not filter(a, v):
68
+ continue
69
+
70
+ if value_serializer is not None:
71
+ v = value_serializer(inst, a, v)
72
+
73
+ if recurse is True:
74
+ if has(v.__class__):
75
+ rv[a.name] = asdict(
76
+ v,
77
+ recurse=True,
78
+ filter=filter,
79
+ dict_factory=dict_factory,
80
+ retain_collection_types=retain_collection_types,
81
+ value_serializer=value_serializer,
82
+ )
83
+ elif isinstance(v, (tuple, list, set, frozenset)):
84
+ cf = v.__class__ if retain_collection_types is True else list
85
+ items = [
86
+ _asdict_anything(
87
+ i,
88
+ is_key=False,
89
+ filter=filter,
90
+ dict_factory=dict_factory,
91
+ retain_collection_types=retain_collection_types,
92
+ value_serializer=value_serializer,
93
+ )
94
+ for i in v
95
+ ]
96
+ try:
97
+ rv[a.name] = cf(items)
98
+ except TypeError:
99
+ if not issubclass(cf, tuple):
100
+ raise
101
+ # Workaround for TypeError: cf.__new__() missing 1 required
102
+ # positional argument (which appears, for a namedtuple)
103
+ rv[a.name] = cf(*items)
104
+ elif isinstance(v, dict):
105
+ df = dict_factory
106
+ rv[a.name] = df(
107
+ (
108
+ _asdict_anything(
109
+ kk,
110
+ is_key=True,
111
+ filter=filter,
112
+ dict_factory=df,
113
+ retain_collection_types=retain_collection_types,
114
+ value_serializer=value_serializer,
115
+ ),
116
+ _asdict_anything(
117
+ vv,
118
+ is_key=False,
119
+ filter=filter,
120
+ dict_factory=df,
121
+ retain_collection_types=retain_collection_types,
122
+ value_serializer=value_serializer,
123
+ ),
124
+ )
125
+ for kk, vv in v.items()
126
+ )
127
+ else:
128
+ rv[a.name] = v
129
+ else:
130
+ rv[a.name] = v
131
+ return rv
132
+
133
+
134
+ def _asdict_anything(
135
+ val,
136
+ is_key,
137
+ filter,
138
+ dict_factory,
139
+ retain_collection_types,
140
+ value_serializer,
141
+ ):
142
+ """
143
+ ``asdict`` only works on attrs instances, this works on anything.
144
+ """
145
+ if getattr(val.__class__, "__attrs_attrs__", None) is not None:
146
+ # Attrs class.
147
+ rv = asdict(
148
+ val,
149
+ recurse=True,
150
+ filter=filter,
151
+ dict_factory=dict_factory,
152
+ retain_collection_types=retain_collection_types,
153
+ value_serializer=value_serializer,
154
+ )
155
+ elif isinstance(val, (tuple, list, set, frozenset)):
156
+ if retain_collection_types is True:
157
+ cf = val.__class__
158
+ elif is_key:
159
+ cf = tuple
160
+ else:
161
+ cf = list
162
+
163
+ rv = cf(
164
+ [
165
+ _asdict_anything(
166
+ i,
167
+ is_key=False,
168
+ filter=filter,
169
+ dict_factory=dict_factory,
170
+ retain_collection_types=retain_collection_types,
171
+ value_serializer=value_serializer,
172
+ )
173
+ for i in val
174
+ ]
175
+ )
176
+ elif isinstance(val, dict):
177
+ df = dict_factory
178
+ rv = df(
179
+ (
180
+ _asdict_anything(
181
+ kk,
182
+ is_key=True,
183
+ filter=filter,
184
+ dict_factory=df,
185
+ retain_collection_types=retain_collection_types,
186
+ value_serializer=value_serializer,
187
+ ),
188
+ _asdict_anything(
189
+ vv,
190
+ is_key=False,
191
+ filter=filter,
192
+ dict_factory=df,
193
+ retain_collection_types=retain_collection_types,
194
+ value_serializer=value_serializer,
195
+ ),
196
+ )
197
+ for kk, vv in val.items()
198
+ )
199
+ else:
200
+ rv = val
201
+ if value_serializer is not None:
202
+ rv = value_serializer(None, None, rv)
203
+
204
+ return rv
205
+
206
+
207
+ def astuple(
208
+ inst,
209
+ recurse=True,
210
+ filter=None,
211
+ tuple_factory=tuple,
212
+ retain_collection_types=False,
213
+ ):
214
+ """
215
+ Return the *attrs* attribute values of *inst* as a tuple.
216
+
217
+ Optionally recurse into other *attrs*-decorated classes.
218
+
219
+ Args:
220
+ inst: Instance of an *attrs*-decorated class.
221
+
222
+ recurse (bool):
223
+ Recurse into classes that are also *attrs*-decorated.
224
+
225
+ filter (~typing.Callable):
226
+ A callable whose return code determines whether an attribute or
227
+ element is included (`True`) or dropped (`False`). Is called with
228
+ the `attrs.Attribute` as the first argument and the value as the
229
+ second argument.
230
+
231
+ tuple_factory (~typing.Callable):
232
+ A callable to produce tuples from. For example, to produce lists
233
+ instead of tuples.
234
+
235
+ retain_collection_types (bool):
236
+ Do not convert to `list` or `dict` when encountering an attribute
237
+ which type is `tuple`, `dict` or `set`. Only meaningful if
238
+ *recurse* is `True`.
239
+
240
+ Returns:
241
+ Return type of *tuple_factory*
242
+
243
+ Raises:
244
+ attrs.exceptions.NotAnAttrsClassError:
245
+ If *cls* is not an *attrs* class.
246
+
247
+ .. versionadded:: 16.2.0
248
+ """
249
+ attrs = fields(inst.__class__)
250
+ rv = []
251
+ retain = retain_collection_types # Very long. :/
252
+ for a in attrs:
253
+ v = getattr(inst, a.name)
254
+ if filter is not None and not filter(a, v):
255
+ continue
256
+ if recurse is True:
257
+ if has(v.__class__):
258
+ rv.append(
259
+ astuple(
260
+ v,
261
+ recurse=True,
262
+ filter=filter,
263
+ tuple_factory=tuple_factory,
264
+ retain_collection_types=retain,
265
+ )
266
+ )
267
+ elif isinstance(v, (tuple, list, set, frozenset)):
268
+ cf = v.__class__ if retain is True else list
269
+ items = [
270
+ (
271
+ astuple(
272
+ j,
273
+ recurse=True,
274
+ filter=filter,
275
+ tuple_factory=tuple_factory,
276
+ retain_collection_types=retain,
277
+ )
278
+ if has(j.__class__)
279
+ else j
280
+ )
281
+ for j in v
282
+ ]
283
+ try:
284
+ rv.append(cf(items))
285
+ except TypeError:
286
+ if not issubclass(cf, tuple):
287
+ raise
288
+ # Workaround for TypeError: cf.__new__() missing 1 required
289
+ # positional argument (which appears, for a namedtuple)
290
+ rv.append(cf(*items))
291
+ elif isinstance(v, dict):
292
+ df = v.__class__ if retain is True else dict
293
+ rv.append(
294
+ df(
295
+ (
296
+ (
297
+ astuple(
298
+ kk,
299
+ tuple_factory=tuple_factory,
300
+ retain_collection_types=retain,
301
+ )
302
+ if has(kk.__class__)
303
+ else kk
304
+ ),
305
+ (
306
+ astuple(
307
+ vv,
308
+ tuple_factory=tuple_factory,
309
+ retain_collection_types=retain,
310
+ )
311
+ if has(vv.__class__)
312
+ else vv
313
+ ),
314
+ )
315
+ for kk, vv in v.items()
316
+ )
317
+ )
318
+ else:
319
+ rv.append(v)
320
+ else:
321
+ rv.append(v)
322
+
323
+ return rv if tuple_factory is list else tuple_factory(rv)
324
+
325
+
326
+ def has(cls):
327
+ """
328
+ Check whether *cls* is a class with *attrs* attributes.
329
+
330
+ Args:
331
+ cls (type): Class to introspect.
332
+
333
+ Raises:
334
+ TypeError: If *cls* is not a class.
335
+
336
+ Returns:
337
+ bool:
338
+ """
339
+ attrs = getattr(cls, "__attrs_attrs__", None)
340
+ if attrs is not None:
341
+ return True
342
+
343
+ # No attrs, maybe it's a specialized generic (A[str])?
344
+ generic_base = get_generic_base(cls)
345
+ if generic_base is not None:
346
+ generic_attrs = getattr(generic_base, "__attrs_attrs__", None)
347
+ if generic_attrs is not None:
348
+ # Stick it on here for speed next time.
349
+ cls.__attrs_attrs__ = generic_attrs
350
+ return generic_attrs is not None
351
+ return False
352
+
353
+
354
+ def assoc(inst, **changes):
355
+ """
356
+ Copy *inst* and apply *changes*.
357
+
358
+ This is different from `evolve` that applies the changes to the arguments
359
+ that create the new instance.
360
+
361
+ `evolve`'s behavior is preferable, but there are `edge cases`_ where it
362
+ doesn't work. Therefore `assoc` is deprecated, but will not be removed.
363
+
364
+ .. _`edge cases`: https://github.com/python-attrs/attrs/issues/251
365
+
366
+ Args:
367
+ inst: Instance of a class with *attrs* attributes.
368
+
369
+ changes: Keyword changes in the new copy.
370
+
371
+ Returns:
372
+ A copy of inst with *changes* incorporated.
373
+
374
+ Raises:
375
+ attrs.exceptions.AttrsAttributeNotFoundError:
376
+ If *attr_name* couldn't be found on *cls*.
377
+
378
+ attrs.exceptions.NotAnAttrsClassError:
379
+ If *cls* is not an *attrs* class.
380
+
381
+ .. deprecated:: 17.1.0
382
+ Use `attrs.evolve` instead if you can. This function will not be
383
+ removed due to the slightly different approach compared to
384
+ `attrs.evolve`, though.
385
+ """
386
+ new = copy.copy(inst)
387
+ attrs = fields(inst.__class__)
388
+ for k, v in changes.items():
389
+ a = getattr(attrs, k, NOTHING)
390
+ if a is NOTHING:
391
+ msg = f"{k} is not an attrs attribute on {new.__class__}."
392
+ raise AttrsAttributeNotFoundError(msg)
393
+ _OBJ_SETATTR(new, k, v)
394
+ return new
395
+
396
+
397
+ def resolve_types(
398
+ cls, globalns=None, localns=None, attribs=None, include_extras=True
399
+ ):
400
+ """
401
+ Resolve any strings and forward annotations in type annotations.
402
+
403
+ This is only required if you need concrete types in :class:`Attribute`'s
404
+ *type* field. In other words, you don't need to resolve your types if you
405
+ only use them for static type checking.
406
+
407
+ With no arguments, names will be looked up in the module in which the class
408
+ was created. If this is not what you want, for example, if the name only
409
+ exists inside a method, you may pass *globalns* or *localns* to specify
410
+ other dictionaries in which to look up these names. See the docs of
411
+ `typing.get_type_hints` for more details.
412
+
413
+ Args:
414
+ cls (type): Class to resolve.
415
+
416
+ globalns (dict | None): Dictionary containing global variables.
417
+
418
+ localns (dict | None): Dictionary containing local variables.
419
+
420
+ attribs (list | None):
421
+ List of attribs for the given class. This is necessary when calling
422
+ from inside a ``field_transformer`` since *cls* is not an *attrs*
423
+ class yet.
424
+
425
+ include_extras (bool):
426
+ Resolve more accurately, if possible. Pass ``include_extras`` to
427
+ ``typing.get_type_hints``, if supported by the typing module. On
428
+ supported Python versions (3.9+), this resolves the types more
429
+ accurately.
430
+
431
+ Raises:
432
+ TypeError: If *cls* is not a class.
433
+
434
+ attrs.exceptions.NotAnAttrsClassError:
435
+ If *cls* is not an *attrs* class and you didn't pass any attribs.
436
+
437
+ NameError: If types cannot be resolved because of missing variables.
438
+
439
+ Returns:
440
+ *cls* so you can use this function also as a class decorator. Please
441
+ note that you have to apply it **after** `attrs.define`. That means the
442
+ decorator has to come in the line **before** `attrs.define`.
443
+
444
+ .. versionadded:: 20.1.0
445
+ .. versionadded:: 21.1.0 *attribs*
446
+ .. versionadded:: 23.1.0 *include_extras*
447
+ """
448
+ # Since calling get_type_hints is expensive we cache whether we've
449
+ # done it already.
450
+ if getattr(cls, "__attrs_types_resolved__", None) != cls:
451
+ import typing
452
+
453
+ kwargs = {"globalns": globalns, "localns": localns}
454
+
455
+ if PY_3_9_PLUS:
456
+ kwargs["include_extras"] = include_extras
457
+
458
+ hints = typing.get_type_hints(cls, **kwargs)
459
+ for field in fields(cls) if attribs is None else attribs:
460
+ if field.name in hints:
461
+ # Since fields have been frozen we must work around it.
462
+ _OBJ_SETATTR(field, "type", hints[field.name])
463
+ # We store the class we resolved so that subclasses know they haven't
464
+ # been resolved.
465
+ cls.__attrs_types_resolved__ = cls
466
+
467
+ # Return the class so you can use it as a decorator too.
468
+ return cls
.venv/lib/python3.11/site-packages/attr/_version_info.pyi ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ class VersionInfo:
2
+ @property
3
+ def year(self) -> int: ...
4
+ @property
5
+ def minor(self) -> int: ...
6
+ @property
7
+ def micro(self) -> int: ...
8
+ @property
9
+ def releaselevel(self) -> str: ...
.venv/lib/python3.11/site-packages/attr/exceptions.pyi ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ class FrozenError(AttributeError):
4
+ msg: str = ...
5
+
6
+ class FrozenInstanceError(FrozenError): ...
7
+ class FrozenAttributeError(FrozenError): ...
8
+ class AttrsAttributeNotFoundError(ValueError): ...
9
+ class NotAnAttrsClassError(ValueError): ...
10
+ class DefaultAlreadySetError(RuntimeError): ...
11
+ class UnannotatedAttributeError(RuntimeError): ...
12
+ class PythonTooOldError(RuntimeError): ...
13
+
14
+ class NotCallableError(TypeError):
15
+ msg: str = ...
16
+ value: Any = ...
17
+ def __init__(self, msg: str, value: Any) -> None: ...
.venv/lib/python3.11/site-packages/attr/validators.pyi ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import UnionType
2
+ from typing import (
3
+ Any,
4
+ AnyStr,
5
+ Callable,
6
+ Container,
7
+ ContextManager,
8
+ Iterable,
9
+ Mapping,
10
+ Match,
11
+ Pattern,
12
+ TypeVar,
13
+ overload,
14
+ )
15
+
16
+ from attrs import _ValidatorType
17
+ from attrs import _ValidatorArgType
18
+
19
+ _T = TypeVar("_T")
20
+ _T1 = TypeVar("_T1")
21
+ _T2 = TypeVar("_T2")
22
+ _T3 = TypeVar("_T3")
23
+ _I = TypeVar("_I", bound=Iterable)
24
+ _K = TypeVar("_K")
25
+ _V = TypeVar("_V")
26
+ _M = TypeVar("_M", bound=Mapping)
27
+
28
+ def set_disabled(run: bool) -> None: ...
29
+ def get_disabled() -> bool: ...
30
+ def disabled() -> ContextManager[None]: ...
31
+
32
+ # To be more precise on instance_of use some overloads.
33
+ # If there are more than 3 items in the tuple then we fall back to Any
34
+ @overload
35
+ def instance_of(type: type[_T]) -> _ValidatorType[_T]: ...
36
+ @overload
37
+ def instance_of(type: tuple[type[_T]]) -> _ValidatorType[_T]: ...
38
+ @overload
39
+ def instance_of(
40
+ type: tuple[type[_T1], type[_T2]],
41
+ ) -> _ValidatorType[_T1 | _T2]: ...
42
+ @overload
43
+ def instance_of(
44
+ type: tuple[type[_T1], type[_T2], type[_T3]],
45
+ ) -> _ValidatorType[_T1 | _T2 | _T3]: ...
46
+ @overload
47
+ def instance_of(type: tuple[type, ...]) -> _ValidatorType[Any]: ...
48
+ @overload
49
+ def instance_of(type: UnionType) -> _ValidatorType[Any]: ...
50
+ def optional(
51
+ validator: (
52
+ _ValidatorType[_T]
53
+ | list[_ValidatorType[_T]]
54
+ | tuple[_ValidatorType[_T]]
55
+ ),
56
+ ) -> _ValidatorType[_T | None]: ...
57
+ def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
58
+ def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
59
+ def matches_re(
60
+ regex: Pattern[AnyStr] | AnyStr,
61
+ flags: int = ...,
62
+ func: Callable[[AnyStr, AnyStr, int], Match[AnyStr] | None] | None = ...,
63
+ ) -> _ValidatorType[AnyStr]: ...
64
+ def deep_iterable(
65
+ member_validator: _ValidatorArgType[_T],
66
+ iterable_validator: _ValidatorType[_I] | None = ...,
67
+ ) -> _ValidatorType[_I]: ...
68
+ def deep_mapping(
69
+ key_validator: _ValidatorType[_K],
70
+ value_validator: _ValidatorType[_V],
71
+ mapping_validator: _ValidatorType[_M] | None = ...,
72
+ ) -> _ValidatorType[_M]: ...
73
+ def is_callable() -> _ValidatorType[_T]: ...
74
+ def lt(val: _T) -> _ValidatorType[_T]: ...
75
+ def le(val: _T) -> _ValidatorType[_T]: ...
76
+ def ge(val: _T) -> _ValidatorType[_T]: ...
77
+ def gt(val: _T) -> _ValidatorType[_T]: ...
78
+ def max_len(length: int) -> _ValidatorType[_T]: ...
79
+ def min_len(length: int) -> _ValidatorType[_T]: ...
80
+ def not_(
81
+ validator: _ValidatorType[_T],
82
+ *,
83
+ msg: str | None = None,
84
+ exc_types: type[Exception] | Iterable[type[Exception]] = ...,
85
+ ) -> _ValidatorType[_T]: ...
86
+ def or_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a6211b1e1e47f505c8b79cb8b191ba1169b99f917cd874de883cffd11aa9883
3
+ size 406024
.venv/lib/python3.11/site-packages/outlines/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Outlines is a Generative Model Programming Framework."""
2
+ import outlines.generate
3
+ import outlines.grammars
4
+ import outlines.models
5
+ import outlines.processors
6
+ import outlines.types
7
+ from outlines.base import vectorize
8
+ from outlines.caching import clear_cache, disable_cache, get_cache
9
+ from outlines.function import Function
10
+ from outlines.prompts import prompt
11
+
12
+ __all__ = [
13
+ "clear_cache",
14
+ "disable_cache",
15
+ "get_cache",
16
+ "Function",
17
+ "prompt",
18
+ "vectorize",
19
+ "grammars",
20
+ ]
.venv/lib/python3.11/site-packages/outlines/_version.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file generated by setuptools_scm
2
+ # don't change, don't track in version control
3
+ TYPE_CHECKING = False
4
+ if TYPE_CHECKING:
5
+ from typing import Tuple, Union
6
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
7
+ else:
8
+ VERSION_TUPLE = object
9
+
10
+ version: str
11
+ __version__: str
12
+ __version_tuple__: VERSION_TUPLE
13
+ version_tuple: VERSION_TUPLE
14
+
15
+ __version__ = version = '0.1.11'
16
+ __version_tuple__ = version_tuple = (0, 1, 11)
.venv/lib/python3.11/site-packages/outlines/base.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import builtins
3
+ import functools
4
+ import inspect
5
+ from typing import Callable, Optional
6
+
7
+ import numpy as np
8
+
9
+ # Import required functions based on NumPy version
10
+ np_major_version = int(np.__version__.split(".")[0])
11
+ if np_major_version >= 2:
12
+ from numpy.lib._function_base_impl import (
13
+ _calculate_shapes,
14
+ _parse_gufunc_signature,
15
+ _parse_input_dimensions,
16
+ _update_dim_sizes,
17
+ )
18
+ else:
19
+ from numpy.lib.function_base import (
20
+ _calculate_shapes,
21
+ _parse_gufunc_signature,
22
+ _parse_input_dimensions,
23
+ _update_dim_sizes,
24
+ )
25
+
26
# Allow nested loops for running in notebook. We don't enable it globally as it
# may interfere with other libraries that use asyncio.
if hasattr(builtins, "__IPYTHON__"):
    try:
        import nest_asyncio

        nest_asyncio.apply()
    except ImportError:
        # Fix: corrected the grammar of the warning message
        # ("might be have issues" -> "might have issues").
        print(
            "Couldn't patch nest_asyncio because it's not installed. Running in the notebook might have issues"
        )
37
+
38
+
39
class vectorize:
    """Returns an object that acts like a function but takes arrays as an input.

    The vectorized function evaluates `func` over successive tuples of the input
    chararrays and returns a single NumPy chararrays or a tuple of NumPy chararrays.

    Its behavior is similar to NumPy's `vectorize` for Python functions: the function
    being vectorized is executed in a `for` loop. Coroutines, however, are executed
    concurrently.

    Part of the code was adapted from `numpy.lib.function_base`.

    """

    def __init__(self, func: Callable, signature: Optional[str] = None):
        self.func = func
        self.signature = signature
        self.is_coroutine_fn = inspect.iscoroutinefunction(func)

        functools.update_wrapper(self, func)

        if signature is not None:
            # Parse the signature string into a Python data structure.
            # For instance "(m),(s)->(s,m)" becomes `([(m,),(s,)],[(s,m)])`.
            self._in_and_out_core_dimensions = _parse_gufunc_signature(signature)
        else:
            self._in_and_out_core_dimensions = None

    def __call__(self, *args, **kwargs):
        """Call the vectorized function."""
        if not args and not kwargs:
            return self.call_thunk()
        elif self.signature is not None:
            return self.call_with_signature(*args, **kwargs)
        else:
            return self.call_no_signature(*args, **kwargs)

    def call_thunk(self):
        """Call a vectorized thunk.

        Thunks have no arguments and can thus be called directly.

        """
        if self.is_coroutine_fn:
            # Run the coroutine in a dedicated event loop so we don't depend
            # on (or disturb) any loop the caller may own.
            loop = asyncio.new_event_loop()
            try:
                outputs = loop.run_until_complete(self.func())
            finally:
                loop.close()
        else:
            outputs = self.func()

        return outputs

    def call_no_signature(self, *args, **kwargs):
        """Call functions and coroutines when no signature is specified.

        When no signature is specified we assume that all of the function's
        inputs and outputs are scalars (core dimension of zero). We first
        broadcast the input arrays, then iteratively apply the function over the
        elements of the broadcasted arrays and finally reshape the results to
        match the input shape.

        Functions are executed in a for loop, coroutines are executed
        concurrently.

        """
        # Convert args and kwargs to arrays
        args = [np.array(arg) for arg in args]
        kwargs = {key: np.array(value) for key, value in kwargs.items()}

        # Broadcast args and kwargs
        broadcast_shape = np.broadcast(*args, *list(kwargs.values())).shape
        args = [np.broadcast_to(arg, broadcast_shape) for arg in args]
        kwargs = {
            key: np.broadcast_to(value, broadcast_shape)
            for key, value in kwargs.items()
        }

        # Execute functions in a loop, and coroutines concurrently
        if self.is_coroutine_fn:
            outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs)
        else:
            outputs = self.vectorize_call(broadcast_shape, args, kwargs)

        # `outputs` is a flat array or a tuple of flat arrays. We reshape the arrays
        # to match the input shape.
        outputs = [
            results if isinstance(results, tuple) else (results,) for results in outputs
        ]
        outputs = tuple(
            [np.asarray(x).reshape(broadcast_shape).squeeze() for x in zip(*outputs)]
        )
        outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs])

        # Fix: `outputs` is already a tuple; the intermediate `list()` copy
        # in the original was redundant.
        n_results = len(outputs)

        return outputs[0] if n_results == 1 else outputs

    def call_with_signature(self, *args, **kwargs):
        """Call functions and coroutines when a signature is specified."""
        input_core_dims, output_core_dims = self._in_and_out_core_dimensions

        # Make sure that the numbers of arguments passed is compatible with
        # the signature.
        num_args = len(args) + len(kwargs)
        if num_args != len(input_core_dims):
            # Fix: the check compares the *total* number of arguments
            # (positional + keyword), so the error must report `num_args`,
            # not just `len(args)` as the original did.
            raise TypeError(
                "wrong number of arguments: "
                "expected %r, got %r" % (len(input_core_dims), num_args)
            )

        # Convert args and kwargs to arrays
        args = [np.asarray(arg) for arg in args]
        kwargs = {key: np.array(value) for key, value in kwargs.items()}

        # Find the arguments' broadcast shape, and map placeholder
        # variables in the signature to the number of dimensions
        # they correspond to given the arguments.
        broadcast_shape, dim_sizes = _parse_input_dimensions(
            args + list(kwargs.values()), input_core_dims
        )

        # Calculate the shape to which each of the arguments should be broadcasted
        # and reshape them accordingly.
        input_shapes = _calculate_shapes(broadcast_shape, dim_sizes, input_core_dims)
        args = [
            np.broadcast_to(arg, shape, subok=True)
            for arg, shape in zip(args, input_shapes)
        ]
        kwargs = {
            key: np.broadcast_to(value, broadcast_shape)
            for key, value in kwargs.items()
        }

        n_out = len(output_core_dims)

        if self.is_coroutine_fn:
            outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs)
        else:
            outputs = self.vectorize_call(broadcast_shape, args, kwargs)

        outputs = [
            results if isinstance(results, tuple) else (results,) for results in outputs
        ]

        flat_outputs = list(zip(*outputs))
        n_results = len(flat_outputs)

        if n_out != n_results:
            raise ValueError(
                f"wrong number of outputs from the function, expected {n_out}, got {n_results}"
            )

        # The number of dimensions of the outputs are not necessarily known in
        # advance. The following iterates over the results and updates the
        # number of dimensions of the outputs accordingly.
        for results, core_dims in zip(flat_outputs, output_core_dims):
            for result in results:
                _update_dim_sizes(dim_sizes, result, core_dims)

        # Calculate the shape to which each of the outputs should be broadcasted
        # and reshape them.
        shapes = _calculate_shapes(broadcast_shape, dim_sizes, output_core_dims)
        outputs = tuple(
            [
                np.hstack(results).reshape(shape).squeeze()
                for shape, results in zip(shapes, zip(*outputs))
            ]
        )
        outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs])

        return outputs[0] if n_results == 1 else outputs

    def vectorize_call(self, broadcast_shape, args, kwargs):
        """Run the function in a for loop.

        A possible extension would be to parallelize the calls.

        Parameters
        ----------
        broadcast_shape
            The brodcast shape of the input arrays.
        args
            The function's broadcasted arguments.
        kwargs
            The function's broadcasted keyword arguments.

        """
        outputs = []
        for index in np.ndindex(*broadcast_shape):
            current_args = tuple(arg[index] for arg in args)
            current_kwargs = {key: value[index] for key, value in kwargs.items()}
            outputs.append(self.func(*current_args, **current_kwargs))

        return outputs

    def vectorize_call_coroutine(self, broadcast_shape, args, kwargs):
        """Run coroutines concurrently.

        Creates as many tasks as needed and executes them in a new event
        loop.

        Parameters
        ----------
        broadcast_shape
            The brodcast shape of the input arrays.
        args
            The function's broadcasted arguments.
        kwargs
            The function's broadcasted keyword arguments.

        """

        async def create_and_gather_tasks():
            tasks = []
            for index in np.ndindex(*broadcast_shape):
                current_args = tuple(arg[index] for arg in args)
                current_kwargs = {key: value[index] for key, value in kwargs.items()}
                tasks.append(self.func(*current_args, **current_kwargs))

            outputs = await asyncio.gather(*tasks)

            return outputs

        loop = asyncio.new_event_loop()
        try:
            outputs = loop.run_until_complete(create_and_gather_tasks())
        finally:
            loop.close()

        return outputs
271
+
272
+
273
+ def _update_arrays_type(arrays, results):
274
+ """Update the dtype of arrays.
275
+
276
+ String arrays contain strings of fixed length. Here they are initialized with
277
+ the type of the first results, so that if the next results contain longer
278
+ strings they will be truncated when added to the output arrays. Here we
279
+ update the type if the current results contain longer strings than in the
280
+ current output array.
281
+
282
+ Parameters
283
+ ----------
284
+ arrays
285
+ Arrays that contain the vectorized function's results.
286
+ results
287
+ The current output of the function being vectorized.
288
+
289
+ """
290
+
291
+ updated_arrays = []
292
+ for array, result in zip(arrays, results):
293
+ if array.dtype.type == np.str_:
294
+ if array.dtype < np.array(result).dtype:
295
+ array = array.astype(np.array(result).dtype)
296
+
297
+ updated_arrays.append(array)
298
+
299
+ return tuple(updated_arrays)
.venv/lib/python3.11/site-packages/outlines/caching.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import contextlib
3
+ import functools
4
+ import os
5
+ from typing import Callable, Optional
6
+
7
+ import cloudpickle
8
+ from diskcache import Cache, Disk
9
+ from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name
10
+
11
+ _caching_enabled = True
12
+
13
+
14
class CloudpickleDisk(Disk):
    """A diskcache ``Disk`` that (de)serializes keys and values with cloudpickle.

    cloudpickle can serialize objects the stdlib pickler rejects (closures,
    lambdas, locally-defined classes), which cached call arguments and results
    may contain.
    """

    def __init__(self, directory, compress_level=1, **kwargs):
        # NOTE(review): compress_level is stored but never read in this class;
        # presumably kept for API compatibility with other Disk backends — confirm.
        self.compress_level = compress_level
        super().__init__(directory, **kwargs)

    def put(self, key):
        # Serialize the key before handing it to the base Disk.
        data = cloudpickle.dumps(key)
        return super().put(data)

    def get(self, key, raw):
        # Inverse of `put`: the base class hands back the pickled bytes.
        data = super().get(key, raw)
        return cloudpickle.loads(data)

    def store(self, value, read, key=UNKNOWN):
        # Only serialize in-memory values; payloads flagged `read` pass
        # through to the base class untouched (mirrored in `fetch`).
        if not read:
            value = cloudpickle.dumps(value)
        return super().store(value, read, key=key)

    def fetch(self, mode, filename, value, read):
        # Mirror of `store`: unpickle only when the payload was pickled.
        data = super().fetch(mode, filename, value, read)
        if not read:
            data = cloudpickle.loads(data)
        return data
37
+
38
+
39
@functools.lru_cache(1)
def get_cache():
    """Return the process-wide disk cache of previously-computed return values.

    The cache avoids repeating computations and API calls that can be long and
    expensive for large models.  It lives under ``HOMEDIR/.cache/outlines``
    unless the ``OUTLINES_CACHE_DIR`` environment variable points elsewhere.
    Thanks to ``lru_cache(1)``, every caller shares a single ``Cache`` object.

    """
    from outlines._version import __version__ as outlines_version  # type: ignore

    default_dir = f"{os.path.expanduser('~')}/.cache/outlines"
    cache_dir = os.environ.get("OUTLINES_CACHE_DIR", default_dir)
    disk_cache = Cache(
        cache_dir,
        eviction_policy="none",
        cull_limit=0,
        disk=CloudpickleDisk,
    )

    # Prune stale entries whenever the installed outlines version changes.
    if disk_cache.get("__version__") != outlines_version:
        disk_cache.clear()
        disk_cache["__version__"] = outlines_version

    return disk_cache
68
+
69
+
70
def cache(expire: Optional[float] = None, typed=False, ignore=()):
    """Caching decorator for memoizing function calls.

    The cache key is created based on the values returned by the key_function callable
    if provided or based on the arguments of the decorated function directly otherwise

    This is based on `diskcache`'s `memoize`.

    Parameters
    ----------
    expire
        Seconds until arguments expire.
    typed
        Cache different types separately.
    ignore
        Positional or keyword arguments to ignore.

    Returns
    -------
    A decorator function that can be applied to other functions.
    """

    def decorator(cached_function: Callable):
        memory = get_cache()

        base = (full_name(cached_function),)

        if asyncio.iscoroutinefunction(cached_function):

            async def wrapper(*args, **kwargs):
                if not _caching_enabled:
                    return await cached_function(*args, **kwargs)

                cache_key = wrapper.__cache_key__(*args, **kwargs)
                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)

                if result is ENOVAL:
                    result = await cached_function(*args, **kwargs)
                    wrapper.__memory__.set(cache_key, result, expire, retry=True)

                return result

        else:

            def wrapper(*args, **kwargs):
                if not _caching_enabled:
                    return cached_function(*args, **kwargs)

                cache_key = wrapper.__cache_key__(*args, **kwargs)
                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)

                if result is ENOVAL:
                    result = cached_function(*args, **kwargs)
                    wrapper.__memory__.set(cache_key, result, expire, retry=True)

                return result

        def __cache_key__(*args, **kwargs):
            """Make key for cache given function arguments."""
            return args_to_key(base, args, kwargs, typed, ignore)

        # Fix: the original only set `__wrapped__` by hand, so decorated
        # functions lost their `__name__`, `__doc__`, `__module__`, etc.
        # `update_wrapper` copies the metadata and also sets `__wrapped__`.
        functools.update_wrapper(wrapper, cached_function)
        wrapper.__cache_key__ = __cache_key__  # type: ignore
        wrapper.__memory__ = memory  # type: ignore

        return wrapper

    return decorator
138
+
139
+
140
def disable_cache():
    """Turn off result caching for the rest of this session.

    Sampling generative models yields different results on every call, which
    is sometimes desirable; calling this disables memoization so fresh values
    are always computed.

    This does not delete the existing cache (use `outlines.caching.clear_cache`
    for that), and values computed while disabled are not written back.

    Example
    -------

    Call right after importing outlines:

    >>> import outlines.caching as cache
    >>> cache.disable_cache()

    """
    global _caching_enabled
    _caching_enabled = False
162
+
163
+
164
def clear_cache():
    """Remove every entry from the on-disk cache."""
    get_cache().clear()
168
+
169
+
170
@contextlib.contextmanager
def cache_disabled():
    """Context manager that temporarily disables caching, then restores the
    previous state of ``outlines.caching._caching_enabled`` on exit."""
    global _caching_enabled
    saved_state = _caching_enabled
    _caching_enabled = False
    try:
        yield
    finally:
        _caching_enabled = saved_state
.venv/lib/python3.11/site-packages/outlines/fsm/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (185 Bytes). View file
 
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc ADDED
Binary file (4.13 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc ADDED
Binary file (52.9 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc ADDED
Binary file (5.1 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/fsm/guide.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import copy
3
+ import warnings
4
+ from typing import TYPE_CHECKING, Any, Generator, Union
5
+
6
+ import torch
7
+ from lark.indenter import DedentError
8
+ from lark.lexer import UnexpectedCharacters, UnexpectedToken
9
+ from outlines_core.fsm.guide import Generate
10
+ from outlines_core.fsm.guide import Guide as CoreGuide
11
+ from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide
12
+ from outlines_core.fsm.guide import Write
13
+ from outlines_core.fsm.guide import (
14
+ create_states_mapping as uncached_create_states_mapping,
15
+ )
16
+
17
+ from outlines import grammars
18
+ from outlines.fsm.parsing import PartialLark, PartialParserState
19
+
20
+ if TYPE_CHECKING:
21
+ from outlines.models.tokenizer import Tokenizer
22
+
23
+
24
+ Instruction = Union[Write, Generate]
25
+
26
+
27
class Guide(CoreGuide):
    """Base definition of a generation guide.

    A generation guide defines the behavior of a finite-state machine that guides
    a text generation procedure. Unlike the DFAs built from regular expressions
    guides can also emit a `Write` instructions which tells the model that it can
    append a sequence of tokens (or token word) instead of generating it.

    """

    # Starting state of the guide's state machine. The concrete type is
    # subclass-specific: an int for StopAtEOSGuide, a CFGState for CFGGuide.
    initial_state: Any
38
+
39
+
40
class StopAtEOSGuide(Guide):
    """Guide that lets the model generate freely until it emits the EOS token."""

    final_state = 1
    start_state = 0  # TODO: remove start_state, use only initial_state
    initial_state = 0

    def __init__(self, tokenizer: "Tokenizer"):
        """Initialize the generation guide.

        model
            The logit generator used to generate the next token.

        """
        self.eos_token_id = tokenizer.eos_token_id
        self.vocabulary = tokenizer.vocabulary.values()

    def get_next_instruction(self, state: int) -> Instruction:
        # Once finished, only EOS may be written; otherwise anything goes.
        if not self.is_final_state(state):
            return Generate(None)
        return Write([self.eos_token_id])

    def get_next_state(self, state: int, token_id: int) -> int:
        finished = state == self.final_state or token_id == self.eos_token_id
        return self.final_state if finished else self.initial_state

    def is_final_state(self, state: int):
        return state == self.final_state

    def copy(self):
        # The guide is stateless between calls, so sharing one instance is safe.
        return self
73
+
74
+
75
def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
    """Build the regex states mapping via ``outlines_core``.

    NOTE(review): despite the name, this wrapper currently performs no caching
    — it is a plain passthrough to ``create_states_mapping``; confirm whether a
    cache decorator was intentionally removed.
    """
    return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
77
+
78
+
79
class RegexGuide(CoreRegexGuide):
    """
    Guide to generate text in the language of a regular expression.
    CoreRegexGuide with outlines cache
    """

    @classmethod
    def from_regex(
        cls,
        regex_string: str,
        tokenizer,
        **kwargs,
    ):
        # Delegate to the core implementation, injecting this module's
        # states-mapping builder as the construction hook.
        return super().from_regex(
            regex_string,
            tokenizer,
            _create_states_mapping=cached_create_states_mapping,
            **kwargs,
        )
98
+
99
+
100
+ CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])
101
+
102
+
103
class CFGGuide(Guide):
    """Guide to generate text that is in the language of a context-free Lark grammar."""

    def __init__(self, cfg_string: str, tokenizer):
        """
        Construct the PartialLark parser and set the empty initial_state (PartialParserState)
        """
        warnings.warn(
            "Outlines' public *community-contributed* CFG structured generation is experimental. "
            "Please review https://dottxt-ai.github.io/outlines/latest/reference/generation/cfg#disclaimer"
        )

        self.cfg_string = cfg_string
        self.tokenizer = tokenizer
        self.eos_token_id = self.tokenizer.eos_token_id
        self.parser = PartialLark(
            cfg_string,
            parser="lalr",
            import_paths=[grammars.GRAMMAR_PATH],
        )
        # Parse the empty string to obtain a fresh parser state to start from.
        self.initial_state = CFGState(
            parser_state=self.parser.parse(""), prev_token=None
        )

    def get_next_instruction(self, state: CFGState) -> Instruction:
        """Return the next instruction for guided generation.

        Current lazy approach:
        - For each token in the vocabulary
          - create a copy of the parsers state
          - add the tokens to the parsers input text
          - if valid, add token to returned tokens

        Further refinements are necessary for performant text processing.

        Parameters
        ----------
        state
            The guides current PartialParserState, or None if complete

        Returns
        -------
        A `Generate` instance that contains the model and the allowed token ids.

        """

        # A completed parse admits only EOS.
        if state.parser_state is None:
            return Write(torch.tensor([self.eos_token_id]))

        valid_tokens = list(
            self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
        )
        # With a single legal continuation we can commit to it directly.
        if len(valid_tokens) == 1:
            return Write(torch.tensor(valid_tokens))
        return Generate(torch.tensor(valid_tokens))

    def iter_valid_token_ids(
        self, state: CFGState, candidate_token_ids: list
    ) -> Generator[int, None, None]:
        """
        Iterate over the given token_ids and yield those that are valid for the current parser state.

        Parameters
        ----------
        parser_state
            The current state of the parser, or None if complete.
        token_ids
            The list of token ids to check for validity.

        Yields
        ------
        int
            Valid token ids.
        """
        if state.parser_state is None:
            yield self.eos_token_id
            return

        for token_id in candidate_token_ids:
            if token_id == self.eos_token_id:
                if self.can_terminate_state(state):
                    yield token_id
            else:
                # A token is valid iff appending its text advances the parser
                # without raising; the trial state is discarded either way.
                try:
                    self._get_parser_state_token_applied(state, int(token_id))
                    yield token_id
                except (
                    ValueError,
                    EOFError,
                    UnexpectedToken,
                    UnexpectedCharacters,
                    DedentError,
                ):
                    pass

    def get_next_state(self, state: CFGState, token_id: int) -> CFGState:
        """
        Update the state of the guide.
        Decode the token_id, and calculate the new parser_state with the token applied.

        Parameters
        ----------
        state
            The guides current PartialParserState, or None if complete
        token_id
            The id of the token that was just generated.

        Returns
        -------
        The guides new PartialParserState

        """
        if state.parser_state is None or token_id == self.eos_token_id:
            # EOS (or an already-finished parse) transitions to the terminal state.
            parser_state = None
        else:
            parser_state = self._get_parser_state_token_applied(state, int(token_id))
        return CFGState(parser_state=parser_state, prev_token=token_id)

    def _get_parser_state_token_applied(
        self, state: CFGState, token_id: int
    ) -> PartialParserState:
        """
        Don't mutate `parser_state`, copy to protect

        Get the token string
          - if first token in generation: tokenizer.decode (no leading whitespace)
          - else: normalized (with possibly leading whitespace)

        Don't allow empty ("") tokens, raise ValueError
        """
        parser_state = copy.copy(state.parser_state)  # prevent side effects

        # normalize
        if state.prev_token is None:
            new_token_str = self.tokenizer.decode([token_id])[0]
        else:
            # Decode prev and prev+current together, then strip the prev prefix
            # so tokenizer whitespace normalization is applied consistently.
            # NOTE(review): the nested-list argument (`[[...]]` here vs `[...]`
            # above) implies batch decoding — confirm the tokenizer contract.
            prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
            combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
                0
            ]
            new_token_str = combined_token_str[len(prev_token_str) :]

        if new_token_str == "":
            raise ValueError("empty next token")

        # update parser with new token
        parser_state.lexer.state.text += new_token_str
        self.parser.parse_from_state(parser_state, is_end=False)

        return parser_state

    def is_final_state(self, state: CFGState) -> bool:
        # TODO: remove this method, use can_terminate_state and must_terminate_state
        # here and in RegexGuide per https://github.com/dottxt-ai/outlines/issues/885
        return self.can_terminate_state(state)

    def can_terminate_state(self, state: CFGState) -> bool:
        """Generation is allowed to terminate"""
        if state.parser_state is not None:
            # Probe a copy with EOF; rejection means the parse is incomplete.
            try:
                copy.copy(state.parser_state).feed_eof()
            except UnexpectedToken:
                return False
        return True

    def must_terminate_state(self, state: CFGState) -> bool:
        """Generation must terminate, no legal continuations"""
        return state.parser_state is None or set(state.parser_state.accepts()).issubset(
            {"$END"}
        )

    def copy(self) -> "CFGGuide":
        """Create a copy of the Guide."""
        return CFGGuide(self.cfg_string, self.tokenizer)
.venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import json
3
+ import warnings
4
+ from enum import Enum
5
+ from typing import Callable, Type, Union
6
+
7
+ from pydantic import BaseModel, create_model
8
+
9
+
10
def convert_json_schema_to_str(json_schema: "Union[dict, str, Type[BaseModel]]") -> str:
    """Convert a JSON schema to a string.

    Parameters
    ----------
    json_schema
        The JSON schema, given as a dictionary, a pre-serialized JSON string,
        or a Pydantic model class.

    Returns
    -------
    str
        The JSON schema converted to a string.

    Raises
    ------
    ValueError
        If the schema is not a dictionary, a string or a Pydantic class.
    """
    if isinstance(json_schema, dict):
        schema_str = json.dumps(json_schema)
    elif isinstance(json_schema, str):
        schema_str = json_schema
    # Fix: check `isinstance(json_schema, type)` before `issubclass` —
    # `issubclass` raises TypeError on non-class arguments (e.g. ints, lists),
    # which previously made the ValueError branch unreachable for such inputs.
    elif isinstance(json_schema, type) and issubclass(json_schema, BaseModel):
        schema_str = json.dumps(json_schema.model_json_schema())
    else:
        raise ValueError(
            f"Cannot parse schema {json_schema}. The schema must be either "
            + "a Pydantic class, a dictionary or a string that contains the JSON "
            + "schema specification"
        )
    return schema_str
41
+
42
+
43
def get_schema_from_signature(fn: Callable) -> dict:
    """Turn a function signature into a JSON schema.

    Every JSON object valid to the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    """
    parameters = inspect.signature(fn).parameters

    fields = {}
    for param_name, param in parameters.items():
        # Every parameter must carry a type annotation so a field type can be
        # derived for the generated model.
        if param.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        fields[param_name] = (param.annotation, ...)

    try:
        fn_name = fn.__name__
    except Exception as e:
        # Best effort: fall back to a generic model name if the callable has
        # no usable __name__ (e.g. some partials or callables).
        fn_name = "Arguments"
        warnings.warn(
            f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}",
            category=UserWarning,
        )

    return create_model(fn_name, **fields).model_json_schema()
69
+
70
+
71
def get_schema_from_enum(myenum: type[Enum]) -> dict:
    """Build a JSON schema describing the members of an Enum class.

    Callable members contribute the schema of their underlying function's
    signature; plain values contribute a `const` schema.
    """
    if len(myenum) == 0:
        raise ValueError(
            f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
        )

    alternatives = []
    for member in myenum:
        if callable(member.value):
            alternatives.append(get_schema_from_signature(member.value.func))
        else:
            alternatives.append({"const": member.value})

    return {"title": myenum.__name__, "oneOf": alternatives}
.venv/lib/python3.11/site-packages/outlines/fsm/parsing.py ADDED
@@ -0,0 +1,1127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import copy, deepcopy
2
+ from dataclasses import dataclass
3
+ from functools import lru_cache
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ FrozenSet,
8
+ Generator,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ Sequence,
13
+ Set,
14
+ Tuple,
15
+ Union,
16
+ )
17
+
18
+ import interegular
19
+ from interegular.fsm import FSM, Alphabet, OblivionError
20
+ from interegular.patterns import Unsupported
21
+ from lark import Lark, Token
22
+ from lark.common import LexerConf, ParserConf
23
+ from lark.exceptions import LexError, UnexpectedInput
24
+ from lark.indenter import Indenter
25
+ from lark.lexer import (
26
+ BasicLexer,
27
+ ContextualLexer,
28
+ LexerState,
29
+ LexerThread,
30
+ Scanner,
31
+ UnexpectedCharacters,
32
+ UnexpectedToken,
33
+ _create_unless,
34
+ )
35
+ from lark.parser_frontends import (
36
+ ParsingFrontend,
37
+ PostLexConnector,
38
+ _validate_frontend_args,
39
+ )
40
+ from lark.parsers.lalr_analysis import (
41
+ Action,
42
+ IntParseTable,
43
+ LALR_Analyzer,
44
+ ParseTable,
45
+ Shift,
46
+ )
47
+ from lark.parsers.lalr_interactive_parser import InteractiveParser
48
+ from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser
49
+ from outlines_core.fsm.regex import (
50
+ BetterFSM,
51
+ get_token_transition_keys,
52
+ make_deterministic_fsm,
53
+ )
54
+
55
# A saved parse position: (text parsed so far, offset) — presumably; confirm against callers.
PartialParseState = Tuple[str, int]
# LALR parse-table states: plain ints normally, frozensets/tuples of items in
# the "deterministic" table construction (see `PartialLALRParser`).
ParseStateType = Union[int, FrozenSet]
57
+
58
+
59
@dataclass
class PartialTerminalInfo:
    """Information about one terminal symbol compatible with a partial FSM scan."""

    # Scan priority: the terminal's enumeration index (see
    # `PartialScanner.get_terminals_info`).
    priority: int
    # Name of the terminal symbol.
    terminal_name: str
    # Whether the terminal's FSM can still consume more characters from here.
    can_transition: bool
    # Whether the terminal's FSM is currently in a final (accepting) state.
    is_final: bool
65
+
66
+
67
@dataclass
class PartialTokensInfo:
    """Payload of a "partial" token: the FSM scan state and candidate terminals."""

    # Sequence of union-FSM states traversed while scanning the partial token.
    fsm_state_seq: Tuple[int, ...]
    # True when the scan may still be extended by further input.
    is_not_finished: bool
    # All terminals consistent with the scan so far.
    terminals_and_info: Tuple[PartialTerminalInfo, ...]
    # The subset of terminals whose FSMs are in a final state.
    final_terminals_and_info: Tuple[PartialTerminalInfo, ...]
73
+
74
+
75
class PartialParserConf(ParserConf):
    """A ``ParserConf`` extended with the partial-parsing options
    ``deterministic`` and ``use_value_stack`` (both serialized)."""

    __serialize_fields__ = (
        "rules",
        "start",
        "parser_type",
        "deterministic",
        "use_value_stack",
    )

    def __init__(self, rules, callbacks, start, deterministic, use_value_stack):
        super().__init__(rules, callbacks, start)
        # When True, the LALR table is rebuilt with a reproducible state order
        # (see `PartialLALRParser.__init__`).
        self.deterministic = deterministic
        # When True, `PartialParserState.feed_token` maintains lark's value
        # stack; otherwise the cheaper no-stack path is used.
        self.use_value_stack = use_value_stack
88
+
89
+
90
class PartialLark(Lark):
    """A ``Lark`` subclass wired for partial/incremental parsing.

    Extra keyword options (popped before delegating to ``Lark``):
      - ``deterministic``: build the LALR parse table with a reproducible
        state ordering.
      - ``use_value_stack``: maintain lark's value stack while parsing.

    Only the LALR parser is supported (asserted in ``__init__``).
    """

    __serialize_fields__ = (
        "parser",
        "rules",
        "options",
        "deterministic",
        "use_value_stack",
    )

    def __init__(self, grammar, **options):
        # TODO: Could've extended `LarkOptions`, but all these extensions are
        # already way too much (and brittle). This library really needs a
        # complete refactoring.
        self.deterministic = options.pop("deterministic", False)
        self.use_value_stack = options.pop("use_value_stack", False)
        # Force lark's `regex` backend; the FSM construction relies on it.
        options["regex"] = True
        super().__init__(grammar, **options)
        assert self.options.parser == "lalr"

    def _build_lexer(self, dont_ignore: bool = False) -> "PartialBasicLexer":
        """Build the basic lexer, optionally keeping normally-ignored terminals."""
        lexer_conf = self.lexer_conf
        if dont_ignore:
            from copy import copy

            lexer_conf = copy(lexer_conf)
            lexer_conf.ignore = ()

        return PartialBasicLexer(lexer_conf)

    def _build_parser(self) -> "PartialParsingFrontend":
        """Build the parsing frontend using the ``Partial*`` configuration classes."""
        self._prepare_callbacks()
        _validate_frontend_args(self.options.parser, self.options.lexer)
        parser_conf = PartialParserConf(
            self.rules,
            self._callbacks,
            self.options.start,
            self.deterministic,
            self.use_value_stack,
        )

        # This is `_construct_parsing_frontend` expanded/inlined
        parser_type = self.options.parser
        lexer_type = self.options.lexer
        lexer_conf = self.lexer_conf

        assert isinstance(lexer_conf, LexerConf)
        assert isinstance(parser_conf, ParserConf)
        parser_conf.parser_type = parser_type
        self.lexer_conf.lexer_type = lexer_type
        return PartialParsingFrontend(lexer_conf, parser_conf, self.options)

    def __repr__(self):
        return "{}(open({!r}), parser={!r}, lexer={!r}, ...)".format(
            type(self).__name__,
            self.source_path,
            self.options.parser,
            self.options.lexer,
        )

    def parse_from_state(self, parse_state: "PartialParseState", is_end=False):
        # Resume parsing from a previously saved parser state.  The triple
        # `.parser` chain unwraps frontend -> LALR parser -> `PartialParser`.
        return self.parser.parser.parser.parse_from_state(parse_state, is_end=is_end)
151
+
152
+
153
class PartialLexerThread(LexerThread):
    """A ``LexerThread`` that supports shallow copying and a debug-friendly repr."""

    def __copy__(self):
        cls = type(self)
        lexer_clone = copy(self.lexer)
        state_clone = copy(self.state)
        return cls(lexer_clone, state_clone)

    def __repr__(self):
        name = type(self).__name__
        return "{}(lexer={!r}, state={!r})".format(name, self.lexer, self.state)
159
+
160
+
161
class PartialPostLexConnector(PostLexConnector):
    """A ``PostLexConnector`` whose copies share the lexer but clone the postlexer."""

    def __copy__(self):
        cls = type(self)
        postlexer_clone = copy(self.postlexer)
        return cls(self.lexer, postlexer_clone)

    def __repr__(self):
        name = type(self).__name__
        return "{}(lexer={!r}, postlexer={!r})".format(
            name, self.lexer, self.postlexer
        )
169
+
170
+
171
class PartialParsingFrontend(ParsingFrontend):
    """A LALR-only ``ParsingFrontend`` that installs the ``Partial*`` plugin
    classes and lazily computes lookup maps over the parse table."""

    def __init__(self, lexer_conf, parser_conf, options, parser=None):
        assert parser_conf.parser_type == "lalr"

        # Swap in the partial-parsing implementations via lark's plugin hooks
        # before the base class builds the parser/lexer.
        options._plugins["LALR_Parser"] = PartialLALRParser
        options._plugins["BasicLexer"] = PartialBasicLexer
        options._plugins["ContextualLexer"] = PartialContextualLexer
        options._plugins["LexerThread"] = PartialLexerThread

        super().__init__(lexer_conf, parser_conf, options, parser=parser)

        if lexer_conf.postlex:
            self.lexer = PartialPostLexConnector(self.lexer.lexer, lexer_conf.postlex)

        # Lazily-populated caches; see the properties below.
        self._termset_fsm_info = None
        self._symbols_to_states: Optional[
            Dict[str, Set[Tuple[ParseStateType, Action]]]
        ] = None
        self._reverse_shifts: Optional[
            Dict[ParseStateType, Dict[str, Set[ParseStateType]]]
        ] = None
        # self._state_transition_map: Optional[
        #     Dict[Tuple[ParseStateType, str], Set[ParseStateType]]
        # ] = None

    def _compute_maps(
        self,
    ):
        """Compute state transition and symbols-to-states maps."""
        self._reverse_shifts = {}
        self._symbols_to_states = {}

        parse_table = self.parser.parser.parse_table

        for from_state, symbols_to_ops in parse_table.states.items():
            for symbol, op in symbols_to_ops.items():
                if op[0] == Shift:
                    # Map shift target state -> symbol -> origin states.
                    symbols_to_from_states = self._reverse_shifts.setdefault(op[1], {})
                    symbols_to_from_states.setdefault(symbol, set()).add(from_state)
                # Map symbol -> every (state, action) pair that consumes it.
                self._symbols_to_states.setdefault(symbol, set()).add((from_state, op))

        # # TODO: This approach is very wasteful.
        # context_lexer = get_contextual_lexer(self)
        # self._state_transition_map = {}
        #
        # for from_state, transitions in parse_table.states.items():
        #     for symbol, action in transitions.items():
        #         # TODO: Filter non-terminals
        #         if symbol not in context_lexer.root_lexer.terminals_by_name:
        #             continue
        #
        #         if action[0] is Shift:
        #             self._state_transition_map.setdefault(
        #                 (from_state, symbol), set()
        #             ).add(action[1])
        #             continue
        #
        #         antecedent_state_seqs = parse_to_terminal(self, [(from_state,)], symbol)
        #
        #         for antecedent_state_seq in antecedent_state_seqs:
        #             antecedent_state = antecedent_state_seq[-1]
        #             self._state_transition_map.setdefault(
        #                 (from_state, symbol), set()
        #             ).add(antecedent_state)

    def _compute_termset_fsm_info(self):
        """Collect and return information about terminal symbol sets and their FSMs.

        Terminal symbol sets (or "termsets") are ordered sequences of terminal
        symbols that are used by each parser state. Associated with each is a
        collection of FSMs for each terminal and a single parse state FSM that is
        the union of each terminal's FSM.

        This constructs a list of tuples containing the termset, the set of
        parse states that use the termsets, parse state FSMs, and information
        mapping the components of the parse state FSMs to their terminal symbol
        FSMs.

        """
        context_lexer = get_contextual_lexer(self)
        termsets_to_fsms = {}
        termsets_to_parse_states: Dict[Tuple[str, ...], Set[ParseStateType]] = {}
        for parse_state, lexer in context_lexer.lexers.items():
            scanner = lexer.scanner
            key = tuple(term.name for term in scanner.terminals)
            termsets_to_fsms[key] = (scanner.fsm, scanner.fsms_to_trans_finals)
            termsets_to_parse_states.setdefault(key, set()).add(parse_state)

        self._termset_fsm_info = [
            (
                termset,
                frozenset(termsets_to_parse_states[termset]),
                fsm,
                fsms_to_trans_finals,
            )
            for termset, (fsm, fsms_to_trans_finals) in termsets_to_fsms.items()
        ]

    @property
    def termset_fsm_info(self):
        # Computed on first access, then cached.
        if self._termset_fsm_info is None:
            self._compute_termset_fsm_info()
        return self._termset_fsm_info

    @property
    def symbols_to_states(self):
        # Computed on first access, then cached.
        if self._symbols_to_states is None:
            self._compute_maps()
        return self._symbols_to_states

    @property
    def reverse_shifts(self):
        # Computed on first access, then cached.
        if self._reverse_shifts is None:
            self._compute_maps()
        return self._reverse_shifts

    # @property
    # def state_transition_map(self):
    #     if self._state_transition_map is None:
    #         self._compute_maps()
    #     return self._state_transition_map
292
+
293
+
294
class PartialLALRParser(LALR_Parser):
    """A ``LALR_Parser`` that can optionally rebuild its parse table with a
    deterministic (reproducible, string-sorted) state ordering."""

    def __init__(self, parser_conf, debug=False, strict=False):
        # In deterministic mode, force `debug=True` so the analyzer keeps the
        # original (non-integer) state representation we need to re-key below.
        analysis = LALR_Analyzer(
            parser_conf, debug=debug if not parser_conf.deterministic else True
        )
        analysis.compute_lalr()
        callbacks = parser_conf.callbacks

        self.parser_conf = parser_conf
        self._parse_table = analysis.parse_table

        if parser_conf.deterministic:
            # Re-key every state as a sorted tuple so table construction is
            # order-independent across runs.
            old_to_new = {}

            def to_tuple(v):
                # Memoized conversion of a state to its sorted-tuple form.
                new = old_to_new.get(v)
                if new is None:
                    new = tuple(sorted(v, key=lambda y: str(y)))
                    old_to_new[v] = new
                return new

            enum = sorted(
                self._parse_table.states.keys(),
                key=lambda x: str(sorted(x, key=lambda y: str(y))),
            )

            new_states = {}
            for s in enum:
                # Shift targets are states too, so convert them as well.
                transitions = {
                    term: op if op[0] is not Shift else (op[0], to_tuple(op[1]))
                    for term, op in self._parse_table.states[s].items()
                }
                new_states[to_tuple(s)] = transitions

            self._parse_table = type(self._parse_table)(
                new_states,
                {k: to_tuple(v) for k, v in self._parse_table.start_states.items()},
                {k: to_tuple(v) for k, v in self._parse_table.end_states.items()},
            )

            if not debug:
                # Compress to integer states, remembering which ruleset each
                # integer corresponds to.
                self._parse_table = IntParseTable.from_ParseTable(self._parse_table)
                self.states_to_rulesets = dict(
                    zip(self._parse_table.states.keys(), new_states.keys())
                )

        self.parser = PartialParser(
            self._parse_table,
            callbacks,
            debug,
            use_value_stack=parser_conf.use_value_stack,
        )

    @classmethod
    def deserialize(cls, data, memo, callbacks, debug=False):
        # Reconstruct a parser from serialized table data without re-analysis.
        inst = cls.__new__(cls)
        inst._parse_table = ParseTable.deserialize(data, memo)
        inst.parser = PartialParser(inst._parse_table, callbacks, debug)
        return inst
353
+
354
+
355
class PartialParserState(ParserState):
    """A ``ParserState`` that understands "partial" tokens (incomplete scans)
    and can skip value-stack maintenance for speed."""

    __slots__ = "use_value_stack"

    def __init__(
        self,
        parse_conf,
        lexer,
        state_stack=None,
        value_stack=None,
        use_value_stack=False,
    ):
        super().__init__(
            parse_conf, lexer, state_stack=state_stack, value_stack=value_stack
        )
        self.use_value_stack = use_value_stack

    def feed_token(self, token, is_end=False):
        """Feed one token; "partial" tokens are probed for viability instead of
        being shifted."""
        if token.type == "partial":
            # If none of the potential terminals can transition, we need to know now
            current_state = self.state_stack[-1]
            current_lexer = get_contextual_lexer(self.lexer).lexers[current_state]

            # We have to feed the token and determine whether or not at least
            # one terminal is consistent with the stack; otherwise, we'll miss
            # invalid REDUCE cases.
            # TODO: We should track separate parses conditional on possible
            # token/symbol types, then we can coherently reuse the following
            # results instead of recomputing it later.
            can_transition = False
            for terminal_info in token.value.terminals_and_info:
                if terminal_info.terminal_name not in current_lexer.ignore_types:
                    test_token = Token.new_borrow_pos(
                        terminal_info.terminal_name, "", token
                    )

                    stack = copy(self.state_stack)
                    try:
                        self.feed_token_no_stack(test_token, is_end=is_end)
                        can_transition = True
                        break
                    except UnexpectedToken:
                        continue
                    finally:
                        # The feed above was only a probe: always restore the
                        # state stack.
                        self.state_stack = stack
                else:
                    # Ignored terminals never block the parse.
                    can_transition = True

            if not can_transition:
                expected = {
                    s
                    for s in self.parse_conf.states[current_state].keys()
                    if s.isupper()
                }
                raise UnexpectedToken(
                    token, expected, state=self, interactive_parser=None
                )

        elif self.use_value_stack:
            super().feed_token(token, is_end=is_end)
        else:
            self.feed_token_no_stack(token, is_end=is_end)

    def feed_token_no_stack(self, token, is_end=False):
        """
        This is a copy of `ParserState.feed_token` with all the value stack
        steps removed. Since we're not exactly parsing in order to obtain a
        CST or anything similar, we can avoid the growing expense of tracking
        the parse tree.
        """
        state_stack = self.state_stack
        states = self.parse_conf.states
        end_state = self.parse_conf.end_state

        while True:
            state = state_stack[-1]
            try:
                action, arg = states[state][token.type]
            except KeyError:
                expected = {s for s in states[state].keys() if s.isupper()}
                raise UnexpectedToken(
                    token, expected, state=self, interactive_parser=None
                )

            assert arg != end_state

            if action is Shift:
                # shift once and return
                assert not is_end
                state_stack.append(arg)
                return
            else:
                # reduce+shift as many times as necessary
                rule = arg
                size = len(rule.expansion)
                if size:
                    del state_stack[-size:]

                _action, new_state = states[state_stack[-1]][rule.origin.name]
                assert _action is Shift
                state_stack.append(new_state)

                if is_end and state_stack[-1] == end_state:
                    return

    def feed_eof(self):
        """Feed a synthesized ``$END`` token, unless the trailing partial token
        cannot terminate in any final state."""
        last_token = self.lexer.state.last_token

        if last_token is None:
            eof_token = self.lexer._Token("$END", "", 0, 1, 1)
        else:
            eof_token = Token.new_borrow_pos("$END", "", last_token)

        new_token_is_legal = (
            last_token is None
            or last_token.type != "partial"
            or any(ti.is_final for ti in last_token.value.terminals_and_info)
        )
        if new_token_is_legal:
            self.feed_token(eof_token, is_end=True)
        else:
            raise UnexpectedToken(eof_token, [], state=self, interactive_parser=None)

    def choices(self):
        # All symbol -> action entries reachable from the current position.
        return self.parse_conf.parse_table.states[self.position]

    def accepts(self):
        """
        Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95
        Returns the set of possible tokens that will advance the parser into a new valid state.
        """
        accepts = set()
        conf_no_callbacks = copy(self.parse_conf)
        # We don't want to call callbacks here since those might have arbitrary side effects
        # and are unnecessarily slow.
        conf_no_callbacks.callbacks = {}
        for t in self.choices():
            if t.isupper():  # is terminal?
                new_state = copy(self)
                new_state.parse_conf = conf_no_callbacks
                try:
                    new_state.feed_token(new_state.lexer._Token(t, ""))
                except UnexpectedToken:
                    pass
                else:
                    accepts.add(t)
        return accepts

    def __copy__(self):
        return type(self)(
            self.parse_conf,
            copy(self.lexer),
            copy(self.state_stack),
            deepcopy(self.value_stack),
            use_value_stack=self.use_value_stack,
        )

    def __repr__(self):
        return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})"
513
+
514
+
515
class PartialParser(_Parser):
    """A ``_Parser`` that creates ``PartialParserState``s and can resume
    parsing from a previously saved state."""

    def __init__(self, parse_table, callbacks, debug=False, use_value_stack=False):
        super().__init__(parse_table, callbacks, debug=debug)
        self.use_value_stack = use_value_stack

    def parse(
        self, lexer, start, value_stack=None, state_stack=None, start_interactive=False
    ):
        parse_conf = ParseConf(self.parse_table, self.callbacks, start)
        parser_state = PartialParserState(
            parse_conf, copy(lexer), state_stack, value_stack, self.use_value_stack
        )
        if start_interactive:
            return InteractiveParser(self, parser_state, parser_state.lexer)
        return self.parse_from_state(parser_state)

    def parse_from_state(self, state, last_token=None, is_end=False):
        """Drive the lexer/parser loop from `state`; feed EOF when `is_end` and
        the final token is not partial.  Returns the (mutated) state."""
        try:
            token = last_token
            for token in state.lexer.lex(state):
                state.feed_token(token)

            if is_end and (not token or token.type != "partial"):
                state.feed_eof()

            return state
        except UnexpectedInput as e:
            # Attach an interactive parser for error recovery, when possible.
            try:
                e.interactive_parser = InteractiveParser(self, state, state.lexer)
            except NameError:
                pass
            raise e
        except Exception:
            if self.debug:
                print("")
                print("STATE STACK DUMP")
                print("----------------")
                for i, s in enumerate(state.state_stack):
                    print("%d)" % i, s)
                print("")

            raise
557
+
558
+
559
class PartialScanner(Scanner):
    """A ``Scanner`` that matches text against a single union FSM built from
    every terminal's regex, so prefix matches can be continued later."""

    @classmethod
    @lru_cache
    def construct_terminal_fsm(cls, terminal):
        # TODO: This should really be done at the lexer/parser level so that
        # the lifetime of these objects is tied to the parser itself.
        regex_str = terminal.pattern.to_regexp()
        pattern = interegular.parse_pattern(regex_str)
        fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())
        return fsm, pattern.prefix_postfix

    def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False):
        self.terminals = terminals
        self.g_regex_flags = g_regex_flags
        self.use_bytes = use_bytes
        self.match_whole = match_whole
        self.allowed_types = {t.name for t in self.terminals}
        self._mres = None

        fsms = []
        for t in self.terminals:
            fsm, prefix_postfix = self.construct_terminal_fsm(t)

            # TODO FIXME: We don't support this right now.
            assert prefix_postfix == (0, 0)

            fsms.append(fsm)

        # One FSM for all terminals, plus a map back to each component FSM.
        self.fsm, self.fsms_to_trans_finals = fsm_union(fsms)

    def get_terminals_info(
        self, fsm_state_seq
    ) -> Tuple[Tuple[PartialTerminalInfo, ...], Tuple[PartialTerminalInfo, ...]]:
        """Get the possible terminal symbols for an FSM state sequence."""
        terminals_and_info: Tuple[PartialTerminalInfo, ...] = ()
        final_terminals_and_info: Tuple[PartialTerminalInfo, ...] = ()
        for i, (fsm_id, fsm_reads_more, in_final) in enumerate(
            get_sub_fsms_from_seq(fsm_state_seq, self.fsms_to_trans_finals)
        ):
            terminal_name = self.terminals[fsm_id].name
            info = PartialTerminalInfo(i, terminal_name, fsm_reads_more, in_final)
            terminals_and_info += (info,)
            if in_final:
                final_terminals_and_info += (info,)

        return terminals_and_info, final_terminals_and_info

    def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None):
        """Determine an FSM match over `text` starting at `pos` and continuing `last_fsm_state_seq`."""

        start_pos = pos

        if last_fsm_state_seq:
            # Resume from where the previous partial scan left off.
            assert len(last_fsm_state_seq) > 1
            start_pos += len(last_fsm_state_seq) - 1
            start_state = last_fsm_state_seq[-1]
        else:
            start_state = self.fsm.initial

        text_part = text[start_pos:]

        # Convert the raw text into the union FSM's transition-key alphabet.
        text_transitions = get_token_transition_keys(
            self.fsm.fsm_info.alphabet_symbol_mapping,
            self.fsm.fsm_info.alphabet_anything_value,
            text_part,
        )

        state_seq = walk_fsm(
            self.fsm,
            text_transitions,
            start_state,
            full_match=self.match_whole,
        )

        if not state_seq:
            return None

        if last_fsm_state_seq:
            res = last_fsm_state_seq + tuple(state_seq)
        else:
            res = (start_state,) + tuple(state_seq)

        return res
642
+
643
+
644
class PartialContextualLexer(ContextualLexer):
    """A ``ContextualLexer`` built from ``PartialBasicLexer``s, one per parser
    state (lexers are shared between states accepting the same terminal set)."""

    def __init__(self, conf: "LexerConf", states, always_accept=()):
        terminals = list(conf.terminals)
        terminals_by_name = conf.terminals_by_name

        trad_conf = copy(conf)
        trad_conf.terminals = terminals

        # Cache lexers by their accepted-terminal set so equivalent parser
        # states share a lexer instance.
        lexer_by_symbols: Dict = {}
        self.lexers = {}
        for state, accepts in states.items():
            key = frozenset(accepts)
            try:
                lexer = lexer_by_symbols[key]
            except KeyError:
                accepts = set(accepts) | set(conf.ignore) | set(always_accept)
                lexer_conf = copy(trad_conf)
                lexer_conf.terminals = [
                    terminals_by_name[n] for n in accepts if n in terminals_by_name
                ]
                if not lexer_conf.terminals:
                    continue
                lexer = PartialBasicLexer(lexer_conf)
                lexer_by_symbols[key] = lexer

            self.lexers[state] = lexer

        assert trad_conf.terminals is terminals
        # Fallback lexer that accepts every terminal.
        self.root_lexer = PartialBasicLexer(trad_conf)

    def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]:
        """Yield tokens using the lexer associated with the current parser state."""
        try:
            while True:
                lexer = self.lexers[parser_state.position]
                next_tok = lexer.next_token(lexer_state, parser_state)
                yield next_tok
        except EOFError:
            pass
        except KeyError:
            # No lexer for this parser state: only an error if input remains.
            if len(lexer_state.text) > lexer_state.line_ctr.char_pos:
                raise UnexpectedCharacters(
                    lexer_state.text,
                    lexer_state.line_ctr.char_pos,
                    lexer_state.line_ctr.line,
                    lexer_state.line_ctr.column,
                    allowed=False,
                    token_history=lexer_state.last_token and [lexer_state.last_token],
                    state=parser_state,
                    terminals_by_name=self.root_lexer.terminals,
                )
694
+
695
+
696
class PartialBasicLexer(BasicLexer):
    """A ``BasicLexer`` backed by ``PartialScanner``: it emits special
    "partial" tokens when the input ends mid-terminal, so lexing can resume."""

    def __init__(self, conf: "LexerConf"):
        super().__init__(conf)
        # Eagerly construct the scanner
        self._build_scanner()

    def _build_scanner(self):
        # This seems incredibly convoluted: `lark` creates callback-triggered
        # nested scanners for regex-defined terminals that overlap with
        # string-defined terminals when both types of terminals have the same
        # priority. Unless I'm missing something important, why not simply
        # reorder the terminals so that the string-defined ones come before the
        # regex-defined ones?
        terminals, self.callback = _create_unless(
            self.terminals, self.g_regex_flags, self.re, self.use_bytes
        )

        # We can't let people arbitrarily mess with the scanning process.
        assert not self.user_callbacks
        # for type_, f in self.user_callbacks.items():
        #     if type_ in self.callback:
        #         # Already a callback there, probably UnlessCallback
        #         self.callback[type_] = CallChain(
        #             self.callback[type_], f, lambda t: t.type == type_
        #         )
        #     else:
        #         self.callback[type_] = f

        # We used the "callback" results to reorder the terminals (see the
        # comments above).
        for terminal_name, callback in self.callback.items():
            terminal = self.terminals_by_name[terminal_name]
            for sub_terminal in callback.scanner.terminals:
                self.terminals.remove(sub_terminal)
                idx = self.terminals.index(terminal)
                self.terminals.insert(idx, sub_terminal)

        self._scanner = PartialScanner(
            self.terminals, self.g_regex_flags, self.re, self.use_bytes
        )

    def match(self, text, pos, last_fsm_state_seq=None):
        # Delegate to the FSM-based partial scanner.
        return self.scanner.match(text, pos, last_fsm_state_seq)

    def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
        """Produce the next token, continuing any in-progress partial scan.

        Emits a "partial" token (with ``PartialTokensInfo`` payload) when the
        text runs out mid-terminal; raises ``EOFError`` at end of input.
        """
        last_token = lex_state.last_token

        last_fsm_state_seq = None
        if last_token and last_token.type == "partial":
            # Continue from last partial lexer state
            last_fsm_state_seq = last_token.value.fsm_state_seq

        line_ctr = lex_state.line_ctr
        end_pos = line_ctr.char_pos + (
            len(last_fsm_state_seq) - 1 if last_fsm_state_seq else 0
        )
        while end_pos < len(lex_state.text):
            res = self.match(lex_state.text, line_ctr.char_pos, last_fsm_state_seq)

            if not res:
                if (
                    not last_fsm_state_seq
                    or last_fsm_state_seq[-1] not in self.scanner.fsm.finals
                ):
                    # No match and the previous partial scan isn't complete
                    # either: genuine lexing error.
                    allowed = self.scanner.allowed_types - self.ignore_types
                    if not allowed:
                        allowed = {"<END-OF-FILE>"}
                    raise UnexpectedCharacters(
                        lex_state.text,
                        line_ctr.char_pos,
                        line_ctr.line,
                        line_ctr.column,
                        allowed=allowed,
                        token_history=lex_state.last_token and [lex_state.last_token],
                        state=parser_state,
                        terminals_by_name=self.terminals_by_name,
                    )

                # The partial match might be complete now
                fsm_state_seq = last_token.value.fsm_state_seq
                terminals_and_info = last_token.value.terminals_and_info
                final_terminals_and_info = last_token.value.final_terminals_and_info
            else:
                fsm_state_seq = res
                (
                    terminals_and_info,
                    final_terminals_and_info,
                ) = self.scanner.get_terminals_info(fsm_state_seq)

            # Prefer a terminal already in a final state; fall back to the
            # highest-priority candidate.
            priority_terminal_info = (
                final_terminals_and_info[0]
                if final_terminals_and_info
                else terminals_and_info[0]
            )

            is_not_finished = (
                not priority_terminal_info.is_final
                or priority_terminal_info.can_transition
                or len(terminals_and_info) > 1
            )

            start_pos = line_ctr.char_pos
            end_pos = start_pos + len(fsm_state_seq) - 1

            if end_pos >= len(lex_state.text) and is_not_finished:
                # Input ran out mid-terminal: emit a "partial" token.
                type_name = "partial"
                token_value = PartialTokensInfo(
                    fsm_state_seq,
                    is_not_finished,
                    terminals_and_info,
                    final_terminals_and_info,
                )
                # Don't update the line counter states until we've finished
                value = ""
            else:
                type_name = priority_terminal_info.terminal_name
                # The token value should contain all partial scan parts in this
                # case
                value = token_value = lex_state.text[start_pos:end_pos]

            assert isinstance(self.callback, Dict)

            if type_name not in self.ignore_types:
                t = Token(
                    type_name,
                    token_value,
                    line_ctr.char_pos,
                    line_ctr.line,
                    line_ctr.column,
                )

                line_ctr.feed(value, type_name in self.newline_types)

                t.end_line = line_ctr.line
                t.end_column = line_ctr.column
                t.end_pos = line_ctr.char_pos
                if t.type in self.callback:
                    t = self.callback[t.type](t)
                    if not isinstance(t, Token):
                        raise LexError(
                            "Callbacks must return a token (returned %r)" % t
                        )
                lex_state.last_token = t
                return t

            # Ignored terminal: run its callback (if any) and keep scanning.
            if type_name in self.callback:
                t2 = Token(
                    type_name, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
                )
                self.callback[type_name](t2)

            line_ctr.feed(value, type_name in self.newline_types)

            last_fsm_state_seq = None

        raise EOFError(self)
852
+
853
+
854
class PartialIndenter(Indenter):
    """An ``Indenter`` that doesn't reset its state every time ``process`` is
    called, so indentation tracking can span multiple partial parses."""

    def process(self, stream):
        # Unlike lark's `Indenter.process`, no state reset happens here.
        return self._process(stream)

    def _process(self, stream):
        for tok in stream:
            # Bracket bookkeeping happens *before* the yield (upstream lark
            # does it after), which keeps incremental state tracking simple.
            tok_type = tok.type
            if tok_type in self.OPEN_PAREN_types:
                self.paren_level += 1
            elif tok_type in self.CLOSE_PAREN_types:
                self.paren_level -= 1
                if self.paren_level < 0:
                    raise UnexpectedToken(tok, [])

            if tok_type == self.NL_type:
                yield from self.handle_NL(tok)
            else:
                yield tok

        # TODO: What do we want to do here?
        # while len(self.indent_level) > 1:
        #     self.indent_level.pop()
        #     yield Token(self.DEDENT_type, "")

    def accepts_token_type(self, token_type):
        # A closing bracket is only acceptable while at least one is open.
        if token_type in self.CLOSE_PAREN_types and self.paren_level <= 0:
            return False

        # TODO:
        # if token_type == self.NL_type and self.paren_level == 0:
        #     ...
        #     return False

        return True

    def __copy__(self):
        clone = type(self)()
        clone.paren_level = self.paren_level
        clone.indent_level = copy(self.indent_level)
        return clone

    def __repr__(self):
        return f"{type(self).__name__}(paren_level={self.paren_level!r}, indent_level={self.indent_level!r})"
900
+
901
+
902
class PartialPythonIndenter(PartialIndenter):
    """Indenter configuration for a Python-like grammar's terminal names."""

    NL_type = "_NEWLINE"
    OPEN_PAREN_types = ["LPAR", "LSQB", "LBRACE"]
    CLOSE_PAREN_types = ["RPAR", "RSQB", "RBRACE"]
    INDENT_type = "_INDENT"
    DEDENT_type = "_DEDENT"
    tab_len = 8
909
+
910
+
911
def get_contextual_lexer(x: Union[PartialLexerThread, PartialParsingFrontend]):
    """Return the ``ContextualLexer`` held by `x`, unwrapping one level if
    `x.lexer` is itself a wrapper (e.g. a post-lex connector)."""
    lexer = x.lexer
    if isinstance(lexer, ContextualLexer):
        return lexer
    return lexer.lexer
916
+
917
+
918
def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]:
    """Construct a ``dict`` mapping terminal symbol names to their finite state machines."""
    name_to_fsm: Dict[str, FSM] = {}
    for term in lp.terminals:
        regex = term.pattern.to_regexp()
        parsed = interegular.parse_pattern(regex)
        # TODO: Use `pyparser.terminals[0].pattern.flags`?
        try:
            det_fsm, _ = make_deterministic_fsm(parsed.to_fsm().reduce())
        except Unsupported:
            # Patterns interegular cannot translate get a ``None`` placeholder.
            det_fsm = None
        name_to_fsm[term.name] = det_fsm
    return name_to_fsm
933
+
934
+
935
def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    Parameters
    ----------
    fsms
        The component FSMs to union.

    Returns
    -------
    A tuple of the deterministic union FSM and a map from each component FSM
    index to ``(transitions, finals, old_to_new_states)`` expressed in the
    union FSM's state numbering.
    """
    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM.
    # (Local renamed from `next` so the builtin is not shadowed.)
    def follow(current_state, new_transition: int):
        next_state = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next_state[i] = f.map[current_state[i]][old_transition]
        if not next_state:
            # No component FSM can follow this transition.
            raise OblivionError
        return next_state

    states = [initial]
    finals: Set[int] = set()
    # Renamed from `map` to avoid shadowing the builtin.
    transition_map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the transition map for this state
        transition_map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next_state = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next_state)
                except ValueError:
                    j = len(states)
                    states.append(next_state)

                transition_map[i][transition] = j

                for fsm_id, fsm_state in next_state.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=transition_map,
        __no_validation__=True,
    )

    # Determinize and translate the per-component bookkeeping into the new
    # (deterministic) state numbering.
    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )
1049
+
1050
+
1051
+ def get_sub_fsms_from_seq(
1052
+ state_seq: Sequence[int],
1053
+ fsms_to_trans_finals: Dict[
1054
+ int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
1055
+ ],
1056
+ ) -> Generator[Tuple[int, bool, bool], None, None]:
1057
+ """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
1058
+
1059
+ Parameters
1060
+ ----------
1061
+ state_seq
1062
+ A state sequence.
1063
+ fsms_to_trans_finals
1064
+ A map from FSM indices to tuples containing sets of their state transitions
1065
+ and sets of the final/accept states.
1066
+
1067
+ Returns
1068
+ -------
1069
+ A generator returning tuples containing each sub-FSM index (in the order
1070
+ they were union-ed to construct `fsm`) and booleans indicating whether or
1071
+ not there is another valid transition from the last state in the sequence
1072
+ for the associated sub-FSM (i.e. if the FSM can continue
1073
+ accepting/matching) and whether or not the sequence ends in a final state
1074
+ of the sub-FSM.
1075
+ """
1076
+ state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
1077
+ last_fsm_state = state_seq[-1]
1078
+ yield from (
1079
+ (
1080
+ # The sub-FMS index
1081
+ fsm_idx,
1082
+ # Is there another possible transition in this sub-FSM?
1083
+ any(last_fsm_state == from_s for (from_s, to_s) in transitions),
1084
+ # Is this sub-FSM in a final state?
1085
+ state_seq[-1] in finals,
1086
+ )
1087
+ for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
1088
+ if state_seq_transitions.issubset(transitions)
1089
+ )
1090
+
1091
+
1092
def walk_fsm(
    fsm: BetterFSM,
    token_transition_keys: Sequence[int],
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    """Follow `token_transition_keys` through `fsm` from `start_state`.

    Returns the list of states visited (excluding `start_state`). When
    `full_match` is true, the walk must end in a final state or the result is
    empty. When it is false and the walk dead-ends, the longest prefix that
    ends in a final state is returned instead (empty if there is none).
    """
    final_states = fsm.finals
    transitions = fsm.flat_transition_map

    visited: List[int] = []
    last_final_idx = 0
    state = start_state

    # The transition-key sequence encodes the FSM traversal rules of the
    # token's symbols.
    for idx, trans_key in enumerate(token_transition_keys):
        next_state = transitions.get((state, trans_key))

        if next_state is None:
            # Dead end: optionally salvage the longest accepted prefix.
            if not full_match and last_final_idx > 0:
                return visited[:last_final_idx]
            return []

        state = next_state
        if state in final_states:
            last_final_idx = idx + 1
        visited.append(state)

    if full_match and last_final_idx - 1 != idx:
        # The walk completed but did not end in a final state.
        return []

    return visited
.venv/lib/python3.11/site-packages/outlines/fsm/types.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import datetime
from enum import EnumMeta
from typing import Any, Protocol, Tuple, Type

from typing_extensions import _AnnotatedAlias, get_args

# Regular expressions describing the textual form of each supported type.
INTEGER = r"[+-]?(0|[1-9][0-9]*)"
BOOLEAN = "(True|False)"
# NOTE(review): the exponent requires an explicit sign, so "1e5" does not
# match while "1e+5" does — confirm this restriction is intentional.
FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?"
DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])"
TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])"
DATETIME = rf"({DATE})(\s)({TIME})"


class FormatFunction(Protocol):
    """Callable that converts a generated string into a Python value."""

    def __call__(self, sequence: str) -> Any:
        ...


def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
    """Return the regex matching `python_type` and a parser for the match.

    Parameters
    ----------
    python_type
        A supported Python type: an ``Annotated`` custom type carrying a
        ``json_schema`` with a ``pattern``, an ``Enum`` class, or one of
        ``float``, ``int``, ``bool``, ``datetime.date``, ``datetime.time``,
        ``datetime.datetime``.

    Returns
    -------
    A ``(regex, format_function)`` tuple; the format function converts a
    string that matched the regex into a value of `python_type`.

    Raises
    ------
    NotImplementedError
        If `python_type` is not supported.
    """
    # If it is a custom type
    if isinstance(python_type, _AnnotatedAlias):
        json_schema = get_args(python_type)[1].json_schema
        type_class = get_args(python_type)[0]

        custom_regex_str = json_schema["pattern"]

        def custom_format_fn(sequence: str) -> Any:
            return type_class(sequence)

        return custom_regex_str, custom_format_fn

    if isinstance(python_type, EnumMeta):
        values = python_type.__members__.keys()
        # NOTE(review): member names are interpolated unescaped; names
        # containing regex metacharacters would change the pattern — verify.
        enum_regex_str: str = "(" + "|".join(values) + ")"

        def enum_format_fn(sequence: str) -> str:
            return str(sequence)

        return enum_regex_str, enum_format_fn

    if python_type == float:

        def float_format_fn(sequence: str) -> float:
            return float(sequence)

        return FLOAT, float_format_fn
    elif python_type == int:

        def int_format_fn(sequence: str) -> int:
            return int(sequence)

        return INTEGER, int_format_fn
    elif python_type == bool:

        def bool_format_fn(sequence: str) -> bool:
            # Bug fix: `bool(sequence)` is True for ANY non-empty string,
            # including "False". BOOLEAN only ever matches the literals
            # "True" and "False", so compare against the literal instead.
            return sequence == "True"

        return BOOLEAN, bool_format_fn
    elif python_type == datetime.date:

        def date_format_fn(sequence: str) -> datetime.date:
            return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()

        return DATE, date_format_fn
    elif python_type == datetime.time:

        def time_format_fn(sequence: str) -> datetime.time:
            return datetime.datetime.strptime(sequence, "%H:%M:%S").time()

        return TIME, time_format_fn
    elif python_type == datetime.datetime:

        def datetime_format_fn(sequence: str) -> datetime.datetime:
            return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")

        return DATETIME, datetime_format_fn
    else:
        raise NotImplementedError(
            f"The Python type {python_type} is not supported. Please open an issue."
        )
.venv/lib/python3.11/site-packages/outlines/function.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
4
+
5
+ import requests
6
+
7
+ from outlines import generate, models
8
+
9
+ if TYPE_CHECKING:
10
+ from outlines.generate.api import SequenceGenerator
11
+ from outlines.prompts import Prompt
12
+
13
+
14
@dataclass
class Function:
    """Represents an Outlines function.

    Functions are a convenient way to encapsulate a prompt template, a language
    model and a Pydantic model that define the output structure. Once defined,
    the function can be called with arguments that will be used to render the
    prompt template.

    """

    prompt_template: "Prompt"
    schema: Union[str, Callable, object]
    model_name: str
    generator: Optional["SequenceGenerator"] = None

    @classmethod
    def from_github(cls, program_path: str, function_name: str = "fn"):
        """Load a function stored on GitHub"""
        source = download_from_github(program_path)
        return extract_function_from_file(source, function_name)

    def init_generator(self):
        """Load the model and initialize the generator."""
        model = models.transformers(self.model_name)
        self.generator = generate.json(model, self.schema)

    def __call__(self, *args, **kwargs):
        """Call the function.

        .. warning::

            This currently does not support batching.

        Parameters
        ----------
        args
            Values to pass to the prompt template as positional arguments.
        kwargs
            Values to pass to the prompt template as keyword arguments.

        """
        # The generator is created lazily on first call.
        if self.generator is None:
            self.init_generator()

        rendered_prompt = self.prompt_template(*args, **kwargs)
        return self.generator(rendered_prompt)
63
+
64
+
65
def download_from_github(short_path: str) -> str:
    """Download the file in which the function is stored on GitHub.

    Parameters
    ----------
    short_path
        A ``{USERNAME}/{REPO_NAME}/{PATH_TO_FILE}`` path, without the ``.py``
        extension. The file is fetched from the ``main`` branch.

    Returns
    -------
    The text content of the remote file.

    Raises
    ------
    ValueError
        If the path is malformed, ends in ``.py``, the file is not found, or
        GitHub returns an unexpected non-error status.
    requests.HTTPError
        If GitHub returns another HTTP error status (4xx/5xx).
    """
    GITHUB_BASE_URL = "https://raw.githubusercontent.com"
    BRANCH = "main"

    path = short_path.split("/")
    if len(path) < 3:
        raise ValueError(
            "Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}."
        )
    elif short_path.endswith(".py"):
        raise ValueError("Do not append the `.py` extension to the program name.")

    username = path[0]
    repo = path[1]
    path_to_file = path[2:]

    url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py"
    result = requests.get(url)

    if result.status_code == 200:
        return result.text
    elif result.status_code == 404:
        raise ValueError(
            f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly."
        )
    else:
        result.raise_for_status()
        # `raise_for_status` is a no-op for non-error statuses (e.g. 204);
        # previously this path silently returned None. Fail loudly instead.
        raise ValueError(f"Unexpected HTTP status {result.status_code} for {url}.")
93
+
94
+
95
def extract_function_from_file(content: str, function_name: str) -> Callable:
    """Extract a function object from a downloaded file.

    Parameters
    ----------
    content
        The Python source code of the downloaded program. It is executed
        as-is; only call this on content from a trusted source.
    function_name
        The name of the module-level variable that holds the function.

    Returns
    -------
    The extracted function object. (Return annotation fixed from
    ``Tuple[Callable]`` — a single callable is returned, not a tuple.)

    Raises
    ------
    ImportError
        If a module spec could not be created for the program.
    AttributeError
        If `function_name` is not defined in the program.
    TypeError
        If the extracted object is not an `outlines.Function` instance.
    """
    spec = importlib.util.spec_from_loader(
        "outlines_function", loader=None, origin="github"
    )
    # Bug fix: `module` was only bound when `spec` was not None but was used
    # unconditionally below, which raised a confusing NameError in that case.
    if spec is None:
        raise ImportError("Could not create a module spec for the remote program.")

    module = importlib.util.module_from_spec(spec)
    # SECURITY: executes arbitrary downloaded code in-process.
    exec(content, module.__dict__)

    try:
        fn = getattr(module, function_name)
    except AttributeError:
        raise AttributeError(
            "Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct."
        )

    if not isinstance(fn, module.outlines.Function):
        raise TypeError(
            f"The `{function_name}` variable in the program must be an instance of `outlines.Function`"
        )

    return fn
.venv/lib/python3.11/site-packages/outlines/generate/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .api import SequenceGenerator
2
+ from .cfg import cfg
3
+ from .choice import choice
4
+ from .format import format
5
+ from .fsm import fsm
6
+ from .json import json
7
+ from .regex import regex
8
+ from .text import text
.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc ADDED
Binary file (3.4 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc ADDED
Binary file (6.05 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/generate/api.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from copy import copy
3
+ from dataclasses import dataclass
4
+ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
5
+
6
+ from outlines.generate.generator import sequence_generator
7
+ from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
8
+
9
+ if TYPE_CHECKING:
10
+ import torch
11
+
12
# Union of the Python types a generator's `format_sequence` may produce:
# plain text plus the primitive and date/time types used by the typed
# generators in this package.
FormattedOutput = Union[
    str, int, float, bool, datetime.date, datetime.time, datetime.datetime
]
15
+
16
+
17
+ class SequenceGenerator:
18
+ def __init__(
19
+ self,
20
+ fsm,
21
+ model,
22
+ sampler,
23
+ device,
24
+ ):
25
+ self.fsm = fsm
26
+ self.model = model
27
+ self.sampler = sampler
28
+ self.tokenizer = model.tokenizer
29
+ self.device = device
30
+ self.num_samples = sampler.samples
31
+
32
+ def get_generated_token_ids(
33
+ self,
34
+ prompt_token_ids: "torch.Tensor",
35
+ token_ids: "torch.Tensor",
36
+ ) -> List["torch.Tensor"]:
37
+ """Get the tokens generated so far.
38
+
39
+ Parameters
40
+ ----------
41
+ prompt_token_ids
42
+ Tensor that contains the token ids of the sequences' prompts.
43
+ token_ids
44
+ The generated token ids.
45
+
46
+ Returns
47
+ -------
48
+ A tensor that contains the token ids that have been generated so far.
49
+
50
+ """
51
+ prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
52
+ token_ids = [
53
+ cur_token_ids[length:]
54
+ for cur_token_ids, length in zip(token_ids, prompt_lengths)
55
+ ]
56
+
57
+ return token_ids
58
+
59
+ def is_stop_sequence_found(
60
+ self, generated_sequences: List[str], stop_sequences: List[str]
61
+ ) -> bool:
62
+ """Determine whether one of the stop sequences has been generated.
63
+
64
+ Parameters
65
+ ----------
66
+ generated_sequences
67
+ The list of sequences generated so far.
68
+ stop_sequences
69
+ The list that contains the sequence which stop the generation when
70
+ found.
71
+
72
+ Returns
73
+ -------
74
+ True if at least one of the stop sequences has been found in each generated
75
+ sequence.
76
+
77
+ """
78
+ return all(
79
+ [
80
+ any([seq in generated for seq in stop_sequences])
81
+ for generated in generated_sequences
82
+ ]
83
+ )
84
+
85
+ def strip_stop_sequences(
86
+ self, sequence: str, stop_sequences: Optional[List[str]]
87
+ ) -> str:
88
+ """Remove the stop sequences from the generated sequences.
89
+
90
+ Parameters
91
+ ----------
92
+ sequence
93
+ One of the generated sequences.
94
+ stop_sequences
95
+ The list that contains the sequence which stop the generation when
96
+ found.
97
+
98
+ """
99
+ if stop_sequences:
100
+ match_indexes = [sequence.find(seq) for seq in stop_sequences]
101
+ if any([index != -1 for index in match_indexes]):
102
+ # select the stop_sequence that is found first in the sequence
103
+ min_match_index_value = min([i for i in match_indexes if i != -1])
104
+ min_match_index_pos = match_indexes.index(min_match_index_value)
105
+ sequence = sequence[
106
+ : match_indexes[min_match_index_pos]
107
+ + len(stop_sequences[min_match_index_pos])
108
+ ]
109
+
110
+ return sequence
111
+
112
+ def format_sequence(self, sequence: str) -> FormattedOutput:
113
+ """Translate the generated sequence to another type.
114
+
115
+ This method is for instance overridden when generating JSON to either
116
+ return a dictionnary or a Pydantic model.
117
+
118
+ Parameters
119
+ ----------
120
+ sequence
121
+ A generated sequences.
122
+
123
+ Returns
124
+ -------
125
+ The formatted sequence.
126
+
127
+ """
128
+ return sequence
129
+
130
    def __call__(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        rng: Optional["torch.Generator"] = None,
    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
        """Generate the full text sequence.

        Since `SequenceGenerator.stream` calls the tokenizer at every step this
        method loops over the generator returned by `sequence_generator` itself
        so the tokenizer is called only once after all token ids have been
        generated.

        Parameters
        ----------
        prompts
            A string or list of strings that are passed to the model before
            generating the first token.
        max_tokens
            An integer representing maximum number of tokens that will be generated
            (per prompt)
        stop_at
            A string or list of strings at which the text generated will stop
        rng
            The random number generator. Defaults to a non-seeded `torch.Generator`
            instance.

        Returns
        -------
        The generation(s), potentially cast to another type.
        """
        import torch

        # Normalize single prompt / single stop string to lists.
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(stop_at, str):
            stop_at = [stop_at]

        stop_sequences = stop_at
        num_samples = self.num_samples

        if rng is None:
            rng = torch.Generator(device=self.device)
            rng.seed()

        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
        prompt_token_ids = prompt_token_ids.to(self.device)
        attention_masks = attention_masks.to(self.device)

        # To draw multiple samples we repeat the prompt as many times
        # as there are samples. We copy the FSMs and initialize the
        # FSM states.
        # NOTE(review): `num_samples` is re-assigned here; the assignment
        # above is redundant.
        num_samples = self.num_samples
        batch_size = len(prompts)

        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
        fsm_states = [0 for _ in range(batch_size * num_samples)]
        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
        weights = torch.zeros(
            (batch_size * num_samples), dtype=torch.float, device=self.device
        )

        states = sequence_generator(
            self.model,
            self.sampler,
            fsms,
            prompt_token_ids,
            weights,
            attention_masks,
            fsm_states,
            rng=rng,
        )

        # Drain the generator, stopping early when the token budget is spent
        # or every sequence contains a stop sequence.
        while True:
            try:
                last_state = next(states)
                if max_tokens or stop_sequences:
                    token_ids = last_state.token_ids
                    generated_token_ids = self.get_generated_token_ids(
                        prompt_token_ids, token_ids
                    )
                    # NOTE(review): only the first sequence's length is
                    # compared against `max_tokens` — confirm intentional.
                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                        break
                    if stop_sequences and self.is_stop_sequence_found(
                        self.tokenizer.decode(generated_token_ids), stop_sequences
                    ):
                        break
            except StopIteration:
                break

        # Decode once, after generation is complete.
        token_ids = last_state.token_ids
        generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)

        generated = self.tokenizer.decode(generated_token_ids)
        stripped = [
            self.strip_stop_sequences(sequence, stop_sequences)
            for sequence in generated
        ]
        formatted = [self.format_sequence(sequence) for sequence in stripped]

        # We reshape the output to (batch_size, sample_size)
        output: List[List[FormattedOutput]] = list()
        for i in range(0, batch_size * num_samples, num_samples):
            output.append(formatted[i : i + num_samples])

        # We remove leading dimensions for the output
        if batch_size == 1 and num_samples == 1:
            return output[0][0]
        elif batch_size == 1:
            return output[0]
        elif num_samples == 1:
            return [samples[0] for samples in output]
        else:
            return output
247
+
248
    def stream(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        rng: Optional["torch.Generator"] = None,
    ) -> Iterator[Union[List[str], str, List[List[str]]]]:
        """Generate the text sequence one token at a time.

        Since `Tokenizer.decode` strips the whitespaces from the tokens we have no
        choice but to decode the generated token ids at each step and compare the
        current decoded strings to the previously decoded strings.

        Parameters
        ----------
        prompts
            A string or list of strings that are passed to the model before
            generating the first token.
        max_tokens
            An integer representing maximum number of tokens that will be generated
            (per prompt)
        stop_at
            A string or list of strings at which the text generated will stop
        rng
            The random number generator. Defaults to a non-seeded `torch.Generator`
            instance.

        Returns
        -------
        A string or list of strings that contain the generated text.

        """
        import torch

        # Normalize single prompt / single stop string to lists.
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(stop_at, str):
            stop_at = [stop_at]

        stop_sequences = stop_at
        num_samples = self.num_samples

        prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
        prompt_token_ids = prompt_token_ids.to(self.device)
        attention_masks = attention_masks.to(prompt_token_ids.device)

        # To draw multiple samples we repeat the prompt as many times
        # as there are samples. We copy the FSMs and initialize the
        # FSM states.
        num_samples = self.num_samples
        batch_size = len(prompts)

        prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
        attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
        fsm_states = [0 for _ in range(batch_size * num_samples)]
        fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
        weights = torch.zeros(
            (batch_size * num_samples),
            dtype=torch.float,
            device=prompt_token_ids.device,
        )

        if rng is None:
            rng = torch.Generator(device=prompt_token_ids.device)
            rng.seed()

        states = sequence_generator(
            self.model,
            self.sampler,
            fsms,
            prompt_token_ids,
            weights,
            attention_masks,
            fsm_states,
            rng=rng,
        )

        def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
            # One entry per (prompt, sample) pair; all start empty / not
            # stopped. (Values are immutable, so list multiplication is safe.)
            previously_generated_sequences = [
                "" for _ in range(batch_size)
            ] * num_samples
            num_generated = 0
            is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
            while True:
                if (max_tokens and num_generated >= max_tokens) or all(
                    is_stop_at_reached
                ):
                    return
                try:
                    sequence = next(states)
                    num_generated += 1
                except StopIteration:
                    return
                generated_token_ids = sequence.token_ids[:, -num_generated:]
                generated_sequences = self.tokenizer.decode(generated_token_ids)
                if stop_sequences:
                    is_stop_at_reached = [
                        stop
                        or self.is_stop_sequence_found(
                            [generated_sequence], stop_sequences
                        )
                        for generated_sequence, stop in zip(
                            generated_sequences, is_stop_at_reached
                        )
                    ]

                    # Freeze stopped sequences at their stripped, formatted
                    # form; others continue to grow.
                    generated_sequences = [
                        self.format_sequence(
                            self.strip_stop_sequences(sequence, stop_sequences)
                        )
                        if stop
                        else sequence
                        for sequence, stop in zip(
                            generated_sequences, is_stop_at_reached
                        )
                    ]
                # The newly produced text is the suffix past what was already
                # yielded for each sequence.
                next_tokens = [
                    token[len(sequence) :]
                    for token, sequence, stop in zip(
                        generated_sequences,
                        previously_generated_sequences,
                        is_stop_at_reached,
                    )
                ]
                previously_generated_sequences = generated_sequences
                # We reshape the output to (batch_size, sample_size)
                output: List[List[str]] = list()
                for i in range(0, batch_size * num_samples, num_samples):
                    output.append(next_tokens[i : i + num_samples])

                # We remove leading dimensions for the output
                if batch_size == 1 and num_samples == 1:
                    yield output[0][0]
                elif batch_size == 1:
                    yield output[0]
                elif num_samples == 1:
                    yield [samples[0] for samples in output]
                else:
                    yield output

        return token_generator()
390
+
391
+
392
@dataclass(frozen=True)
class GenerationParameters:
    """Generation parameters used in Outlines' public API."""

    # Maximum number of tokens to generate; None means no explicit limit.
    max_tokens: Optional[int]
    # Sequence(s) at which generation stops; callers normalize a single
    # string to a one-element list.
    stop_at: Optional[Union[str, List[str]]]
    # Seed for the random number generator; None draws a fresh seed.
    seed: Optional[int]
399
+
400
+
401
@dataclass(frozen=True)
class SamplingParameters:
    """Sampling parameters available in Outlines."""

    # Name of the sampling strategy ("multinomial", "greedy", "beam_search").
    sampler: str
    # Number of samples drawn per prompt.
    num_samples: int = 1
    # Nucleus-sampling threshold; None disables it.
    top_p: Optional[float] = None
    # Top-k filtering; None disables it.
    top_k: Optional[int] = None
    # Softmax temperature; None leaves the provider default.
    temperature: Optional[float] = None
410
+
411
+
412
class SequenceGeneratorAdapter:
    """Class used to unify the interface to the model providers'
    generation functions.

    Attributes
    ----------
    model
        The wrapped model.
    logits_processor
        The logits processor to use to generate text.
    sampler
        The sampler to use to generate text.

    """

    def __init__(self, model, logits_processor, sampler):
        self.model = model
        self.logits_processor = logits_processor

        # Translate the sampler object into a provider-agnostic
        # `SamplingParameters` record.
        if isinstance(sampler, MultinomialSampler):
            self.sampling_params = SamplingParameters(
                "multinomial",
                sampler.samples,
                sampler.top_p,
                sampler.top_k,
                sampler.temperature,
            )
        elif isinstance(sampler, GreedySampler):
            self.sampling_params = SamplingParameters(
                "greedy", sampler.samples, None, None, 0.0
            )
        elif isinstance(sampler, BeamSearchSampler):
            self.sampling_params = SamplingParameters(
                "beam_search", sampler.samples, None, None, 1.0
            )

    def prepare_generation_parameters(
        self,
        max_tokens: Optional[int],
        stop_at: Optional[Union[str, List[str]]],
        seed: Optional[int],
    ):
        """Normalize `stop_at` and bundle the arguments into a record."""
        stop_list = [stop_at] if isinstance(stop_at, str) else stop_at
        return GenerationParameters(max_tokens, stop_list, seed)

    def format_sequence(self, sequence: str) -> FormattedOutput:
        """Translate the generated sequence to another type.

        The base implementation is the identity; subclasses override it, for
        instance to parse generated JSON into a dict or a Pydantic model.

        Parameters
        ----------
        sequence
            A generated sequence.

        Returns
        -------
        The formatted sequence.

        """
        return sequence

    def _format(self, sequences):
        """Apply formatting to every string in a completion."""
        if not isinstance(sequences, list):
            return self.format_sequence(sequences)
        return [self._format(item) for item in sequences]

    def __call__(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Generate text from a prompt of list of prompts."""
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        # The logits processor is copied so concurrent generations do not
        # share mutable processor state.
        completions = self.model.generate(
            prompts,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

        return self._format(completions)

    def stream(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts."""
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )
533
+
534
+
535
class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter):
    """`SequenceGeneratorAdapter` whose generate/stream calls also take media
    (images) alongside the prompts."""

    def __call__(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: Union[str, Any],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """
        Generate text from a prompt of list of prompts.

        Media: A URI to construct media or media object itself. Used as AutoProcessor argument.
        """
        prompts, media = self._validate_prompt_media_types(prompts, media)

        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        completions = self.model.generate(
            prompts,
            media,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

        return self._format(completions)

    def stream(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: List[Union[str, Any, List[Union[str, Any]]]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts."""
        prompts, media = self._validate_prompt_media_types(prompts, media)
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            media,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

    @classmethod
    def _validate_prompt_media_types(
        cls,
        prompts: Union[str, List[str]],
        media: Union[str, Any, List[Union[str, Any]]],
    ) -> Union[Any, List[Any]]:
        """
        Prepare media as PIL.Image and ensure for every prompt str there is one List[PIL.Image]
        """

        # Despite the docstring, this only validates the pairing of prompts
        # and media; no conversion is performed — the inputs are returned
        # unchanged when valid.
        def valid_types(prompts, media):
            from PIL import Image  # type: ignore

            if isinstance(prompts, list):
                # A list of prompts requires a parallel list of media lists.
                if not isinstance(media, list) or len(prompts) != len(media):
                    return False
                for subprompt, submedia in zip(prompts, media):
                    # NOTE(review): assumes each `submedia` is iterable;
                    # a non-iterable entry raises TypeError here rather than
                    # returning False — confirm this is acceptable.
                    if not isinstance(subprompt, str) or not all(
                        isinstance(m, Image.Image) for m in submedia
                    ):
                        return False
            elif isinstance(prompts, str):
                # A single prompt takes an iterable of images.
                if not all(isinstance(m, Image.Image) for m in media):
                    return False
            return True

        if not valid_types(prompts, media):
            raise TypeError(
                "Expected (prompts, media) to be of type "
                "(str, List[Image])), or (List[str], List[List[Image]]) "
                f"instead got prompts={prompts}, media={media}"
            )

        return prompts, media
.venv/lib/python3.11/site-packages/outlines/generate/cfg.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import singledispatch
2
+
3
+ from outlines.generate.api import (
4
+ SequenceGeneratorAdapter,
5
+ VisionSequenceGeneratorAdapter,
6
+ )
7
+ from outlines.models import LlamaCpp, OpenAI, TransformersVision
8
+ from outlines.samplers import Sampler, multinomial
9
+
10
+
11
@singledispatch
def cfg(
    model, cfg_str: str, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate text in the language of a Context-Free Grammar

    Arguments
    ---------
    model:
        An `outlines.model` instance.
    cfg_str:
        The grammar (Lark format) the output must conform to.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGeneratorAdapter` instance that generates text.

    """
    # Imported lazily to keep module import time low.
    from outlines.processors import CFGLogitsProcessor

    processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
    return SequenceGeneratorAdapter(model, processor, sampler)
34
+
35
+
36
@cfg.register(TransformersVision)
def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
    """Grammar-constrained generation for vision-language models."""
    from outlines.processors import CFGLogitsProcessor

    processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
42
+
43
+
44
@cfg.register(LlamaCpp)
def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()):
    """Grammar-constrained generation is currently unavailable for llama.cpp."""
    raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer")
47
+
48
+
49
@cfg.register(OpenAI)
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
    """Grammar-constrained generation is unsupported for OpenAI models.

    Raises
    ------
    NotImplementedError
        Always: the OpenAI API exposes no grammar constraint mechanism.
    """
    # Bug fix: the two concatenated fragments previously joined as
    # "...modeldue to..." with no separating space.
    raise NotImplementedError(
        "Cannot use grammar-structured generation with an OpenAI model "
        "due to the limitations of the OpenAI API."
    )
.venv/lib/python3.11/site-packages/outlines/generate/choice.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json as pyjson
2
+ import re
3
+ from enum import Enum
4
+ from functools import singledispatch
5
+ from typing import Callable, List, Union
6
+
7
+ from outlines_core.fsm.json_schema import build_regex_from_schema
8
+
9
+ from outlines.fsm.json_schema import get_schema_from_enum
10
+ from outlines.generate.api import SequenceGeneratorAdapter
11
+ from outlines.models import OpenAI
12
+ from outlines.samplers import Sampler, multinomial
13
+
14
+ from .json import json
15
+ from .regex import regex
16
+
17
+
18
@singledispatch
def choice(
    model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate one of a fixed set of choices.

    ``choices`` is either a list of literal strings, matched verbatim, or an
    `Enum` type whose JSON schema constrains the output.
    """
    choices_is_enum = isinstance(choices, type(Enum))

    if choices_is_enum:
        schema = pyjson.dumps(get_schema_from_enum(choices))
        regex_str = build_regex_from_schema(schema)
    else:
        alternatives = (re.escape(option) for option in choices)
        regex_str = r"(" + r"|".join(alternatives) + r")"

    generator = regex(model, regex_str, sampler)
    # Enum outputs are JSON-encoded and must be decoded; plain strings pass
    # through untouched.
    generator.format_sequence = pyjson.loads if choices_is_enum else (lambda x: x)

    return generator
35
+
36
+
37
@choice.register(OpenAI)
def choice_openai(
    model: OpenAI, choices: List[str], sampler: Sampler = multinomial()
) -> Callable:
    """
    Call OpenAI API with response_format of a dict:
    {"result": <one of choices>}
    """
    schema = {
        "type": "object",
        "properties": {"result": {"type": "string", "enum": choices}},
        "additionalProperties": False,
        "required": ["result"],
    }
    generator = json(model, pyjson.dumps(schema), sampler)

    def generate_choice(*args, **kwargs):
        # Unwrap the {"result": ...} envelope the schema forces on the API.
        return generator(*args, **kwargs)["result"]

    return generate_choice
.venv/lib/python3.11/site-packages/outlines/generate/format.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import singledispatch
2
+
3
+ from outlines.fsm.types import python_types_to_regex
4
+ from outlines.generate.api import SequenceGeneratorAdapter
5
+ from outlines.models import OpenAI
6
+ from outlines.samplers import Sampler, multinomial
7
+
8
+ from .regex import regex
9
+
10
+
11
@singledispatch
def format(
    model, python_type, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate structured data that can be parsed as a Python type.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    python_type:
        A Python type. The output of the generator must be parseable into
        this type.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGenerator` instance that generates text constrained by the Python type
    and translates this text into the corresponding type.

    """
    pattern, to_python = python_types_to_regex(python_type)
    generator = regex(model, pattern, sampler)
    # Parse the raw text output back into the requested Python type.
    generator.format_sequence = to_python

    return generator
40
+
41
+
42
@format.register(OpenAI)
def format_openai(model, python_type, sampler: Sampler = multinomial()):
    """Type-structured generation is unsupported for OpenAI models."""
    raise NotImplementedError(
        "Cannot use Python type-structured generation with an OpenAI model"
        " due to the limitations of the OpenAI API."
    )
.venv/lib/python3.11/site-packages/outlines/generate/fsm.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import singledispatch
2
+
3
+ import interegular
4
+
5
+ from outlines.fsm.guide import RegexGuide
6
+ from outlines.generate.api import (
7
+ SequenceGeneratorAdapter,
8
+ VisionSequenceGeneratorAdapter,
9
+ )
10
+ from outlines.models import TransformersVision
11
+ from outlines.samplers import Sampler, multinomial
12
+
13
+
14
@singledispatch
def fsm(
    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate text whose token sequence is accepted by ``fsm``."""
    from outlines.processors import GuideLogitsProcessor

    regex_guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=regex_guide)
    return SequenceGeneratorAdapter(model, processor, sampler)
23
+
24
+
25
@fsm.register(TransformersVision)
def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
    """FSM-constrained generation for vision-language models."""
    from outlines.processors import GuideLogitsProcessor

    regex_guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=regex_guide)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
.venv/lib/python3.11/site-packages/outlines/generate/generator.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import math
3
+ from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple
4
+
5
+ if TYPE_CHECKING:
6
+ import torch
7
+
8
+ from outlines.fsm.guide import Guide
9
+
10
+
11
class ContextLengthExceededError(Exception):
    """Raised when the input is longer than the model's context length."""
13
+
14
+
15
@dataclasses.dataclass(frozen=True)
class GenerationState:
    """Immutable snapshot of one step of the token-generation loop.

    Yielded by `sequence_generator` after every sampling step.
    """

    # Running token sequences, one row per sequence in the batch.
    token_ids: "torch.Tensor"
    # Model KV-cache for the sequences; layout is model-specific.
    kv_cache: "torch.Tensor"
    # Logits returned by the model at this step.
    logits: "torch.Tensor"
    # Cumulative sequence weights maintained by the sampler.
    weights: "torch.Tensor"
    # Current finite-state-machine state for each sequence.
    fsm_states: List[int]
22
+
23
+
24
def sequence_generator(
    model: Callable,
    sampler: Callable,
    fsms: List["Guide"],
    token_ids: "torch.Tensor",
    sequence_weights: "torch.Tensor",
    attention_masks: "torch.Tensor",
    fsm_states: List[int],
    rng: "torch.Generator",
) -> Iterator[GenerationState]:
    """Generates sequences of tokens.

    Parameters
    ----------
    model
        A callable that generates a probability distribution over the
        vocabulary when passed a tensor of token ids.
    sampler
        A callable that returns the next token ids, their ancestor sequence and
        the updated sequence weights when passed a distribution over the
        vocabulary.
    token_ids
        A tensor of token ids on which the sequence distribution is conditioned, of
        shape ``(n_seqs, n_prompt_tokens)``
    sequence_weights
        A tensor that contains the initial weights of the sequences, of shape
        ``(n_seqs,)``
    attention_masks
        A tensor of tensors that represent the tokens considered at the attention
        layer, of shape ``(n_seqs, n_prompt_tokens)``.
    fsms
        List of finite-state machines that drive the text generation,
        one for each sequence in the batch.
    fsm_states
        The initial states of the finite-state machine for each sequence in the batch.
    rng
        Random-number generator passed to the sampler; a fresh default
        generator is created when ``None``.

    Yields
    ------
    A new sequence.

    """
    import torch

    if rng is None:
        rng = torch.Generator()

    kv_cache = None

    while True:
        try:
            logits, kv_cache = model(token_ids, attention_masks, kv_cache)
        except IndexError:  # Exceeding the context length
            raise ContextLengthExceededError(
                "The input length exceeds the context length of the model."
            )

        # Mask logits so only tokens allowed by each sequence's FSM survive.
        allowed_tokens = get_allowed_tokens(fsms, fsm_states)
        biased_logits = bias_logits(logits, allowed_tokens)
        next_token_ids, ancestors, sequence_weights = sampler(
            biased_logits, sequence_weights, rng
        )

        # Samplers may reorder/duplicate sequences (ancestors); keep every
        # per-sequence structure aligned with the new ordering.
        token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
        attention_masks = update_attention_masks(attention_masks, ancestors)
        kv_cache = reorder_kv_cache(kv_cache, ancestors)
        if len(ancestors) > 1:
            fsms = reorder_fsms(fsms, ancestors)
            fsm_states = reorder_fsm_states(fsm_states, ancestors)

        fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids)
        is_finished = is_generation_finished(fsms, fsm_states)

        if is_finished:
            yield GenerationState(
                token_ids,
                kv_cache,
                logits,
                sequence_weights,
                fsm_states,
            )
            return

        yield GenerationState(
            token_ids,
            kv_cache,
            logits,
            sequence_weights,
            fsm_states,
        )
113
+
114
+
115
def get_next_fsm_states(
    fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
) -> List[int]:
    """Advance every sequence's FSM with the token that was just sampled.

    Parameters
    ----------
    fsms
        The finite-state machines used to monitor this batch, one per sequence.
    fsm_states
        The current FSM state of each sequence in the batch.
    next_token_ids
        The tokens that were just generated, one single-element row per
        sequence.

    Returns
    -------
    The next FSM state for each sequence in the batch.

    """
    # Docstring fixed: this returns FSM states, not a logit mask.
    return [
        fsm.get_next_state(fsm_state, int(token_id[0]))
        for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
    ]
136
+
137
+
138
def get_allowed_tokens(
    fsms: List["Guide"], fsm_states: List[int]
) -> List[Optional[Iterable[int]]]:
    """Get the new instructions for each sequence from the finite-state machine.

    Parameters
    ----------
    fsms
        The finite-state machines used to monitor this batch.
    fsm_states
        The FSM states corresponding to each sequence in the batch.

    Returns
    -------
    A nested list that contains the ids of the logits to keep.

    """
    return [
        fsm.get_next_instruction(fsm_state).tokens
        for fsm, fsm_state in zip(fsms, fsm_states)
    ]
158
+
159
+
160
def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool:
    """Determine if the generation is finished.

    A generation is considered finished if the FSM of every sequence in the
    batch is in a final state.

    A better solution is to return finished sequences as soon as their FSM
    is in a final state.

    Parameters
    ----------
    fsms
        The finite-state machines used to monitor this batch.
    fsm_states
        The FSM states corresponding to each sequence in the batch.

    Returns
    -------
    Whether all sequences are finished sampling.

    """
    # Generator expression instead of a materialized list lets `all`
    # short-circuit on the first unfinished sequence.
    return all(fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states))
182
+
183
+
184
def update_token_ids(
    token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
    """Append the sampled tokens to the running sequence of tokens.

    Parameters
    ----------
    token_ids
        The current token sequences
    next_token_ids
        The tokens that were just generated and that we need to append
        to the existing sequences.
    ancestors
        The sequences to which the token ids need to be added.

    Returns
    -------
    A new sequence of token ids that contains the tokens that were
    just generated.

    """
    import torch

    # Select each new token's parent sequence, then append along the last dim.
    parent_rows = torch.index_select(token_ids, 0, ancestors)
    return torch.cat([parent_rows, next_token_ids], dim=-1)
209
+
210
+
211
def update_attention_masks(
    attention_masks: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
    """Expand the attention masks.

    Parameters
    ----------
    attention_masks
        The attention masks for each sequence in the batch.
    ancestors
        The sequences to which the token ids need to be added.

    Returns
    -------
    The attention masks padded with 1s.

    """
    import torch

    reordered = torch.index_select(attention_masks, 0, ancestors)
    # One extra attention slot per sequence for the newly appended token.
    padding = torch.ones(reordered.shape[:-1] + (1,), device=reordered.device)
    return torch.cat([reordered, padding], dim=-1)
240
+
241
+
242
def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]:
    """Return a copy of each ancestor's FSM, in ancestor order."""
    return [fsms[ancestor].copy() for ancestor in ancestors]
248
+
249
+
250
def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]:
    """Gather each ancestor's FSM state, in ancestor order."""
    return [fsm_states[ancestor] for ancestor in ancestors]
256
+
257
+
258
def reorder_kv_cache(
    kv_cache: Optional[Tuple], ancestors: "torch.Tensor"
) -> Optional[Tuple]:
    """Re-order the KV-cache based on the ancestors.

    In transformers, the object that stores the KV-cache is a tuple whose
    elements are the key cache and the value cache. Each of these caches are
    tuples where each element corresponds to a layer. To each layer corresponds
    a tensor whose first dimension is the batch size.

    """
    import torch

    if kv_cache is None:
        return None

    return tuple(
        tuple(
            torch.index_select(layer, 0, ancestors.to(layer.device))
            for layer in cache_item
        )
        for cache_item in kv_cache
    )
283
+
284
+
285
def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor":
    """Mask the logits.

    The function iterates over a nested list where each list corresponds to the
    indices that need to be kept for each row in the array.

    Parameters
    ----------
    logits
        Two dimensional tensor that contains the next-token probability
        distribution.
    allowed_token_ids
        A list that contains the tokens that can be generated by the model; a
        ``None`` entry means the corresponding row is left unconstrained.

    Returns
    -------
    A new tensor (not a view) where disallowed tokens are set to ``-inf``.

    """
    import torch

    biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
    for i, ids in enumerate(allowed_token_ids):
        if ids is not None:
            biased_logits[i, ids] = logits[i, ids]
        else:
            biased_logits[i] = logits[i]
    return biased_logits
.venv/lib/python3.11/site-packages/outlines/generate/json.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json as pyjson
2
+ from enum import Enum
3
+ from functools import singledispatch
4
+ from typing import Callable, Optional, Union
5
+
6
+ from outlines_core.fsm.json_schema import build_regex_from_schema
7
+ from pydantic import BaseModel
8
+
9
+ from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature
10
+ from outlines.generate.api import SequenceGeneratorAdapter
11
+ from outlines.models import OpenAI
12
+ from outlines.samplers import Sampler, multinomial
13
+
14
+ from .regex import regex
15
+
16
+
17
@singledispatch
def json(
    model,
    schema_object: Union[str, object, Callable],
    sampler: Sampler = multinomial(),
    whitespace_pattern: Optional[str] = None,
) -> SequenceGeneratorAdapter:
    """
    Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    schema_object:
        The JSON Schema to generate data for. Can be a JSON string, a Pydantic
        model, an Enum type, or a callable whose signature yields a schema.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.
    whitespace_pattern
        Pattern to use for JSON syntactic whitespace (doesn't impact string
        literals). Example: allow only a single space or newline with
        ``whitespace_pattern=r"[\\n ]?"``

    Returns
    -------
    A `SequenceGenerator` instance that generates text constrained by the schema_object and
    transforms the result if BaseModel is used.

    """
    # Normalize every supported input into (schema string, result parser).
    if isinstance(schema_object, type(BaseModel)):
        schema = pyjson.dumps(schema_object.model_json_schema())
        format_sequence = lambda x: schema_object.parse_raw(x)
    elif isinstance(schema_object, type(Enum)):
        schema = pyjson.dumps(get_schema_from_enum(schema_object))
        format_sequence = pyjson.loads
    elif callable(schema_object):
        schema = pyjson.dumps(get_schema_from_signature(schema_object))
        format_sequence = pyjson.loads
    elif isinstance(schema_object, str):
        schema = schema_object
        format_sequence = pyjson.loads
    else:
        raise ValueError(
            f"Cannot parse schema {schema_object}. The schema must be either "
            + "a Pydantic object, a function or a string that contains the JSON "
            + "Schema specification"
        )

    regex_str = build_regex_from_schema(schema, whitespace_pattern)
    generator = regex(model, regex_str, sampler)
    generator.format_sequence = format_sequence

    return generator
76
+
77
+
78
@json.register(OpenAI)
def json_openai(
    model, schema_object: Union[str, object], sampler: Sampler = multinomial()
):
    """Generate schema-constrained JSON through the OpenAI API."""
    if not isinstance(sampler, multinomial):
        raise NotImplementedError(
            r"The OpenAI API does not support any other sampling algorithm "
            + "than the multinomial sampler."
        )

    if isinstance(schema_object, type(BaseModel)):
        schema = pyjson.dumps(schema_object.model_json_schema())
        format_sequence = lambda x: schema_object.parse_raw(x)
    elif isinstance(schema_object, str):
        schema = schema_object
        format_sequence = pyjson.loads
    else:
        raise ValueError(
            f"Cannot parse schema {schema_object}. The schema must be either "
            + "a Pydantic object, a function or a string that contains the JSON "
            + "Schema specification"
        )

    # create copied, patched model with normalized json schema set
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "default",
            "strict": True,
            "schema": pyjson.loads(schema),
        },
    }
    generator = model.new_with_replacements(response_format=response_format)
    generator.format_sequence = format_sequence

    return generator
.venv/lib/python3.11/site-packages/outlines/generate/regex.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import singledispatch
2
+
3
+ from outlines.generate.api import (
4
+ SequenceGeneratorAdapter,
5
+ VisionSequenceGeneratorAdapter,
6
+ )
7
+ from outlines.models import OpenAI, TransformersVision
8
+ from outlines.samplers import Sampler, multinomial
9
+
10
+
11
@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
    """Generate structured text in the language of a regular expression.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    regex_str:
        The regular expression that the output must follow.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGeneratorAdapter` instance that generates text constrained by the
    regular expression.

    """
    from outlines.processors import RegexLogitsProcessor

    processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
    return SequenceGeneratorAdapter(model, processor, sampler)
36
+
37
+
38
@regex.register(TransformersVision)
def regex_vision(
    model,
    regex_str: str,
    sampler: Sampler = multinomial(),
):
    """Regex-constrained generation for vision-language models."""
    from outlines.processors import RegexLogitsProcessor

    processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
48
+
49
+
50
@regex.register(OpenAI)
def regex_openai(
    model: OpenAI,
    regex_str: str,
    sampler: Sampler = multinomial(),
):
    """Regex-structured generation is unsupported for OpenAI models.

    Raises
    ------
    NotImplementedError
        Always: the OpenAI API exposes no regex constraint mechanism.
    """
    # Bug fix: the two concatenated fragments previously joined as
    # "...modeldue to..." with no separating space.
    raise NotImplementedError(
        "Cannot use regex-structured generation with an OpenAI model "
        "due to the limitations of the OpenAI API."
    )
.venv/lib/python3.11/site-packages/outlines/generate/text.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import singledispatch
2
+
3
+ from outlines.generate.api import (
4
+ SequenceGeneratorAdapter,
5
+ VisionSequenceGeneratorAdapter,
6
+ )
7
+ from outlines.models import OpenAI, TransformersVision
8
+ from outlines.samplers import Sampler, multinomial
9
+
10
+
11
@singledispatch
def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
    """Generate text with a `Transformer` model.

    Note
    ----
    Python 3.11 allows dispatching on Union types and
    this should greatly simplify the code.

    Arguments
    ---------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGeneratorAdapter` instance that generates text.

    """
    # No logits processor: unconstrained generation.
    return SequenceGeneratorAdapter(model, None, sampler)
35
+
36
+
37
@text.register(TransformersVision)
def text_vision(model, sampler: Sampler = multinomial()):
    """Unconstrained generation for vision-language models."""
    return VisionSequenceGeneratorAdapter(model, None, sampler)
40
+
41
+
42
@text.register(OpenAI)
def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
    """Unconstrained generation with the OpenAI API (multinomial only)."""
    if not isinstance(sampler, multinomial):
        raise NotImplementedError(
            r"The OpenAI API does not support any other sampling algorithm "
            + "than the multinomial sampler."
        )

    # The OpenAI model object is itself callable as a generator.
    return model
.venv/lib/python3.11/site-packages/outlines/grammars.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
# Directory that ships the built-in .lark grammar files.
GRAMMAR_PATH = Path(__file__).parent / "grammars"


def read_grammar(grammar_file_name, base_grammar_path=GRAMMAR_PATH):
    """Read grammar file from default grammar path.

    Parameters
    ----------
    grammar_file_name
        Name of the grammar file, e.g. ``"json.lark"``.
    base_grammar_path
        Directory to read from; defaults to the packaged grammars directory.
    """
    full_path = base_grammar_path / grammar_file_name
    # Explicit encoding: grammar files are UTF-8 and must not be read with
    # the platform's locale-dependent default encoding.
    with open(full_path, encoding="utf-8") as file:
        return file.read()
11
+
12
+
13
+ arithmetic = read_grammar("arithmetic.lark")
14
+ json = read_grammar("json.lark")
.venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ?start: sum
2
+
3
+ ?sum: product
4
+ | sum "+" product -> add
5
+ | sum "-" product -> sub
6
+
7
+ ?product: atom
8
+ | product "*" atom -> mul
9
+ | product "/" atom -> div
10
+
11
+ ?atom: NUMBER -> number
12
+ | "-" atom -> neg
13
+ | "(" sum ")"
14
+
15
+ %import common.NUMBER
16
+ %import common.WS_INLINE
17
+
18
+ %ignore WS_INLINE
.venv/lib/python3.11/site-packages/outlines/grammars/common.lark ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Adapted from https://github.com/lark-parser/lark/blob/master/lark/grammars/common.lark
2
+
3
+ // Lark License:
4
+ // Copyright © 2017 Erez Shinan
5
+ //
6
+ // Permission is hereby granted, free of charge, to any person obtaining a copy of
7
+ // this software and associated documentation files (the "Software"), to deal in
8
+ // the Software without restriction, including without limitation the rights to
9
+ // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
10
+ // the Software, and to permit persons to whom the Software is furnished to do so,
11
+ // subject to the following conditions:
12
+ //
13
+ // The above copyright notice and this permission notice shall be included in all
14
+ // copies or substantial portions of the Software.
15
+ //
16
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
18
+ // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
19
+ // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
20
+ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21
+ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22
+
23
+
24
+ // Basic terminals for common use
25
+
26
+
27
+ //
28
+ // Numbers
29
+ //
30
+
31
+ DIGIT: "0".."9"
32
+ HEXDIGIT: "a".."f"|"A".."F"|DIGIT
33
+
34
+ INT: DIGIT+
35
+ SIGNED_INT: ["+"|"-"] INT
36
+ DECIMAL: INT "." INT? | "." INT
37
+
38
+ // float = /-?\d+(\.\d+)?([eE][+-]?\d+)?/
39
+ _EXP: ("e"|"E") SIGNED_INT
40
+ FLOAT: INT _EXP | DECIMAL _EXP?
41
+ SIGNED_FLOAT: ["+"|"-"] FLOAT
42
+
43
+ NUMBER: FLOAT | INT
44
+ SIGNED_NUMBER: ["+"|"-"] NUMBER
45
+
46
+ UNESCAPED_STRING: /\"[^"]*\"/
47
+
48
+ // based on `outlines/fsm/json_schema.py`
49
+ _NON_CONTROL_CHAR: /([^"\\\x00-\x1F\x7F-\x9F])/
50
+ _ESCAPED_CHAR: /\\/ (_NON_CONTROL_CHAR | /\\/ | /"/)
51
+ ESCAPED_STRING_INNER: _NON_CONTROL_CHAR | _ESCAPED_CHAR
52
+ ESCAPED_STRING: /"/ ESCAPED_STRING_INNER* /"/
53
+
54
+
55
+
56
+ //
57
+ // Names (Variables)
58
+ //
59
+ LCASE_LETTER: "a".."z"
60
+ UCASE_LETTER: "A".."Z"
61
+
62
+ LETTER: UCASE_LETTER | LCASE_LETTER
63
+ WORD: LETTER+
64
+
65
+ CNAME: ("_"|LETTER) ("_"|LETTER|DIGIT)*
66
+
67
+
68
+ //
69
+ // Whitespace
70
+ //
71
+ WS_INLINE: (" "|/\t/)+
72
+ WS: /[ \t\f\r\n]/+
73
+
74
+ CR : /\r/
75
+ LF : /\n/
76
+ NEWLINE: (CR? LF)+
77
+
78
+
79
+ // Comments
80
+ SH_COMMENT: /#[^\n]*/
81
+ CPP_COMMENT: /\/\/[^\n]*/
82
+ C_COMMENT: "/*" /(.|\n)*?/ "*/"
83
+ SQL_COMMENT: /--[^\n]*/
.venv/lib/python3.11/site-packages/outlines/grammars/json.lark ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ?start: value
2
+
3
+ ?value: object
4
+ | array
5
+ | ESCAPED_STRING
6
+ | SIGNED_NUMBER -> number
7
+ | "true" -> true
8
+ | "false" -> false
9
+ | "null" -> null
10
+
11
+ array : "[" [value ("," value)*] "]"
12
+ object : "{" [pair ("," pair)*] "}"
13
+ pair : ESCAPED_STRING ":" value
14
+
15
+ %import common.ESCAPED_STRING
16
+ %import common.SIGNED_NUMBER
17
+ %import common.WS
18
+
19
+ %ignore WS
.venv/lib/python3.11/site-packages/outlines/processors/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .structured import (
2
+ CFGLogitsProcessor,
3
+ GuideLogitsProcessor,
4
+ JSONLogitsProcessor,
5
+ OutlinesLogitsProcessor,
6
+ RegexLogitsProcessor,
7
+ )
.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (441 Bytes). View file
 
.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc ADDED
Binary file (7.81 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from typing import TYPE_CHECKING, List, Protocol, Type, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from numpy.typing import NDArray
7
+
8
+ if TYPE_CHECKING:
9
+ import mlx.core as mx
10
+
11
+
12
# Union of every array flavor a logits processor may receive. "mx.array" is a
# string forward reference so mlx remains an optional import (only resolved
# under the TYPE_CHECKING guard above).
Array = Union[NDArray, torch.Tensor, List, "mx.array"]
13
+
14
+
15
def is_mlx_array_type(array_type):
    """Return True iff mlx is installed and *array_type* subclasses mlx.core.array.

    mlx is an optional dependency; when it is absent, no type can be an mlx
    array, so the answer is always False.
    """
    try:
        from mlx.core import array as _mlx_array
    except ImportError:
        return False
    return issubclass(array_type, _mlx_array)
21
+
22
+
23
+ def is_jax_array_type(array_type):
24
+ try:
25
+ import jaxlib
26
+ except ImportError:
27
+ return False
28
+ return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance(
29
+ array_type, jaxlib.xla_extension.ArrayImpl
30
+ )
31
+
32
+
33
class OutlinesLogitsProcessor(Protocol):
    """
    Base class for logits processors which normalizes types of logits:
    - ndarray (used by llama-cpp-python), converted to torch.Tensor
    - mlx.core.array (used by mlx-lm), converted to torch.Tensor
    - torch.Tensor (used by everything else)

    Normalization of types and conversion to torch.Tensor
    doesn't move memory, it just casts the type.

    Normalizing the types allows all logits processors inheriting from this class
    to implement a single method for all the business logic: `process_logits()`
    """

    @abstractmethod
    def process_logits(
        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
        """
        input_ids and logits are always 2D tensors for handling a batch of sequences.

        - input_ids -> List[List[tokens]]
        - logits -> 2D_Tensor[logit floats]

        Important to keep in mind when designing universal logits processors
        - logits processors are only used once and never re-applied for a new sequence generator
        - Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids
        - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM
        """
        pass

    @torch.no_grad()
    def __call__(
        self,
        input_ids: Array,
        logits: Array,
    ) -> Array:
        """
        Apply logits processor

        1) Unify type
        - convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]]
        - convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor
        2) Unify shape, ensure logits and input_ids are 2D
        3) Call self.process_logits() to perform business logic
        4) Cast logits back to original array library type

        Raises
        ------
        ValueError
            If logits are neither 1D nor 2D.
        TypeError
            If the array type is not one of the supported libraries.
        """
        # ensure logits are torch Tensors
        torch_logits = self._to_torch(logits)
        input_ids = self._to_torch(input_ids)

        # leading (batch) dimensions must agree; last dims differ (vocab vs. seq len)
        assert torch_logits.shape[:-1] == input_ids.shape[:-1]

        # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
        if len(torch_logits.shape) == 2:
            processed_logits = self.process_logits(input_ids, torch_logits)
        elif len(torch_logits.shape) == 1:
            processed_logits = self.process_logits(
                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
            ).squeeze(0)
        else:
            # BUGFIX: previously this branch fell through with `processed_logits`
            # unbound, raising a confusing UnboundLocalError for >2D logits.
            raise ValueError(
                f"Logits must be 1D or 2D, got {len(torch_logits.shape)} dimensions"
            )

        # return logits as passed array type
        return self._from_torch(processed_logits, type(logits))

    @staticmethod
    def _to_torch(tensor_like: Array) -> torch.Tensor:
        """Convert ndarray / list / tuple / mlx / jax arrays to torch.Tensor.

        Conversions go through dlpack / numpy views where possible, so no
        memory is copied for the torch, numpy, and jax paths.
        """
        if isinstance(tensor_like, torch.Tensor):
            return tensor_like

        elif isinstance(tensor_like, np.ndarray):
            return torch.from_numpy(tensor_like)

        elif isinstance(tensor_like, (list, tuple)):
            return torch.tensor(tensor_like)

        elif is_mlx_array_type(type(tensor_like)):
            import mlx.core as mx

            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
            return torch.from_dlpack(
                np.array(tensor_like.astype(mx.float32), copy=False)
            )

        elif is_jax_array_type(type(tensor_like)):
            import jax

            torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
            return torch_tensor

        else:
            raise TypeError(
                "LogitsProcessor must be called with either np.NDArray, "
                "torch.Tensor, list, or mlx.core.array typed logits. "
                f"Logits type: `{type(tensor_like)}`"
            )

    @staticmethod
    def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
        """Convert torch.Tensor back to the caller's original array type."""
        if target_type == torch.Tensor:
            return tensor

        elif target_type == np.ndarray:
            return tensor.detach().numpy()

        elif target_type == list:
            return tensor.detach().tolist()

        elif target_type == tuple:
            return tuple(tensor.detach().tolist())

        elif is_mlx_array_type(target_type):
            import mlx.core as mx

            # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
            return mx.array(tensor.float().numpy())

        elif is_jax_array_type(target_type):
            import jax

            return jax.dlpack.from_dlpack(tensor)

        else:
            raise TypeError(
                f"Failed to convert torch tensors to target_type `{target_type}`"
            )