Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .venv/lib/python3.11/site-packages/attr/__init__.py +104 -0
- .venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/attr/_compat.py +94 -0
- .venv/lib/python3.11/site-packages/attr/_config.py +31 -0
- .venv/lib/python3.11/site-packages/attr/_funcs.py +468 -0
- .venv/lib/python3.11/site-packages/attr/_version_info.pyi +9 -0
- .venv/lib/python3.11/site-packages/attr/exceptions.pyi +17 -0
- .venv/lib/python3.11/site-packages/attr/validators.pyi +86 -0
- .venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so +3 -0
- .venv/lib/python3.11/site-packages/outlines/__init__.py +20 -0
- .venv/lib/python3.11/site-packages/outlines/_version.py +16 -0
- .venv/lib/python3.11/site-packages/outlines/base.py +299 -0
- .venv/lib/python3.11/site-packages/outlines/caching.py +179 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/guide.py +276 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py +83 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/parsing.py +1127 -0
- .venv/lib/python3.11/site-packages/outlines/fsm/types.py +81 -0
- .venv/lib/python3.11/site-packages/outlines/function.py +117 -0
- .venv/lib/python3.11/site-packages/outlines/generate/__init__.py +8 -0
- .venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/generate/api.py +623 -0
- .venv/lib/python3.11/site-packages/outlines/generate/cfg.py +54 -0
- .venv/lib/python3.11/site-packages/outlines/generate/choice.py +59 -0
- .venv/lib/python3.11/site-packages/outlines/generate/format.py +47 -0
- .venv/lib/python3.11/site-packages/outlines/generate/fsm.py +31 -0
- .venv/lib/python3.11/site-packages/outlines/generate/generator.py +312 -0
- .venv/lib/python3.11/site-packages/outlines/generate/json.py +115 -0
- .venv/lib/python3.11/site-packages/outlines/generate/regex.py +59 -0
- .venv/lib/python3.11/site-packages/outlines/generate/text.py +50 -0
- .venv/lib/python3.11/site-packages/outlines/grammars.py +14 -0
- .venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark +18 -0
- .venv/lib/python3.11/site-packages/outlines/grammars/common.lark +83 -0
- .venv/lib/python3.11/site-packages/outlines/grammars/json.lark +19 -0
- .venv/lib/python3.11/site-packages/outlines/processors/__init__.py +7 -0
- .venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py +159 -0
.gitattributes
CHANGED
|
@@ -249,3 +249,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 249 |
.venv/lib/python3.11/site-packages/pycountry/locales/tr/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
|
| 250 |
.venv/lib/python3.11/site-packages/pycountry/locales/kn/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
|
| 251 |
.venv/lib/python3.11/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 249 |
.venv/lib/python3.11/site-packages/pycountry/locales/tr/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
|
| 250 |
.venv/lib/python3.11/site-packages/pycountry/locales/kn/LC_MESSAGES/iso639-3.mo filter=lfs diff=lfs merge=lfs -text
|
| 251 |
.venv/lib/python3.11/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 252 |
+
.venv/lib/python3.11/site-packages/torchvision.libs/libpng16.7f72a3c5.so.16 filter=lfs diff=lfs merge=lfs -text
|
| 253 |
+
.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/attr/__init__.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Classes Without Boilerplate
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Callable, Literal, Protocol
|
| 9 |
+
|
| 10 |
+
from . import converters, exceptions, filters, setters, validators
|
| 11 |
+
from ._cmp import cmp_using
|
| 12 |
+
from ._config import get_run_validators, set_run_validators
|
| 13 |
+
from ._funcs import asdict, assoc, astuple, has, resolve_types
|
| 14 |
+
from ._make import (
|
| 15 |
+
NOTHING,
|
| 16 |
+
Attribute,
|
| 17 |
+
Converter,
|
| 18 |
+
Factory,
|
| 19 |
+
_Nothing,
|
| 20 |
+
attrib,
|
| 21 |
+
attrs,
|
| 22 |
+
evolve,
|
| 23 |
+
fields,
|
| 24 |
+
fields_dict,
|
| 25 |
+
make_class,
|
| 26 |
+
validate,
|
| 27 |
+
)
|
| 28 |
+
from ._next_gen import define, field, frozen, mutable
|
| 29 |
+
from ._version_info import VersionInfo
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
s = attributes = attrs
|
| 33 |
+
ib = attr = attrib
|
| 34 |
+
dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AttrsInstance(Protocol):
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
NothingType = Literal[_Nothing.NOTHING]
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
"NOTHING",
|
| 45 |
+
"Attribute",
|
| 46 |
+
"AttrsInstance",
|
| 47 |
+
"Converter",
|
| 48 |
+
"Factory",
|
| 49 |
+
"NothingType",
|
| 50 |
+
"asdict",
|
| 51 |
+
"assoc",
|
| 52 |
+
"astuple",
|
| 53 |
+
"attr",
|
| 54 |
+
"attrib",
|
| 55 |
+
"attributes",
|
| 56 |
+
"attrs",
|
| 57 |
+
"cmp_using",
|
| 58 |
+
"converters",
|
| 59 |
+
"define",
|
| 60 |
+
"evolve",
|
| 61 |
+
"exceptions",
|
| 62 |
+
"field",
|
| 63 |
+
"fields",
|
| 64 |
+
"fields_dict",
|
| 65 |
+
"filters",
|
| 66 |
+
"frozen",
|
| 67 |
+
"get_run_validators",
|
| 68 |
+
"has",
|
| 69 |
+
"ib",
|
| 70 |
+
"make_class",
|
| 71 |
+
"mutable",
|
| 72 |
+
"resolve_types",
|
| 73 |
+
"s",
|
| 74 |
+
"set_run_validators",
|
| 75 |
+
"setters",
|
| 76 |
+
"validate",
|
| 77 |
+
"validators",
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _make_getattr(mod_name: str) -> Callable:
|
| 82 |
+
"""
|
| 83 |
+
Create a metadata proxy for packaging information that uses *mod_name* in
|
| 84 |
+
its warnings and errors.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __getattr__(name: str) -> str:
|
| 88 |
+
if name not in ("__version__", "__version_info__"):
|
| 89 |
+
msg = f"module {mod_name} has no attribute {name}"
|
| 90 |
+
raise AttributeError(msg)
|
| 91 |
+
|
| 92 |
+
from importlib.metadata import metadata
|
| 93 |
+
|
| 94 |
+
meta = metadata("attrs")
|
| 95 |
+
|
| 96 |
+
if name == "__version_info__":
|
| 97 |
+
return VersionInfo._from_version_string(meta["version"])
|
| 98 |
+
|
| 99 |
+
return meta["version"]
|
| 100 |
+
|
| 101 |
+
return __getattr__
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
__getattr__ = _make_getattr(__name__)
|
.venv/lib/python3.11/site-packages/attr/__pycache__/_funcs.cpython-311.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/attr/__pycache__/_version_info.cpython-311.pyc
ADDED
|
Binary file (3.61 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/attr/__pycache__/converters.cpython-311.pyc
ADDED
|
Binary file (5.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/attr/__pycache__/setters.cpython-311.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/attr/_compat.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
import platform
|
| 5 |
+
import sys
|
| 6 |
+
import threading
|
| 7 |
+
|
| 8 |
+
from collections.abc import Mapping, Sequence # noqa: F401
|
| 9 |
+
from typing import _GenericAlias
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
PYPY = platform.python_implementation() == "PyPy"
|
| 13 |
+
PY_3_9_PLUS = sys.version_info[:2] >= (3, 9)
|
| 14 |
+
PY_3_10_PLUS = sys.version_info[:2] >= (3, 10)
|
| 15 |
+
PY_3_11_PLUS = sys.version_info[:2] >= (3, 11)
|
| 16 |
+
PY_3_12_PLUS = sys.version_info[:2] >= (3, 12)
|
| 17 |
+
PY_3_13_PLUS = sys.version_info[:2] >= (3, 13)
|
| 18 |
+
PY_3_14_PLUS = sys.version_info[:2] >= (3, 14)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if PY_3_14_PLUS: # pragma: no cover
|
| 22 |
+
import annotationlib
|
| 23 |
+
|
| 24 |
+
_get_annotations = annotationlib.get_annotations
|
| 25 |
+
|
| 26 |
+
else:
|
| 27 |
+
|
| 28 |
+
def _get_annotations(cls):
|
| 29 |
+
"""
|
| 30 |
+
Get annotations for *cls*.
|
| 31 |
+
"""
|
| 32 |
+
return cls.__dict__.get("__annotations__", {})
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class _AnnotationExtractor:
|
| 36 |
+
"""
|
| 37 |
+
Extract type annotations from a callable, returning None whenever there
|
| 38 |
+
is none.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
__slots__ = ["sig"]
|
| 42 |
+
|
| 43 |
+
def __init__(self, callable):
|
| 44 |
+
try:
|
| 45 |
+
self.sig = inspect.signature(callable)
|
| 46 |
+
except (ValueError, TypeError): # inspect failed
|
| 47 |
+
self.sig = None
|
| 48 |
+
|
| 49 |
+
def get_first_param_type(self):
|
| 50 |
+
"""
|
| 51 |
+
Return the type annotation of the first argument if it's not empty.
|
| 52 |
+
"""
|
| 53 |
+
if not self.sig:
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
params = list(self.sig.parameters.values())
|
| 57 |
+
if params and params[0].annotation is not inspect.Parameter.empty:
|
| 58 |
+
return params[0].annotation
|
| 59 |
+
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
def get_return_type(self):
|
| 63 |
+
"""
|
| 64 |
+
Return the return type if it's not empty.
|
| 65 |
+
"""
|
| 66 |
+
if (
|
| 67 |
+
self.sig
|
| 68 |
+
and self.sig.return_annotation is not inspect.Signature.empty
|
| 69 |
+
):
|
| 70 |
+
return self.sig.return_annotation
|
| 71 |
+
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Thread-local global to track attrs instances which are already being repr'd.
|
| 76 |
+
# This is needed because there is no other (thread-safe) way to pass info
|
| 77 |
+
# about the instances that are already being repr'd through the call stack
|
| 78 |
+
# in order to ensure we don't perform infinite recursion.
|
| 79 |
+
#
|
| 80 |
+
# For instance, if an instance contains a dict which contains that instance,
|
| 81 |
+
# we need to know that we're already repr'ing the outside instance from within
|
| 82 |
+
# the dict's repr() call.
|
| 83 |
+
#
|
| 84 |
+
# This lives here rather than in _make.py so that the functions in _make.py
|
| 85 |
+
# don't have a direct reference to the thread-local in their globals dict.
|
| 86 |
+
# If they have such a reference, it breaks cloudpickle.
|
| 87 |
+
repr_context = threading.local()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_generic_base(cl):
|
| 91 |
+
"""If this is a generic class (A[str]), return the generic base for it."""
|
| 92 |
+
if cl.__class__ is _GenericAlias:
|
| 93 |
+
return cl.__origin__
|
| 94 |
+
return None
|
.venv/lib/python3.11/site-packages/attr/_config.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
__all__ = ["get_run_validators", "set_run_validators"]
|
| 4 |
+
|
| 5 |
+
_run_validators = True
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_run_validators(run):
|
| 9 |
+
"""
|
| 10 |
+
Set whether or not validators are run. By default, they are run.
|
| 11 |
+
|
| 12 |
+
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
|
| 13 |
+
moved to new ``attrs`` namespace. Use `attrs.validators.set_disabled()`
|
| 14 |
+
instead.
|
| 15 |
+
"""
|
| 16 |
+
if not isinstance(run, bool):
|
| 17 |
+
msg = "'run' must be bool."
|
| 18 |
+
raise TypeError(msg)
|
| 19 |
+
global _run_validators
|
| 20 |
+
_run_validators = run
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_run_validators():
|
| 24 |
+
"""
|
| 25 |
+
Return whether or not validators are run.
|
| 26 |
+
|
| 27 |
+
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
|
| 28 |
+
moved to new ``attrs`` namespace. Use `attrs.validators.get_disabled()`
|
| 29 |
+
instead.
|
| 30 |
+
"""
|
| 31 |
+
return _run_validators
|
.venv/lib/python3.11/site-packages/attr/_funcs.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
from ._compat import PY_3_9_PLUS, get_generic_base
|
| 7 |
+
from ._make import _OBJ_SETATTR, NOTHING, fields
|
| 8 |
+
from .exceptions import AttrsAttributeNotFoundError
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def asdict(
|
| 12 |
+
inst,
|
| 13 |
+
recurse=True,
|
| 14 |
+
filter=None,
|
| 15 |
+
dict_factory=dict,
|
| 16 |
+
retain_collection_types=False,
|
| 17 |
+
value_serializer=None,
|
| 18 |
+
):
|
| 19 |
+
"""
|
| 20 |
+
Return the *attrs* attribute values of *inst* as a dict.
|
| 21 |
+
|
| 22 |
+
Optionally recurse into other *attrs*-decorated classes.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
inst: Instance of an *attrs*-decorated class.
|
| 26 |
+
|
| 27 |
+
recurse (bool): Recurse into classes that are also *attrs*-decorated.
|
| 28 |
+
|
| 29 |
+
filter (~typing.Callable):
|
| 30 |
+
A callable whose return code determines whether an attribute or
|
| 31 |
+
element is included (`True`) or dropped (`False`). Is called with
|
| 32 |
+
the `attrs.Attribute` as the first argument and the value as the
|
| 33 |
+
second argument.
|
| 34 |
+
|
| 35 |
+
dict_factory (~typing.Callable):
|
| 36 |
+
A callable to produce dictionaries from. For example, to produce
|
| 37 |
+
ordered dictionaries instead of normal Python dictionaries, pass in
|
| 38 |
+
``collections.OrderedDict``.
|
| 39 |
+
|
| 40 |
+
retain_collection_types (bool):
|
| 41 |
+
Do not convert to `list` when encountering an attribute whose type
|
| 42 |
+
is `tuple` or `set`. Only meaningful if *recurse* is `True`.
|
| 43 |
+
|
| 44 |
+
value_serializer (typing.Callable | None):
|
| 45 |
+
A hook that is called for every attribute or dict key/value. It
|
| 46 |
+
receives the current instance, field and value and must return the
|
| 47 |
+
(updated) value. The hook is run *after* the optional *filter* has
|
| 48 |
+
been applied.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Return type of *dict_factory*.
|
| 52 |
+
|
| 53 |
+
Raises:
|
| 54 |
+
attrs.exceptions.NotAnAttrsClassError:
|
| 55 |
+
If *cls* is not an *attrs* class.
|
| 56 |
+
|
| 57 |
+
.. versionadded:: 16.0.0 *dict_factory*
|
| 58 |
+
.. versionadded:: 16.1.0 *retain_collection_types*
|
| 59 |
+
.. versionadded:: 20.3.0 *value_serializer*
|
| 60 |
+
.. versionadded:: 21.3.0
|
| 61 |
+
If a dict has a collection for a key, it is serialized as a tuple.
|
| 62 |
+
"""
|
| 63 |
+
attrs = fields(inst.__class__)
|
| 64 |
+
rv = dict_factory()
|
| 65 |
+
for a in attrs:
|
| 66 |
+
v = getattr(inst, a.name)
|
| 67 |
+
if filter is not None and not filter(a, v):
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
if value_serializer is not None:
|
| 71 |
+
v = value_serializer(inst, a, v)
|
| 72 |
+
|
| 73 |
+
if recurse is True:
|
| 74 |
+
if has(v.__class__):
|
| 75 |
+
rv[a.name] = asdict(
|
| 76 |
+
v,
|
| 77 |
+
recurse=True,
|
| 78 |
+
filter=filter,
|
| 79 |
+
dict_factory=dict_factory,
|
| 80 |
+
retain_collection_types=retain_collection_types,
|
| 81 |
+
value_serializer=value_serializer,
|
| 82 |
+
)
|
| 83 |
+
elif isinstance(v, (tuple, list, set, frozenset)):
|
| 84 |
+
cf = v.__class__ if retain_collection_types is True else list
|
| 85 |
+
items = [
|
| 86 |
+
_asdict_anything(
|
| 87 |
+
i,
|
| 88 |
+
is_key=False,
|
| 89 |
+
filter=filter,
|
| 90 |
+
dict_factory=dict_factory,
|
| 91 |
+
retain_collection_types=retain_collection_types,
|
| 92 |
+
value_serializer=value_serializer,
|
| 93 |
+
)
|
| 94 |
+
for i in v
|
| 95 |
+
]
|
| 96 |
+
try:
|
| 97 |
+
rv[a.name] = cf(items)
|
| 98 |
+
except TypeError:
|
| 99 |
+
if not issubclass(cf, tuple):
|
| 100 |
+
raise
|
| 101 |
+
# Workaround for TypeError: cf.__new__() missing 1 required
|
| 102 |
+
# positional argument (which appears, for a namedturle)
|
| 103 |
+
rv[a.name] = cf(*items)
|
| 104 |
+
elif isinstance(v, dict):
|
| 105 |
+
df = dict_factory
|
| 106 |
+
rv[a.name] = df(
|
| 107 |
+
(
|
| 108 |
+
_asdict_anything(
|
| 109 |
+
kk,
|
| 110 |
+
is_key=True,
|
| 111 |
+
filter=filter,
|
| 112 |
+
dict_factory=df,
|
| 113 |
+
retain_collection_types=retain_collection_types,
|
| 114 |
+
value_serializer=value_serializer,
|
| 115 |
+
),
|
| 116 |
+
_asdict_anything(
|
| 117 |
+
vv,
|
| 118 |
+
is_key=False,
|
| 119 |
+
filter=filter,
|
| 120 |
+
dict_factory=df,
|
| 121 |
+
retain_collection_types=retain_collection_types,
|
| 122 |
+
value_serializer=value_serializer,
|
| 123 |
+
),
|
| 124 |
+
)
|
| 125 |
+
for kk, vv in v.items()
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
rv[a.name] = v
|
| 129 |
+
else:
|
| 130 |
+
rv[a.name] = v
|
| 131 |
+
return rv
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _asdict_anything(
|
| 135 |
+
val,
|
| 136 |
+
is_key,
|
| 137 |
+
filter,
|
| 138 |
+
dict_factory,
|
| 139 |
+
retain_collection_types,
|
| 140 |
+
value_serializer,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
``asdict`` only works on attrs instances, this works on anything.
|
| 144 |
+
"""
|
| 145 |
+
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
|
| 146 |
+
# Attrs class.
|
| 147 |
+
rv = asdict(
|
| 148 |
+
val,
|
| 149 |
+
recurse=True,
|
| 150 |
+
filter=filter,
|
| 151 |
+
dict_factory=dict_factory,
|
| 152 |
+
retain_collection_types=retain_collection_types,
|
| 153 |
+
value_serializer=value_serializer,
|
| 154 |
+
)
|
| 155 |
+
elif isinstance(val, (tuple, list, set, frozenset)):
|
| 156 |
+
if retain_collection_types is True:
|
| 157 |
+
cf = val.__class__
|
| 158 |
+
elif is_key:
|
| 159 |
+
cf = tuple
|
| 160 |
+
else:
|
| 161 |
+
cf = list
|
| 162 |
+
|
| 163 |
+
rv = cf(
|
| 164 |
+
[
|
| 165 |
+
_asdict_anything(
|
| 166 |
+
i,
|
| 167 |
+
is_key=False,
|
| 168 |
+
filter=filter,
|
| 169 |
+
dict_factory=dict_factory,
|
| 170 |
+
retain_collection_types=retain_collection_types,
|
| 171 |
+
value_serializer=value_serializer,
|
| 172 |
+
)
|
| 173 |
+
for i in val
|
| 174 |
+
]
|
| 175 |
+
)
|
| 176 |
+
elif isinstance(val, dict):
|
| 177 |
+
df = dict_factory
|
| 178 |
+
rv = df(
|
| 179 |
+
(
|
| 180 |
+
_asdict_anything(
|
| 181 |
+
kk,
|
| 182 |
+
is_key=True,
|
| 183 |
+
filter=filter,
|
| 184 |
+
dict_factory=df,
|
| 185 |
+
retain_collection_types=retain_collection_types,
|
| 186 |
+
value_serializer=value_serializer,
|
| 187 |
+
),
|
| 188 |
+
_asdict_anything(
|
| 189 |
+
vv,
|
| 190 |
+
is_key=False,
|
| 191 |
+
filter=filter,
|
| 192 |
+
dict_factory=df,
|
| 193 |
+
retain_collection_types=retain_collection_types,
|
| 194 |
+
value_serializer=value_serializer,
|
| 195 |
+
),
|
| 196 |
+
)
|
| 197 |
+
for kk, vv in val.items()
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
rv = val
|
| 201 |
+
if value_serializer is not None:
|
| 202 |
+
rv = value_serializer(None, None, rv)
|
| 203 |
+
|
| 204 |
+
return rv
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def astuple(
|
| 208 |
+
inst,
|
| 209 |
+
recurse=True,
|
| 210 |
+
filter=None,
|
| 211 |
+
tuple_factory=tuple,
|
| 212 |
+
retain_collection_types=False,
|
| 213 |
+
):
|
| 214 |
+
"""
|
| 215 |
+
Return the *attrs* attribute values of *inst* as a tuple.
|
| 216 |
+
|
| 217 |
+
Optionally recurse into other *attrs*-decorated classes.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
inst: Instance of an *attrs*-decorated class.
|
| 221 |
+
|
| 222 |
+
recurse (bool):
|
| 223 |
+
Recurse into classes that are also *attrs*-decorated.
|
| 224 |
+
|
| 225 |
+
filter (~typing.Callable):
|
| 226 |
+
A callable whose return code determines whether an attribute or
|
| 227 |
+
element is included (`True`) or dropped (`False`). Is called with
|
| 228 |
+
the `attrs.Attribute` as the first argument and the value as the
|
| 229 |
+
second argument.
|
| 230 |
+
|
| 231 |
+
tuple_factory (~typing.Callable):
|
| 232 |
+
A callable to produce tuples from. For example, to produce lists
|
| 233 |
+
instead of tuples.
|
| 234 |
+
|
| 235 |
+
retain_collection_types (bool):
|
| 236 |
+
Do not convert to `list` or `dict` when encountering an attribute
|
| 237 |
+
which type is `tuple`, `dict` or `set`. Only meaningful if
|
| 238 |
+
*recurse* is `True`.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
Return type of *tuple_factory*
|
| 242 |
+
|
| 243 |
+
Raises:
|
| 244 |
+
attrs.exceptions.NotAnAttrsClassError:
|
| 245 |
+
If *cls* is not an *attrs* class.
|
| 246 |
+
|
| 247 |
+
.. versionadded:: 16.2.0
|
| 248 |
+
"""
|
| 249 |
+
attrs = fields(inst.__class__)
|
| 250 |
+
rv = []
|
| 251 |
+
retain = retain_collection_types # Very long. :/
|
| 252 |
+
for a in attrs:
|
| 253 |
+
v = getattr(inst, a.name)
|
| 254 |
+
if filter is not None and not filter(a, v):
|
| 255 |
+
continue
|
| 256 |
+
if recurse is True:
|
| 257 |
+
if has(v.__class__):
|
| 258 |
+
rv.append(
|
| 259 |
+
astuple(
|
| 260 |
+
v,
|
| 261 |
+
recurse=True,
|
| 262 |
+
filter=filter,
|
| 263 |
+
tuple_factory=tuple_factory,
|
| 264 |
+
retain_collection_types=retain,
|
| 265 |
+
)
|
| 266 |
+
)
|
| 267 |
+
elif isinstance(v, (tuple, list, set, frozenset)):
|
| 268 |
+
cf = v.__class__ if retain is True else list
|
| 269 |
+
items = [
|
| 270 |
+
(
|
| 271 |
+
astuple(
|
| 272 |
+
j,
|
| 273 |
+
recurse=True,
|
| 274 |
+
filter=filter,
|
| 275 |
+
tuple_factory=tuple_factory,
|
| 276 |
+
retain_collection_types=retain,
|
| 277 |
+
)
|
| 278 |
+
if has(j.__class__)
|
| 279 |
+
else j
|
| 280 |
+
)
|
| 281 |
+
for j in v
|
| 282 |
+
]
|
| 283 |
+
try:
|
| 284 |
+
rv.append(cf(items))
|
| 285 |
+
except TypeError:
|
| 286 |
+
if not issubclass(cf, tuple):
|
| 287 |
+
raise
|
| 288 |
+
# Workaround for TypeError: cf.__new__() missing 1 required
|
| 289 |
+
# positional argument (which appears, for a namedturle)
|
| 290 |
+
rv.append(cf(*items))
|
| 291 |
+
elif isinstance(v, dict):
|
| 292 |
+
df = v.__class__ if retain is True else dict
|
| 293 |
+
rv.append(
|
| 294 |
+
df(
|
| 295 |
+
(
|
| 296 |
+
(
|
| 297 |
+
astuple(
|
| 298 |
+
kk,
|
| 299 |
+
tuple_factory=tuple_factory,
|
| 300 |
+
retain_collection_types=retain,
|
| 301 |
+
)
|
| 302 |
+
if has(kk.__class__)
|
| 303 |
+
else kk
|
| 304 |
+
),
|
| 305 |
+
(
|
| 306 |
+
astuple(
|
| 307 |
+
vv,
|
| 308 |
+
tuple_factory=tuple_factory,
|
| 309 |
+
retain_collection_types=retain,
|
| 310 |
+
)
|
| 311 |
+
if has(vv.__class__)
|
| 312 |
+
else vv
|
| 313 |
+
),
|
| 314 |
+
)
|
| 315 |
+
for kk, vv in v.items()
|
| 316 |
+
)
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
rv.append(v)
|
| 320 |
+
else:
|
| 321 |
+
rv.append(v)
|
| 322 |
+
|
| 323 |
+
return rv if tuple_factory is list else tuple_factory(rv)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def has(cls):
|
| 327 |
+
"""
|
| 328 |
+
Check whether *cls* is a class with *attrs* attributes.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
cls (type): Class to introspect.
|
| 332 |
+
|
| 333 |
+
Raises:
|
| 334 |
+
TypeError: If *cls* is not a class.
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
bool:
|
| 338 |
+
"""
|
| 339 |
+
attrs = getattr(cls, "__attrs_attrs__", None)
|
| 340 |
+
if attrs is not None:
|
| 341 |
+
return True
|
| 342 |
+
|
| 343 |
+
# No attrs, maybe it's a specialized generic (A[str])?
|
| 344 |
+
generic_base = get_generic_base(cls)
|
| 345 |
+
if generic_base is not None:
|
| 346 |
+
generic_attrs = getattr(generic_base, "__attrs_attrs__", None)
|
| 347 |
+
if generic_attrs is not None:
|
| 348 |
+
# Stick it on here for speed next time.
|
| 349 |
+
cls.__attrs_attrs__ = generic_attrs
|
| 350 |
+
return generic_attrs is not None
|
| 351 |
+
return False
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def assoc(inst, **changes):
|
| 355 |
+
"""
|
| 356 |
+
Copy *inst* and apply *changes*.
|
| 357 |
+
|
| 358 |
+
This is different from `evolve` that applies the changes to the arguments
|
| 359 |
+
that create the new instance.
|
| 360 |
+
|
| 361 |
+
`evolve`'s behavior is preferable, but there are `edge cases`_ where it
|
| 362 |
+
doesn't work. Therefore `assoc` is deprecated, but will not be removed.
|
| 363 |
+
|
| 364 |
+
.. _`edge cases`: https://github.com/python-attrs/attrs/issues/251
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
inst: Instance of a class with *attrs* attributes.
|
| 368 |
+
|
| 369 |
+
changes: Keyword changes in the new copy.
|
| 370 |
+
|
| 371 |
+
Returns:
|
| 372 |
+
A copy of inst with *changes* incorporated.
|
| 373 |
+
|
| 374 |
+
Raises:
|
| 375 |
+
attrs.exceptions.AttrsAttributeNotFoundError:
|
| 376 |
+
If *attr_name* couldn't be found on *cls*.
|
| 377 |
+
|
| 378 |
+
attrs.exceptions.NotAnAttrsClassError:
|
| 379 |
+
If *cls* is not an *attrs* class.
|
| 380 |
+
|
| 381 |
+
.. deprecated:: 17.1.0
|
| 382 |
+
Use `attrs.evolve` instead if you can. This function will not be
|
| 383 |
+
removed du to the slightly different approach compared to
|
| 384 |
+
`attrs.evolve`, though.
|
| 385 |
+
"""
|
| 386 |
+
new = copy.copy(inst)
|
| 387 |
+
attrs = fields(inst.__class__)
|
| 388 |
+
for k, v in changes.items():
|
| 389 |
+
a = getattr(attrs, k, NOTHING)
|
| 390 |
+
if a is NOTHING:
|
| 391 |
+
msg = f"{k} is not an attrs attribute on {new.__class__}."
|
| 392 |
+
raise AttrsAttributeNotFoundError(msg)
|
| 393 |
+
_OBJ_SETATTR(new, k, v)
|
| 394 |
+
return new
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def resolve_types(
|
| 398 |
+
cls, globalns=None, localns=None, attribs=None, include_extras=True
|
| 399 |
+
):
|
| 400 |
+
"""
|
| 401 |
+
Resolve any strings and forward annotations in type annotations.
|
| 402 |
+
|
| 403 |
+
This is only required if you need concrete types in :class:`Attribute`'s
|
| 404 |
+
*type* field. In other words, you don't need to resolve your types if you
|
| 405 |
+
only use them for static type checking.
|
| 406 |
+
|
| 407 |
+
With no arguments, names will be looked up in the module in which the class
|
| 408 |
+
was created. If this is not what you want, for example, if the name only
|
| 409 |
+
exists inside a method, you may pass *globalns* or *localns* to specify
|
| 410 |
+
other dictionaries in which to look up these names. See the docs of
|
| 411 |
+
`typing.get_type_hints` for more details.
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
cls (type): Class to resolve.
|
| 415 |
+
|
| 416 |
+
globalns (dict | None): Dictionary containing global variables.
|
| 417 |
+
|
| 418 |
+
localns (dict | None): Dictionary containing local variables.
|
| 419 |
+
|
| 420 |
+
attribs (list | None):
|
| 421 |
+
List of attribs for the given class. This is necessary when calling
|
| 422 |
+
from inside a ``field_transformer`` since *cls* is not an *attrs*
|
| 423 |
+
class yet.
|
| 424 |
+
|
| 425 |
+
include_extras (bool):
|
| 426 |
+
Resolve more accurately, if possible. Pass ``include_extras`` to
|
| 427 |
+
``typing.get_hints``, if supported by the typing module. On
|
| 428 |
+
supported Python versions (3.9+), this resolves the types more
|
| 429 |
+
accurately.
|
| 430 |
+
|
| 431 |
+
Raises:
|
| 432 |
+
TypeError: If *cls* is not a class.
|
| 433 |
+
|
| 434 |
+
attrs.exceptions.NotAnAttrsClassError:
|
| 435 |
+
If *cls* is not an *attrs* class and you didn't pass any attribs.
|
| 436 |
+
|
| 437 |
+
NameError: If types cannot be resolved because of missing variables.
|
| 438 |
+
|
| 439 |
+
Returns:
|
| 440 |
+
*cls* so you can use this function also as a class decorator. Please
|
| 441 |
+
note that you have to apply it **after** `attrs.define`. That means the
|
| 442 |
+
decorator has to come in the line **before** `attrs.define`.
|
| 443 |
+
|
| 444 |
+
.. versionadded:: 20.1.0
|
| 445 |
+
.. versionadded:: 21.1.0 *attribs*
|
| 446 |
+
.. versionadded:: 23.1.0 *include_extras*
|
| 447 |
+
"""
|
| 448 |
+
# Since calling get_type_hints is expensive we cache whether we've
|
| 449 |
+
# done it already.
|
| 450 |
+
if getattr(cls, "__attrs_types_resolved__", None) != cls:
|
| 451 |
+
import typing
|
| 452 |
+
|
| 453 |
+
kwargs = {"globalns": globalns, "localns": localns}
|
| 454 |
+
|
| 455 |
+
if PY_3_9_PLUS:
|
| 456 |
+
kwargs["include_extras"] = include_extras
|
| 457 |
+
|
| 458 |
+
hints = typing.get_type_hints(cls, **kwargs)
|
| 459 |
+
for field in fields(cls) if attribs is None else attribs:
|
| 460 |
+
if field.name in hints:
|
| 461 |
+
# Since fields have been frozen we must work around it.
|
| 462 |
+
_OBJ_SETATTR(field, "type", hints[field.name])
|
| 463 |
+
# We store the class we resolved so that subclasses know they haven't
|
| 464 |
+
# been resolved.
|
| 465 |
+
cls.__attrs_types_resolved__ = cls
|
| 466 |
+
|
| 467 |
+
# Return the class so you can use it as a decorator too.
|
| 468 |
+
return cls
|
.venv/lib/python3.11/site-packages/attr/_version_info.pyi
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class VersionInfo:
|
| 2 |
+
@property
|
| 3 |
+
def year(self) -> int: ...
|
| 4 |
+
@property
|
| 5 |
+
def minor(self) -> int: ...
|
| 6 |
+
@property
|
| 7 |
+
def micro(self) -> int: ...
|
| 8 |
+
@property
|
| 9 |
+
def releaselevel(self) -> str: ...
|
.venv/lib/python3.11/site-packages/attr/exceptions.pyi
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
class FrozenError(AttributeError):
|
| 4 |
+
msg: str = ...
|
| 5 |
+
|
| 6 |
+
class FrozenInstanceError(FrozenError): ...
|
| 7 |
+
class FrozenAttributeError(FrozenError): ...
|
| 8 |
+
class AttrsAttributeNotFoundError(ValueError): ...
|
| 9 |
+
class NotAnAttrsClassError(ValueError): ...
|
| 10 |
+
class DefaultAlreadySetError(RuntimeError): ...
|
| 11 |
+
class UnannotatedAttributeError(RuntimeError): ...
|
| 12 |
+
class PythonTooOldError(RuntimeError): ...
|
| 13 |
+
|
| 14 |
+
class NotCallableError(TypeError):
|
| 15 |
+
msg: str = ...
|
| 16 |
+
value: Any = ...
|
| 17 |
+
def __init__(self, msg: str, value: Any) -> None: ...
|
.venv/lib/python3.11/site-packages/attr/validators.pyi
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import UnionType
|
| 2 |
+
from typing import (
|
| 3 |
+
Any,
|
| 4 |
+
AnyStr,
|
| 5 |
+
Callable,
|
| 6 |
+
Container,
|
| 7 |
+
ContextManager,
|
| 8 |
+
Iterable,
|
| 9 |
+
Mapping,
|
| 10 |
+
Match,
|
| 11 |
+
Pattern,
|
| 12 |
+
TypeVar,
|
| 13 |
+
overload,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from attrs import _ValidatorType
|
| 17 |
+
from attrs import _ValidatorArgType
|
| 18 |
+
|
| 19 |
+
_T = TypeVar("_T")
|
| 20 |
+
_T1 = TypeVar("_T1")
|
| 21 |
+
_T2 = TypeVar("_T2")
|
| 22 |
+
_T3 = TypeVar("_T3")
|
| 23 |
+
_I = TypeVar("_I", bound=Iterable)
|
| 24 |
+
_K = TypeVar("_K")
|
| 25 |
+
_V = TypeVar("_V")
|
| 26 |
+
_M = TypeVar("_M", bound=Mapping)
|
| 27 |
+
|
| 28 |
+
def set_disabled(run: bool) -> None: ...
|
| 29 |
+
def get_disabled() -> bool: ...
|
| 30 |
+
def disabled() -> ContextManager[None]: ...
|
| 31 |
+
|
| 32 |
+
# To be more precise on instance_of use some overloads.
|
| 33 |
+
# If there are more than 3 items in the tuple then we fall back to Any
|
| 34 |
+
@overload
|
| 35 |
+
def instance_of(type: type[_T]) -> _ValidatorType[_T]: ...
|
| 36 |
+
@overload
|
| 37 |
+
def instance_of(type: tuple[type[_T]]) -> _ValidatorType[_T]: ...
|
| 38 |
+
@overload
|
| 39 |
+
def instance_of(
|
| 40 |
+
type: tuple[type[_T1], type[_T2]],
|
| 41 |
+
) -> _ValidatorType[_T1 | _T2]: ...
|
| 42 |
+
@overload
|
| 43 |
+
def instance_of(
|
| 44 |
+
type: tuple[type[_T1], type[_T2], type[_T3]],
|
| 45 |
+
) -> _ValidatorType[_T1 | _T2 | _T3]: ...
|
| 46 |
+
@overload
|
| 47 |
+
def instance_of(type: tuple[type, ...]) -> _ValidatorType[Any]: ...
|
| 48 |
+
@overload
|
| 49 |
+
def instance_of(type: UnionType) -> _ValidatorType[Any]: ...
|
| 50 |
+
def optional(
|
| 51 |
+
validator: (
|
| 52 |
+
_ValidatorType[_T]
|
| 53 |
+
| list[_ValidatorType[_T]]
|
| 54 |
+
| tuple[_ValidatorType[_T]]
|
| 55 |
+
),
|
| 56 |
+
) -> _ValidatorType[_T | None]: ...
|
| 57 |
+
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
|
| 58 |
+
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
|
| 59 |
+
def matches_re(
|
| 60 |
+
regex: Pattern[AnyStr] | AnyStr,
|
| 61 |
+
flags: int = ...,
|
| 62 |
+
func: Callable[[AnyStr, AnyStr, int], Match[AnyStr] | None] | None = ...,
|
| 63 |
+
) -> _ValidatorType[AnyStr]: ...
|
| 64 |
+
def deep_iterable(
|
| 65 |
+
member_validator: _ValidatorArgType[_T],
|
| 66 |
+
iterable_validator: _ValidatorType[_I] | None = ...,
|
| 67 |
+
) -> _ValidatorType[_I]: ...
|
| 68 |
+
def deep_mapping(
|
| 69 |
+
key_validator: _ValidatorType[_K],
|
| 70 |
+
value_validator: _ValidatorType[_V],
|
| 71 |
+
mapping_validator: _ValidatorType[_M] | None = ...,
|
| 72 |
+
) -> _ValidatorType[_M]: ...
|
| 73 |
+
def is_callable() -> _ValidatorType[_T]: ...
|
| 74 |
+
def lt(val: _T) -> _ValidatorType[_T]: ...
|
| 75 |
+
def le(val: _T) -> _ValidatorType[_T]: ...
|
| 76 |
+
def ge(val: _T) -> _ValidatorType[_T]: ...
|
| 77 |
+
def gt(val: _T) -> _ValidatorType[_T]: ...
|
| 78 |
+
def max_len(length: int) -> _ValidatorType[_T]: ...
|
| 79 |
+
def min_len(length: int) -> _ValidatorType[_T]: ...
|
| 80 |
+
def not_(
|
| 81 |
+
validator: _ValidatorType[_T],
|
| 82 |
+
*,
|
| 83 |
+
msg: str | None = None,
|
| 84 |
+
exc_types: type[Exception] | Iterable[type[Exception]] = ...,
|
| 85 |
+
) -> _ValidatorType[_T]: ...
|
| 86 |
+
def or_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
|
.venv/lib/python3.11/site-packages/msgspec/_core.cpython-311-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6a6211b1e1e47f505c8b79cb8b191ba1169b99f917cd874de883cffd11aa9883
|
| 3 |
+
size 406024
|
.venv/lib/python3.11/site-packages/outlines/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Outlines is a Generative Model Programming Framework."""
|
| 2 |
+
import outlines.generate
|
| 3 |
+
import outlines.grammars
|
| 4 |
+
import outlines.models
|
| 5 |
+
import outlines.processors
|
| 6 |
+
import outlines.types
|
| 7 |
+
from outlines.base import vectorize
|
| 8 |
+
from outlines.caching import clear_cache, disable_cache, get_cache
|
| 9 |
+
from outlines.function import Function
|
| 10 |
+
from outlines.prompts import prompt
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"clear_cache",
|
| 14 |
+
"disable_cache",
|
| 15 |
+
"get_cache",
|
| 16 |
+
"Function",
|
| 17 |
+
"prompt",
|
| 18 |
+
"vectorize",
|
| 19 |
+
"grammars",
|
| 20 |
+
]
|
.venv/lib/python3.11/site-packages/outlines/_version.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file generated by setuptools_scm
|
| 2 |
+
# don't change, don't track in version control
|
| 3 |
+
TYPE_CHECKING = False
|
| 4 |
+
if TYPE_CHECKING:
|
| 5 |
+
from typing import Tuple, Union
|
| 6 |
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
| 7 |
+
else:
|
| 8 |
+
VERSION_TUPLE = object
|
| 9 |
+
|
| 10 |
+
version: str
|
| 11 |
+
__version__: str
|
| 12 |
+
__version_tuple__: VERSION_TUPLE
|
| 13 |
+
version_tuple: VERSION_TUPLE
|
| 14 |
+
|
| 15 |
+
__version__ = version = '0.1.11'
|
| 16 |
+
__version_tuple__ = version_tuple = (0, 1, 11)
|
.venv/lib/python3.11/site-packages/outlines/base.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import builtins
|
| 3 |
+
import functools
|
| 4 |
+
import inspect
|
| 5 |
+
from typing import Callable, Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
# Import required functions based on NumPy version
|
| 10 |
+
np_major_version = int(np.__version__.split(".")[0])
|
| 11 |
+
if np_major_version >= 2:
|
| 12 |
+
from numpy.lib._function_base_impl import (
|
| 13 |
+
_calculate_shapes,
|
| 14 |
+
_parse_gufunc_signature,
|
| 15 |
+
_parse_input_dimensions,
|
| 16 |
+
_update_dim_sizes,
|
| 17 |
+
)
|
| 18 |
+
else:
|
| 19 |
+
from numpy.lib.function_base import (
|
| 20 |
+
_calculate_shapes,
|
| 21 |
+
_parse_gufunc_signature,
|
| 22 |
+
_parse_input_dimensions,
|
| 23 |
+
_update_dim_sizes,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Allow nested loops for running in notebook. We don't enable it globally as it
|
| 27 |
+
# may interfere with other libraries that use asyncio.
|
| 28 |
+
if hasattr(builtins, "__IPYTHON__"):
|
| 29 |
+
try:
|
| 30 |
+
import nest_asyncio
|
| 31 |
+
|
| 32 |
+
nest_asyncio.apply()
|
| 33 |
+
except ImportError:
|
| 34 |
+
print(
|
| 35 |
+
"Couldn't patch nest_asyncio because it's not installed. Running in the notebook might be have issues"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class vectorize:
|
| 40 |
+
"""Returns an object that acts like a function but takes arrays as an input.
|
| 41 |
+
|
| 42 |
+
The vectorized function evaluates `func` over successive tuples of the input
|
| 43 |
+
chararrays and returns a single NumPy chararrays or a tuple of NumPy chararrays.
|
| 44 |
+
|
| 45 |
+
Its behavior is similar to NumPy's `vectorize` for Python functions: the function
|
| 46 |
+
being vectorized is executed in a `for` loop. Coroutines, however, are executed
|
| 47 |
+
concurrently.
|
| 48 |
+
|
| 49 |
+
Part of the code was adapted from `numpy.lib.function_base`.
|
| 50 |
+
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(self, func: Callable, signature: Optional[str] = None):
|
| 54 |
+
self.func = func
|
| 55 |
+
self.signature = signature
|
| 56 |
+
self.is_coroutine_fn = inspect.iscoroutinefunction(func)
|
| 57 |
+
|
| 58 |
+
functools.update_wrapper(self, func)
|
| 59 |
+
|
| 60 |
+
if signature is not None:
|
| 61 |
+
# Parse the signature string into a Python data structure.
|
| 62 |
+
# For instance "(m),(s)->(s,m)" becomes `([(m,),(s,)],[(s,m)])`.
|
| 63 |
+
self._in_and_out_core_dimensions = _parse_gufunc_signature(signature)
|
| 64 |
+
else:
|
| 65 |
+
self._in_and_out_core_dimensions = None
|
| 66 |
+
|
| 67 |
+
def __call__(self, *args, **kwargs):
|
| 68 |
+
"""Call the vectorized function."""
|
| 69 |
+
if not args and not kwargs:
|
| 70 |
+
return self.call_thunk()
|
| 71 |
+
elif self.signature is not None:
|
| 72 |
+
return self.call_with_signature(*args, **kwargs)
|
| 73 |
+
else:
|
| 74 |
+
return self.call_no_signature(*args, **kwargs)
|
| 75 |
+
|
| 76 |
+
def call_thunk(self):
|
| 77 |
+
"""Call a vectorized thunk.
|
| 78 |
+
|
| 79 |
+
Thunks have no arguments and can thus be called directly.
|
| 80 |
+
|
| 81 |
+
"""
|
| 82 |
+
if self.is_coroutine_fn:
|
| 83 |
+
loop = asyncio.new_event_loop()
|
| 84 |
+
try:
|
| 85 |
+
outputs = loop.run_until_complete(self.func())
|
| 86 |
+
finally:
|
| 87 |
+
loop.close()
|
| 88 |
+
else:
|
| 89 |
+
outputs = self.func()
|
| 90 |
+
|
| 91 |
+
return outputs
|
| 92 |
+
|
| 93 |
+
def call_no_signature(self, *args, **kwargs):
|
| 94 |
+
"""Call functions and coroutines when no signature is specified.
|
| 95 |
+
|
| 96 |
+
When no signature is specified we assume that all of the function's
|
| 97 |
+
inputs and outputs are scalars (core dimension of zero). We first
|
| 98 |
+
broadcast the input arrays, then iteratively apply the function over the
|
| 99 |
+
elements of the broadcasted arrays and finally reshape the results to
|
| 100 |
+
match the input shape.
|
| 101 |
+
|
| 102 |
+
Functions are executed in a for loop, coroutines are executed
|
| 103 |
+
concurrently.
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
# Convert args and kwargs to arrays
|
| 107 |
+
args = [np.array(arg) for arg in args]
|
| 108 |
+
kwargs = {key: np.array(value) for key, value in kwargs.items()}
|
| 109 |
+
|
| 110 |
+
# Broadcast args and kwargs
|
| 111 |
+
broadcast_shape = np.broadcast(*args, *list(kwargs.values())).shape
|
| 112 |
+
args = [np.broadcast_to(arg, broadcast_shape) for arg in args]
|
| 113 |
+
kwargs = {
|
| 114 |
+
key: np.broadcast_to(value, broadcast_shape)
|
| 115 |
+
for key, value in kwargs.items()
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
# Execute functions in a loop, and coroutines concurrently
|
| 119 |
+
if self.is_coroutine_fn:
|
| 120 |
+
outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs)
|
| 121 |
+
else:
|
| 122 |
+
outputs = self.vectorize_call(broadcast_shape, args, kwargs)
|
| 123 |
+
|
| 124 |
+
# `outputs` is a flat array or a tuple of flat arrays. We reshape the arrays
|
| 125 |
+
# to match the input shape.
|
| 126 |
+
outputs = [
|
| 127 |
+
results if isinstance(results, tuple) else (results,) for results in outputs
|
| 128 |
+
]
|
| 129 |
+
outputs = tuple(
|
| 130 |
+
[np.asarray(x).reshape(broadcast_shape).squeeze() for x in zip(*outputs)]
|
| 131 |
+
)
|
| 132 |
+
outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs])
|
| 133 |
+
|
| 134 |
+
n_results = len(list(outputs))
|
| 135 |
+
|
| 136 |
+
return outputs[0] if n_results == 1 else outputs
|
| 137 |
+
|
| 138 |
+
def call_with_signature(self, *args, **kwargs):
|
| 139 |
+
"""Call functions and coroutines when a signature is specified."""
|
| 140 |
+
input_core_dims, output_core_dims = self._in_and_out_core_dimensions
|
| 141 |
+
|
| 142 |
+
# Make sure that the numbers of arguments passed is compatible with
|
| 143 |
+
# the signature.
|
| 144 |
+
num_args = len(args) + len(kwargs)
|
| 145 |
+
if num_args != len(input_core_dims):
|
| 146 |
+
raise TypeError(
|
| 147 |
+
"wrong number of positional arguments: "
|
| 148 |
+
"expected %r, got %r" % (len(input_core_dims), len(args))
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Convert args and kwargs to arrays
|
| 152 |
+
args = [np.asarray(arg) for arg in args]
|
| 153 |
+
kwargs = {key: np.array(value) for key, value in kwargs.items()}
|
| 154 |
+
|
| 155 |
+
# Find the arguments' broadcast shape, and map placeholder
|
| 156 |
+
# variables in the signature to the number of dimensions
|
| 157 |
+
# they correspond to given the arguments.
|
| 158 |
+
broadcast_shape, dim_sizes = _parse_input_dimensions(
|
| 159 |
+
args + list(kwargs.values()), input_core_dims
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Calculate the shape to which each of the arguments should be broadcasted
|
| 163 |
+
# and reshape them accordingly.
|
| 164 |
+
input_shapes = _calculate_shapes(broadcast_shape, dim_sizes, input_core_dims)
|
| 165 |
+
args = [
|
| 166 |
+
np.broadcast_to(arg, shape, subok=True)
|
| 167 |
+
for arg, shape in zip(args, input_shapes)
|
| 168 |
+
]
|
| 169 |
+
kwargs = {
|
| 170 |
+
key: np.broadcast_to(value, broadcast_shape)
|
| 171 |
+
for key, value in kwargs.items()
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
n_out = len(output_core_dims)
|
| 175 |
+
|
| 176 |
+
if self.is_coroutine_fn:
|
| 177 |
+
outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs)
|
| 178 |
+
else:
|
| 179 |
+
outputs = self.vectorize_call(broadcast_shape, args, kwargs)
|
| 180 |
+
|
| 181 |
+
outputs = [
|
| 182 |
+
results if isinstance(results, tuple) else (results,) for results in outputs
|
| 183 |
+
]
|
| 184 |
+
|
| 185 |
+
flat_outputs = list(zip(*outputs))
|
| 186 |
+
n_results = len(flat_outputs)
|
| 187 |
+
|
| 188 |
+
if n_out != n_results:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"wrong number of outputs from the function, expected {n_out}, got {n_results}"
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# The number of dimensions of the outputs are not necessarily known in
|
| 194 |
+
# advance. The following iterates over the results and updates the
|
| 195 |
+
# number of dimensions of the outputs accordingly.
|
| 196 |
+
for results, core_dims in zip(flat_outputs, output_core_dims):
|
| 197 |
+
for result in results:
|
| 198 |
+
_update_dim_sizes(dim_sizes, result, core_dims)
|
| 199 |
+
|
| 200 |
+
# Calculate the shape to which each of the outputs should be broadcasted
|
| 201 |
+
# and reshape them.
|
| 202 |
+
shapes = _calculate_shapes(broadcast_shape, dim_sizes, output_core_dims)
|
| 203 |
+
outputs = tuple(
|
| 204 |
+
[
|
| 205 |
+
np.hstack(results).reshape(shape).squeeze()
|
| 206 |
+
for shape, results in zip(shapes, zip(*outputs))
|
| 207 |
+
]
|
| 208 |
+
)
|
| 209 |
+
outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs])
|
| 210 |
+
|
| 211 |
+
return outputs[0] if n_results == 1 else outputs
|
| 212 |
+
|
| 213 |
+
def vectorize_call(self, broadcast_shape, args, kwargs):
|
| 214 |
+
"""Run the function in a for loop.
|
| 215 |
+
|
| 216 |
+
A possible extension would be to parallelize the calls.
|
| 217 |
+
|
| 218 |
+
Parameters
|
| 219 |
+
----------
|
| 220 |
+
broadcast_shape
|
| 221 |
+
The brodcast shape of the input arrays.
|
| 222 |
+
args
|
| 223 |
+
The function's broadcasted arguments.
|
| 224 |
+
kwargs
|
| 225 |
+
The function's broadcasted keyword arguments.
|
| 226 |
+
|
| 227 |
+
"""
|
| 228 |
+
outputs = []
|
| 229 |
+
for index in np.ndindex(*broadcast_shape):
|
| 230 |
+
current_args = tuple(arg[index] for arg in args)
|
| 231 |
+
current_kwargs = {key: value[index] for key, value in kwargs.items()}
|
| 232 |
+
outputs.append(self.func(*current_args, **current_kwargs))
|
| 233 |
+
|
| 234 |
+
return outputs
|
| 235 |
+
|
| 236 |
+
def vectorize_call_coroutine(self, broadcast_shape, args, kwargs):
|
| 237 |
+
"""Run coroutines concurrently.
|
| 238 |
+
|
| 239 |
+
Creates as many tasks as needed and executes them in a new event
|
| 240 |
+
loop.
|
| 241 |
+
|
| 242 |
+
Parameters
|
| 243 |
+
----------
|
| 244 |
+
broadcast_shape
|
| 245 |
+
The brodcast shape of the input arrays.
|
| 246 |
+
args
|
| 247 |
+
The function's broadcasted arguments.
|
| 248 |
+
kwargs
|
| 249 |
+
The function's broadcasted keyword arguments.
|
| 250 |
+
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
async def create_and_gather_tasks():
|
| 254 |
+
tasks = []
|
| 255 |
+
for index in np.ndindex(*broadcast_shape):
|
| 256 |
+
current_args = tuple(arg[index] for arg in args)
|
| 257 |
+
current_kwargs = {key: value[index] for key, value in kwargs.items()}
|
| 258 |
+
tasks.append(self.func(*current_args, **current_kwargs))
|
| 259 |
+
|
| 260 |
+
outputs = await asyncio.gather(*tasks)
|
| 261 |
+
|
| 262 |
+
return outputs
|
| 263 |
+
|
| 264 |
+
loop = asyncio.new_event_loop()
|
| 265 |
+
try:
|
| 266 |
+
outputs = loop.run_until_complete(create_and_gather_tasks())
|
| 267 |
+
finally:
|
| 268 |
+
loop.close()
|
| 269 |
+
|
| 270 |
+
return outputs
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _update_arrays_type(arrays, results):
|
| 274 |
+
"""Update the dtype of arrays.
|
| 275 |
+
|
| 276 |
+
String arrays contain strings of fixed length. Here they are initialized with
|
| 277 |
+
the type of the first results, so that if the next results contain longer
|
| 278 |
+
strings they will be truncated when added to the output arrays. Here we
|
| 279 |
+
update the type if the current results contain longer strings than in the
|
| 280 |
+
current output array.
|
| 281 |
+
|
| 282 |
+
Parameters
|
| 283 |
+
----------
|
| 284 |
+
arrays
|
| 285 |
+
Arrays that contain the vectorized function's results.
|
| 286 |
+
results
|
| 287 |
+
The current output of the function being vectorized.
|
| 288 |
+
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
updated_arrays = []
|
| 292 |
+
for array, result in zip(arrays, results):
|
| 293 |
+
if array.dtype.type == np.str_:
|
| 294 |
+
if array.dtype < np.array(result).dtype:
|
| 295 |
+
array = array.astype(np.array(result).dtype)
|
| 296 |
+
|
| 297 |
+
updated_arrays.append(array)
|
| 298 |
+
|
| 299 |
+
return tuple(updated_arrays)
|
.venv/lib/python3.11/site-packages/outlines/caching.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import contextlib
|
| 3 |
+
import functools
|
| 4 |
+
import os
|
| 5 |
+
from typing import Callable, Optional
|
| 6 |
+
|
| 7 |
+
import cloudpickle
|
| 8 |
+
from diskcache import Cache, Disk
|
| 9 |
+
from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name
|
| 10 |
+
|
| 11 |
+
_caching_enabled = True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CloudpickleDisk(Disk):
|
| 15 |
+
def __init__(self, directory, compress_level=1, **kwargs):
|
| 16 |
+
self.compress_level = compress_level
|
| 17 |
+
super().__init__(directory, **kwargs)
|
| 18 |
+
|
| 19 |
+
def put(self, key):
|
| 20 |
+
data = cloudpickle.dumps(key)
|
| 21 |
+
return super().put(data)
|
| 22 |
+
|
| 23 |
+
def get(self, key, raw):
|
| 24 |
+
data = super().get(key, raw)
|
| 25 |
+
return cloudpickle.loads(data)
|
| 26 |
+
|
| 27 |
+
def store(self, value, read, key=UNKNOWN):
|
| 28 |
+
if not read:
|
| 29 |
+
value = cloudpickle.dumps(value)
|
| 30 |
+
return super().store(value, read, key=key)
|
| 31 |
+
|
| 32 |
+
def fetch(self, mode, filename, value, read):
|
| 33 |
+
data = super().fetch(mode, filename, value, read)
|
| 34 |
+
if not read:
|
| 35 |
+
data = cloudpickle.loads(data)
|
| 36 |
+
return data
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@functools.lru_cache(1)
|
| 40 |
+
def get_cache():
|
| 41 |
+
"""Get the context object that contains previously-computed return values.
|
| 42 |
+
|
| 43 |
+
The cache is used to avoid unnecessary computations and API calls, which can
|
| 44 |
+
be long and expensive for large models.
|
| 45 |
+
|
| 46 |
+
The cache directory defaults to `HOMEDIR/.cache/outlines`, but this choice
|
| 47 |
+
can be overridden by the user by setting the value of the `OUTLINES_CACHE_DIR`
|
| 48 |
+
environment variable.
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
from outlines._version import __version__ as outlines_version # type: ignore
|
| 52 |
+
|
| 53 |
+
home_dir = os.path.expanduser("~")
|
| 54 |
+
cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines")
|
| 55 |
+
memory = Cache(
|
| 56 |
+
cache_dir,
|
| 57 |
+
eviction_policy="none",
|
| 58 |
+
cull_limit=0,
|
| 59 |
+
disk=CloudpickleDisk,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# ensure if version upgrade occurs, old cache is pruned
|
| 63 |
+
if outlines_version != memory.get("__version__"):
|
| 64 |
+
memory.clear()
|
| 65 |
+
memory["__version__"] = outlines_version
|
| 66 |
+
|
| 67 |
+
return memory
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def cache(expire: Optional[float] = None, typed=False, ignore=()):
|
| 71 |
+
"""Caching decorator for memoizing function calls.
|
| 72 |
+
|
| 73 |
+
The cache key is created based on the values returned by the key_function callable
|
| 74 |
+
if provided or based on the arguments of the decorated function directly otherwise
|
| 75 |
+
|
| 76 |
+
This is based on `diskcache`'s `memoize`.
|
| 77 |
+
|
| 78 |
+
Parameters
|
| 79 |
+
----------
|
| 80 |
+
expire
|
| 81 |
+
Seconds until arguments expire.
|
| 82 |
+
typed
|
| 83 |
+
Cache different types separately.
|
| 84 |
+
ignore
|
| 85 |
+
Positional or keyword arguments to ignore.
|
| 86 |
+
|
| 87 |
+
Returns
|
| 88 |
+
-------
|
| 89 |
+
A decorator function that can be applied to other functions.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def decorator(cached_function: Callable):
|
| 93 |
+
memory = get_cache()
|
| 94 |
+
|
| 95 |
+
base = (full_name(cached_function),)
|
| 96 |
+
|
| 97 |
+
if asyncio.iscoroutinefunction(cached_function):
|
| 98 |
+
|
| 99 |
+
async def wrapper(*args, **kwargs):
|
| 100 |
+
if not _caching_enabled:
|
| 101 |
+
return await cached_function(*args, **kwargs)
|
| 102 |
+
|
| 103 |
+
cache_key = wrapper.__cache_key__(*args, **kwargs)
|
| 104 |
+
result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
|
| 105 |
+
|
| 106 |
+
if result is ENOVAL:
|
| 107 |
+
result = await cached_function(*args, **kwargs)
|
| 108 |
+
wrapper.__memory__.set(cache_key, result, expire, retry=True)
|
| 109 |
+
|
| 110 |
+
return result
|
| 111 |
+
|
| 112 |
+
else:
|
| 113 |
+
|
| 114 |
+
def wrapper(*args, **kwargs):
|
| 115 |
+
if not _caching_enabled:
|
| 116 |
+
return cached_function(*args, **kwargs)
|
| 117 |
+
|
| 118 |
+
cache_key = wrapper.__cache_key__(*args, **kwargs)
|
| 119 |
+
result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
|
| 120 |
+
|
| 121 |
+
if result is ENOVAL:
|
| 122 |
+
result = cached_function(*args, **kwargs)
|
| 123 |
+
wrapper.__memory__.set(cache_key, result, expire, retry=True)
|
| 124 |
+
|
| 125 |
+
return result
|
| 126 |
+
|
| 127 |
+
def __cache_key__(*args, **kwargs):
|
| 128 |
+
"""Make key for cache given function arguments."""
|
| 129 |
+
return args_to_key(base, args, kwargs, typed, ignore)
|
| 130 |
+
|
| 131 |
+
wrapper.__cache_key__ = __cache_key__ # type: ignore
|
| 132 |
+
wrapper.__memory__ = memory # type: ignore
|
| 133 |
+
wrapper.__wrapped__ = cached_function # type: ignore
|
| 134 |
+
|
| 135 |
+
return wrapper
|
| 136 |
+
|
| 137 |
+
return decorator
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def disable_cache():
|
| 141 |
+
"""Disable the cache for this session.
|
| 142 |
+
|
| 143 |
+
Generative models output different results each time they are called when
|
| 144 |
+
sampling. This can be a desirable property for some workflows, in which case
|
| 145 |
+
one can call `outlines.call.disable` to disable the cache for the session.
|
| 146 |
+
|
| 147 |
+
This function does not delete the cache, call `outlines.cache.clear`
|
| 148 |
+
instead. It also does not overwrite the cache with the values returned
|
| 149 |
+
during the session.
|
| 150 |
+
|
| 151 |
+
Example
|
| 152 |
+
-------
|
| 153 |
+
|
| 154 |
+
`outlines.cache.disable` should be called right after importing outlines:
|
| 155 |
+
|
| 156 |
+
>>> import outlines.caching as cache
|
| 157 |
+
>>> cache.disable_cache()
|
| 158 |
+
|
| 159 |
+
"""
|
| 160 |
+
global _caching_enabled
|
| 161 |
+
_caching_enabled = False
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def clear_cache():
|
| 165 |
+
"""Erase the cache completely."""
|
| 166 |
+
memory = get_cache()
|
| 167 |
+
memory.clear()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@contextlib.contextmanager
|
| 171 |
+
def cache_disabled():
|
| 172 |
+
# outlines.caching._caching_enabled
|
| 173 |
+
global _caching_enabled
|
| 174 |
+
original_state = _caching_enabled
|
| 175 |
+
_caching_enabled = False
|
| 176 |
+
try:
|
| 177 |
+
yield
|
| 178 |
+
finally:
|
| 179 |
+
_caching_enabled = original_state
|
.venv/lib/python3.11/site-packages/outlines/fsm/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/guide.cpython-311.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/json_schema.cpython-311.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/parsing.cpython-311.pyc
ADDED
|
Binary file (52.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/fsm/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (5.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/fsm/guide.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import copy
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Generator, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from lark.indenter import DedentError
|
| 8 |
+
from lark.lexer import UnexpectedCharacters, UnexpectedToken
|
| 9 |
+
from outlines_core.fsm.guide import Generate
|
| 10 |
+
from outlines_core.fsm.guide import Guide as CoreGuide
|
| 11 |
+
from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide
|
| 12 |
+
from outlines_core.fsm.guide import Write
|
| 13 |
+
from outlines_core.fsm.guide import (
|
| 14 |
+
create_states_mapping as uncached_create_states_mapping,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from outlines import grammars
|
| 18 |
+
from outlines.fsm.parsing import PartialLark, PartialParserState
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from outlines.models.tokenizer import Tokenizer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# A guide step is either `Write` (append the given token ids verbatim) or
# `Generate` (sample constrained to the given token ids).
Instruction = Union[Write, Generate]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Guide(CoreGuide):
    """Base definition of a generation guide.

    A generation guide defines the behavior of a finite-state machine that
    guides a text generation procedure. Unlike the DFAs built from regular
    expressions, guides can also emit a `Write` instruction, which tells the
    model that it can append a sequence of tokens (or token word) instead of
    generating it.

    """

    # State the guide starts in; concrete subclasses set this (an int for the
    # simple guides, a `CFGState` for the CFG guide).
    initial_state: Any
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class StopAtEOSGuide(Guide):
    """Guide that generates tokens until the EOS token has been produced."""

    final_state = 1
    start_state = 0  # TODO: remove start_state, use only initial_state
    initial_state = 0

    def __init__(self, tokenizer: "Tokenizer"):
        """Initialize the generation guide.

        model
            The logit generator used to generate the next token.

        """
        self.eos_token_id = tokenizer.eos_token_id
        self.vocabulary = tokenizer.vocabulary.values()

    def get_next_instruction(self, state: int) -> Instruction:
        # Finished: force EOS. Otherwise: unconstrained generation.
        if not self.is_final_state(state):
            return Generate(None)
        return Write([self.eos_token_id])

    def get_next_state(self, state: int, token_id: int) -> int:
        # The final state is absorbing: once EOS is seen (or we are already
        # final) we stay final; otherwise remain in the initial state.
        if state == self.final_state or token_id == self.eos_token_id:
            return self.final_state
        return self.initial_state

    def is_final_state(self, state: int):
        return state == self.final_state

    def copy(self):
        # The guide holds no per-generation mutable state, so it can be shared.
        return self
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
    """Pass-through to the uncached states-mapping builder.

    Kept as a named wrapper so outlines can hook its cache in at this point.
    """
    mapping = uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
    return mapping
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class RegexGuide(CoreRegexGuide):
    """Guide to generate text in the language of a regular expression.

    Identical to `CoreRegexGuide`, except the token/state mapping is built
    through outlines' caching wrapper.
    """

    @classmethod
    def from_regex(cls, regex_string: str, tokenizer, **kwargs):
        # Inject the (potentially cached) states-mapping builder; everything
        # else is handled by the core implementation.
        return super().from_regex(
            regex_string,
            tokenizer,
            _create_states_mapping=cached_create_states_mapping,
            **kwargs,
        )
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Immutable snapshot of CFG-guided generation: the current partial parser
# state (None once generation is complete) and the previously generated
# token id (None before the first token).
CFGState = collections.namedtuple("CFGState", "parser_state prev_token")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class CFGGuide(Guide):
    """Guide to generate text that is in the language of a context-free Lark grammar."""

    def __init__(self, cfg_string: str, tokenizer):
        """
        Construct the PartialLark parser and set the empty initial_state (PartialParserState)
        """
        warnings.warn(
            "Outlines' public *community-contributed* CFG structured generation is experimental. "
            "Please review https://dottxt-ai.github.io/outlines/latest/reference/generation/cfg#disclaimer"
        )

        self.cfg_string = cfg_string
        self.tokenizer = tokenizer
        self.eos_token_id = self.tokenizer.eos_token_id
        self.parser = PartialLark(
            cfg_string,
            parser="lalr",
            import_paths=[grammars.GRAMMAR_PATH],
        )
        # Parsing the empty string yields the parser state that accepts any
        # valid start of the grammar; prev_token stays None until generation
        # produces the first token.
        self.initial_state = CFGState(
            parser_state=self.parser.parse(""), prev_token=None
        )

    def get_next_instruction(self, state: CFGState) -> Instruction:
        """Return the next instruction for guided generation.

        Current lazy approach:
        - For each token in the vocabulary
          - create a copy of the parsers state
          - add the tokens to the parsers input text
          - if valid, add token to returned tokens

        Further refinements are necessary for performant text processing.

        Parameters
        ----------
        state
            The guides current PartialParserState, or None if complete

        Returns
        -------
        A `Generate` instance that contains the model and the allowed token ids.

        """

        if state.parser_state is None:
            # Parsing is complete: the only legal continuation is EOS.
            return Write(torch.tensor([self.eos_token_id]))

        valid_tokens = list(
            self.iter_valid_token_ids(state, self.tokenizer.vocabulary.values())
        )
        # With a single legal token there is nothing to sample; force it.
        if len(valid_tokens) == 1:
            return Write(torch.tensor(valid_tokens))
        return Generate(torch.tensor(valid_tokens))

    def iter_valid_token_ids(
        self, state: CFGState, candidate_token_ids: list
    ) -> Generator[int, None, None]:
        """
        Iterate over the given token_ids and yield those that are valid for the current parser state.

        Parameters
        ----------
        parser_state
            The current state of the parser, or None if complete.
        token_ids
            The list of token ids to check for validity.

        Yields
        ------
        int
            Valid token ids.
        """
        if state.parser_state is None:
            yield self.eos_token_id
            return

        for token_id in candidate_token_ids:
            if token_id == self.eos_token_id:
                # EOS is only valid if the grammar would accept end-of-input.
                if self.can_terminate_state(state):
                    yield token_id
            else:
                # Trial-apply the token to a copy of the parser state; any of
                # these lexer/parser errors simply means "not a legal token".
                try:
                    self._get_parser_state_token_applied(state, int(token_id))
                    yield token_id
                except (
                    ValueError,
                    EOFError,
                    UnexpectedToken,
                    UnexpectedCharacters,
                    DedentError,
                ):
                    pass

    def get_next_state(self, state: CFGState, token_id: int) -> CFGState:
        """
        Update the state of the guide.
        Decode the token_id, and calculate the new parser_state with the token applied.

        Parameters
        ----------
        state
            The guides current PartialParserState, or None if complete
        token_id
            The id of the token that was just generated.

        Returns
        -------
        The guides new PartialParserState

        """
        if state.parser_state is None or token_id == self.eos_token_id:
            # EOS (or an already-finished parse) closes the parse permanently.
            parser_state = None
        else:
            parser_state = self._get_parser_state_token_applied(state, int(token_id))
        return CFGState(parser_state=parser_state, prev_token=token_id)

    def _get_parser_state_token_applied(
        self, state: CFGState, token_id: int
    ) -> PartialParserState:
        """
        Don't mutate `parser_state`, copy to protect

        Get the token string
          - if first token in generation: tokenizer.decode (no leading whitespace)
          - else: normalized (with possibly leading whitespace)

        Don't allow empty ("") tokens, raise ValueError
        """
        parser_state = copy.copy(state.parser_state)  # prevent side effects

        # normalize
        if state.prev_token is None:
            new_token_str = self.tokenizer.decode([token_id])[0]
        else:
            # Decode prev alone and prev+new together, then take the suffix:
            # this recovers any leading whitespace the tokenizer would insert
            # between the two tokens.
            # NOTE(review): prev_token is wrapped in an extra list here while
            # the first-token path passes a flat list — presumably matching a
            # batched decode API; confirm against the tokenizer contract.
            prev_token_str = self.tokenizer.decode([[state.prev_token]])[0]
            combined_token_str = self.tokenizer.decode([[state.prev_token, token_id]])[
                0
            ]
            new_token_str = combined_token_str[len(prev_token_str) :]

        if new_token_str == "":
            raise ValueError("empty next token")

        # update parser with new token
        parser_state.lexer.state.text += new_token_str
        self.parser.parse_from_state(parser_state, is_end=False)

        return parser_state

    def is_final_state(self, state: CFGState) -> bool:
        # TODO: remove this method, use can_terminate_state and must_terminate_state
        # here and in RegexGuide per https://github.com/dottxt-ai/outlines/issues/885
        return self.can_terminate_state(state)

    def can_terminate_state(self, state: CFGState) -> bool:
        """Generation is allowed to terminate"""
        if state.parser_state is not None:
            # Feed EOF to a *copy* so the real parser state is untouched.
            try:
                copy.copy(state.parser_state).feed_eof()
            except UnexpectedToken:
                return False
        return True

    def must_terminate_state(self, state: CFGState) -> bool:
        """Generation must terminate, no legal continuations"""
        return state.parser_state is None or set(state.parser_state.accepts()).issubset(
            {"$END"}
        )

    def copy(self) -> "CFGGuide":
        """Create a copy of the Guide."""
        return CFGGuide(self.cfg_string, self.tokenizer)
|
.venv/lib/python3.11/site-packages/outlines/fsm/json_schema.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import json
|
| 3 |
+
import warnings
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Callable, Type, Union
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel, create_model
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
    """Convert a JSON schema to a string.

    Parameters
    ----------
    json_schema
        The JSON schema: a dictionary, an already-serialized string, or a
        Pydantic model class.

    Returns
    -------
    str
        The JSON schema converted to a string.

    Raises
    ------
    ValueError
        If the schema is not a dictionary, a string or a Pydantic class.
    """
    if isinstance(json_schema, dict):
        schema_str = json.dumps(json_schema)
    elif isinstance(json_schema, str):
        schema_str = json_schema
    # Guard with `isinstance(..., type)` so that non-class inputs (e.g. an
    # int) fall through to the documented ValueError below instead of
    # `issubclass` raising a TypeError.
    elif isinstance(json_schema, type) and issubclass(json_schema, BaseModel):
        schema_str = json.dumps(json_schema.model_json_schema())
    else:
        raise ValueError(
            f"Cannot parse schema {json_schema}. The schema must be either "
            + "a Pydantic class, a dictionary or a string that contains the JSON "
            + "schema specification"
        )
    return schema_str
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_schema_from_signature(fn: Callable) -> dict:
    """Turn a function signature into a JSON schema.

    Every JSON object valid to the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    Raises
    ------
    ValueError
        If any parameter of `fn` lacks a type annotation.
    """
    fields = {}
    for param_name, param in inspect.signature(fn).parameters.items():
        # Annotations are mandatory: they become the schema field types.
        if param.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        fields[param_name] = (param.annotation, ...)

    try:
        model_name = fn.__name__
    except Exception as e:
        # Best-effort fallback for callables without a __name__.
        model_name = "Arguments"
        warnings.warn(
            f"The function name could not be determined. Using default name 'Arguments' instead. For debugging, here is exact error:\n{e}",
            category=UserWarning,
        )

    return create_model(model_name, **fields).model_json_schema()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_schema_from_enum(myenum: type[Enum]) -> dict:
    """Build a `oneOf` JSON schema over the members of an Enum class.

    Callable members (e.g. `functools.partial` objects) are described by the
    schema of their wrapped function's signature; plain values become
    `{"const": value}` alternatives.
    """
    if len(myenum) == 0:
        raise ValueError(
            f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
        )
    choices = []
    for member in myenum:
        if callable(member.value):
            choices.append(get_schema_from_signature(member.value.func))
        else:
            choices.append({"const": member.value})
    return {"title": myenum.__name__, "oneOf": choices}
|
.venv/lib/python3.11/site-packages/outlines/fsm/parsing.py
ADDED
|
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import copy, deepcopy
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
from typing import (
|
| 5 |
+
Any,
|
| 6 |
+
Dict,
|
| 7 |
+
FrozenSet,
|
| 8 |
+
Generator,
|
| 9 |
+
Iterator,
|
| 10 |
+
List,
|
| 11 |
+
Optional,
|
| 12 |
+
Sequence,
|
| 13 |
+
Set,
|
| 14 |
+
Tuple,
|
| 15 |
+
Union,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import interegular
|
| 19 |
+
from interegular.fsm import FSM, Alphabet, OblivionError
|
| 20 |
+
from interegular.patterns import Unsupported
|
| 21 |
+
from lark import Lark, Token
|
| 22 |
+
from lark.common import LexerConf, ParserConf
|
| 23 |
+
from lark.exceptions import LexError, UnexpectedInput
|
| 24 |
+
from lark.indenter import Indenter
|
| 25 |
+
from lark.lexer import (
|
| 26 |
+
BasicLexer,
|
| 27 |
+
ContextualLexer,
|
| 28 |
+
LexerState,
|
| 29 |
+
LexerThread,
|
| 30 |
+
Scanner,
|
| 31 |
+
UnexpectedCharacters,
|
| 32 |
+
UnexpectedToken,
|
| 33 |
+
_create_unless,
|
| 34 |
+
)
|
| 35 |
+
from lark.parser_frontends import (
|
| 36 |
+
ParsingFrontend,
|
| 37 |
+
PostLexConnector,
|
| 38 |
+
_validate_frontend_args,
|
| 39 |
+
)
|
| 40 |
+
from lark.parsers.lalr_analysis import (
|
| 41 |
+
Action,
|
| 42 |
+
IntParseTable,
|
| 43 |
+
LALR_Analyzer,
|
| 44 |
+
ParseTable,
|
| 45 |
+
Shift,
|
| 46 |
+
)
|
| 47 |
+
from lark.parsers.lalr_interactive_parser import InteractiveParser
|
| 48 |
+
from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser
|
| 49 |
+
from outlines_core.fsm.regex import (
|
| 50 |
+
BetterFSM,
|
| 51 |
+
get_token_transition_keys,
|
| 52 |
+
make_deterministic_fsm,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# NOTE(review): a (str, int) pair used as a resumable parse-state handle —
# the exact field semantics are not established in this chunk; confirm
# against `parse_from_state` usage.
PartialParseState = Tuple[str, int]
# Parser states are plain int table indices or frozensets of them.
ParseStateType = Union[int, FrozenSet]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
class PartialTerminalInfo:
    """Per-terminal metadata gathered while partially scanning input."""

    priority: int
    terminal_name: str
    can_transition: bool
    is_final: bool
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
class PartialTokensInfo:
    """Result of a partial token scan: the FSM state sequence traversed,
    whether scanning may continue, and the terminals compatible with it."""

    fsm_state_seq: Tuple[int, ...]
    is_not_finished: bool
    terminals_and_info: Tuple[PartialTerminalInfo, ...]
    final_terminals_and_info: Tuple[PartialTerminalInfo, ...]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class PartialParserConf(ParserConf):
    """ParserConf carrying two extra partial-parsing flags
    (`deterministic`, `use_value_stack`) and serializing them alongside
    the base fields."""

    __serialize_fields__ = (
        "rules",
        "start",
        "parser_type",
        "deterministic",
        "use_value_stack",
    )

    def __init__(self, rules, callbacks, start, deterministic, use_value_stack):
        super().__init__(rules, callbacks, start)
        self.deterministic = deterministic
        self.use_value_stack = use_value_stack
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class PartialLark(Lark):
    """Lark subclass whose lexer/parser machinery supports *partial* parsing:
    input can be fed incrementally and parsing resumed from saved states."""

    __serialize_fields__ = (
        "parser",
        "rules",
        "options",
        "deterministic",
        "use_value_stack",
    )

    def __init__(self, grammar, **options):
        # TODO: Could've extended `LarkOptions`, but all these extensions are
        # already way too much (and brittle). This library really needs a
        # complete refactoring.
        self.deterministic = options.pop("deterministic", False)
        self.use_value_stack = options.pop("use_value_stack", False)
        # Force the `regex` backend before handing options to Lark.
        options["regex"] = True
        super().__init__(grammar, **options)
        # Partial parsing is only implemented for the LALR parser.
        assert self.options.parser == "lalr"

    def _build_lexer(self, dont_ignore: bool = False) -> "PartialBasicLexer":
        lexer_conf = self.lexer_conf
        if dont_ignore:
            from copy import copy

            # Copy so the original conf keeps its `ignore` terminals.
            lexer_conf = copy(lexer_conf)
            lexer_conf.ignore = ()

        return PartialBasicLexer(lexer_conf)

    def _build_parser(self) -> "PartialParsingFrontend":
        self._prepare_callbacks()
        _validate_frontend_args(self.options.parser, self.options.lexer)
        parser_conf = PartialParserConf(
            self.rules,
            self._callbacks,
            self.options.start,
            self.deterministic,
            self.use_value_stack,
        )

        # This is `_construct_parsing_frontend` expanded/inlined
        parser_type = self.options.parser
        lexer_type = self.options.lexer
        lexer_conf = self.lexer_conf

        assert isinstance(lexer_conf, LexerConf)
        assert isinstance(parser_conf, ParserConf)
        parser_conf.parser_type = parser_type
        self.lexer_conf.lexer_type = lexer_type
        return PartialParsingFrontend(lexer_conf, parser_conf, self.options)

    def __repr__(self):
        return "{}(open({!r}), parser={!r}, lexer={!r}, ...)".format(
            type(self).__name__,
            self.source_path,
            self.options.parser,
            self.options.lexer,
        )

    def parse_from_state(self, parse_state: "PartialParseState", is_end=False):
        # Resume parsing from a previously saved state; `is_end` signals EOF.
        return self.parser.parser.parser.parse_from_state(parse_state, is_end=is_end)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class PartialLexerThread(LexerThread):
    """LexerThread that supports shallow copying (both lexer and state are
    copied) and has a readable repr for debugging."""

    def __copy__(self):
        return type(self)(copy(self.lexer), copy(self.state))

    def __repr__(self):
        return f"{type(self).__name__}(lexer={self.lexer!r}, state={self.state!r})"
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class PartialPostLexConnector(PostLexConnector):
    """PostLexConnector that copies its postlexer on `copy()`; note the
    underlying lexer is shared, not copied."""

    def __copy__(self):
        return type(self)(self.lexer, copy(self.postlexer))

    def __repr__(self):
        return (
            f"{type(self).__name__}(lexer={self.lexer!r}, postlexer={self.postlexer!r})"
        )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class PartialParsingFrontend(ParsingFrontend):
|
| 172 |
+
def __init__(self, lexer_conf, parser_conf, options, parser=None):
|
| 173 |
+
assert parser_conf.parser_type == "lalr"
|
| 174 |
+
|
| 175 |
+
options._plugins["LALR_Parser"] = PartialLALRParser
|
| 176 |
+
options._plugins["BasicLexer"] = PartialBasicLexer
|
| 177 |
+
options._plugins["ContextualLexer"] = PartialContextualLexer
|
| 178 |
+
options._plugins["LexerThread"] = PartialLexerThread
|
| 179 |
+
|
| 180 |
+
super().__init__(lexer_conf, parser_conf, options, parser=parser)
|
| 181 |
+
|
| 182 |
+
if lexer_conf.postlex:
|
| 183 |
+
self.lexer = PartialPostLexConnector(self.lexer.lexer, lexer_conf.postlex)
|
| 184 |
+
|
| 185 |
+
self._termset_fsm_info = None
|
| 186 |
+
self._symbols_to_states: Optional[
|
| 187 |
+
Dict[str, Set[Tuple[ParseStateType, Action]]]
|
| 188 |
+
] = None
|
| 189 |
+
self._reverse_shifts: Optional[
|
| 190 |
+
Dict[ParseStateType, Dict[str, Set[ParseStateType]]]
|
| 191 |
+
] = None
|
| 192 |
+
# self._state_transition_map: Optional[
|
| 193 |
+
# Dict[Tuple[ParseStateType, str], Set[ParseStateType]]
|
| 194 |
+
# ] = None
|
| 195 |
+
|
| 196 |
+
def _compute_maps(
|
| 197 |
+
self,
|
| 198 |
+
):
|
| 199 |
+
"""Compute state transition and symbols-to-states maps."""
|
| 200 |
+
self._reverse_shifts = {}
|
| 201 |
+
self._symbols_to_states = {}
|
| 202 |
+
|
| 203 |
+
parse_table = self.parser.parser.parse_table
|
| 204 |
+
|
| 205 |
+
for from_state, symbols_to_ops in parse_table.states.items():
|
| 206 |
+
for symbol, op in symbols_to_ops.items():
|
| 207 |
+
if op[0] == Shift:
|
| 208 |
+
symbols_to_from_states = self._reverse_shifts.setdefault(op[1], {})
|
| 209 |
+
symbols_to_from_states.setdefault(symbol, set()).add(from_state)
|
| 210 |
+
self._symbols_to_states.setdefault(symbol, set()).add((from_state, op))
|
| 211 |
+
|
| 212 |
+
# # TODO: This approach is very wasteful.
|
| 213 |
+
# context_lexer = get_contextual_lexer(self)
|
| 214 |
+
# self._state_transition_map = {}
|
| 215 |
+
#
|
| 216 |
+
# for from_state, transitions in parse_table.states.items():
|
| 217 |
+
# for symbol, action in transitions.items():
|
| 218 |
+
# # TODO: Filter non-terminals
|
| 219 |
+
# if symbol not in context_lexer.root_lexer.terminals_by_name:
|
| 220 |
+
# continue
|
| 221 |
+
#
|
| 222 |
+
# if action[0] is Shift:
|
| 223 |
+
# self._state_transition_map.setdefault(
|
| 224 |
+
# (from_state, symbol), set()
|
| 225 |
+
# ).add(action[1])
|
| 226 |
+
# continue
|
| 227 |
+
#
|
| 228 |
+
# antecedent_state_seqs = parse_to_terminal(self, [(from_state,)], symbol)
|
| 229 |
+
#
|
| 230 |
+
# for antecedent_state_seq in antecedent_state_seqs:
|
| 231 |
+
# antecedent_state = antecedent_state_seq[-1]
|
| 232 |
+
# self._state_transition_map.setdefault(
|
| 233 |
+
# (from_state, symbol), set()
|
| 234 |
+
# ).add(antecedent_state)
|
| 235 |
+
|
| 236 |
+
def _compute_termset_fsm_info(self):
    """Collect and return information about terminal symbol sets and their FSMs.

    Terminal symbol sets (or "termsets") are ordered sequences of terminal
    symbols that are used by each parser state.  Associated with each is a
    collection of FSMs for each terminal and a single parse state FSM that is
    the union of each terminal's FSM.

    This constructs a list of tuples containing the termset, the set of
    parse states that use the termsets, parse state FSMs, and information
    mapping the components of the parse state FSMs to their terminal symbol
    FSMs.

    """
    context_lexer = get_contextual_lexer(self)
    termsets_to_fsms = {}
    termsets_to_parse_states: Dict[Tuple[str, ...], Set[ParseStateType]] = {}
    for parse_state, lexer in context_lexer.lexers.items():
        scanner = lexer.scanner
        # The ordered tuple of terminal names identifies the termset; states
        # whose per-state lexers share a scanner produce the same key.
        key = tuple(term.name for term in scanner.terminals)
        termsets_to_fsms[key] = (scanner.fsm, scanner.fsms_to_trans_finals)
        termsets_to_parse_states.setdefault(key, set()).add(parse_state)

    # Cache the result; read back through the `termset_fsm_info` property.
    self._termset_fsm_info = [
        (
            termset,
            frozenset(termsets_to_parse_states[termset]),
            fsm,
            fsms_to_trans_finals,
        )
        for termset, (fsm, fsms_to_trans_finals) in termsets_to_fsms.items()
    ]
|
| 268 |
+
|
| 269 |
+
@property
def termset_fsm_info(self):
    """Termset FSM information, computed on first access and cached."""
    info = self._termset_fsm_info
    if info is None:
        self._compute_termset_fsm_info()
        info = self._termset_fsm_info
    return info
|
| 274 |
+
|
| 275 |
+
@property
def symbols_to_states(self):
    """Symbol-to-states map, computed on first access and cached."""
    mapping = self._symbols_to_states
    if mapping is None:
        self._compute_maps()
        mapping = self._symbols_to_states
    return mapping
|
| 280 |
+
|
| 281 |
+
@property
def reverse_shifts(self):
    """Reverse-shift map, computed on first access and cached."""
    shifts = self._reverse_shifts
    if shifts is None:
        self._compute_maps()
        shifts = self._reverse_shifts
    return shifts
|
| 286 |
+
|
| 287 |
+
# @property
|
| 288 |
+
# def state_transition_map(self):
|
| 289 |
+
# if self._state_transition_map is None:
|
| 290 |
+
# self._compute_maps()
|
| 291 |
+
# return self._state_transition_map
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class PartialLALRParser(LALR_Parser):
    """A `LALR_Parser` that builds `PartialParser` instances.

    When ``parser_conf.deterministic`` is set, the parse-table states are
    canonicalized (sorted tuples) so that table construction is reproducible
    across runs.
    """

    def __init__(self, parser_conf, debug=False, strict=False):
        # NOTE(review): when `deterministic` is set, the analyzer is forced
        # into debug mode regardless of the `debug` argument — presumably so
        # that states stay symbolic for the canonicalization below; confirm.
        analysis = LALR_Analyzer(
            parser_conf, debug=debug if not parser_conf.deterministic else True
        )
        analysis.compute_lalr()
        callbacks = parser_conf.callbacks

        self.parser_conf = parser_conf
        self._parse_table = analysis.parse_table

        if parser_conf.deterministic:
            # Memoized canonicalization of state sets into sorted tuples.
            old_to_new = {}

            def to_tuple(v):
                new = old_to_new.get(v)
                if new is None:
                    new = tuple(sorted(v, key=lambda y: str(y)))
                    old_to_new[v] = new
                return new

            # Deterministic enumeration order of the states.
            enum = sorted(
                self._parse_table.states.keys(),
                key=lambda x: str(sorted(x, key=lambda y: str(y))),
            )

            new_states = {}
            for s in enum:
                # Only SHIFT destinations (op[1]) are state sets needing
                # canonicalization; REDUCE ops are kept as-is.
                transitions = {
                    term: op if op[0] is not Shift else (op[0], to_tuple(op[1]))
                    for term, op in self._parse_table.states[s].items()
                }
                new_states[to_tuple(s)] = transitions

            # Rebuild the table with canonicalized states/start/end states.
            self._parse_table = type(self._parse_table)(
                new_states,
                {k: to_tuple(v) for k, v in self._parse_table.start_states.items()},
                {k: to_tuple(v) for k, v in self._parse_table.end_states.items()},
            )

            if not debug:
                # Compress symbolic states to ints; keep a map back to the
                # canonical rule-set tuples for debugging/inspection.
                self._parse_table = IntParseTable.from_ParseTable(self._parse_table)
                self.states_to_rulesets = dict(
                    zip(self._parse_table.states.keys(), new_states.keys())
                )

        self.parser = PartialParser(
            self._parse_table,
            callbacks,
            debug,
            use_value_stack=parser_conf.use_value_stack,
        )

    @classmethod
    def deserialize(cls, data, memo, callbacks, debug=False):
        """Alternate constructor: rebuild a parser from serialized table data."""
        inst = cls.__new__(cls)
        inst._parse_table = ParseTable.deserialize(data, memo)
        inst.parser = PartialParser(inst._parse_table, callbacks, debug)
        return inst
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class PartialParserState(ParserState):
    """A `ParserState` that can consume "partial" tokens.

    A "partial" token carries a `PartialTokensInfo` value describing the set
    of terminals its scanned prefix could still become; feeding one checks
    that at least one candidate terminal is consistent with the parse stack
    without permanently advancing it.
    """

    __slots__ = "use_value_stack"

    def __init__(
        self,
        parse_conf,
        lexer,
        state_stack=None,
        value_stack=None,
        use_value_stack=False,
    ):
        super().__init__(
            parse_conf, lexer, state_stack=state_stack, value_stack=value_stack
        )
        # When False, `feed_token` skips value-stack bookkeeping entirely.
        self.use_value_stack = use_value_stack

    def feed_token(self, token, is_end=False):
        """Feed one token; "partial" tokens are validated without advancing the stack."""
        if token.type == "partial":
            # If none of the potential terminals can transition, we need to know now
            current_state = self.state_stack[-1]
            current_lexer = get_contextual_lexer(self.lexer).lexers[current_state]

            # We have to feed the token and determine whether or not at least
            # one terminal is consistent with the stack; otherwise, we'll miss
            # invalid REDUCE cases.
            # TODO: We should track separate parses conditional on possible
            # token/symbol types, then we can coherently reuse the following
            # results instead of recomputing it later.
            can_transition = False
            for terminal_info in token.value.terminals_and_info:
                if terminal_info.terminal_name not in current_lexer.ignore_types:
                    # Probe with an empty-valued token of the candidate type.
                    test_token = Token.new_borrow_pos(
                        terminal_info.terminal_name, "", token
                    )

                    # Snapshot the stack so the probe leaves no trace.
                    stack = copy(self.state_stack)
                    try:
                        self.feed_token_no_stack(test_token, is_end=is_end)
                        can_transition = True
                        break
                    except UnexpectedToken:
                        continue
                    finally:
                        # Always restore, whether the probe succeeded or not.
                        self.state_stack = stack
                else:
                    # Ignored terminals never block the parse.
                    can_transition = True

            if not can_transition:
                # Uppercase table keys are terminals (lark convention used
                # elsewhere in this class as well).
                expected = {
                    s
                    for s in self.parse_conf.states[current_state].keys()
                    if s.isupper()
                }
                raise UnexpectedToken(
                    token, expected, state=self, interactive_parser=None
                )

        elif self.use_value_stack:
            super().feed_token(token, is_end=is_end)
        else:
            self.feed_token_no_stack(token, is_end=is_end)

    def feed_token_no_stack(self, token, is_end=False):
        """
        This is a copy of `ParserState.feed_token` with all the value stack
        steps removed.  Since we're not exactly parsing in order to obtain a
        CST or anything similar, we can avoid the growing expense of tracking
        the parse tree.
        """
        state_stack = self.state_stack
        states = self.parse_conf.states
        end_state = self.parse_conf.end_state

        while True:
            state = state_stack[-1]
            try:
                action, arg = states[state][token.type]
            except KeyError:
                expected = {s for s in states[state].keys() if s.isupper()}
                raise UnexpectedToken(
                    token, expected, state=self, interactive_parser=None
                )

            assert arg != end_state

            if action is Shift:
                # shift once and return
                assert not is_end
                state_stack.append(arg)
                return
            else:
                # reduce+shift as many times as necessary
                rule = arg
                size = len(rule.expansion)
                if size:
                    del state_stack[-size:]

                _action, new_state = states[state_stack[-1]][rule.origin.name]
                assert _action is Shift
                state_stack.append(new_state)

                if is_end and state_stack[-1] == end_state:
                    return

    def feed_eof(self):
        """Feed an ``$END`` token, refusing if a pending partial token cannot finalize."""
        last_token = self.lexer.state.last_token

        if last_token is None:
            eof_token = self.lexer._Token("$END", "", 0, 1, 1)
        else:
            eof_token = Token.new_borrow_pos("$END", "", last_token)

        # EOF is only legal if there is no dangling partial token, or the
        # partial token could terminate as at least one complete terminal.
        new_token_is_legal = (
            last_token is None
            or last_token.type != "partial"
            or any(ti.is_final for ti in last_token.value.terminals_and_info)
        )
        if new_token_is_legal:
            self.feed_token(eof_token, is_end=True)
        else:
            raise UnexpectedToken(eof_token, [], state=self, interactive_parser=None)

    def choices(self):
        """Return the symbol -> op transitions available in the current state."""
        return self.parse_conf.parse_table.states[self.position]

    def accepts(self):
        """
        Adapted from https://github.com/lark-parser/lark/blob/be542c2ff6d968817df019b8bf03f37b3111c08c/lark/parsers/lalr_interactive_parser.py#L95
        Returns the set of possible tokens that will advance the parser into a new valid state.
        """
        accepts = set()
        conf_no_callbacks = copy(self.parse_conf)
        # We don't want to call callbacks here since those might have arbitrary side effects
        # and are unnecessarily slow.
        conf_no_callbacks.callbacks = {}
        for t in self.choices():
            if t.isupper():  # is terminal?
                # Probe a copy of the state so `self` is never mutated.
                new_state = copy(self)
                new_state.parse_conf = conf_no_callbacks
                try:
                    new_state.feed_token(new_state.lexer._Token(t, ""))
                except UnexpectedToken:
                    pass
                else:
                    accepts.add(t)
        return accepts

    def __copy__(self):
        # Note: the state stack is shallow-copied but the value stack is
        # deep-copied, matching how each is mutated during parsing.
        return type(self)(
            self.parse_conf,
            copy(self.lexer),
            copy(self.state_stack),
            deepcopy(self.value_stack),
            use_value_stack=self.use_value_stack,
        )

    def __repr__(self):
        return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})"
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class PartialParser(_Parser):
    """A lark `_Parser` driving `PartialParserState` instances."""

    def __init__(self, parse_table, callbacks, debug=False, use_value_stack=False):
        super().__init__(parse_table, callbacks, debug=debug)
        # Forwarded to each `PartialParserState` created in `parse`.
        self.use_value_stack = use_value_stack

    def parse(
        self, lexer, start, value_stack=None, state_stack=None, start_interactive=False
    ):
        """Parse from `lexer`, or return an `InteractiveParser` when requested."""
        parse_conf = ParseConf(self.parse_table, self.callbacks, start)
        parser_state = PartialParserState(
            parse_conf, copy(lexer), state_stack, value_stack, self.use_value_stack
        )
        if start_interactive:
            return InteractiveParser(self, parser_state, parser_state.lexer)
        return self.parse_from_state(parser_state)

    def parse_from_state(self, state, last_token=None, is_end=False):
        """Drain the lexer into `state` and return the (possibly partial) state.

        Unlike the lark base implementation, this returns the parser state
        itself rather than a completed parse result, so parsing can resume.
        """
        try:
            token = last_token
            for token in state.lexer.lex(state):
                state.feed_token(token)

            # Only feed EOF when we did not stop on a dangling partial token.
            if is_end and (not token or token.type != "partial"):
                state.feed_eof()

            return state
        except UnexpectedInput as e:
            try:
                e.interactive_parser = InteractiveParser(self, state, state.lexer)
            except NameError:
                pass
            raise e
        except Exception:
            if self.debug:
                print("")
                print("STATE STACK DUMP")
                print("----------------")
                for i, s in enumerate(state.state_stack):
                    print("%d)" % i, s)
                print("")

            raise
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class PartialScanner(Scanner):
    """A `Scanner` backed by a single union FSM over all of its terminals.

    Matching walks the union FSM over the input and records the full state
    sequence, so a scan can be paused at end-of-input and resumed later.
    """

    @classmethod
    @lru_cache
    def construct_terminal_fsm(cls, terminal):
        """Build (and cache) the deterministic FSM for one terminal's regex."""
        # TODO: This should really be done at the lexer/parser level so that
        # the lifetime of these objects is tied to the parser itself.
        regex_str = terminal.pattern.to_regexp()
        pattern = interegular.parse_pattern(regex_str)
        fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())
        return fsm, pattern.prefix_postfix

    def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False):
        # `re_` is accepted for interface compatibility but not used here:
        # matching is FSM-based rather than regex-based.
        self.terminals = terminals
        self.g_regex_flags = g_regex_flags
        self.use_bytes = use_bytes
        self.match_whole = match_whole
        self.allowed_types = {t.name for t in self.terminals}
        self._mres = None

        fsms = []
        for t in self.terminals:
            fsm, prefix_postfix = self.construct_terminal_fsm(t)

            # TODO FIXME: We don't support this right now.
            assert prefix_postfix == (0, 0)

            fsms.append(fsm)

        # One FSM for the whole termset, plus the component-FSM bookkeeping
        # needed to recover which terminal(s) a match corresponds to.
        self.fsm, self.fsms_to_trans_finals = fsm_union(fsms)

    def get_terminals_info(
        self, fsm_state_seq
    ) -> Tuple[Tuple[PartialTerminalInfo, ...], Tuple[PartialTerminalInfo, ...]]:
        """Get the possible terminal symbols for an FSM state sequence."""
        terminals_and_info: Tuple[PartialTerminalInfo, ...] = ()
        final_terminals_and_info: Tuple[PartialTerminalInfo, ...] = ()
        for i, (fsm_id, fsm_reads_more, in_final) in enumerate(
            get_sub_fsms_from_seq(fsm_state_seq, self.fsms_to_trans_finals)
        ):
            terminal_name = self.terminals[fsm_id].name
            info = PartialTerminalInfo(i, terminal_name, fsm_reads_more, in_final)
            terminals_and_info += (info,)
            if in_final:
                # Terminals whose FSM is in an accept state after this
                # sequence are collected separately as complete candidates.
                final_terminals_and_info += (info,)

        return terminals_and_info, final_terminals_and_info

    def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None):
        """Determine an FSM match over `text` starting at `pos` and continuing `last_fsm_state_seq`."""

        start_pos = pos

        if last_fsm_state_seq:
            # Resume: skip the characters already consumed by the previous
            # partial scan and continue from its last FSM state.
            assert len(last_fsm_state_seq) > 1
            start_pos += len(last_fsm_state_seq) - 1
            start_state = last_fsm_state_seq[-1]
        else:
            start_state = self.fsm.initial

        text_part = text[start_pos:]

        text_transitions = get_token_transition_keys(
            self.fsm.fsm_info.alphabet_symbol_mapping,
            self.fsm.fsm_info.alphabet_anything_value,
            text_part,
        )

        state_seq = walk_fsm(
            self.fsm,
            text_transitions,
            start_state,
            full_match=self.match_whole,
        )

        if not state_seq:
            return None

        # The returned sequence always includes the starting state, so its
        # length is (characters consumed) + 1.
        if last_fsm_state_seq:
            res = last_fsm_state_seq + tuple(state_seq)
        else:
            res = (start_state,) + tuple(state_seq)

        return res
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
class PartialContextualLexer(ContextualLexer):
    """A `ContextualLexer` whose per-state lexers are `PartialBasicLexer`s."""

    def __init__(self, conf: "LexerConf", states, always_accept=()):
        terminals = list(conf.terminals)
        terminals_by_name = conf.terminals_by_name

        trad_conf = copy(conf)
        trad_conf.terminals = terminals

        # Lexers are shared between parse states that accept the same set of
        # terminals, keyed by that (frozen) set.
        lexer_by_symbols: Dict = {}
        self.lexers = {}
        for state, accepts in states.items():
            key = frozenset(accepts)
            try:
                lexer = lexer_by_symbols[key]
            except KeyError:
                accepts = set(accepts) | set(conf.ignore) | set(always_accept)
                lexer_conf = copy(trad_conf)
                lexer_conf.terminals = [
                    terminals_by_name[n] for n in accepts if n in terminals_by_name
                ]
                if not lexer_conf.terminals:
                    # A state with no usable terminals gets no lexer at all;
                    # `lex` handles the resulting KeyError below.
                    continue
                lexer = PartialBasicLexer(lexer_conf)
                lexer_by_symbols[key] = lexer

            self.lexers[state] = lexer

        assert trad_conf.terminals is terminals
        # Fallback lexer over the full terminal set.
        self.root_lexer = PartialBasicLexer(trad_conf)

    def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]:
        """Yield tokens using the lexer bound to the parser's current state."""
        try:
            while True:
                lexer = self.lexers[parser_state.position]
                next_tok = lexer.next_token(lexer_state, parser_state)
                yield next_tok
        except EOFError:
            # Raised by `PartialBasicLexer.next_token` at end of input.
            pass
        except KeyError:
            # No lexer for the current parse state; only an error if there is
            # unconsumed input left.
            if len(lexer_state.text) > lexer_state.line_ctr.char_pos:
                raise UnexpectedCharacters(
                    lexer_state.text,
                    lexer_state.line_ctr.char_pos,
                    lexer_state.line_ctr.line,
                    lexer_state.line_ctr.column,
                    allowed=False,
                    token_history=lexer_state.last_token and [lexer_state.last_token],
                    state=parser_state,
                    terminals_by_name=self.root_lexer.terminals,
                )
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class PartialBasicLexer(BasicLexer):
    """A `BasicLexer` that scans with a `PartialScanner` and can emit "partial" tokens."""

    def __init__(self, conf: "LexerConf"):
        super().__init__(conf)
        # Eagerly construct the scanner
        self._build_scanner()

    def _build_scanner(self):
        # This seems incredibly convoluted: `lark` creates callback-triggered
        # nested scanners for regex-defined terminals that overlap with
        # string-defined terminals when both types of terminals have the same
        # priority.  Unless I'm missing something important, why not simply
        # reorder the terminals so that the string-defined ones come before the
        # regex-defined ones?
        terminals, self.callback = _create_unless(
            self.terminals, self.g_regex_flags, self.re, self.use_bytes
        )

        # We can't let people arbitrarily mess with the scanning process.
        assert not self.user_callbacks
        # for type_, f in self.user_callbacks.items():
        #     if type_ in self.callback:
        #         # Already a callback there, probably UnlessCallback
        #         self.callback[type_] = CallChain(
        #             self.callback[type_], f, lambda t: t.type == type_
        #         )
        #     else:
        #         self.callback[type_] = f

        # We used the "callback" results to reorder the terminals (see the
        # comments above).
        for terminal_name, callback in self.callback.items():
            terminal = self.terminals_by_name[terminal_name]
            for sub_terminal in callback.scanner.terminals:
                # Move each string-defined sub-terminal directly in front of
                # the regex terminal it would otherwise be shadowed by.
                self.terminals.remove(sub_terminal)
                idx = self.terminals.index(terminal)
                self.terminals.insert(idx, sub_terminal)

        self._scanner = PartialScanner(
            self.terminals, self.g_regex_flags, self.re, self.use_bytes
        )

    def match(self, text, pos, last_fsm_state_seq=None):
        """Delegate FSM matching to the underlying `PartialScanner`."""
        return self.scanner.match(text, pos, last_fsm_state_seq)

    def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token:
        """Produce the next token, possibly a "partial" token at end of input.

        Raises `EOFError` when the input is exhausted, and
        `UnexpectedCharacters` when no terminal can match.
        """
        last_token = lex_state.last_token

        last_fsm_state_seq = None
        if last_token and last_token.type == "partial":
            # Continue from last partial lexer state
            last_fsm_state_seq = last_token.value.fsm_state_seq

        line_ctr = lex_state.line_ctr
        # Account for characters already consumed by a pending partial scan.
        end_pos = line_ctr.char_pos + (
            len(last_fsm_state_seq) - 1 if last_fsm_state_seq else 0
        )
        while end_pos < len(lex_state.text):
            res = self.match(lex_state.text, line_ctr.char_pos, last_fsm_state_seq)

            if not res:
                # No extension of the scan is possible.  That is fatal unless
                # the pending partial scan already ended in an accept state.
                if (
                    not last_fsm_state_seq
                    or last_fsm_state_seq[-1] not in self.scanner.fsm.finals
                ):
                    allowed = self.scanner.allowed_types - self.ignore_types
                    if not allowed:
                        allowed = {"<END-OF-FILE>"}
                    raise UnexpectedCharacters(
                        lex_state.text,
                        line_ctr.char_pos,
                        line_ctr.line,
                        line_ctr.column,
                        allowed=allowed,
                        token_history=lex_state.last_token and [lex_state.last_token],
                        state=parser_state,
                        terminals_by_name=self.terminals_by_name,
                    )

                # The partial match might be complete now
                fsm_state_seq = last_token.value.fsm_state_seq
                terminals_and_info = last_token.value.terminals_and_info
                final_terminals_and_info = last_token.value.final_terminals_and_info
            else:
                fsm_state_seq = res
                (
                    terminals_and_info,
                    final_terminals_and_info,
                ) = self.scanner.get_terminals_info(fsm_state_seq)

            # Prefer a terminal that is already complete; otherwise take the
            # first still-possible candidate.
            priority_terminal_info = (
                final_terminals_and_info[0]
                if final_terminals_and_info
                else terminals_and_info[0]
            )

            is_not_finished = (
                not priority_terminal_info.is_final
                or priority_terminal_info.can_transition
                or len(terminals_and_info) > 1
            )

            start_pos = line_ctr.char_pos
            end_pos = start_pos + len(fsm_state_seq) - 1

            if end_pos >= len(lex_state.text) and is_not_finished:
                # Input ran out mid-scan: emit a "partial" token carrying the
                # scan state so a later call can resume it.
                type_name = "partial"
                token_value = PartialTokensInfo(
                    fsm_state_seq,
                    is_not_finished,
                    terminals_and_info,
                    final_terminals_and_info,
                )
                # Don't update the line counter states until we've finished
                value = ""
            else:
                type_name = priority_terminal_info.terminal_name
                # The token value should contain all partial scan parts in this
                # case
                value = token_value = lex_state.text[start_pos:end_pos]

            assert isinstance(self.callback, Dict)

            if type_name not in self.ignore_types:
                t = Token(
                    type_name,
                    token_value,
                    line_ctr.char_pos,
                    line_ctr.line,
                    line_ctr.column,
                )

                line_ctr.feed(value, type_name in self.newline_types)

                t.end_line = line_ctr.line
                t.end_column = line_ctr.column
                t.end_pos = line_ctr.char_pos
                if t.type in self.callback:
                    t = self.callback[t.type](t)
                    if not isinstance(t, Token):
                        raise LexError(
                            "Callbacks must return a token (returned %r)" % t
                        )
                lex_state.last_token = t
                return t

            # Ignored token: still run its callback (if any) and advance the
            # line counter, then loop to scan the next token.
            if type_name in self.callback:
                t2 = Token(
                    type_name, value, line_ctr.char_pos, line_ctr.line, line_ctr.column
                )
                self.callback[type_name](t2)

            line_ctr.feed(value, type_name in self.newline_types)

            last_fsm_state_seq = None

        raise EOFError(self)
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
class PartialIndenter(Indenter):
    """An `Indenter` that doesn't reset its state every time `process` is called."""

    def process(self, stream):
        # Unlike the base class, no state reset happens here, so repeated
        # calls continue from the previous paren/indent levels.
        return self._process(stream)

    def _process(self, stream):
        for token in stream:
            # These were previously *after* the `yield`, but that makes the
            # state tracking unnecessarily convoluted.
            if token.type in self.OPEN_PAREN_types:
                self.paren_level += 1
            elif token.type in self.CLOSE_PAREN_types:
                self.paren_level -= 1
                if self.paren_level < 0:
                    raise UnexpectedToken(token, [])

            if token.type == self.NL_type:
                yield from self.handle_NL(token)
            else:
                yield token

        # TODO: What do we want to do here?
        # while len(self.indent_level) > 1:
        #     self.indent_level.pop()
        #     yield Token(self.DEDENT_type, "")

    def accepts_token_type(self, token_type):
        """Return False when `token_type` would drive the paren level negative."""
        if token_type in self.CLOSE_PAREN_types and self.paren_level - 1 < 0:
            return False

        # TODO:
        # if token_type == self.NL_type and self.paren_level == 0:
        #     ...
        #     return False

        return True

    def __copy__(self):
        res = type(self)()
        res.paren_level = self.paren_level
        res.indent_level = copy(self.indent_level)
        return res

    def __repr__(self):
        return f"{type(self).__name__}(paren_level={self.paren_level!r}, indent_level={self.indent_level!r})"
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
class PartialPythonIndenter(PartialIndenter):
    """A `PartialIndenter` configured with Python's grammar terminal names."""

    NL_type = "_NEWLINE"
    OPEN_PAREN_types = ["LPAR", "LSQB", "LBRACE"]
    CLOSE_PAREN_types = ["RPAR", "RSQB", "RBRACE"]
    INDENT_type = "_INDENT"
    DEDENT_type = "_DEDENT"
    # Tab width used when computing indentation columns.
    tab_len = 8
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
def get_contextual_lexer(x: Union[PartialLexerThread, PartialParsingFrontend]):
    """Return the contextual lexer held by `x`, unwrapping one level if needed."""
    lexer = x.lexer
    return lexer if isinstance(lexer, ContextualLexer) else lexer.lexer
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]:
    """Construct a ``dict`` mapping terminal symbol names to their finite state machines."""
    result: Dict[str, FSM] = {}
    for term in lp.terminals:
        parsed = interegular.parse_pattern(term.pattern.to_regexp())
        # TODO: Use `pyparser.terminals[0].pattern.flags`?
        try:
            dfa, _ = make_deterministic_fsm(parsed.to_fsm().reduce())
        except Unsupported:
            # Patterns interegular cannot convert map to ``None``.
            dfa = None
        result[term.name] = dfa
    return result
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    Returns
    -------
    A tuple of the deterministic union FSM and a dict mapping each component
    FSM's index to ``(transitions, finals, old_to_new_states)`` expressed in
    the union FSM's (determinized) state numbering.
    """

    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    # A union state is a dict {component index: component state}; start with
    # every component in its initial state.
    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM
    def follow(current_state, new_transition: int):
        next = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next[i] = f.map[current_state[i]][old_transition]
        if not next:
            # No component can follow this transition: dead end.
            raise OblivionError
        return next

    states = [initial]
    finals: Set[int] = set()
    map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    # Breadth-first exploration of reachable union states; `states` grows as
    # new supersets are discovered, and indices into it are the new state ids.
    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next)
                except ValueError:
                    j = len(states)
                    states.append(next)

                map[i][transition] = j

                # Record, per component FSM, the new-FSM edges and finals this
                # transition contributes.
                for fsm_id, fsm_state in next.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

    # Determinize, then translate all recorded bookkeeping into the
    # determinized FSM's state numbering.
    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
def get_sub_fsms_from_seq(
|
| 1052 |
+
state_seq: Sequence[int],
|
| 1053 |
+
fsms_to_trans_finals: Dict[
|
| 1054 |
+
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
|
| 1055 |
+
],
|
| 1056 |
+
) -> Generator[Tuple[int, bool, bool], None, None]:
|
| 1057 |
+
"""Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
|
| 1058 |
+
|
| 1059 |
+
Parameters
|
| 1060 |
+
----------
|
| 1061 |
+
state_seq
|
| 1062 |
+
A state sequence.
|
| 1063 |
+
fsms_to_trans_finals
|
| 1064 |
+
A map from FSM indices to tuples containing sets of their state transitions
|
| 1065 |
+
and sets of the final/accept states.
|
| 1066 |
+
|
| 1067 |
+
Returns
|
| 1068 |
+
-------
|
| 1069 |
+
A generator returning tuples containing each sub-FSM index (in the order
|
| 1070 |
+
they were union-ed to construct `fsm`) and booleans indicating whether or
|
| 1071 |
+
not there is another valid transition from the last state in the sequence
|
| 1072 |
+
for the associated sub-FSM (i.e. if the FSM can continue
|
| 1073 |
+
accepting/matching) and whether or not the sequence ends in a final state
|
| 1074 |
+
of the sub-FSM.
|
| 1075 |
+
"""
|
| 1076 |
+
state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
|
| 1077 |
+
last_fsm_state = state_seq[-1]
|
| 1078 |
+
yield from (
|
| 1079 |
+
(
|
| 1080 |
+
# The sub-FMS index
|
| 1081 |
+
fsm_idx,
|
| 1082 |
+
# Is there another possible transition in this sub-FSM?
|
| 1083 |
+
any(last_fsm_state == from_s for (from_s, to_s) in transitions),
|
| 1084 |
+
# Is this sub-FSM in a final state?
|
| 1085 |
+
state_seq[-1] in finals,
|
| 1086 |
+
)
|
| 1087 |
+
for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
|
| 1088 |
+
if state_seq_transitions.issubset(transitions)
|
| 1089 |
+
)
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
def walk_fsm(
|
| 1093 |
+
fsm: BetterFSM,
|
| 1094 |
+
token_transition_keys: Sequence[int],
|
| 1095 |
+
start_state: int,
|
| 1096 |
+
full_match: bool = True,
|
| 1097 |
+
) -> List[int]:
|
| 1098 |
+
fsm_finals = fsm.finals
|
| 1099 |
+
|
| 1100 |
+
state = start_state
|
| 1101 |
+
accepted_states: List[int] = []
|
| 1102 |
+
last_final_idx: int = 0
|
| 1103 |
+
|
| 1104 |
+
fsm_transitions = fsm.flat_transition_map
|
| 1105 |
+
|
| 1106 |
+
# Iterate over token transition key sequence. The transition key
|
| 1107 |
+
# sequence represents the FSM traversal rules of the tokens symbols.
|
| 1108 |
+
for i, trans_key in enumerate(token_transition_keys):
|
| 1109 |
+
new_state = fsm_transitions.get((state, trans_key))
|
| 1110 |
+
|
| 1111 |
+
if new_state is None:
|
| 1112 |
+
if not full_match and last_final_idx > 0:
|
| 1113 |
+
return accepted_states[:last_final_idx]
|
| 1114 |
+
|
| 1115 |
+
return []
|
| 1116 |
+
|
| 1117 |
+
state = new_state
|
| 1118 |
+
|
| 1119 |
+
if state in fsm_finals:
|
| 1120 |
+
last_final_idx = i + 1
|
| 1121 |
+
|
| 1122 |
+
accepted_states.append(state)
|
| 1123 |
+
|
| 1124 |
+
if full_match and last_final_idx - 1 != i:
|
| 1125 |
+
return []
|
| 1126 |
+
|
| 1127 |
+
return accepted_states
|
.venv/lib/python3.11/site-packages/outlines/fsm/types.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
from enum import EnumMeta
|
| 3 |
+
from typing import Any, Protocol, Tuple, Type
|
| 4 |
+
|
| 5 |
+
from typing_extensions import _AnnotatedAlias, get_args
|
| 6 |
+
|
| 7 |
+
INTEGER = r"[+-]?(0|[1-9][0-9]*)"
|
| 8 |
+
BOOLEAN = "(True|False)"
|
| 9 |
+
FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?"
|
| 10 |
+
DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])"
|
| 11 |
+
TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])"
|
| 12 |
+
DATETIME = rf"({DATE})(\s)({TIME})"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FormatFunction(Protocol):
|
| 16 |
+
def __call__(self, sequence: str) -> Any:
|
| 17 |
+
...
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
|
| 21 |
+
# If it is a custom type
|
| 22 |
+
if isinstance(python_type, _AnnotatedAlias):
|
| 23 |
+
json_schema = get_args(python_type)[1].json_schema
|
| 24 |
+
type_class = get_args(python_type)[0]
|
| 25 |
+
|
| 26 |
+
custom_regex_str = json_schema["pattern"]
|
| 27 |
+
|
| 28 |
+
def custom_format_fn(sequence: str) -> Any:
|
| 29 |
+
return type_class(sequence)
|
| 30 |
+
|
| 31 |
+
return custom_regex_str, custom_format_fn
|
| 32 |
+
|
| 33 |
+
if isinstance(python_type, EnumMeta):
|
| 34 |
+
values = python_type.__members__.keys()
|
| 35 |
+
enum_regex_str: str = "(" + "|".join(values) + ")"
|
| 36 |
+
|
| 37 |
+
def enum_format_fn(sequence: str) -> str:
|
| 38 |
+
return str(sequence)
|
| 39 |
+
|
| 40 |
+
return enum_regex_str, enum_format_fn
|
| 41 |
+
|
| 42 |
+
if python_type == float:
|
| 43 |
+
|
| 44 |
+
def float_format_fn(sequence: str) -> float:
|
| 45 |
+
return float(sequence)
|
| 46 |
+
|
| 47 |
+
return FLOAT, float_format_fn
|
| 48 |
+
elif python_type == int:
|
| 49 |
+
|
| 50 |
+
def int_format_fn(sequence: str) -> int:
|
| 51 |
+
return int(sequence)
|
| 52 |
+
|
| 53 |
+
return INTEGER, int_format_fn
|
| 54 |
+
elif python_type == bool:
|
| 55 |
+
|
| 56 |
+
def bool_format_fn(sequence: str) -> bool:
|
| 57 |
+
return bool(sequence)
|
| 58 |
+
|
| 59 |
+
return BOOLEAN, bool_format_fn
|
| 60 |
+
elif python_type == datetime.date:
|
| 61 |
+
|
| 62 |
+
def date_format_fn(sequence: str) -> datetime.date:
|
| 63 |
+
return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()
|
| 64 |
+
|
| 65 |
+
return DATE, date_format_fn
|
| 66 |
+
elif python_type == datetime.time:
|
| 67 |
+
|
| 68 |
+
def time_format_fn(sequence: str) -> datetime.time:
|
| 69 |
+
return datetime.datetime.strptime(sequence, "%H:%M:%S").time()
|
| 70 |
+
|
| 71 |
+
return TIME, time_format_fn
|
| 72 |
+
elif python_type == datetime.datetime:
|
| 73 |
+
|
| 74 |
+
def datetime_format_fn(sequence: str) -> datetime.datetime:
|
| 75 |
+
return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")
|
| 76 |
+
|
| 77 |
+
return DATETIME, datetime_format_fn
|
| 78 |
+
else:
|
| 79 |
+
raise NotImplementedError(
|
| 80 |
+
f"The Python type {python_type} is not supported. Please open an issue."
|
| 81 |
+
)
|
.venv/lib/python3.11/site-packages/outlines/function.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.util
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
|
| 7 |
+
from outlines import generate, models
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from outlines.generate.api import SequenceGenerator
|
| 11 |
+
from outlines.prompts import Prompt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class Function:
|
| 16 |
+
"""Represents an Outlines function.
|
| 17 |
+
|
| 18 |
+
Functions are a convenient way to encapsulate a prompt template, a language
|
| 19 |
+
model and a Pydantic model that define the output structure. Once defined,
|
| 20 |
+
the function can be called with arguments that will be used to render the
|
| 21 |
+
prompt template.
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
prompt_template: "Prompt"
|
| 26 |
+
schema: Union[str, Callable, object]
|
| 27 |
+
model_name: str
|
| 28 |
+
generator: Optional["SequenceGenerator"] = None
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def from_github(cls, program_path: str, function_name: str = "fn"):
|
| 32 |
+
"""Load a function stored on GitHub"""
|
| 33 |
+
program_content = download_from_github(program_path)
|
| 34 |
+
function = extract_function_from_file(program_content, function_name)
|
| 35 |
+
|
| 36 |
+
return function
|
| 37 |
+
|
| 38 |
+
def init_generator(self):
|
| 39 |
+
"""Load the model and initialize the generator."""
|
| 40 |
+
model = models.transformers(self.model_name)
|
| 41 |
+
self.generator = generate.json(model, self.schema)
|
| 42 |
+
|
| 43 |
+
def __call__(self, *args, **kwargs):
|
| 44 |
+
"""Call the function.
|
| 45 |
+
|
| 46 |
+
.. warning::
|
| 47 |
+
|
| 48 |
+
This currently does not support batching.
|
| 49 |
+
|
| 50 |
+
Parameters
|
| 51 |
+
----------
|
| 52 |
+
args
|
| 53 |
+
Values to pass to the prompt template as positional arguments.
|
| 54 |
+
kwargs
|
| 55 |
+
Values to pass to the prompt template as keyword arguments.
|
| 56 |
+
|
| 57 |
+
"""
|
| 58 |
+
if self.generator is None:
|
| 59 |
+
self.init_generator()
|
| 60 |
+
|
| 61 |
+
prompt = self.prompt_template(*args, **kwargs)
|
| 62 |
+
return self.generator(prompt)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def download_from_github(short_path: str):
|
| 66 |
+
"""Download the file in which the function is stored on GitHub."""
|
| 67 |
+
GITHUB_BASE_URL = "https://raw.githubusercontent.com"
|
| 68 |
+
BRANCH = "main"
|
| 69 |
+
|
| 70 |
+
path = short_path.split("/")
|
| 71 |
+
if len(path) < 3:
|
| 72 |
+
raise ValueError(
|
| 73 |
+
"Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}."
|
| 74 |
+
)
|
| 75 |
+
elif short_path[-3:] == ".py":
|
| 76 |
+
raise ValueError("Do not append the `.py` extension to the program name.")
|
| 77 |
+
|
| 78 |
+
username = path[0]
|
| 79 |
+
repo = path[1]
|
| 80 |
+
path_to_file = path[2:]
|
| 81 |
+
|
| 82 |
+
url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py"
|
| 83 |
+
result = requests.get(url)
|
| 84 |
+
|
| 85 |
+
if result.status_code == 200:
|
| 86 |
+
return result.text
|
| 87 |
+
elif result.status_code == 404:
|
| 88 |
+
raise ValueError(
|
| 89 |
+
f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly."
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
result.raise_for_status()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def extract_function_from_file(content: str, function_name: str) -> Tuple[Callable]:
|
| 96 |
+
"""Extract a function object from a downloaded file."""
|
| 97 |
+
|
| 98 |
+
spec = importlib.util.spec_from_loader(
|
| 99 |
+
"outlines_function", loader=None, origin="github"
|
| 100 |
+
)
|
| 101 |
+
if spec is not None:
|
| 102 |
+
module = importlib.util.module_from_spec(spec)
|
| 103 |
+
exec(content, module.__dict__)
|
| 104 |
+
|
| 105 |
+
try:
|
| 106 |
+
fn = getattr(module, function_name)
|
| 107 |
+
except AttributeError:
|
| 108 |
+
raise AttributeError(
|
| 109 |
+
"Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct."
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if not isinstance(fn, module.outlines.Function):
|
| 113 |
+
raise TypeError(
|
| 114 |
+
f"The `{function_name}` variable in the program must be an instance of `outlines.Function`"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
return fn
|
.venv/lib/python3.11/site-packages/outlines/generate/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .api import SequenceGenerator
|
| 2 |
+
from .cfg import cfg
|
| 3 |
+
from .choice import choice
|
| 4 |
+
from .format import format
|
| 5 |
+
from .fsm import fsm
|
| 6 |
+
from .json import json
|
| 7 |
+
from .regex import regex
|
| 8 |
+
from .text import text
|
.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/choice.cpython-311.pyc
ADDED
|
Binary file (3.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/generator.cpython-311.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/generate/__pycache__/json.cpython-311.pyc
ADDED
|
Binary file (6.05 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/generate/api.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
from copy import copy
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from outlines.generate.generator import sequence_generator
|
| 7 |
+
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
FormattedOutput = Union[
|
| 13 |
+
str, int, float, bool, datetime.date, datetime.time, datetime.datetime
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SequenceGenerator:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
fsm,
|
| 21 |
+
model,
|
| 22 |
+
sampler,
|
| 23 |
+
device,
|
| 24 |
+
):
|
| 25 |
+
self.fsm = fsm
|
| 26 |
+
self.model = model
|
| 27 |
+
self.sampler = sampler
|
| 28 |
+
self.tokenizer = model.tokenizer
|
| 29 |
+
self.device = device
|
| 30 |
+
self.num_samples = sampler.samples
|
| 31 |
+
|
| 32 |
+
def get_generated_token_ids(
|
| 33 |
+
self,
|
| 34 |
+
prompt_token_ids: "torch.Tensor",
|
| 35 |
+
token_ids: "torch.Tensor",
|
| 36 |
+
) -> List["torch.Tensor"]:
|
| 37 |
+
"""Get the tokens generated so far.
|
| 38 |
+
|
| 39 |
+
Parameters
|
| 40 |
+
----------
|
| 41 |
+
prompt_token_ids
|
| 42 |
+
Tensor that contains the token ids of the sequences' prompts.
|
| 43 |
+
token_ids
|
| 44 |
+
The generated token ids.
|
| 45 |
+
|
| 46 |
+
Returns
|
| 47 |
+
-------
|
| 48 |
+
A tensor that contains the token ids that have been generated so far.
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
|
| 52 |
+
token_ids = [
|
| 53 |
+
cur_token_ids[length:]
|
| 54 |
+
for cur_token_ids, length in zip(token_ids, prompt_lengths)
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
return token_ids
|
| 58 |
+
|
| 59 |
+
def is_stop_sequence_found(
|
| 60 |
+
self, generated_sequences: List[str], stop_sequences: List[str]
|
| 61 |
+
) -> bool:
|
| 62 |
+
"""Determine whether one of the stop sequences has been generated.
|
| 63 |
+
|
| 64 |
+
Parameters
|
| 65 |
+
----------
|
| 66 |
+
generated_sequences
|
| 67 |
+
The list of sequences generated so far.
|
| 68 |
+
stop_sequences
|
| 69 |
+
The list that contains the sequence which stop the generation when
|
| 70 |
+
found.
|
| 71 |
+
|
| 72 |
+
Returns
|
| 73 |
+
-------
|
| 74 |
+
True if at least one of the stop sequences has been found in each generated
|
| 75 |
+
sequence.
|
| 76 |
+
|
| 77 |
+
"""
|
| 78 |
+
return all(
|
| 79 |
+
[
|
| 80 |
+
any([seq in generated for seq in stop_sequences])
|
| 81 |
+
for generated in generated_sequences
|
| 82 |
+
]
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def strip_stop_sequences(
|
| 86 |
+
self, sequence: str, stop_sequences: Optional[List[str]]
|
| 87 |
+
) -> str:
|
| 88 |
+
"""Remove the stop sequences from the generated sequences.
|
| 89 |
+
|
| 90 |
+
Parameters
|
| 91 |
+
----------
|
| 92 |
+
sequence
|
| 93 |
+
One of the generated sequences.
|
| 94 |
+
stop_sequences
|
| 95 |
+
The list that contains the sequence which stop the generation when
|
| 96 |
+
found.
|
| 97 |
+
|
| 98 |
+
"""
|
| 99 |
+
if stop_sequences:
|
| 100 |
+
match_indexes = [sequence.find(seq) for seq in stop_sequences]
|
| 101 |
+
if any([index != -1 for index in match_indexes]):
|
| 102 |
+
# select the stop_sequence that is found first in the sequence
|
| 103 |
+
min_match_index_value = min([i for i in match_indexes if i != -1])
|
| 104 |
+
min_match_index_pos = match_indexes.index(min_match_index_value)
|
| 105 |
+
sequence = sequence[
|
| 106 |
+
: match_indexes[min_match_index_pos]
|
| 107 |
+
+ len(stop_sequences[min_match_index_pos])
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
return sequence
|
| 111 |
+
|
| 112 |
+
def format_sequence(self, sequence: str) -> FormattedOutput:
|
| 113 |
+
"""Translate the generated sequence to another type.
|
| 114 |
+
|
| 115 |
+
This method is for instance overridden when generating JSON to either
|
| 116 |
+
return a dictionnary or a Pydantic model.
|
| 117 |
+
|
| 118 |
+
Parameters
|
| 119 |
+
----------
|
| 120 |
+
sequence
|
| 121 |
+
A generated sequences.
|
| 122 |
+
|
| 123 |
+
Returns
|
| 124 |
+
-------
|
| 125 |
+
The formatted sequence.
|
| 126 |
+
|
| 127 |
+
"""
|
| 128 |
+
return sequence
|
| 129 |
+
|
| 130 |
+
def __call__(
|
| 131 |
+
self,
|
| 132 |
+
prompts: Union[str, List[str]],
|
| 133 |
+
max_tokens: Optional[int] = None,
|
| 134 |
+
stop_at: Optional[Union[str, List[str]]] = None,
|
| 135 |
+
rng: Optional["torch.Generator"] = None,
|
| 136 |
+
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
|
| 137 |
+
"""Generate the full text sequence.
|
| 138 |
+
|
| 139 |
+
Since `SequenceGenerator.stream` calls the tokenizer at every step this
|
| 140 |
+
method loops over the generator returned by `sequence_generator` itself
|
| 141 |
+
so the tokenizer is called only once after all token ids have been
|
| 142 |
+
generated.
|
| 143 |
+
|
| 144 |
+
Parameters
|
| 145 |
+
----------
|
| 146 |
+
prompts
|
| 147 |
+
A string or list of strings that are passed to the model before
|
| 148 |
+
generating the first token.
|
| 149 |
+
max_tokens
|
| 150 |
+
An integer representing maximum number of tokens that will be generated
|
| 151 |
+
(per prompt)
|
| 152 |
+
stop_at
|
| 153 |
+
A string or list of strings at which the text generated will stop
|
| 154 |
+
rng
|
| 155 |
+
The random number generator. Defaults to a non-seeded `torch.Generator`
|
| 156 |
+
instance.
|
| 157 |
+
|
| 158 |
+
Returns
|
| 159 |
+
-------
|
| 160 |
+
The generation(s), potentially cast to another type.
|
| 161 |
+
"""
|
| 162 |
+
import torch
|
| 163 |
+
|
| 164 |
+
if isinstance(prompts, str):
|
| 165 |
+
prompts = [prompts]
|
| 166 |
+
|
| 167 |
+
if isinstance(stop_at, str):
|
| 168 |
+
stop_at = [stop_at]
|
| 169 |
+
|
| 170 |
+
stop_sequences = stop_at
|
| 171 |
+
num_samples = self.num_samples
|
| 172 |
+
|
| 173 |
+
if rng is None:
|
| 174 |
+
rng = torch.Generator(device=self.device)
|
| 175 |
+
rng.seed()
|
| 176 |
+
|
| 177 |
+
prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
|
| 178 |
+
prompt_token_ids = prompt_token_ids.to(self.device)
|
| 179 |
+
attention_masks = attention_masks.to(self.device)
|
| 180 |
+
|
| 181 |
+
# To draw multiple samples we repeat the prompt as many times
|
| 182 |
+
# as there are samples. We copy the FSMs and initialize the
|
| 183 |
+
# FSM states.
|
| 184 |
+
num_samples = self.num_samples
|
| 185 |
+
batch_size = len(prompts)
|
| 186 |
+
|
| 187 |
+
prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
|
| 188 |
+
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
|
| 189 |
+
fsm_states = [0 for _ in range(batch_size * num_samples)]
|
| 190 |
+
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
|
| 191 |
+
weights = torch.zeros(
|
| 192 |
+
(batch_size * num_samples), dtype=torch.float, device=self.device
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
states = sequence_generator(
|
| 196 |
+
self.model,
|
| 197 |
+
self.sampler,
|
| 198 |
+
fsms,
|
| 199 |
+
prompt_token_ids,
|
| 200 |
+
weights,
|
| 201 |
+
attention_masks,
|
| 202 |
+
fsm_states,
|
| 203 |
+
rng=rng,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
while True:
|
| 207 |
+
try:
|
| 208 |
+
last_state = next(states)
|
| 209 |
+
if max_tokens or stop_sequences:
|
| 210 |
+
token_ids = last_state.token_ids
|
| 211 |
+
generated_token_ids = self.get_generated_token_ids(
|
| 212 |
+
prompt_token_ids, token_ids
|
| 213 |
+
)
|
| 214 |
+
if max_tokens and len(generated_token_ids[0]) >= max_tokens:
|
| 215 |
+
break
|
| 216 |
+
if stop_sequences and self.is_stop_sequence_found(
|
| 217 |
+
self.tokenizer.decode(generated_token_ids), stop_sequences
|
| 218 |
+
):
|
| 219 |
+
break
|
| 220 |
+
except StopIteration:
|
| 221 |
+
break
|
| 222 |
+
|
| 223 |
+
token_ids = last_state.token_ids
|
| 224 |
+
generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
|
| 225 |
+
|
| 226 |
+
generated = self.tokenizer.decode(generated_token_ids)
|
| 227 |
+
stripped = [
|
| 228 |
+
self.strip_stop_sequences(sequence, stop_sequences)
|
| 229 |
+
for sequence in generated
|
| 230 |
+
]
|
| 231 |
+
formatted = [self.format_sequence(sequence) for sequence in stripped]
|
| 232 |
+
|
| 233 |
+
# We reshape the output to (batch_size, sample_size)
|
| 234 |
+
output: List[List[FormattedOutput]] = list()
|
| 235 |
+
for i in range(0, batch_size * num_samples, num_samples):
|
| 236 |
+
output.append(formatted[i : i + num_samples])
|
| 237 |
+
|
| 238 |
+
# We remove leading dimensions for the output
|
| 239 |
+
if batch_size == 1 and num_samples == 1:
|
| 240 |
+
return output[0][0]
|
| 241 |
+
elif batch_size == 1:
|
| 242 |
+
return output[0]
|
| 243 |
+
elif num_samples == 1:
|
| 244 |
+
return [samples[0] for samples in output]
|
| 245 |
+
else:
|
| 246 |
+
return output
|
| 247 |
+
|
| 248 |
+
def stream(
|
| 249 |
+
self,
|
| 250 |
+
prompts: Union[str, List[str]],
|
| 251 |
+
max_tokens: Optional[int] = None,
|
| 252 |
+
stop_at: Optional[Union[str, List[str]]] = None,
|
| 253 |
+
rng: Optional["torch.Generator"] = None,
|
| 254 |
+
) -> Iterator[Union[List[str], str, List[List[str]]]]:
|
| 255 |
+
"""Generate the text sequence one token at a time.
|
| 256 |
+
|
| 257 |
+
Since `Tokenizer.decode` strips the whitespaces from the tokens we have no
|
| 258 |
+
choice but to decode the generated token ids at each step and compare the
|
| 259 |
+
current decoded strings to the previously decoded strings.
|
| 260 |
+
|
| 261 |
+
Parameters
|
| 262 |
+
----------
|
| 263 |
+
prompts
|
| 264 |
+
A string or list of strings that are passed to the model before
|
| 265 |
+
generating the first token.
|
| 266 |
+
max_tokens
|
| 267 |
+
An integer representing maximum number of tokens that will be generated
|
| 268 |
+
(per prompt)
|
| 269 |
+
stop_at
|
| 270 |
+
A string or list of strings at which the text generated will stop
|
| 271 |
+
rng
|
| 272 |
+
The random number generator. Defaults to a non-seeded `torch.Generator`
|
| 273 |
+
instance.
|
| 274 |
+
|
| 275 |
+
Returns
|
| 276 |
+
-------
|
| 277 |
+
A string or list of strings that contain the generated text.
|
| 278 |
+
|
| 279 |
+
"""
|
| 280 |
+
import torch
|
| 281 |
+
|
| 282 |
+
if isinstance(prompts, str):
|
| 283 |
+
prompts = [prompts]
|
| 284 |
+
|
| 285 |
+
if isinstance(stop_at, str):
|
| 286 |
+
stop_at = [stop_at]
|
| 287 |
+
|
| 288 |
+
stop_sequences = stop_at
|
| 289 |
+
num_samples = self.num_samples
|
| 290 |
+
|
| 291 |
+
prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
|
| 292 |
+
prompt_token_ids = prompt_token_ids.to(self.device)
|
| 293 |
+
attention_masks = attention_masks.to(prompt_token_ids.device)
|
| 294 |
+
|
| 295 |
+
# To draw multiple samples we repeat the prompt as many times
|
| 296 |
+
# as there are samples. We copy the FSMs and initialize the
|
| 297 |
+
# FSM states.
|
| 298 |
+
num_samples = self.num_samples
|
| 299 |
+
batch_size = len(prompts)
|
| 300 |
+
|
| 301 |
+
prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
|
| 302 |
+
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
|
| 303 |
+
fsm_states = [0 for _ in range(batch_size * num_samples)]
|
| 304 |
+
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
|
| 305 |
+
weights = torch.zeros(
|
| 306 |
+
(batch_size * num_samples),
|
| 307 |
+
dtype=torch.float,
|
| 308 |
+
device=prompt_token_ids.device,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if rng is None:
|
| 312 |
+
rng = torch.Generator(device=prompt_token_ids.device)
|
| 313 |
+
rng.seed()
|
| 314 |
+
|
| 315 |
+
states = sequence_generator(
|
| 316 |
+
self.model,
|
| 317 |
+
self.sampler,
|
| 318 |
+
fsms,
|
| 319 |
+
prompt_token_ids,
|
| 320 |
+
weights,
|
| 321 |
+
attention_masks,
|
| 322 |
+
fsm_states,
|
| 323 |
+
rng=rng,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
|
| 327 |
+
previously_generated_sequences = [
|
| 328 |
+
"" for _ in range(batch_size)
|
| 329 |
+
] * num_samples
|
| 330 |
+
num_generated = 0
|
| 331 |
+
is_stop_at_reached = [False for _ in range(batch_size)] * num_samples
|
| 332 |
+
while True:
|
| 333 |
+
if (max_tokens and num_generated >= max_tokens) or all(
|
| 334 |
+
is_stop_at_reached
|
| 335 |
+
):
|
| 336 |
+
return
|
| 337 |
+
try:
|
| 338 |
+
sequence = next(states)
|
| 339 |
+
num_generated += 1
|
| 340 |
+
except StopIteration:
|
| 341 |
+
return
|
| 342 |
+
generated_token_ids = sequence.token_ids[:, -num_generated:]
|
| 343 |
+
generated_sequences = self.tokenizer.decode(generated_token_ids)
|
| 344 |
+
if stop_sequences:
|
| 345 |
+
is_stop_at_reached = [
|
| 346 |
+
stop
|
| 347 |
+
or self.is_stop_sequence_found(
|
| 348 |
+
[generated_sequence], stop_sequences
|
| 349 |
+
)
|
| 350 |
+
for generated_sequence, stop in zip(
|
| 351 |
+
generated_sequences, is_stop_at_reached
|
| 352 |
+
)
|
| 353 |
+
]
|
| 354 |
+
|
| 355 |
+
generated_sequences = [
|
| 356 |
+
self.format_sequence(
|
| 357 |
+
self.strip_stop_sequences(sequence, stop_sequences)
|
| 358 |
+
)
|
| 359 |
+
if stop
|
| 360 |
+
else sequence
|
| 361 |
+
for sequence, stop in zip(
|
| 362 |
+
generated_sequences, is_stop_at_reached
|
| 363 |
+
)
|
| 364 |
+
]
|
| 365 |
+
next_tokens = [
|
| 366 |
+
token[len(sequence) :]
|
| 367 |
+
for token, sequence, stop in zip(
|
| 368 |
+
generated_sequences,
|
| 369 |
+
previously_generated_sequences,
|
| 370 |
+
is_stop_at_reached,
|
| 371 |
+
)
|
| 372 |
+
]
|
| 373 |
+
previously_generated_sequences = generated_sequences
|
| 374 |
+
# We reshape the output to (batch_size, sample_size)
|
| 375 |
+
output: List[List[str]] = list()
|
| 376 |
+
for i in range(0, batch_size * num_samples, num_samples):
|
| 377 |
+
output.append(next_tokens[i : i + num_samples])
|
| 378 |
+
|
| 379 |
+
# We remove leading dimensions for the output
|
| 380 |
+
if batch_size == 1 and num_samples == 1:
|
| 381 |
+
yield output[0][0]
|
| 382 |
+
elif batch_size == 1:
|
| 383 |
+
yield output[0]
|
| 384 |
+
elif num_samples == 1:
|
| 385 |
+
yield [samples[0] for samples in output]
|
| 386 |
+
else:
|
| 387 |
+
yield output
|
| 388 |
+
|
| 389 |
+
return token_generator()
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@dataclass(frozen=True)
|
| 393 |
+
class GenerationParameters:
|
| 394 |
+
"""Generation parameters used in Outlines' public API."""
|
| 395 |
+
|
| 396 |
+
max_tokens: Optional[int]
|
| 397 |
+
stop_at: Optional[Union[str, List[str]]]
|
| 398 |
+
seed: Optional[int]
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
@dataclass(frozen=True)
|
| 402 |
+
class SamplingParameters:
|
| 403 |
+
"""Sampling parameters available in Outlines."""
|
| 404 |
+
|
| 405 |
+
sampler: str
|
| 406 |
+
num_samples: int = 1
|
| 407 |
+
top_p: Optional[float] = None
|
| 408 |
+
top_k: Optional[int] = None
|
| 409 |
+
temperature: Optional[float] = None
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class SequenceGeneratorAdapter:
    """Class used to unify the interface to the model providers'
    generation functions.

    Attributes
    ----------
    model
        The wrapped model.
    logits_processor
        The logits processor to use to generate text.
    sampler
        The sampler to use to generate text.

    """

    def __init__(self, model, logits_processor, sampler):
        self.model = model
        self.logits_processor = logits_processor

        # Flatten the sampler object into the `SamplingParameters` record
        # that the model providers consume.
        if isinstance(sampler, MultinomialSampler):
            self.sampling_params = SamplingParameters(
                "multinomial",
                sampler.samples,
                sampler.top_p,
                sampler.top_k,
                sampler.temperature,
            )
        elif isinstance(sampler, GreedySampler):
            self.sampling_params = SamplingParameters(
                "greedy", sampler.samples, None, None, 0.0
            )
        elif isinstance(sampler, BeamSearchSampler):
            self.sampling_params = SamplingParameters(
                "beam_search", sampler.samples, None, None, 1.0
            )
        else:
            # Fail fast instead of leaving `sampling_params` unset, which
            # would otherwise surface later as a confusing AttributeError.
            raise TypeError(f"Unsupported sampler type: {type(sampler).__name__}")

    def prepare_generation_parameters(
        self,
        max_tokens: Optional[int],
        stop_at: Optional[Union[str, List[str]]],
        seed: Optional[int],
    ):
        """Normalize the generation arguments into a `GenerationParameters`.

        A single `stop_at` string is wrapped in a list so downstream code
        only ever sees a list (or None).
        """
        if isinstance(stop_at, str):
            stop_at = [stop_at]

        generation_params = GenerationParameters(
            max_tokens,
            stop_at,
            seed,
        )

        return generation_params

    def format_sequence(self, sequence: str) -> "FormattedOutput":
        """Translate the generated sequence to another type.

        This method is for instance overridden when generating JSON to either
        return a dictionary or a Pydantic model.

        Parameters
        ----------
        sequence
            A generated sequence.

        Returns
        -------
        The formatted sequence.

        """
        return sequence

    def _format(self, sequences):
        """Apply formatting to every string in a completion.

        Recurses through nested lists (batch / samples) and applies
        `format_sequence` to each leaf string.
        """
        if isinstance(sequences, list):
            return [self._format(sequence) for sequence in sequences]
        else:
            return self.format_sequence(sequences)

    def __call__(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Generate text from a prompt of list of prompts."""

        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        completions = self.model.generate(
            prompts,
            generation_params,
            # Copy so the model cannot mutate the adapter's processor state.
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

        return self._format(completions)

    def stream(
        self,
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts."""
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class VisionSequenceGeneratorAdapter(SequenceGeneratorAdapter):
    """`SequenceGeneratorAdapter` for vision models that take media
    (images) alongside the text prompts."""

    def __call__(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: Union[str, Any],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """
        Generate text from a prompt of list of prompts.

        Media: A URI to construct media or media object itself. Used as AutoProcessor argument.
        """
        prompts, media = self._validate_prompt_media_types(prompts, media)

        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )

        completions = self.model.generate(
            prompts,
            media,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

        return self._format(completions)

    def stream(  # type: ignore
        self,
        prompts: Union[str, List[str]],
        media: List[Union[str, Any, List[Union[str, Any]]]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        seed: Optional[int] = None,
        **model_specific_params,
    ):
        """Return a text generator from a prompt or a list of prompts."""
        prompts, media = self._validate_prompt_media_types(prompts, media)
        generation_params = self.prepare_generation_parameters(
            max_tokens, stop_at, seed
        )
        return self.model.stream(
            prompts,
            media,
            generation_params,
            copy(self.logits_processor),
            self.sampling_params,
            **model_specific_params,
        )

    @classmethod
    def _validate_prompt_media_types(
        cls,
        prompts: Union[str, List[str]],
        media: Union[str, Any, List[Union[str, Any]]],
    ) -> Union[Any, List[Any]]:
        """
        Prepare media as PIL.Image and ensure for every prompt str there is one List[PIL.Image]
        """

        def valid_types(prompts, media):
            # PIL is imported lazily so the adapter can be imported without it.
            from PIL import Image  # type: ignore

            if isinstance(prompts, list):
                # One sub-list of images per prompt string.
                if not isinstance(media, list) or len(prompts) != len(media):
                    return False
                for subprompt, submedia in zip(prompts, media):
                    if not isinstance(subprompt, str) or not all(
                        isinstance(m, Image.Image) for m in submedia
                    ):
                        return False
            elif isinstance(prompts, str):
                if not all(isinstance(m, Image.Image) for m in media):
                    return False
            else:
                # Reject any other prompt type explicitly; previously this
                # fell through to `return True` and failed later in
                # generation with a less helpful error.
                return False
            return True

        if not valid_types(prompts, media):
            raise TypeError(
                "Expected (prompts, media) to be of type "
                "(str, List[Image])), or (List[str], List[List[Image]]) "
                f"instead got prompts={prompts}, media={media}"
            )

        return prompts, media
|
.venv/lib/python3.11/site-packages/outlines/generate/cfg.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import singledispatch
|
| 2 |
+
|
| 3 |
+
from outlines.generate.api import (
|
| 4 |
+
SequenceGeneratorAdapter,
|
| 5 |
+
VisionSequenceGeneratorAdapter,
|
| 6 |
+
)
|
| 7 |
+
from outlines.models import LlamaCpp, OpenAI, TransformersVision
|
| 8 |
+
from outlines.samplers import Sampler, multinomial
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@singledispatch
def cfg(
    model, cfg_str: str, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate text in the language of a Context-Free Grammar

    Arguments
    ---------
    model:
        An `outlines.model` instance.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGeneratorAdapter` instance that generates text.

    """
    from outlines.processors import CFGLogitsProcessor

    # Constrain decoding so every completion stays inside the grammar.
    processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
    return SequenceGeneratorAdapter(model, processor, sampler)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@cfg.register(TransformersVision)
def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()):
    """Grammar-structured generation for vision models."""
    from outlines.processors import CFGLogitsProcessor

    processor = CFGLogitsProcessor(cfg_str, tokenizer=model.tokenizer)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@cfg.register(LlamaCpp)
def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()):
    """Grammar-structured generation is disabled for llama.cpp models."""
    raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@cfg.register(OpenAI)
def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()):
    """Grammar-structured generation is not possible through the OpenAI API."""
    raise NotImplementedError(
        # Fixed missing space between the two string fragments ("modeldue").
        "Cannot use grammar-structured generation with an OpenAI model "
        + "due to the limitations of the OpenAI API."
    )
|
.venv/lib/python3.11/site-packages/outlines/generate/choice.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json as pyjson
|
| 2 |
+
import re
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from functools import singledispatch
|
| 5 |
+
from typing import Callable, List, Union
|
| 6 |
+
|
| 7 |
+
from outlines_core.fsm.json_schema import build_regex_from_schema
|
| 8 |
+
|
| 9 |
+
from outlines.fsm.json_schema import get_schema_from_enum
|
| 10 |
+
from outlines.generate.api import SequenceGeneratorAdapter
|
| 11 |
+
from outlines.models import OpenAI
|
| 12 |
+
from outlines.samplers import Sampler, multinomial
|
| 13 |
+
|
| 14 |
+
from .json import json
|
| 15 |
+
from .regex import regex
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@singledispatch
def choice(
    model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Constrain generation to one of a fixed set of choices.

    `choices` is either a list of literal strings or an Enum class whose
    values are turned into a JSON-schema-derived regex.
    """
    is_enum = isinstance(choices, type(Enum))
    if is_enum:
        regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices)))
    else:
        escaped = [re.escape(choice) for choice in choices]  # type: ignore
        regex_str = r"(" + r"|".join(escaped) + r")"

    generator = regex(model, regex_str, sampler)
    # Enum choices are emitted as JSON and must be decoded; plain strings
    # are returned verbatim.
    generator.format_sequence = pyjson.loads if is_enum else (lambda x: x)

    return generator
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@choice.register(OpenAI)
def choice_openai(
    model: OpenAI, choices: List[str], sampler: Sampler = multinomial()
) -> Callable:
    """
    Call OpenAI API with response_format of a dict:
    {"result": <one of choices>}
    """
    schema = {
        "type": "object",
        "properties": {"result": {"type": "string", "enum": choices}},
        "additionalProperties": False,
        "required": ["result"],
    }
    generator = json(model, pyjson.dumps(schema), sampler)

    def generate_choice(*args, **kwargs):
        # Unwrap the {"result": ...} envelope the schema imposes.
        return generator(*args, **kwargs)["result"]

    return generate_choice
|
.venv/lib/python3.11/site-packages/outlines/generate/format.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import singledispatch
|
| 2 |
+
|
| 3 |
+
from outlines.fsm.types import python_types_to_regex
|
| 4 |
+
from outlines.generate.api import SequenceGeneratorAdapter
|
| 5 |
+
from outlines.models import OpenAI
|
| 6 |
+
from outlines.samplers import Sampler, multinomial
|
| 7 |
+
|
| 8 |
+
from .regex import regex
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@singledispatch
def format(
    model, python_type, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate structured data that can be parsed as a Python type.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    python_type:
        A Python type. The output of the generator must be parseable into
        this type.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGenerator` instance that generates text constrained by the Python type
    and translates this text into the corresponding type.

    """
    pattern, caster = python_types_to_regex(python_type)
    generator = regex(model, pattern, sampler)
    # Parse the constrained text back into the requested Python type.
    generator.format_sequence = caster

    return generator
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@format.register(OpenAI)
def format_openai(model, python_type, sampler: Sampler = multinomial()):
    """Type-structured generation is not possible through the OpenAI API."""
    raise NotImplementedError(
        "Cannot use Python type-structured generation with an OpenAI model"
        " due to the limitations of the OpenAI API."
    )
|
.venv/lib/python3.11/site-packages/outlines/generate/fsm.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import singledispatch
|
| 2 |
+
|
| 3 |
+
import interegular
|
| 4 |
+
|
| 5 |
+
from outlines.fsm.guide import RegexGuide
|
| 6 |
+
from outlines.generate.api import (
|
| 7 |
+
SequenceGeneratorAdapter,
|
| 8 |
+
VisionSequenceGeneratorAdapter,
|
| 9 |
+
)
|
| 10 |
+
from outlines.models import TransformersVision
|
| 11 |
+
from outlines.samplers import Sampler, multinomial
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@singledispatch
def fsm(
    model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
    """Generate text whose token sequence follows an interegular FSM."""
    from outlines.processors import GuideLogitsProcessor

    regex_guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=regex_guide)
    return SequenceGeneratorAdapter(model, processor, sampler)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@fsm.register(TransformersVision)
def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()):
    """FSM-structured generation for vision models."""
    from outlines.processors import GuideLogitsProcessor

    regex_guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
    processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=regex_guide)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
|
.venv/lib/python3.11/site-packages/outlines/generate/generator.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import math
|
| 3 |
+
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
if TYPE_CHECKING:
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from outlines.fsm.guide import Guide
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ContextLengthExceededError(Exception):
    """Raised when the input exceeds the model's context length
    (see `sequence_generator`, which converts the model's IndexError)."""

    pass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclasses.dataclass(frozen=True)
class GenerationState:
    """State yielded by `sequence_generator` after each decoding step."""

    # Token ids of every sequence in the batch, prompt included.
    token_ids: "torch.Tensor"
    # Key-value cache returned by the model at this step.
    kv_cache: "torch.Tensor"
    # Raw (unbiased) logits of the last decoding step.
    logits: "torch.Tensor"
    # Sampler-maintained weight of each sequence.
    weights: "torch.Tensor"
    # Current FSM state of each sequence.
    fsm_states: List[int]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def sequence_generator(
    model: Callable,
    sampler: Callable,
    fsms: List["Guide"],
    token_ids: "torch.Tensor",
    sequence_weights: "torch.Tensor",
    attention_masks: "torch.Tensor",
    fsm_states: List[int],
    rng: Optional["torch.Generator"],
) -> Iterator[GenerationState]:
    """Generates sequences of tokens.

    Parameters
    ----------
    model
        A callable that generates a probability distribution over the
        vocabulary when passed a tensor of token ids.
    sampler
        A callable that returns the next token ids, their ancestor sequence and
        the updated sequence weights when passed a distribution over the
        vocabulary.
    token_ids
        A tensor of token ids on which the sequence distribution is conditioned, of
        shape ``(n_seqs, n_prompt_tokens)``
    sequence_weights
        A tensor that contains the initial weights of the sequences, of shape
        ``(n_seqs,)``
    attention_masks
        A tensor of tensors that represent the tokens considered at the attention
        layer, of shape ``(n_seqs, n_prompt_tokens)``.
    fsms
        List of finite-state machines that drive the text generation,
        one for each sequence in the batch.
    fsm_states
        The initial states of the finite-state machine for each sequence in the batch.
    rng
        Random generator used by the sampler; a fresh default generator is
        created when ``None``.

    Yields
    ------
    A new sequence.

    """
    # torch is imported lazily so the module can be imported without it.
    import torch

    if rng is None:
        rng = torch.Generator()

    kv_cache = None

    while True:
        try:
            logits, kv_cache = model(token_ids, attention_masks, kv_cache)
        except IndexError:  # Exceeding the context length
            raise ContextLengthExceededError(
                "The input length exceeds the context length of the model."
            )

        # Mask logits so only FSM-allowed tokens can be sampled, then sample.
        allowed_tokens = get_allowed_tokens(fsms, fsm_states)
        biased_logits = bias_logits(logits, allowed_tokens)
        next_token_ids, ancestors, sequence_weights = sampler(
            biased_logits, sequence_weights, rng
        )

        # Re-align every per-sequence structure with the ancestors chosen by
        # the sampler (relevant e.g. when beams/samples are reordered).
        token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
        attention_masks = update_attention_masks(attention_masks, ancestors)
        kv_cache = reorder_kv_cache(kv_cache, ancestors)
        if len(ancestors) > 1:
            fsms = reorder_fsms(fsms, ancestors)
            fsm_states = reorder_fsm_states(fsm_states, ancestors)

        # Advance the FSMs AFTER reordering so states match their sequences.
        fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids)
        is_finished = is_generation_finished(fsms, fsm_states)

        if is_finished:
            yield GenerationState(
                token_ids,
                kv_cache,
                logits,
                sequence_weights,
                fsm_states,
            )
            return

        yield GenerationState(
            token_ids,
            kv_cache,
            logits,
            sequence_weights,
            fsm_states,
        )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def get_next_fsm_states(
    fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
) -> List[int]:
    """Advance each sequence's FSM with the token that was just sampled.

    Parameters
    ----------
    fsms
        The finite-state machines used to monitor this batch, one per sequence.
    fsm_states
        The current FSM state of each sequence.
    next_token_ids
        The tokens that were just generated, one per sequence.

    Returns
    -------
    The updated FSM state of each sequence.
    (The previous docstring incorrectly claimed a logit-mask tensor was
    returned.)

    """
    return [
        fsm.get_next_state(fsm_state, int(token_id[0]))
        for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids)
    ]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_allowed_tokens(
    fsms: List["Guide"], fsm_states: List[int]
) -> List[Optional[Iterable[int]]]:
    """Ask each sequence's FSM which tokens may come next.

    Parameters
    ----------
    fsms
        The finite-state machines used to monitor this batch, one per sequence.
    fsm_states
        The FSM states corresponding to each sequence in the batch.

    Returns
    -------
    A nested list that contains the ids of the logits to keep.

    """
    return [
        guide.get_next_instruction(current).tokens
        for guide, current in zip(fsms, fsm_states)
    ]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool:
    """Determine if the generation is finished.

    A generation is considered finished if the FSM of every sequence in the
    batch is in a final state.

    A better solution is to return finished sequences as soon as their FSM
    is in a final state.

    Parameters
    ----------
    fsms
        The finite-state machines used to monitor this batch, one per sequence.
    fsm_states
        The FSM states corresponding to each sequence in the batch.

    Returns
    -------
    Whether all sequences are finished sampling.

    """
    # Generator expression instead of a materialized list: lets `all`
    # short-circuit on the first non-final state.
    return all(fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def update_token_ids(
    token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
    """Append the sampled tokens to the running sequence of tokens.

    Parameters
    ----------
    token_ids
        The current token sequences
    next_token_ids
        The tokens that were just generated and that we need to append
        to the existing sequences.
    ancestors
        The sequences to which the token ids need to be added.

    Returns
    -------
    A new sequence of token ids that contains the tokens that were
    just generated.

    """
    import torch

    # Pick each new token's ancestor sequence first, then append the token.
    reordered = torch.index_select(token_ids, 0, ancestors)
    return torch.concatenate([reordered, next_token_ids], dim=-1)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def update_attention_masks(
    attention_masks: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
    """Expand the attention masks.

    Parameters
    ----------
    attention_masks
        The attention masks for each sequence in the batch.
    ancestors
        The sequences to which the token ids need to be added.

    Returns
    -------
    The attention masks padded with 1s.

    """
    import torch

    # Align masks with the chosen ancestors, then append a column of ones
    # for the token that was just generated.
    reordered = torch.index_select(attention_masks, 0, ancestors)
    padding = torch.ones(reordered.shape[:-1] + (1,), device=reordered.device)
    return torch.concatenate([reordered, padding], axis=-1)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]:
    """Return a copy of each ancestor sequence's FSM, in ancestor order."""
    return [fsms[ancestor].copy() for ancestor in ancestors]
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]:
    """Gather the FSM states of the ancestor sequences, in ancestor order."""
    return [fsm_states[ancestor] for ancestor in ancestors]
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def reorder_kv_cache(
    kv_cache: Optional[Tuple], ancestors: "torch.Tensor"
) -> Optional[Tuple]:
    """Re-order the KV-cache based on the ancestors.

    In transformers, the object that stores the KV-cache is a tuple whose
    elements are the key cache and the value cache. Each of these caches are
    tuples where each element corresponds to a layer. To each layer
    corresponds a tensor whose first dimension is the batch size.

    """
    import torch

    if kv_cache is None:
        return None

    # Rebuild the nested tuple structure, selecting each layer's rows by
    # ancestor index (indices moved to the layer's device).
    return tuple(
        tuple(
            torch.index_select(layer, 0, ancestors.to(layer.device))
            for layer in cache_item
        )
        for cache_item in kv_cache
    )
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor":
    """Mask the logits.

    The function iterates over a nested list where each list corresponds to the
    indices that need to be kept for each row in the array.

    Parameters
    ----------
    logits
        Two dimensional tensor that contains the next-token probability
        distribution.
    allowed_token_ids
        A list that contains the tokens that can be generated by the model.

    Returns
    -------
    A tensor of the same shape as `logits` where disallowed entries are -inf.

    """
    import torch

    masked = torch.full_like(logits, -math.inf, device=logits.device)
    for row, ids in enumerate(allowed_token_ids):
        if ids is None:
            # `None` means "no constraint": keep the whole row.
            masked[row] = logits[row]
        else:
            masked[row, ids] = logits[row, ids]
    return masked
|
.venv/lib/python3.11/site-packages/outlines/generate/json.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json as pyjson
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from functools import singledispatch
|
| 4 |
+
from typing import Callable, Optional, Union
|
| 5 |
+
|
| 6 |
+
from outlines_core.fsm.json_schema import build_regex_from_schema
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature
|
| 10 |
+
from outlines.generate.api import SequenceGeneratorAdapter
|
| 11 |
+
from outlines.models import OpenAI
|
| 12 |
+
from outlines.samplers import Sampler, multinomial
|
| 13 |
+
|
| 14 |
+
from .regex import regex
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@singledispatch
def json(
    model,
    schema_object: Union[str, object, Callable],
    sampler: Sampler = multinomial(),
    whitespace_pattern: Optional[str] = None,
) -> SequenceGeneratorAdapter:
    """
    Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    schema_object:
        The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable
        that returns a JSON schema.
    sampler:
        The sampling algorithm to use to generate token ids from the logits
        distribution.
    whitespace_pattern
        Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
        Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`

    Returns
    -------
    A `SequenceGenerator` instance that generates text constrained by the schema_object and
    transforms the result if BaseModel is used.

    """
    # Normalize every accepted input into a JSON-schema string plus the
    # function that converts raw generated text into the caller's type.
    if isinstance(schema_object, type(BaseModel)):
        schema = pyjson.dumps(schema_object.model_json_schema())
        format_sequence = lambda x: schema_object.parse_raw(x)
    elif isinstance(schema_object, type(Enum)):
        schema = pyjson.dumps(get_schema_from_enum(schema_object))
        format_sequence = lambda x: pyjson.loads(x)
    elif callable(schema_object):
        schema = pyjson.dumps(get_schema_from_signature(schema_object))
        format_sequence = lambda x: pyjson.loads(x)
    elif isinstance(schema_object, str):
        schema = schema_object
        format_sequence = lambda x: pyjson.loads(x)
    else:
        raise ValueError(
            f"Cannot parse schema {schema_object}. The schema must be either "
            + "a Pydantic object, a function or a string that contains the JSON "
            + "Schema specification"
        )

    # The regex build and generator construction were duplicated in every
    # branch above; perform them once here.
    regex_str = build_regex_from_schema(schema, whitespace_pattern)
    generator = regex(model, regex_str, sampler)
    generator.format_sequence = format_sequence

    return generator
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@json.register(OpenAI)
def json_openai(
    model, schema_object: Union[str, object], sampler: Sampler = multinomial()
):
    """Generate schema-constrained JSON with an OpenAI model.

    The schema may be given as a Pydantic model class (outputs are parsed
    back into instances of that class) or as a raw JSON Schema string
    (outputs are deserialized into plain Python objects). The schema is
    attached to a patched copy of the model via the API's `response_format`
    so the constraint is enforced server-side.
    """
    if not isinstance(sampler, multinomial):
        raise NotImplementedError(
            r"The OpenAI API does not support any other sampling algorithm "
            + "than the multinomial sampler."
        )

    if isinstance(schema_object, type(BaseModel)):
        # Pydantic model class: dump its JSON schema, parse results back into it.
        schema = pyjson.dumps(schema_object.model_json_schema())

        def format_sequence(x):
            return schema_object.parse_raw(x)

    elif isinstance(schema_object, str):
        # Raw JSON Schema string: deserialize generated text into plain Python.
        schema = schema_object

        def format_sequence(x):
            return pyjson.loads(x)

    else:
        raise ValueError(
            f"Cannot parse schema {schema_object}. The schema must be either "
            + "a Pydantic object, a function or a string that contains the JSON "
            + "Schema specification"
        )

    # create copied, patched model with normalized json schema set
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "default",
            "strict": True,
            "schema": pyjson.loads(schema),
        },
    }
    generator = model.new_with_replacements(response_format=response_format)
    generator.format_sequence = format_sequence
    return generator
|
.venv/lib/python3.11/site-packages/outlines/generate/regex.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import singledispatch
|
| 2 |
+
|
| 3 |
+
from outlines.generate.api import (
|
| 4 |
+
SequenceGeneratorAdapter,
|
| 5 |
+
VisionSequenceGeneratorAdapter,
|
| 6 |
+
)
|
| 7 |
+
from outlines.models import OpenAI, TransformersVision
|
| 8 |
+
from outlines.samplers import Sampler, multinomial
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
    """Generate text that matches a regular expression.

    Parameters
    ----------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    regex_str:
        The regular expression the generated text must match.
    sampler:
        The sampling algorithm used to pick token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGeneratorAdapter` whose output is constrained to the
    language of `regex_str`.

    """
    # Imported lazily to avoid a circular import at module load time.
    from outlines.processors import RegexLogitsProcessor

    processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
    return SequenceGeneratorAdapter(model, processor, sampler)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@regex.register(TransformersVision)
def regex_vision(
    model,
    regex_str: str,
    sampler: Sampler = multinomial(),
):
    """Regex-constrained generation for vision-language transformer models."""
    # Imported lazily to avoid a circular import at module load time.
    from outlines.processors import RegexLogitsProcessor

    processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
    return VisionSequenceGeneratorAdapter(model, processor, sampler)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@regex.register(OpenAI)
def regex_openai(
    model: OpenAI,
    regex_str: str,
    sampler: Sampler = multinomial(),
):
    """Reject regex-structured generation for OpenAI models.

    Raises
    ------
    NotImplementedError
        Always: the OpenAI API exposes no way to constrain sampling to a
        regular expression.
    """
    # Fixed: the original concatenation was missing a space between the two
    # fragments, producing "...an OpenAI modeldue to the limitations...".
    raise NotImplementedError(
        "Cannot use regex-structured generation with an OpenAI model "
        + "due to the limitations of the OpenAI API."
    )
|
.venv/lib/python3.11/site-packages/outlines/generate/text.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import singledispatch
|
| 2 |
+
|
| 3 |
+
from outlines.generate.api import (
|
| 4 |
+
SequenceGeneratorAdapter,
|
| 5 |
+
VisionSequenceGeneratorAdapter,
|
| 6 |
+
)
|
| 7 |
+
from outlines.models import OpenAI, TransformersVision
|
| 8 |
+
from outlines.samplers import Sampler, multinomial
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@singledispatch
def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter:
    """Build an unconstrained text generator for a `Transformer` model.

    Note
    ----
    Python 3.11 allows dispatching on Union types and
    this should greatly simplify the code.

    Arguments
    ---------
    model:
        An instance of `Transformer` that represents a model from the
        `transformers` library.
    sampler:
        The sampling algorithm used to pick token ids from the logits
        distribution.

    Returns
    -------
    A `SequenceGeneratorAdapter` instance that generates text.

    """
    # No logits processor: generation is completely unconstrained.
    return SequenceGeneratorAdapter(model, None, sampler)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@text.register(TransformersVision)
def text_vision(model, sampler: Sampler = multinomial()):
    """Unconstrained text generation for vision-language transformer models."""
    # No logits processor: generation is completely unconstrained.
    return VisionSequenceGeneratorAdapter(model, None, sampler)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@text.register(OpenAI)
def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI:
    """Return the OpenAI model itself: the API already performs free-text sampling.

    Only the multinomial sampler is accepted; anything else raises
    `NotImplementedError` because the OpenAI API cannot express it.
    """
    if isinstance(sampler, multinomial):
        return model
    raise NotImplementedError(
        r"The OpenAI API does not support any other sampling algorithm "
        + "than the multinomial sampler."
    )
|
.venv/lib/python3.11/site-packages/outlines/grammars.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path

# Directory holding the bundled .lark grammar files, next to this module.
GRAMMAR_PATH = Path(__file__).parent / "grammars"


def read_grammar(grammar_file_name, base_grammar_path=GRAMMAR_PATH):
    """Return the text of `grammar_file_name` found under `base_grammar_path`."""
    return (base_grammar_path / grammar_file_name).read_text()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Eagerly load the bundled grammars at import time so callers can use
# `outlines.grammars.arithmetic` / `outlines.grammars.json` directly.
arithmetic = read_grammar("arithmetic.lark")
json = read_grammar("json.lark")
|
.venv/lib/python3.11/site-packages/outlines/grammars/arithmetic.lark
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Grammar for infix arithmetic with the usual precedence: +/- bind loosest,
// then * and /, then unary minus and parentheses.
// A leading "?" inlines a rule when it reduces to a single child.
?start: sum

// Addition and subtraction (lowest precedence, left-associative).
?sum: product
    | sum "+" product -> add
    | sum "-" product -> sub

// Multiplication and division.
?product: atom
    | product "*" atom -> mul
    | product "/" atom -> div

// Numeric literals, unary negation, and parenthesized sub-expressions.
?atom: NUMBER -> number
    | "-" atom -> neg
    | "(" sum ")"

%import common.NUMBER
%import common.WS_INLINE

// Inline whitespace is insignificant.
%ignore WS_INLINE
|
.venv/lib/python3.11/site-packages/outlines/grammars/common.lark
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Adapted from https://github.com/lark-parser/lark/blob/master/lark/grammars/common.lark

// Lark License:
// Copyright © 2017 Erez Shinan
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


// Basic terminals for common use
// Terminals whose names start with "_" are filtered out of the parse tree.


//
// Numbers
//

DIGIT: "0".."9"
HEXDIGIT: "a".."f"|"A".."F"|DIGIT

INT: DIGIT+
SIGNED_INT: ["+"|"-"] INT
DECIMAL: INT "." INT? | "." INT

// float = /-?\d+(\.\d+)?([eE][+-]?\d+)?/
_EXP: ("e"|"E") SIGNED_INT
FLOAT: INT _EXP | DECIMAL _EXP?
SIGNED_FLOAT: ["+"|"-"] FLOAT

NUMBER: FLOAT | INT
SIGNED_NUMBER: ["+"|"-"] NUMBER

// A double-quoted string with no escape handling at all.
UNESCAPED_STRING: /\"[^"]*\"/

// based on `outlines/fsm/json_schema.py`
_NON_CONTROL_CHAR: /([^"\\\x00-\x1F\x7F-\x9F])/
_ESCAPED_CHAR: /\\/ (_NON_CONTROL_CHAR | /\\/ | /"/)
ESCAPED_STRING_INNER: _NON_CONTROL_CHAR | _ESCAPED_CHAR
ESCAPED_STRING: /"/ ESCAPED_STRING_INNER* /"/



//
// Names (Variables)
//
LCASE_LETTER: "a".."z"
UCASE_LETTER: "A".."Z"

LETTER: UCASE_LETTER | LCASE_LETTER
WORD: LETTER+

// C-style identifier: letter or underscore, then letters/digits/underscores.
CNAME: ("_"|LETTER) ("_"|LETTER|DIGIT)*


//
// Whitespace
//
WS_INLINE: (" "|/\t/)+
WS: /[ \t\f\r\n]/+

CR : /\r/
LF : /\n/
NEWLINE: (CR? LF)+


// Comments
SH_COMMENT: /#[^\n]*/
CPP_COMMENT: /\/\/[^\n]*/
C_COMMENT: "/*" /(.|\n)*?/ "*/"
SQL_COMMENT: /--[^\n]*/
|
.venv/lib/python3.11/site-packages/outlines/grammars/json.lark
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// JSON grammar built on the terminals from common.lark.
// A leading "?" inlines a rule when it reduces to a single child.
?start: value

// A JSON document is a single value of any of these kinds.
?value: object
    | array
    | ESCAPED_STRING
    | SIGNED_NUMBER -> number
    | "true" -> true
    | "false" -> false
    | "null" -> null

// The bracketed groups are optional, so [] and {} are valid.
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : ESCAPED_STRING ":" value

%import common.ESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

// Whitespace between tokens is insignificant, as in JSON.
%ignore WS
|
.venv/lib/python3.11/site-packages/outlines/processors/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .structured import (
|
| 2 |
+
CFGLogitsProcessor,
|
| 3 |
+
GuideLogitsProcessor,
|
| 4 |
+
JSONLogitsProcessor,
|
| 5 |
+
OutlinesLogitsProcessor,
|
| 6 |
+
RegexLogitsProcessor,
|
| 7 |
+
)
|
.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (441 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/base_logits_processor.cpython-311.pyc
ADDED
|
Binary file (7.81 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/processors/__pycache__/structured.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/outlines/processors/base_logits_processor.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
from typing import TYPE_CHECKING, List, Protocol, Type, Union

import numpy as np
import torch
from numpy.typing import NDArray

if TYPE_CHECKING:
    # mlx is an optional dependency; import it only for static type checking.
    import mlx.core as mx


# Any array-like a logits processor may receive or return; "mx.array" stays a
# string forward reference so mlx is never imported at runtime here.
Array = Union[NDArray, torch.Tensor, List, "mx.array"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def is_mlx_array_type(array_type):
    """Return True iff `array_type` is a subclass of `mlx.core.array`.

    Returns False when mlx is not installed.
    """
    try:
        import mlx.core as mx
    except ImportError:
        # mlx is an optional dependency; absence simply means "not an mlx type".
        return False
    else:
        return issubclass(array_type, mx.array)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def is_jax_array_type(array_type):
|
| 24 |
+
try:
|
| 25 |
+
import jaxlib
|
| 26 |
+
except ImportError:
|
| 27 |
+
return False
|
| 28 |
+
return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance(
|
| 29 |
+
array_type, jaxlib.xla_extension.ArrayImpl
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class OutlinesLogitsProcessor(Protocol):
    """
    Base class for logits processors which normalizes types of logits:
    - ndarray (used by llama-cpp-python), converted to torch.Tensor
    - mlx.core.array (used by mlx-lm), converted to torch.Tensor
    - torch.Tensor (used by everything else)

    Normalization of types and conversion to torch.Tensor
    doesn't move memory, it just casts the type.

    Normalizing the types allows all logits processors inheriting from this class
    to implement a single method for all the business logit: `process_logits()`
    """

    @abstractmethod
    def process_logits(
        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
        """
        input_ids and logits are always 2D tensors for handling a batch of sequences.

        - input_ids -> List[List[tokens]]
        - logits -> 2D_Tensor[logit floats]

        Important to keep in mind when designing universal logits processors
        - logits processors are only used once and never re-applied for a new sequence generator
        - Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids
        - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM
        """
        pass

    @torch.no_grad()
    def __call__(
        self,
        input_ids: Array,
        logits: Array,
    ) -> Array:
        """
        Apply logits processor

        1) Unify type
        - convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]]
        - convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor
        2) Unify shape, ensure logits and input_ids are 2D
        3) Call self.process_logits() to perform business logic
        4) Cast logits back to original array library type
        """
        # ensure logits are torch Tensors
        torch_logits = self._to_torch(logits)
        input_ids = self._to_torch(input_ids)

        # leading (batch/sequence) dimensions of logits and input_ids must agree
        assert torch_logits.shape[:-1] == input_ids.shape[:-1]

        # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape
        if len(torch_logits.shape) == 2:
            processed_logits = self.process_logits(input_ids, torch_logits)
        elif len(torch_logits.shape) == 1:
            # single sequence: add a batch axis for process_logits, then drop it
            processed_logits = self.process_logits(
                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
            ).squeeze(0)
        # NOTE(review): a tensor of rank > 2 matches neither branch, leaving
        # `processed_logits` unbound (UnboundLocalError) — confirm callers only
        # ever pass 1D or 2D logits.

        # return logits as passed array type
        return self._from_torch(processed_logits, type(logits))

    @staticmethod
    def _to_torch(tensor_like: Array) -> torch.Tensor:
        """Convert various types to torch.Tensor."""
        if isinstance(tensor_like, torch.Tensor):
            return tensor_like

        elif isinstance(tensor_like, np.ndarray):
            # shares memory with the numpy array rather than copying
            return torch.from_numpy(tensor_like)

        elif isinstance(tensor_like, (list, tuple)):
            return torch.tensor(tensor_like)

        elif is_mlx_array_type(type(tensor_like)):
            import mlx.core as mx

            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
            return torch.from_dlpack(
                np.array(tensor_like.astype(mx.float32), copy=False)
            )

        elif is_jax_array_type(type(tensor_like)):
            import jax

            torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
            return torch_tensor

        else:
            raise TypeError(
                "LogitsProcessor must be called with either np.NDArray, "
                "torch.Tensor, list, or mlx.core.array typed logits. "
                f"Logits type: `{type(tensor_like)}`"
            )

    @staticmethod
    def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
        """Convert torch.Tensor to the specified target type."""
        if target_type == torch.Tensor:
            return tensor

        elif target_type == np.ndarray:
            return tensor.detach().numpy()

        elif target_type == list:
            return tensor.detach().tolist()

        elif target_type == tuple:
            return tuple(tensor.detach().tolist())

        elif is_mlx_array_type(target_type):
            import mlx.core as mx

            # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
            return mx.array(tensor.float().numpy())

        elif is_jax_array_type(target_type):
            import jax

            return jax.dlpack.from_dlpack(tensor)

        else:
            raise TypeError(
                f"Failed to convert torch tensors to target_type `{target_type}`"
            )
|